API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,5 +1,5 @@
# util/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -9,7 +9,6 @@
from collections import defaultdict as defaultdict
from functools import partial as partial
from functools import update_wrapper as update_wrapper
from typing import TYPE_CHECKING
from . import preloaded as preloaded
from ._collections import coerce_generator_arg as coerce_generator_arg
@@ -49,7 +48,6 @@ from ._collections import WeakPopulateDict as WeakPopulateDict
from ._collections import WeakSequence as WeakSequence
from .compat import anext_ as anext_
from .compat import arm as arm
from .compat import athrow as athrow
from .compat import b as b
from .compat import b64decode as b64decode
from .compat import b64encode as b64encode
@@ -66,6 +64,8 @@ from .compat import osx as osx
from .compat import py310 as py310
from .compat import py311 as py311
from .compat import py312 as py312
from .compat import py313 as py313
from .compat import py314 as py314
from .compat import py38 as py38
from .compat import py39 as py39
from .compat import pypy as pypy
@@ -157,3 +157,4 @@ from .langhelpers import warn_exception as warn_exception
from .langhelpers import warn_limited as warn_limited
from .langhelpers import wrap_callable as wrap_callable
from .preloaded import preload_module as preload_module
from .typing import is_non_string_iterable as is_non_string_iterable

View File

@@ -1,5 +1,5 @@
# util/_collections.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -9,7 +9,6 @@
"""Collection classes and helpers."""
from __future__ import annotations
import collections.abc as collections_abc
import operator
import threading
import types
@@ -17,6 +16,7 @@ import typing
from typing import Any
from typing import Callable
from typing import cast
from typing import Container
from typing import Dict
from typing import FrozenSet
from typing import Generic
@@ -36,6 +36,7 @@ from typing import ValuesView
import weakref
from ._has_cy import HAS_CYEXTENSION
from .typing import is_non_string_iterable
from .typing import Literal
from .typing import Protocol
@@ -79,8 +80,8 @@ def merge_lists_w_ordering(a: List[Any], b: List[Any]) -> List[Any]:
Example::
>>> a = ['__tablename__', 'id', 'x', 'created_at']
>>> b = ['id', 'name', 'data', 'y', 'created_at']
>>> a = ["__tablename__", "id", "x", "created_at"]
>>> b = ["id", "name", "data", "y", "created_at"]
>>> merge_lists_w_ordering(a, b)
['__tablename__', 'id', 'name', 'data', 'y', 'x', 'created_at']
@@ -227,12 +228,10 @@ class Properties(Generic[_T]):
self._data.update(value)
@overload
def get(self, key: str) -> Optional[_T]:
...
def get(self, key: str) -> Optional[_T]: ...
@overload
def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]:
...
def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ...
def get(
self, key: str, default: Optional[Union[_DT, _T]] = None
@@ -419,9 +418,7 @@ def coerce_generator_arg(arg: Any) -> List[Any]:
def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
if x is None:
return default # type: ignore
if not isinstance(x, collections_abc.Iterable) or isinstance(
x, (str, bytes)
):
if not is_non_string_iterable(x):
return [x]
elif isinstance(x, list):
return x
@@ -429,15 +426,14 @@ def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
return list(x)
def has_intersection(set_, iterable):
def has_intersection(set_: Container[Any], iterable: Iterable[Any]) -> bool:
r"""return True if any items of set\_ are present in iterable.
Goes through special effort to ensure __hash__ is not called
on items in iterable that don't support it.
"""
# TODO: optimize, write in C, etc.
return bool(set_.intersection([i for i in iterable if i.__hash__]))
return any(i in set_ for i in iterable if i.__hash__)
def to_set(x):
@@ -458,7 +454,9 @@ def to_column_set(x: Any) -> Set[Any]:
return x
def update_copy(d, _new=None, **kw):
def update_copy(
d: Dict[Any, Any], _new: Optional[Dict[Any, Any]] = None, **kw: Any
) -> Dict[Any, Any]:
"""Copy the given dict and update with the given values."""
d = d.copy()
@@ -522,12 +520,10 @@ class LRUCache(typing.MutableMapping[_KT, _VT]):
return self._counter
@overload
def get(self, key: _KT) -> Optional[_VT]:
...
def get(self, key: _KT) -> Optional[_VT]: ...
@overload
def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]:
...
def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...
def get(
self, key: _KT, default: Optional[Union[_VT, _T]] = None
@@ -589,13 +585,11 @@ class LRUCache(typing.MutableMapping[_KT, _VT]):
class _CreateFuncType(Protocol[_T_co]):
def __call__(self) -> _T_co:
...
def __call__(self) -> _T_co: ...
class _ScopeFuncType(Protocol):
def __call__(self) -> Any:
...
def __call__(self) -> Any: ...
class ScopedRegistry(Generic[_T]):

View File

@@ -1,5 +1,5 @@
# util/_concurrency_py3k.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -19,10 +19,14 @@ from typing import Coroutine
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .langhelpers import memoized_property
from .. import exc
from ..util import py311
from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
from ..util.typing import TypeGuard
_T = TypeVar("_T")
@@ -33,8 +37,7 @@ if typing.TYPE_CHECKING:
dead: bool
gr_context: Optional[Context]
def __init__(self, fn: Callable[..., Any], driver: greenlet):
...
def __init__(self, fn: Callable[..., Any], driver: greenlet): ...
def throw(self, *arg: Any) -> Any:
return None
@@ -42,8 +45,7 @@ if typing.TYPE_CHECKING:
def switch(self, value: Any) -> Any:
return None
def getcurrent() -> greenlet:
...
def getcurrent() -> greenlet: ...
else:
from greenlet import getcurrent
@@ -72,9 +74,10 @@ def is_exit_exception(e: BaseException) -> bool:
class _AsyncIoGreenlet(greenlet):
dead: bool
__sqlalchemy_greenlet_provider__ = True
def __init__(self, fn: Callable[..., Any], driver: greenlet):
greenlet.__init__(self, fn, driver)
self.driver = driver
if _has_gr_context:
self.gr_context = driver.gr_context
@@ -85,8 +88,7 @@ if TYPE_CHECKING:
def iscoroutine(
awaitable: Awaitable[_T_co],
) -> TypeGuard[Coroutine[Any, Any, _T_co]]:
...
) -> TypeGuard[Coroutine[Any, Any, _T_co]]: ...
else:
iscoroutine = asyncio.iscoroutine
@@ -99,6 +101,11 @@ def _safe_cancel_awaitable(awaitable: Awaitable[Any]) -> None:
awaitable.close()
def in_greenlet() -> bool:
current = getcurrent()
return getattr(current, "__sqlalchemy_greenlet_provider__", False)
def await_only(awaitable: Awaitable[_T]) -> _T:
"""Awaits an async function in a sync method.
@@ -110,7 +117,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T:
"""
# this is called in the context greenlet while running fn
current = getcurrent()
if not isinstance(current, _AsyncIoGreenlet):
if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
_safe_cancel_awaitable(awaitable)
raise exc.MissingGreenlet(
@@ -122,7 +129,7 @@ def await_only(awaitable: Awaitable[_T]) -> _T:
# a coroutine to run. Once the awaitable is done, the driver greenlet
# switches back to this greenlet with the result of awaitable that is
# then returned to the caller (or raised as error)
return current.driver.switch(awaitable) # type: ignore[no-any-return]
return current.parent.switch(awaitable) # type: ignore[no-any-return,attr-defined] # noqa: E501
def await_fallback(awaitable: Awaitable[_T]) -> _T:
@@ -133,11 +140,16 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T:
:param awaitable: The coroutine to call.
.. deprecated:: 2.0.24 The ``await_fallback()`` function will be removed
in SQLAlchemy 2.1. Use :func:`_util.await_only` instead, running the
function / program / etc. within a top-level greenlet that is set up
using :func:`_util.greenlet_spawn`.
"""
# this is called in the context greenlet while running fn
current = getcurrent()
if not isinstance(current, _AsyncIoGreenlet):
if not getattr(current, "__sqlalchemy_greenlet_provider__", False):
loop = get_event_loop()
if loop.is_running():
_safe_cancel_awaitable(awaitable)
@@ -149,7 +161,7 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T:
)
return loop.run_until_complete(awaitable)
return current.driver.switch(awaitable) # type: ignore[no-any-return]
return current.parent.switch(awaitable) # type: ignore[no-any-return,attr-defined] # noqa: E501
async def greenlet_spawn(
@@ -175,24 +187,21 @@ async def greenlet_spawn(
# coroutine to wait. If the context is dead the function has
# returned, and its result can be returned.
switch_occurred = False
try:
result = context.switch(*args, **kwargs)
while not context.dead:
switch_occurred = True
try:
# wait for a coroutine from await_only and then return its
# result back to it.
value = await result
except BaseException:
# this allows an exception to be raised within
# the moderated greenlet so that it can continue
# its expected flow.
result = context.throw(*sys.exc_info())
else:
result = context.switch(value)
finally:
# clean up to avoid cycle resolution by gc
del context.driver
result = context.switch(*args, **kwargs)
while not context.dead:
switch_occurred = True
try:
# wait for a coroutine from await_only and then return its
# result back to it.
value = await result
except BaseException:
# this allows an exception to be raised within
# the moderated greenlet so that it can continue
# its expected flow.
result = context.throw(*sys.exc_info())
else:
result = context.switch(value)
if _require_await and not switch_occurred:
raise exc.AwaitRequired(
"The current operation required an async execution but none was "
@@ -218,34 +227,6 @@ class AsyncAdaptedLock:
self.mutex.release()
def _util_async_run_coroutine_function(
fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
) -> Any:
"""for test suite/ util only"""
loop = get_event_loop()
if loop.is_running():
raise Exception(
"for async run coroutine we expect that no greenlet or event "
"loop is running when we start out"
)
return loop.run_until_complete(fn(*args, **kwargs))
def _util_async_run(
fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
) -> Any:
"""for test suite/ util only"""
loop = get_event_loop()
if not loop.is_running():
return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs))
else:
# allow for a wrapped test function to call another
assert isinstance(getcurrent(), _AsyncIoGreenlet)
return fn(*args, **kwargs)
def get_event_loop() -> asyncio.AbstractEventLoop:
"""vendor asyncio.get_event_loop() for python 3.7 and above.
@@ -258,3 +239,50 @@ def get_event_loop() -> asyncio.AbstractEventLoop:
# avoid "During handling of the above exception, another exception..."
pass
return asyncio.get_event_loop_policy().get_event_loop()
if not TYPE_CHECKING and py311:
_Runner = asyncio.Runner
else:
class _Runner:
"""Runner implementation for test only"""
_loop: Union[None, asyncio.AbstractEventLoop, Literal[False]]
def __init__(self) -> None:
self._loop = None
def __enter__(self) -> Self:
self._lazy_init()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def close(self) -> None:
if self._loop:
try:
self._loop.run_until_complete(
self._loop.shutdown_asyncgens()
)
finally:
self._loop.close()
self._loop = False
def get_loop(self) -> asyncio.AbstractEventLoop:
"""Return embedded event loop."""
self._lazy_init()
assert self._loop
return self._loop
def run(self, coro: Coroutine[Any, Any, _T]) -> _T:
self._lazy_init()
assert self._loop
return self._loop.run_until_complete(coro)
def _lazy_init(self) -> None:
if self._loop is False:
raise RuntimeError("Runner is closed")
if self._loop is None:
self._loop = asyncio.new_event_loop()

View File

@@ -1,4 +1,5 @@
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# util/_has_cy.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# util/_py_collections.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -59,11 +59,9 @@ class ReadOnlyContainer:
class ImmutableDictBase(ReadOnlyContainer, Dict[_KT, _VT]):
if TYPE_CHECKING:
def __new__(cls, *args: Any) -> Self:
...
def __new__(cls, *args: Any) -> Self: ...
def __init__(cls, *args: Any):
...
def __init__(cls, *args: Any): ...
def _readonly(self, *arg: Any, **kw: Any) -> NoReturn:
self._immutable()
@@ -106,7 +104,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
new = ImmutableDictBase.__new__(self.__class__)
dict.__init__(new, self)
dict.update(new, __d) # type: ignore
dict.update(new, __d)
return new
def _union_w_kw(
@@ -119,8 +117,8 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
new = ImmutableDictBase.__new__(self.__class__)
dict.__init__(new, self)
if __d:
dict.update(new, __d) # type: ignore
dict.update(new, kw) # type: ignore
dict.update(new, __d)
dict.update(new, kw)
return new
def merge_with(
@@ -132,7 +130,7 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
if new is None:
new = ImmutableDictBase.__new__(self.__class__)
dict.__init__(new, self)
dict.update(new, d) # type: ignore
dict.update(new, d)
if new is None:
return self
@@ -148,12 +146,16 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
def __or__( # type: ignore[override]
self, __value: Mapping[_KT, _VT]
) -> immutabledict[_KT, _VT]:
return immutabledict(super().__or__(__value))
return immutabledict(
super().__or__(__value), # type: ignore[call-overload]
)
def __ror__( # type: ignore[override]
self, __value: Mapping[_KT, _VT]
) -> immutabledict[_KT, _VT]:
return immutabledict(super().__ror__(__value))
return immutabledict(
super().__ror__(__value), # type: ignore[call-overload]
)
class OrderedSet(Set[_T]):

View File

@@ -1,5 +1,5 @@
# util/compat.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -19,8 +19,6 @@ import platform
import sys
import typing
from typing import Any
from typing import AsyncGenerator
from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import Iterable
@@ -34,6 +32,9 @@ from typing import Type
from typing import TypeVar
py314b1 = sys.version_info >= (3, 14, 0, "beta", 1)
py314 = sys.version_info >= (3, 14)
py313 = sys.version_info >= (3, 13)
py312 = sys.version_info >= (3, 12)
py311 = sys.version_info >= (3, 11)
py310 = sys.version_info >= (3, 10)
@@ -60,7 +61,7 @@ class FullArgSpec(typing.NamedTuple):
varkw: Optional[str]
defaults: Optional[Tuple[Any, ...]]
kwonlyargs: List[str]
kwonlydefaults: Dict[str, Any]
kwonlydefaults: Optional[Dict[str, Any]]
annotations: Dict[str, Any]
@@ -102,24 +103,6 @@ def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec:
)
if py312:
# we are 95% certain this form of athrow works in former Python
# versions, however we are unable to get confirmation;
# see https://github.com/python/cpython/issues/105269 where have
# been unable to get a straight answer so far
def athrow( # noqa
gen: AsyncGenerator[_T_co, Any], typ: Any, value: Any, traceback: Any
) -> Awaitable[_T_co]:
return gen.athrow(value)
else:
def athrow( # noqa
gen: AsyncGenerator[_T_co, Any], typ: Any, value: Any, traceback: Any
) -> Awaitable[_T_co]:
return gen.athrow(typ, value, traceback)
if py39:
# python stubs don't have a public type for this. not worth
# making a protocol
@@ -173,7 +156,7 @@ else:
def importlib_metadata_get(group):
ep = importlib_metadata.entry_points()
if not typing.TYPE_CHECKING and hasattr(ep, "select"):
if typing.TYPE_CHECKING or hasattr(ep, "select"):
return ep.select(group=group)
else:
return ep.get(group, ())

View File

@@ -1,5 +1,5 @@
# util/concurrency.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -10,11 +10,15 @@ from __future__ import annotations
import asyncio # noqa
import typing
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import TypeVar
have_greenlet = False
greenlet_error = None
try:
import greenlet # type: ignore # noqa: F401
import greenlet # type: ignore[import-untyped,unused-ignore] # noqa: F401,E501
except ImportError as e:
greenlet_error = str(e)
pass
@@ -22,15 +26,47 @@ else:
have_greenlet = True
from ._concurrency_py3k import await_only as await_only
from ._concurrency_py3k import await_fallback as await_fallback
from ._concurrency_py3k import in_greenlet as in_greenlet
from ._concurrency_py3k import greenlet_spawn as greenlet_spawn
from ._concurrency_py3k import is_exit_exception as is_exit_exception
from ._concurrency_py3k import AsyncAdaptedLock as AsyncAdaptedLock
from ._concurrency_py3k import (
_util_async_run as _util_async_run,
) # noqa: F401
from ._concurrency_py3k import (
_util_async_run_coroutine_function as _util_async_run_coroutine_function, # noqa: F401, E501
)
from ._concurrency_py3k import _Runner
_T = TypeVar("_T")
class _AsyncUtil:
"""Asyncio util for test suite/ util only"""
def __init__(self) -> None:
if have_greenlet:
self.runner = _Runner()
def run(
self,
fn: Callable[..., Coroutine[Any, Any, _T]],
*args: Any,
**kwargs: Any,
) -> _T:
"""Run coroutine on the loop"""
return self.runner.run(fn(*args, **kwargs))
def run_in_greenlet(
self, fn: Callable[..., _T], *args: Any, **kwargs: Any
) -> _T:
"""Run sync function in greenlet. Support nested calls"""
if have_greenlet:
if self.runner.get_loop().is_running():
return fn(*args, **kwargs)
else:
return self.runner.run(greenlet_spawn(fn, *args, **kwargs))
else:
return fn(*args, **kwargs)
def close(self) -> None:
if have_greenlet:
self.runner.close()
if not typing.TYPE_CHECKING and not have_greenlet:
@@ -56,6 +92,9 @@ if not typing.TYPE_CHECKING and not have_greenlet:
def await_fallback(thing): # type: ignore # noqa: F811
return thing
def in_greenlet(): # type: ignore # noqa: F811
_not_implemented()
def greenlet_spawn(fn, *args, **kw): # type: ignore # noqa: F811
_not_implemented()

View File

@@ -1,5 +1,5 @@
# util/deprecations.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -205,10 +205,10 @@ def deprecated_params(**specs: Tuple[str, str]) -> Callable[[_F], _F]:
weak_identity_map=(
"0.7",
"the :paramref:`.Session.weak_identity_map parameter "
"is deprecated."
"is deprecated.",
)
)
def some_function(**kwargs): ...
"""

View File

@@ -1,5 +1,5 @@
# util/langhelpers.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -60,7 +60,85 @@ _HP = TypeVar("_HP", bound="hybridproperty[Any]")
_HM = TypeVar("_HM", bound="hybridmethod[Any]")
if compat.py310:
if compat.py314:
# vendor a minimal form of get_annotations per
# https://github.com/python/cpython/issues/133684#issuecomment-2863841891
from annotationlib import call_annotate_function # type: ignore
from annotationlib import Format
def _get_and_call_annotate(obj, format): # noqa: A002
annotate = getattr(obj, "__annotate__", None)
if annotate is not None:
ann = call_annotate_function(annotate, format, owner=obj)
if not isinstance(ann, dict):
raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
return ann
return None
# this is ported from py3.13.0a7
_BASE_GET_ANNOTATIONS = type.__dict__["__annotations__"].__get__ # type: ignore # noqa: E501
def _get_dunder_annotations(obj):
if isinstance(obj, type):
try:
ann = _BASE_GET_ANNOTATIONS(obj)
except AttributeError:
# For static types, the descriptor raises AttributeError.
return {}
else:
ann = getattr(obj, "__annotations__", None)
if ann is None:
return {}
if not isinstance(ann, dict):
raise ValueError(
f"{obj!r}.__annotations__ is neither a dict nor None"
)
return dict(ann)
def _vendored_get_annotations(
obj: Any, *, format: Format # noqa: A002
) -> Mapping[str, Any]:
"""A sparse implementation of annotationlib.get_annotations()"""
try:
ann = _get_dunder_annotations(obj)
except Exception:
pass
else:
if ann is not None:
return dict(ann)
# But if __annotations__ threw a NameError, we try calling __annotate__
ann = _get_and_call_annotate(obj, format)
if ann is None:
# If that didn't work either, we have a very weird object:
# evaluating
# __annotations__ threw NameError and there is no __annotate__.
# In that case,
# we fall back to trying __annotations__ again.
ann = _get_dunder_annotations(obj)
if ann is None:
if isinstance(obj, type) or callable(obj):
return {}
raise TypeError(f"{obj!r} does not have annotations")
if not ann:
return {}
return dict(ann)
def get_annotations(obj: Any) -> Mapping[str, Any]:
# FORWARDREF has the effect of giving us ForwardRefs and not
# actually trying to evaluate the annotations. We need this so
# that the annotations act as much like
# "from __future__ import annotations" as possible, which is going
# away in future python as a separate mode
return _vendored_get_annotations(obj, format=Format.FORWARDREF)
elif compat.py310:
def get_annotations(obj: Any) -> Mapping[str, Any]:
return inspect.get_annotations(obj)
@@ -174,10 +252,11 @@ def string_or_unprintable(element: Any) -> str:
return "unprintable element %r" % element
def clsname_as_plain_name(cls: Type[Any]) -> str:
return " ".join(
n.lower() for n in re.findall(r"([A-Z][a-z]+|SQL)", cls.__name__)
)
def clsname_as_plain_name(
cls: Type[Any], use_name: Optional[str] = None
) -> str:
name = use_name or cls.__name__
return " ".join(n.lower() for n in re.findall(r"([A-Z][a-z]+|SQL)", name))
def method_is_overridden(
@@ -249,10 +328,30 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
if not inspect.isfunction(fn) and not inspect.ismethod(fn):
raise Exception("not a decoratable function")
spec = compat.inspect_getfullargspec(fn)
env: Dict[str, Any] = {}
# Python 3.14 defer creating __annotations__ until its used.
# We do not want to create __annotations__ now.
annofunc = getattr(fn, "__annotate__", None)
if annofunc is not None:
fn.__annotate__ = None # type: ignore[union-attr]
try:
spec = compat.inspect_getfullargspec(fn)
finally:
fn.__annotate__ = annofunc # type: ignore[union-attr]
else:
spec = compat.inspect_getfullargspec(fn)
spec = _update_argspec_defaults_into_env(spec, env)
# Do not generate code for annotations.
# update_wrapper() copies the annotation from fn to decorated.
# We use dummy defaults for code generation to avoid having
# copy of large globals for compiling.
# We copy __defaults__ and __kwdefaults__ from fn to decorated.
empty_defaults = (None,) * len(spec.defaults or ())
empty_kwdefaults = dict.fromkeys(spec.kwonlydefaults or ())
spec = spec._replace(
annotations={},
defaults=empty_defaults,
kwonlydefaults=empty_kwdefaults,
)
names = (
tuple(cast("Tuple[str, ...]", spec[0]))
@@ -297,41 +396,21 @@ def decorator(target: Callable[..., Any]) -> Callable[[_Fn], _Fn]:
% metadata
)
mod = sys.modules[fn.__module__]
env.update(vars(mod))
env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__})
env: Dict[str, Any] = {
targ_name: target,
fn_name: fn,
"__name__": fn.__module__,
}
decorated = cast(
types.FunctionType,
_exec_code_in_env(code, env, fn.__name__),
)
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
decorated.__defaults__ = fn.__defaults__
decorated.__kwdefaults__ = fn.__kwdefaults__ # type: ignore
return update_wrapper(decorated, fn) # type: ignore[return-value]
decorated.__wrapped__ = fn # type: ignore
return cast(_Fn, update_wrapper(decorated, fn))
return update_wrapper(decorate, target)
def _update_argspec_defaults_into_env(spec, env):
"""given a FullArgSpec, convert defaults to be symbol names in an env."""
if spec.defaults:
new_defaults = []
i = 0
for arg in spec.defaults:
if type(arg).__module__ not in ("builtins", "__builtin__"):
name = "x%d" % i
env[name] = arg
new_defaults.append(name)
i += 1
else:
new_defaults.append(arg)
elem = list(spec)
elem[3] = tuple(new_defaults)
return compat.FullArgSpec(*elem)
else:
return spec
return update_wrapper(decorate, target) # type: ignore[return-value]
def _exec_code_in_env(
@@ -384,6 +463,9 @@ class PluginLoader:
self.impls[name] = load
def deregister(self, name: str) -> None:
del self.impls[name]
def _inspect_func_args(fn):
try:
@@ -411,15 +493,13 @@ def get_cls_kwargs(
*,
_set: Optional[Set[str]] = None,
raiseerr: Literal[True] = ...,
) -> Set[str]:
...
) -> Set[str]: ...
@overload
def get_cls_kwargs(
cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
) -> Optional[Set[str]]:
...
) -> Optional[Set[str]]: ...
def get_cls_kwargs(
@@ -663,7 +743,9 @@ def format_argspec_init(method, grouped=True):
"""format_argspec_plus with considerations for typical __init__ methods
Wraps format_argspec_plus with error handling strategies for typical
__init__ cases::
__init__ cases:
.. sourcecode:: text
object.__init__ -> (self)
other unreflectable (usually C) -> (self, *args, **kwargs)
@@ -718,7 +800,9 @@ def create_proxy_methods(
def getargspec_init(method):
"""inspect.getargspec with considerations for typical __init__ methods
Wraps inspect.getargspec with error handling for typical __init__ cases::
Wraps inspect.getargspec with error handling for typical __init__ cases:
.. sourcecode:: text
object.__init__ -> (self)
other unreflectable (usually C) -> (self, *args, **kwargs)
@@ -1092,23 +1176,19 @@ class generic_fn_descriptor(Generic[_T_co]):
self.__name__ = fget.__name__
@overload
def __get__(self: _GFD, obj: None, cls: Any) -> _GFD:
...
def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: ...
@overload
def __get__(self, obj: object, cls: Any) -> _T_co:
...
def __get__(self, obj: object, cls: Any) -> _T_co: ...
def __get__(self: _GFD, obj: Any, cls: Any) -> Union[_GFD, _T_co]:
raise NotImplementedError()
if TYPE_CHECKING:
def __set__(self, instance: Any, value: Any) -> None:
...
def __set__(self, instance: Any, value: Any) -> None: ...
def __delete__(self, instance: Any) -> None:
...
def __delete__(self, instance: Any) -> None: ...
def _reset(self, obj: Any) -> None:
raise NotImplementedError()
@@ -1247,12 +1327,10 @@ class HasMemoized:
self.__name__ = fget.__name__
@overload
def __get__(self: _MA, obj: None, cls: Any) -> _MA:
...
def __get__(self: _MA, obj: None, cls: Any) -> _MA: ...
@overload
def __get__(self, obj: Any, cls: Any) -> _T:
...
def __get__(self, obj: Any, cls: Any) -> _T: ...
def __get__(self, obj, cls):
if obj is None:
@@ -1598,9 +1676,9 @@ class hybridmethod(Generic[_T]):
class symbol(int):
"""A constant symbol.
>>> symbol('foo') is symbol('foo')
>>> symbol("foo") is symbol("foo")
True
>>> symbol('foo')
>>> symbol("foo")
<symbol 'foo>
A slight refinement of the MAGICCOOKIE=object() pattern. The primary
@@ -1666,6 +1744,8 @@ class _IntFlagMeta(type):
items: List[symbol]
cls._items = items = []
for k, v in dict_.items():
if re.match(r"^__.*__$", k):
continue
if isinstance(v, int):
sym = symbol(k, canonical=v)
elif not k.startswith("_"):
@@ -1959,12 +2039,15 @@ NoneType = type(None)
def attrsetter(attrname):
code = "def set(obj, value):" " obj.%s = value" % attrname
code = "def set(obj, value): obj.%s = value" % attrname
env = locals().copy()
exec(code, env)
return env["set"]
_dunders = re.compile("^__.+__$")
class TypingOnly:
"""A mixin class that marks a class as 'typing only', meaning it has
absolutely no methods, attributes, or runtime functionality whatsoever.
@@ -1975,15 +2058,9 @@ class TypingOnly:
def __init_subclass__(cls) -> None:
if TypingOnly in cls.__bases__:
remaining = set(cls.__dict__).difference(
{
"__module__",
"__doc__",
"__slots__",
"__orig_bases__",
"__annotations__",
}
)
remaining = {
name for name in cls.__dict__ if not _dunders.match(name)
}
if remaining:
raise AssertionError(
f"Class {cls} directly inherits TypingOnly but has "
@@ -2216,3 +2293,11 @@ def has_compiled_ext(raise_=False):
)
else:
return False
class _Missing(enum.Enum):
Missing = enum.auto()
Missing = _Missing.Missing
MissingOr = Union[_T, Literal[_Missing.Missing]]

View File

@@ -1,5 +1,5 @@
# util/_preloaded.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# util/preloaded.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# util/queue.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -57,8 +57,7 @@ class QueueCommon(Generic[_T]):
maxsize: int
use_lifo: bool
def __init__(self, maxsize: int = 0, use_lifo: bool = False):
...
def __init__(self, maxsize: int = 0, use_lifo: bool = False): ...
def empty(self) -> bool:
raise NotImplementedError()
@@ -242,8 +241,7 @@ class AsyncAdaptedQueue(QueueCommon[_T]):
if typing.TYPE_CHECKING:
@staticmethod
def await_(coroutine: Awaitable[Any]) -> _T:
...
def await_(coroutine: Awaitable[Any]) -> _T: ...
else:
await_ = staticmethod(await_only)

View File

@@ -1,5 +1,5 @@
# util/tool_support.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -27,6 +27,7 @@ from typing import Any
from typing import Dict
from typing import Iterator
from typing import Optional
from typing import Union
from . import compat
@@ -121,7 +122,7 @@ class code_writer_cmd:
sys.stderr.write(" ".join(text))
def write_output_file_from_text(
self, text: str, destination_path: str
self, text: str, destination_path: Union[str, Path]
) -> None:
if self.args.check:
self._run_diff(destination_path, source=text)
@@ -129,7 +130,9 @@ class code_writer_cmd:
print(text)
else:
self.write_status(f"Writing {destination_path}...")
Path(destination_path).write_text(text)
Path(destination_path).write_text(
text, encoding="utf-8", newline="\n"
)
self.write_status("done\n")
def write_output_file_from_tempfile(
@@ -149,24 +152,24 @@ class code_writer_cmd:
def _run_diff(
self,
destination_path: str,
destination_path: Union[str, Path],
*,
source: Optional[str] = None,
source_file: Optional[str] = None,
) -> None:
if source_file:
with open(source_file) as tf:
with open(source_file, encoding="utf-8") as tf:
source_lines = list(tf)
elif source is not None:
source_lines = source.splitlines(keepends=True)
else:
assert False, "source or source_file is required"
with open(destination_path) as dp:
with open(destination_path, encoding="utf-8") as dp:
d = difflib.unified_diff(
list(dp),
source_lines,
fromfile=destination_path,
fromfile=Path(destination_path).as_posix(),
tofile="<proposed changes>",
n=3,
lineterm="\n",

View File

@@ -1,5 +1,5 @@
# util/topological.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -112,7 +112,7 @@ def find_cycles(
todo.remove(node)
break
else:
node = stack.pop()
stack.pop()
return output

View File

@@ -1,5 +1,5 @@
# util/typing.py
# Copyright (C) 2022 the SQLAlchemy authors and contributors
# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -9,12 +9,13 @@
from __future__ import annotations
import builtins
from collections import deque
import collections.abc as collections_abc
import re
import sys
import typing
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import ForwardRef
from typing import Generic
@@ -31,6 +32,8 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import typing_extensions
from . import compat
if True: # zimports removes the tailing comments
@@ -52,7 +55,9 @@ if True: # zimports removes the tailing comments
from typing_extensions import TypedDict as TypedDict # 3.8
from typing_extensions import TypeGuard as TypeGuard # 3.10
from typing_extensions import Self as Self # 3.11
from typing_extensions import TypeAliasType as TypeAliasType # 3.12
from typing_extensions import Never as Never # 3.11
from typing_extensions import LiteralString as LiteralString # 3.11
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
@@ -61,7 +66,6 @@ _KT_contra = TypeVar("_KT_contra", contravariant=True)
_VT = TypeVar("_VT")
_VT_co = TypeVar("_VT_co", covariant=True)
if compat.py310:
# why they took until py310 to put this in stdlib is beyond me,
# I've been wanting it since py27
@@ -69,18 +73,17 @@ if compat.py310:
else:
NoneType = type(None) # type: ignore
NoneFwd = ForwardRef("None")
typing_get_args = get_args
typing_get_origin = get_origin
def is_fwd_none(typ: Any) -> bool:
return isinstance(typ, ForwardRef) and typ.__forward_arg__ == "None"
_AnnotationScanType = Union[
Type[Any], str, ForwardRef, NewType, "GenericProtocol[Any]"
Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
]
class ArgsTypeProcotol(Protocol):
class ArgsTypeProtocol(Protocol):
"""protocol for types that have ``__args__``
there's no public interface for this AFAIK
@@ -111,11 +114,9 @@ class GenericProtocol(Protocol[_T]):
# copied from TypeShed, required in order to implement
# MutableMapping.update()
class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
def keys(self) -> Iterable[_KT]:
...
def keys(self) -> Iterable[_KT]: ...
def __getitem__(self, __k: _KT) -> _VT_co:
...
def __getitem__(self, __k: _KT) -> _VT_co: ...
# work around https://github.com/microsoft/pyright/issues/3025
@@ -155,7 +156,7 @@ def de_stringify_annotation(
annotation = str_cleanup_fn(annotation, originating_module)
annotation = eval_expression(
annotation, originating_module, locals_=locals_
annotation, originating_module, locals_=locals_, in_class=cls
)
if (
@@ -189,9 +190,51 @@ def de_stringify_annotation(
)
return _copy_generic_annotation_with(annotation, elements)
return annotation # type: ignore
def fixup_container_fwd_refs(
type_: _AnnotationScanType,
) -> _AnnotationScanType:
"""Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')]
and similar for list, set
"""
if (
is_generic(type_)
and get_origin(type_)
in (
dict,
set,
list,
collections_abc.MutableSet,
collections_abc.MutableMapping,
collections_abc.MutableSequence,
collections_abc.Mapping,
collections_abc.Sequence,
)
# fight, kick and scream to struggle to tell the difference between
# dict[] and typing.Dict[] which DO NOT compare the same and DO NOT
# behave the same yet there is NO WAY to distinguish between which type
# it is using public attributes
and not re.match(
"typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_)
)
):
# compat with py3.10 and earlier
return get_origin(type_).__class_getitem__( # type: ignore
tuple(
[
ForwardRef(elem) if isinstance(elem, str) else elem
for elem in get_args(type_)
]
)
)
return type_
def _copy_generic_annotation_with(
annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
) -> Type[_T]:
@@ -208,6 +251,7 @@ def eval_expression(
module_name: str,
*,
locals_: Optional[Mapping[str, Any]] = None,
in_class: Optional[Type[Any]] = None,
) -> Any:
try:
base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
@@ -218,7 +262,18 @@ def eval_expression(
) from ke
try:
annotation = eval(expression, base_globals, locals_)
if in_class is not None:
cls_namespace = dict(in_class.__dict__)
cls_namespace.setdefault(in_class.__name__, in_class)
# see #10899. We want the locals/globals to take precedence
# over the class namespace in this context, even though this
# is not the usual way variables would resolve.
cls_namespace.update(base_globals)
annotation = eval(expression, cls_namespace, locals_)
else:
annotation = eval(expression, base_globals, locals_)
except Exception as err:
raise NameError(
f"Could not de-stringify annotation {expression!r}"
@@ -270,34 +325,18 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
return getattr(obj, "__name__", name)
def de_stringify_union_elements(
cls: Type[Any],
annotation: ArgsTypeProcotol,
originating_module: str,
locals_: Mapping[str, Any],
*,
str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
) -> Type[Any]:
return make_union_type(
*[
de_stringify_annotation(
cls,
anno,
originating_module,
{},
str_cleanup_fn=str_cleanup_fn,
)
for anno in annotation.__args__
]
def is_pep593(type_: Optional[Any]) -> bool:
return type_ is not None and get_origin(type_) in _type_tuples.Annotated
def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
return isinstance(obj, collections_abc.Iterable) and not isinstance(
obj, (str, bytes)
)
def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
return type_ is not None and typing_get_origin(type_) is Annotated
def is_literal(type_: _AnnotationScanType) -> bool:
return get_origin(type_) is Literal
def is_literal(type_: Any) -> bool:
return get_origin(type_) in _type_tuples.Literal
def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
@@ -305,46 +344,99 @@ def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
# doesn't work in 3.8, 3.7 as it passes a closure, not an
# object instance
# return isinstance(type_, NewType)
# isinstance(type, type_instances.NewType)
def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
# NOTE: a generic TAT does not instance check as TypeAliasType outside of
# python 3.10. For sqlalchemy use cases it's fine to consider it a TAT
# though.
# NOTE: things seems to work also without this additional check
if is_generic(type_):
return is_pep695(type_.__origin__)
return isinstance(type_, _type_instances.TypeAliasType)
def flatten_newtype(type_: NewType) -> Type[Any]:
super_type = type_.__supertype__
while is_newtype(super_type):
super_type = super_type.__supertype__
return super_type
return super_type # type: ignore[return-value]
def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
"""Extracts the value from a TypeAliasType, recursively exploring unions
and inner TypeAliasType to flatten them into a single set.
Forward references are not evaluated, so no recursive exploration happens
into them.
"""
_seen = set()
def recursive_value(inner_type):
if inner_type in _seen:
# recursion are not supported (at least it's flagged as
# an error by pyright). Just avoid infinite loop
return inner_type
_seen.add(inner_type)
if not is_pep695(inner_type):
return inner_type
value = inner_type.__value__
if not is_union(value):
return value
return [recursive_value(t) for t in value.__args__]
res = recursive_value(type_)
if isinstance(res, list):
types = set()
stack = deque(res)
while stack:
t = stack.popleft()
if isinstance(t, list):
stack.extend(t)
else:
types.add(None if t is NoneType or is_fwd_none(t) else t)
return types
else:
return {res}
def is_fwd_ref(
type_: _AnnotationScanType, check_generic: bool = False
type_: _AnnotationScanType,
check_generic: bool = False,
check_for_plain_string: bool = False,
) -> TypeGuard[ForwardRef]:
if isinstance(type_, ForwardRef):
if check_for_plain_string and isinstance(type_, str):
return True
elif isinstance(type_, _type_instances.ForwardRef):
return True
elif check_generic and is_generic(type_):
return any(is_fwd_ref(arg, True) for arg in type_.__args__)
return any(
is_fwd_ref(
arg, True, check_for_plain_string=check_for_plain_string
)
for arg in type_.__args__
)
else:
return False
@overload
def de_optionalize_union_types(type_: str) -> str:
...
def de_optionalize_union_types(type_: str) -> str: ...
@overload
def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]:
...
def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ...
@overload
def de_optionalize_union_types(
type_: _AnnotationScanType,
) -> _AnnotationScanType:
...
) -> _AnnotationScanType: ...
def de_optionalize_union_types(
@@ -353,16 +445,33 @@ def de_optionalize_union_types(
"""Given a type, filter out ``Union`` types that include ``NoneType``
to not include the ``NoneType``.
Contains extra logic to work on non-flattened unions, unions that contain
``None`` (seen in py38, 37)
"""
if is_fwd_ref(type_):
return de_optionalize_fwd_ref_union_types(type_)
return _de_optionalize_fwd_ref_union_types(type_, False)
elif is_optional(type_):
typ = set(type_.__args__)
elif is_union(type_) and includes_none(type_):
if compat.py39:
typ = set(type_.__args__)
else:
# py38, 37 - unions are not automatically flattened, can contain
# None rather than NoneType
stack_of_unions = deque([type_])
typ = set()
while stack_of_unions:
u_typ = stack_of_unions.popleft()
for elem in u_typ.__args__:
if is_union(elem):
stack_of_unions.append(elem)
else:
typ.add(elem)
typ.discard(NoneType)
typ.discard(NoneFwd)
typ.discard(None) # type: ignore
typ = {t for t in typ if t is not NoneType and not is_fwd_none(t)}
return make_union_type(*typ)
@@ -370,9 +479,21 @@ def de_optionalize_union_types(
return type_
def de_optionalize_fwd_ref_union_types(
type_: ForwardRef,
) -> _AnnotationScanType:
@overload
def _de_optionalize_fwd_ref_union_types(
type_: ForwardRef, return_has_none: Literal[True]
) -> bool: ...
@overload
def _de_optionalize_fwd_ref_union_types(
type_: ForwardRef, return_has_none: Literal[False]
) -> _AnnotationScanType: ...
def _de_optionalize_fwd_ref_union_types(
type_: ForwardRef, return_has_none: bool
) -> Union[_AnnotationScanType, bool]:
"""return the non-optional type for Optional[], Union[None, ...], x|None,
etc. without de-stringifying forward refs.
@@ -384,68 +505,94 @@ def de_optionalize_fwd_ref_union_types(
mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
if mm:
if mm.group(1) == "Optional":
return ForwardRef(mm.group(2))
elif mm.group(1) == "Union":
elements = re.split(r",\s*", mm.group(2))
return make_union_type(
*[ForwardRef(elem) for elem in elements if elem != "None"]
)
g1 = mm.group(1).split(".")[-1]
if g1 == "Optional":
return True if return_has_none else ForwardRef(mm.group(2))
elif g1 == "Union":
if "[" in mm.group(2):
# cases like "Union[Dict[str, int], int, None]"
elements: list[str] = []
current: list[str] = []
ignore_comma = 0
for char in mm.group(2):
if char == "[":
ignore_comma += 1
elif char == "]":
ignore_comma -= 1
elif ignore_comma == 0 and char == ",":
elements.append("".join(current).strip())
current.clear()
continue
current.append(char)
else:
elements = re.split(r",\s*", mm.group(2))
parts = [ForwardRef(elem) for elem in elements if elem != "None"]
if return_has_none:
return len(elements) != len(parts)
else:
return make_union_type(*parts) if parts else Never # type: ignore[return-value] # noqa: E501
else:
return type_
return False if return_has_none else type_
pipe_tokens = re.split(r"\s*\|\s*", annotation)
if "None" in pipe_tokens:
return ForwardRef("|".join(p for p in pipe_tokens if p != "None"))
has_none = "None" in pipe_tokens
if return_has_none:
return has_none
if has_none:
anno_str = "|".join(p for p in pipe_tokens if p != "None")
return ForwardRef(anno_str) if anno_str else Never # type: ignore[return-value] # noqa: E501
return type_
def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
"""Make a Union type.
"""Make a Union type."""
This is needed by :func:`.de_optionalize_union_types` which removes
``NoneType`` from a ``Union``.
return Union[types] # type: ignore
def includes_none(type_: Any) -> bool:
"""Returns if the type annotation ``type_`` allows ``None``.
This function supports:
* forward refs
* unions
* pep593 - Annotated
* pep695 - TypeAliasType (does not support looking into
fw reference of other pep695)
* NewType
* plain types like ``int``, ``None``, etc
"""
return cast(Any, Union).__getitem__(types) # type: ignore
def expand_unions(
type_: Type[Any], include_union: bool = False, discard_none: bool = False
) -> Tuple[Type[Any], ...]:
"""Return a type as a tuple of individual types, expanding for
``Union`` types."""
if is_fwd_ref(type_):
return _de_optionalize_fwd_ref_union_types(type_, True)
if is_union(type_):
typ = set(type_.__args__)
if discard_none:
typ.discard(NoneType)
if include_union:
return (type_,) + tuple(typ) # type: ignore
else:
return tuple(typ) # type: ignore
else:
return (type_,)
return any(includes_none(t) for t in get_args(type_))
if is_pep593(type_):
return includes_none(get_args(type_)[0])
if is_pep695(type_):
return any(includes_none(t) for t in pep695_values(type_))
if is_newtype(type_):
return includes_none(type_.__supertype__)
try:
return type_ in (NoneType, None) or is_fwd_none(type_)
except TypeError:
# if type_ is Column, mapped_column(), etc. the use of "in"
# resolves to ``__eq__()`` which then gives us an expression object
# that can't resolve to boolean. just catch it all via exception
return False
def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
return is_origin_of(
type_,
"Optional",
"Union",
"UnionType",
def is_a_type(type_: Any) -> bool:
return (
isinstance(type_, type)
or hasattr(type_, "__origin__")
or type_.__module__ in ("typing", "typing_extensions")
or type(type_).__mro__[0].__module__ in ("typing", "typing_extensions")
)
def is_optional_union(type_: Any) -> bool:
return is_optional(type_) and NoneType in typing_get_args(type_)
def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
return is_origin_of(type_, "Union")
def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]:
return is_origin_of(type_, "Union", "UnionType")
def is_origin_of_cls(
@@ -454,7 +601,7 @@ def is_origin_of_cls(
"""return True if the given type has an __origin__ that shares a base
with the given class"""
origin = typing_get_origin(type_)
origin = get_origin(type_)
if origin is None:
return False
@@ -467,7 +614,7 @@ def is_origin_of(
"""return True if the given type has an __origin__ with the given name
and optional module."""
origin = typing_get_origin(type_)
origin = get_origin(type_)
if origin is None:
return False
@@ -488,14 +635,11 @@ def _get_type_name(type_: Type[Any]) -> str:
class DescriptorProto(Protocol):
def __get__(self, instance: object, owner: Any) -> Any:
...
def __get__(self, instance: object, owner: Any) -> Any: ...
def __set__(self, instance: Any, value: Any) -> None:
...
def __set__(self, instance: Any, value: Any) -> None: ...
def __delete__(self, instance: Any) -> None:
...
def __delete__(self, instance: Any) -> None: ...
_DESC = TypeVar("_DESC", bound=DescriptorProto)
@@ -514,14 +658,11 @@ class DescriptorReference(Generic[_DESC]):
if TYPE_CHECKING:
def __get__(self, instance: object, owner: Any) -> _DESC:
...
def __get__(self, instance: object, owner: Any) -> _DESC: ...
def __set__(self, instance: Any, value: _DESC) -> None:
...
def __set__(self, instance: Any, value: _DESC) -> None: ...
def __delete__(self, instance: Any) -> None:
...
def __delete__(self, instance: Any) -> None: ...
_DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True)
@@ -537,14 +678,11 @@ class RODescriptorReference(Generic[_DESC_co]):
if TYPE_CHECKING:
def __get__(self, instance: object, owner: Any) -> _DESC_co:
...
def __get__(self, instance: object, owner: Any) -> _DESC_co: ...
def __set__(self, instance: Any, value: Any) -> NoReturn:
...
def __set__(self, instance: Any, value: Any) -> NoReturn: ...
def __delete__(self, instance: Any) -> NoReturn:
...
def __delete__(self, instance: Any) -> NoReturn: ...
_FN = TypeVar("_FN", bound=Optional[Callable[..., Any]])
@@ -561,14 +699,35 @@ class CallableReference(Generic[_FN]):
if TYPE_CHECKING:
def __get__(self, instance: object, owner: Any) -> _FN:
...
def __get__(self, instance: object, owner: Any) -> _FN: ...
def __set__(self, instance: Any, value: _FN) -> None:
...
def __set__(self, instance: Any, value: _FN) -> None: ...
def __delete__(self, instance: Any) -> None:
...
def __delete__(self, instance: Any) -> None: ...
# $def ro_descriptor_reference(fn: Callable[])
class _TypingInstances:
def __getattr__(self, key: str) -> tuple[type, ...]:
types = tuple(
{
t
for t in [
getattr(typing, key, None),
getattr(typing_extensions, key, None),
]
if t is not None
}
)
if not types:
raise AttributeError(key)
self.__dict__[key] = types
return types
_type_tuples = _TypingInstances()
if TYPE_CHECKING:
_type_instances = typing_extensions
else:
_type_instances = _type_tuples
LITERAL_TYPES = _type_tuples.Literal