API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,5 +1,5 @@
# ext/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# ext/associationproxy.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -98,6 +98,8 @@ def association_proxy(
default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
compare: Union[_NoArg, bool] = _NoArg.NO_ARG,
kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002
dataclass_metadata: Union[_NoArg, Mapping[Any, Any], None] = _NoArg.NO_ARG,
) -> AssociationProxy[Any]:
r"""Return a Python property implementing a view of a target
attribute which references an attribute on members of the
@@ -198,6 +200,19 @@ def association_proxy(
.. versionadded:: 2.0.0b4
:param hash: Specific to
:ref:`orm_declarative_native_dataclasses`, controls if this field
is included when generating the ``__hash__()`` method for the mapped
class.
.. versionadded:: 2.0.36
:param dataclass_metadata: Specific to
:ref:`orm_declarative_native_dataclasses`, supplies metadata
to be attached to the generated dataclass field.
.. versionadded:: 2.0.42
:param info: optional, will be assigned to
:attr:`.AssociationProxy.info` if present.
@@ -237,7 +252,14 @@ def association_proxy(
cascade_scalar_deletes=cascade_scalar_deletes,
create_on_none_assignment=create_on_none_assignment,
attribute_options=_AttributeOptions(
init, repr, default, default_factory, compare, kw_only
init,
repr,
default,
default_factory,
compare,
kw_only,
hash,
dataclass_metadata,
),
)
@@ -254,45 +276,39 @@ class AssociationProxyExtensionType(InspectionAttrExtensionType):
class _GetterProtocol(Protocol[_T_co]):
def __call__(self, instance: Any) -> _T_co:
...
def __call__(self, instance: Any) -> _T_co: ...
# mypy 0.990 we are no longer allowed to make this Protocol[_T_con]
class _SetterProtocol(Protocol):
...
class _SetterProtocol(Protocol): ...
class _PlainSetterProtocol(_SetterProtocol, Protocol[_T_con]):
def __call__(self, instance: Any, value: _T_con) -> None:
...
def __call__(self, instance: Any, value: _T_con) -> None: ...
class _DictSetterProtocol(_SetterProtocol, Protocol[_T_con]):
def __call__(self, instance: Any, key: Any, value: _T_con) -> None:
...
def __call__(self, instance: Any, key: Any, value: _T_con) -> None: ...
# mypy 0.990 we are no longer allowed to make this Protocol[_T_con]
class _CreatorProtocol(Protocol):
...
class _CreatorProtocol(Protocol): ...
class _PlainCreatorProtocol(_CreatorProtocol, Protocol[_T_con]):
def __call__(self, value: _T_con) -> Any:
...
def __call__(self, value: _T_con) -> Any: ...
class _KeyCreatorProtocol(_CreatorProtocol, Protocol[_T_con]):
def __call__(self, key: Any, value: Optional[_T_con]) -> Any:
...
def __call__(self, key: Any, value: Optional[_T_con]) -> Any: ...
class _LazyCollectionProtocol(Protocol[_T]):
def __call__(
self,
) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]:
...
) -> Union[
MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]
]: ...
class _GetSetFactoryProtocol(Protocol):
@@ -300,8 +316,7 @@ class _GetSetFactoryProtocol(Protocol):
self,
collection_class: Optional[Type[Any]],
assoc_instance: AssociationProxyInstance[Any],
) -> Tuple[_GetterProtocol[Any], _SetterProtocol]:
...
) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ...
class _ProxyFactoryProtocol(Protocol):
@@ -311,15 +326,13 @@ class _ProxyFactoryProtocol(Protocol):
creator: _CreatorProtocol,
value_attr: str,
parent: AssociationProxyInstance[Any],
) -> Any:
...
) -> Any: ...
class _ProxyBulkSetProtocol(Protocol):
def __call__(
self, proxy: _AssociationCollection[Any], collection: Iterable[Any]
) -> None:
...
) -> None: ...
class _AssociationProxyProtocol(Protocol[_T]):
@@ -337,18 +350,15 @@ class _AssociationProxyProtocol(Protocol[_T]):
proxy_bulk_set: Optional[_ProxyBulkSetProtocol]
@util.ro_memoized_property
def info(self) -> _InfoType:
...
def info(self) -> _InfoType: ...
def for_class(
self, class_: Type[Any], obj: Optional[object] = None
) -> AssociationProxyInstance[_T]:
...
) -> AssociationProxyInstance[_T]: ...
def _default_getset(
self, collection_class: Any
) -> Tuple[_GetterProtocol[Any], _SetterProtocol]:
...
) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ...
class AssociationProxy(
@@ -419,18 +429,17 @@ class AssociationProxy(
self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS
@overload
def __get__(self, instance: Literal[None], owner: Literal[None]) -> Self:
...
def __get__(
self, instance: Literal[None], owner: Literal[None]
) -> Self: ...
@overload
def __get__(
self, instance: Literal[None], owner: Any
) -> AssociationProxyInstance[_T]:
...
) -> AssociationProxyInstance[_T]: ...
@overload
def __get__(self, instance: object, owner: Any) -> _T:
...
def __get__(self, instance: object, owner: Any) -> _T: ...
def __get__(
self, instance: object, owner: Any
@@ -463,7 +472,7 @@ class AssociationProxy(
class User(Base):
# ...
keywords = association_proxy('kws', 'keyword')
keywords = association_proxy("kws", "keyword")
If we access this :class:`.AssociationProxy` from
:attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the
@@ -783,9 +792,9 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
:attr:`.AssociationProxyInstance.remote_attr` attributes separately::
stmt = (
select(Parent).
join(Parent.proxied.local_attr).
join(Parent.proxied.remote_attr)
select(Parent)
.join(Parent.proxied.local_attr)
.join(Parent.proxied.remote_attr)
)
A future release may seek to provide a more succinct join pattern
@@ -861,12 +870,10 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
return self.parent.info
@overload
def get(self: _Self, obj: Literal[None]) -> _Self:
...
def get(self: _Self, obj: Literal[None]) -> _Self: ...
@overload
def get(self, obj: Any) -> _T:
...
def get(self, obj: Any) -> _T: ...
def get(
self, obj: Any
@@ -1089,7 +1096,7 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
and (not self._target_is_object or self._value_is_scalar)
):
raise exc.InvalidRequestError(
"'any()' not implemented for scalar " "attributes. Use has()."
"'any()' not implemented for scalar attributes. Use has()."
)
return self._criterion_exists(
criterion=criterion, is_has=False, **kwargs
@@ -1113,7 +1120,7 @@ class AssociationProxyInstance(SQLORMOperations[_T]):
or (self._target_is_object and not self._value_is_scalar)
):
raise exc.InvalidRequestError(
"'has()' not implemented for collections. " "Use any()."
"'has()' not implemented for collections. Use any()."
)
return self._criterion_exists(
criterion=criterion, is_has=True, **kwargs
@@ -1432,12 +1439,10 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]):
self.setter(object_, value)
@overload
def __getitem__(self, index: int) -> _T:
...
def __getitem__(self, index: int) -> _T: ...
@overload
def __getitem__(self, index: slice) -> MutableSequence[_T]:
...
def __getitem__(self, index: slice) -> MutableSequence[_T]: ...
def __getitem__(
self, index: Union[int, slice]
@@ -1448,12 +1453,10 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]):
return [self._get(member) for member in self.col[index]]
@overload
def __setitem__(self, index: int, value: _T) -> None:
...
def __setitem__(self, index: int, value: _T) -> None: ...
@overload
def __setitem__(self, index: slice, value: Iterable[_T]) -> None:
...
def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ...
def __setitem__(
self, index: Union[int, slice], value: Union[_T, Iterable[_T]]
@@ -1492,12 +1495,10 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]):
self._set(self.col[i], item)
@overload
def __delitem__(self, index: int) -> None:
...
def __delitem__(self, index: int) -> None: ...
@overload
def __delitem__(self, index: slice) -> None:
...
def __delitem__(self, index: slice) -> None: ...
def __delitem__(self, index: Union[slice, int]) -> None:
del self.col[index]
@@ -1624,8 +1625,9 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]):
if typing.TYPE_CHECKING:
# TODO: no idea how to do this without separate "stub"
def index(self, value: Any, start: int = ..., stop: int = ...) -> int:
...
def index(
self, value: Any, start: int = ..., stop: int = ...
) -> int: ...
else:
@@ -1701,12 +1703,10 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]):
return repr(dict(self))
@overload
def get(self, __key: _KT) -> Optional[_VT]:
...
def get(self, __key: _KT) -> Optional[_VT]: ...
@overload
def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]:
...
def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ...
def get(
self, key: _KT, default: Optional[Union[_VT, _T]] = None
@@ -1738,12 +1738,12 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]):
return ValuesView(self)
@overload
def pop(self, __key: _KT) -> _VT:
...
def pop(self, __key: _KT) -> _VT: ...
@overload
def pop(self, __key: _KT, default: Union[_VT, _T] = ...) -> Union[_VT, _T]:
...
def pop(
self, __key: _KT, default: Union[_VT, _T] = ...
) -> Union[_VT, _T]: ...
def pop(self, __key: _KT, *arg: Any, **kw: Any) -> Union[_VT, _T]:
member = self.col.pop(__key, *arg, **kw)
@@ -1756,16 +1756,15 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]):
@overload
def update(
self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT
) -> None:
...
) -> None: ...
@overload
def update(self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT) -> None:
...
def update(
self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
@overload
def update(self, **kwargs: _VT) -> None:
...
def update(self, **kwargs: _VT) -> None: ...
def update(self, *a: Any, **kw: Any) -> None:
up: Dict[_KT, _VT] = {}

View File

@@ -1,5 +1,5 @@
# ext/asyncio/__init__.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# ext/asyncio/base.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -44,12 +44,10 @@ class ReversibleProxy(Generic[_PT]):
__slots__ = ("__weakref__",)
@overload
def _assign_proxied(self, target: _PT) -> _PT:
...
def _assign_proxied(self, target: _PT) -> _PT: ...
@overload
def _assign_proxied(self, target: None) -> None:
...
def _assign_proxied(self, target: None) -> None: ...
def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
if target is not None:
@@ -73,28 +71,26 @@ class ReversibleProxy(Generic[_PT]):
cls._proxy_objects.pop(ref, None)
@classmethod
def _regenerate_proxy_for_target(cls, target: _PT) -> Self:
def _regenerate_proxy_for_target(
cls, target: _PT, **additional_kw: Any
) -> Self:
raise NotImplementedError()
@overload
@classmethod
def _retrieve_proxy_for_target(
cls,
target: _PT,
regenerate: Literal[True] = ...,
) -> Self:
...
cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any
) -> Self: ...
@overload
@classmethod
def _retrieve_proxy_for_target(
cls, target: _PT, regenerate: bool = True
) -> Optional[Self]:
...
cls, target: _PT, regenerate: bool = True, **additional_kw: Any
) -> Optional[Self]: ...
@classmethod
def _retrieve_proxy_for_target(
cls, target: _PT, regenerate: bool = True
cls, target: _PT, regenerate: bool = True, **additional_kw: Any
) -> Optional[Self]:
try:
proxy_ref = cls._proxy_objects[weakref.ref(target)]
@@ -106,7 +102,7 @@ class ReversibleProxy(Generic[_PT]):
return proxy # type: ignore
if regenerate:
return cls._regenerate_proxy_for_target(target)
return cls._regenerate_proxy_for_target(target, **additional_kw)
else:
return None
@@ -182,7 +178,7 @@ class GeneratorStartableContext(StartableContext[_T_co]):
# tell if we get the same exception back
value = typ()
try:
await util.athrow(self.gen, typ, value, traceback)
await self.gen.athrow(value)
except StopAsyncIteration as exc:
# Suppress StopIteration *unless* it's the same exception that
# was passed to throw(). This prevents a StopIteration
@@ -219,7 +215,7 @@ class GeneratorStartableContext(StartableContext[_T_co]):
def asyncstartablecontext(
func: Callable[..., AsyncIterator[_T_co]]
func: Callable[..., AsyncIterator[_T_co]],
) -> Callable[..., GeneratorStartableContext[_T_co]]:
"""@asyncstartablecontext decorator.
@@ -228,7 +224,9 @@ def asyncstartablecontext(
``@contextlib.asynccontextmanager`` supports, and the usage pattern
is different as well.
Typical usage::
Typical usage:
.. sourcecode:: text
@asyncstartablecontext
async def some_async_generator(<arguments>):

View File

@@ -1,5 +1,5 @@
# ext/asyncio/engine.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -41,6 +41,8 @@ from ...engine.base import NestedTransaction
from ...engine.base import Transaction
from ...exc import ArgumentError
from ...util.concurrency import greenlet_spawn
from ...util.typing import Concatenate
from ...util.typing import ParamSpec
if TYPE_CHECKING:
from ...engine.cursor import CursorResult
@@ -61,6 +63,7 @@ if TYPE_CHECKING:
from ...sql.base import Executable
from ...sql.selectable import TypedReturnsRows
_P = ParamSpec("_P")
_T = TypeVar("_T", bound=Any)
@@ -195,6 +198,7 @@ class AsyncConnection(
method of :class:`_asyncio.AsyncEngine`::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
async with engine.connect() as conn:
@@ -251,7 +255,7 @@ class AsyncConnection(
@classmethod
def _regenerate_proxy_for_target(
cls, target: Connection
cls, target: Connection, **additional_kw: Any # noqa: U100
) -> AsyncConnection:
return AsyncConnection(
AsyncEngine._retrieve_proxy_for_target(target.engine), target
@@ -414,13 +418,12 @@ class AsyncConnection(
yield_per: int = ...,
insertmanyvalues_page_size: int = ...,
schema_translate_map: Optional[SchemaTranslateMapType] = ...,
preserve_rowcount: bool = False,
**opt: Any,
) -> AsyncConnection:
...
) -> AsyncConnection: ...
@overload
async def execution_options(self, **opt: Any) -> AsyncConnection:
...
async def execution_options(self, **opt: Any) -> AsyncConnection: ...
async def execution_options(self, **opt: Any) -> AsyncConnection:
r"""Set non-SQL options for the connection which take effect
@@ -518,8 +521,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> GeneratorStartableContext[AsyncResult[_T]]:
...
) -> GeneratorStartableContext[AsyncResult[_T]]: ...
@overload
def stream(
@@ -528,8 +530,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> GeneratorStartableContext[AsyncResult[Any]]:
...
) -> GeneratorStartableContext[AsyncResult[Any]]: ...
@asyncstartablecontext
async def stream(
@@ -544,7 +545,7 @@ class AsyncConnection(
E.g.::
result = await conn.stream(stmt):
result = await conn.stream(stmt)
async for row in result:
print(f"{row}")
@@ -573,6 +574,11 @@ class AsyncConnection(
:meth:`.AsyncConnection.stream_scalars`
"""
if not self.dialect.supports_server_side_cursors:
raise exc.InvalidRequestError(
"Can't use `stream` or `stream_scalars` with the current "
"dialect since it does not support server side cursors."
)
result = await greenlet_spawn(
self._proxied.execute,
@@ -600,8 +606,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> CursorResult[_T]:
...
) -> CursorResult[_T]: ...
@overload
async def execute(
@@ -610,8 +615,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> CursorResult[Any]:
...
) -> CursorResult[Any]: ...
async def execute(
self,
@@ -667,8 +671,7 @@ class AsyncConnection(
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
async def scalar(
@@ -677,8 +680,7 @@ class AsyncConnection(
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> Any:
...
) -> Any: ...
async def scalar(
self,
@@ -709,8 +711,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[_T]:
...
) -> ScalarResult[_T]: ...
@overload
async def scalars(
@@ -719,8 +720,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
...
) -> ScalarResult[Any]: ...
async def scalars(
self,
@@ -752,8 +752,7 @@ class AsyncConnection(
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> GeneratorStartableContext[AsyncScalarResult[_T]]:
...
) -> GeneratorStartableContext[AsyncScalarResult[_T]]: ...
@overload
def stream_scalars(
@@ -762,8 +761,7 @@ class AsyncConnection(
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> GeneratorStartableContext[AsyncScalarResult[Any]]:
...
) -> GeneratorStartableContext[AsyncScalarResult[Any]]: ...
@asyncstartablecontext
async def stream_scalars(
@@ -819,9 +817,12 @@ class AsyncConnection(
yield result.scalars()
async def run_sync(
self, fn: Callable[..., _T], *arg: Any, **kw: Any
self,
fn: Callable[Concatenate[Connection, _P], _T],
*arg: _P.args,
**kw: _P.kwargs,
) -> _T:
"""Invoke the given synchronous (i.e. not async) callable,
'''Invoke the given synchronous (i.e. not async) callable,
passing a synchronous-style :class:`_engine.Connection` as the first
argument.
@@ -831,26 +832,26 @@ class AsyncConnection(
E.g.::
def do_something_with_core(conn: Connection, arg1: int, arg2: str) -> str:
'''A synchronous function that does not require awaiting
"""A synchronous function that does not require awaiting
:param conn: a Core SQLAlchemy Connection, used synchronously
:return: an optional return value is supported
'''
conn.execute(
some_table.insert().values(int_col=arg1, str_col=arg2)
)
"""
conn.execute(some_table.insert().values(int_col=arg1, str_col=arg2))
return "success"
async def do_something_async(async_engine: AsyncEngine) -> None:
'''an async function that uses awaiting'''
"""an async function that uses awaiting"""
async with async_engine.begin() as async_conn:
# run do_something_with_core() with a sync-style
# Connection, proxied into an awaitable
return_code = await async_conn.run_sync(do_something_with_core, 5, "strval")
return_code = await async_conn.run_sync(
do_something_with_core, 5, "strval"
)
print(return_code)
This method maintains the asyncio event loop all the way through
@@ -881,9 +882,11 @@ class AsyncConnection(
:ref:`session_run_sync`
""" # noqa: E501
''' # noqa: E501
return await greenlet_spawn(fn, self._proxied, *arg, **kw)
return await greenlet_spawn(
fn, self._proxied, *arg, _require_await=False, **kw
)
def __await__(self) -> Generator[Any, None, AsyncConnection]:
return self.start().__await__()
@@ -928,7 +931,7 @@ class AsyncConnection(
return self._proxied.invalidated
@property
def dialect(self) -> Any:
def dialect(self) -> Dialect:
r"""Proxy for the :attr:`_engine.Connection.dialect` attribute
on behalf of the :class:`_asyncio.AsyncConnection` class.
@@ -937,7 +940,7 @@ class AsyncConnection(
return self._proxied.dialect
@dialect.setter
def dialect(self, attr: Any) -> None:
def dialect(self, attr: Dialect) -> None:
self._proxied.dialect = attr
@property
@@ -998,6 +1001,7 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
:func:`_asyncio.create_async_engine` function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
.. versionadded:: 1.4
@@ -1037,7 +1041,9 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
return self.sync_engine
@classmethod
def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine:
def _regenerate_proxy_for_target(
cls, target: Engine, **additional_kw: Any # noqa: U100
) -> AsyncEngine:
return AsyncEngine(target)
@contextlib.asynccontextmanager
@@ -1054,7 +1060,6 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
)
await conn.execute(text("my_special_procedure(5)"))
"""
conn = self.connect()
@@ -1100,12 +1105,10 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
insertmanyvalues_page_size: int = ...,
schema_translate_map: Optional[SchemaTranslateMapType] = ...,
**opt: Any,
) -> AsyncEngine:
...
) -> AsyncEngine: ...
@overload
def execution_options(self, **opt: Any) -> AsyncEngine:
...
def execution_options(self, **opt: Any) -> AsyncEngine: ...
def execution_options(self, **opt: Any) -> AsyncEngine:
"""Return a new :class:`_asyncio.AsyncEngine` that will provide
@@ -1160,7 +1163,7 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
This applies **only** to the built-in cache that is established
via the :paramref:`_engine.create_engine.query_cache_size` parameter.
It will not impact any dictionary caches that were passed via the
:paramref:`.Connection.execution_options.query_cache` parameter.
:paramref:`.Connection.execution_options.compiled_cache` parameter.
.. versionadded:: 1.4
@@ -1343,7 +1346,7 @@ class AsyncTransaction(
@classmethod
def _regenerate_proxy_for_target(
cls, target: Transaction
cls, target: Transaction, **additional_kw: Any # noqa: U100
) -> AsyncTransaction:
sync_connection = target.connection
sync_transaction = target
@@ -1418,19 +1421,17 @@ class AsyncTransaction(
@overload
def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine:
...
def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: ...
@overload
def _get_sync_engine_or_connection(
async_engine: AsyncConnection,
) -> Connection:
...
) -> Connection: ...
def _get_sync_engine_or_connection(
async_engine: Union[AsyncEngine, AsyncConnection]
async_engine: Union[AsyncEngine, AsyncConnection],
) -> Union[Engine, Connection]:
if isinstance(async_engine, AsyncConnection):
return async_engine._proxied

View File

@@ -1,5 +1,5 @@
# ext/asyncio/exc.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# ext/asyncio/result.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -93,6 +93,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
self._metadata = real_result._metadata
self._unique_filter_state = real_result._unique_filter_state
self._source_supports_scalars = real_result._source_supports_scalars
self._post_creational_filter = None
# BaseCursorResult pre-generates the "_row_getter". Use that
@@ -324,22 +325,20 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
return await greenlet_spawn(self._only_one_row, True, False, False)
@overload
async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T:
...
async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: ...
@overload
async def scalar_one(self) -> Any:
...
async def scalar_one(self) -> Any: ...
async def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
then :meth:`_asyncio.AsyncResult.one`.
then :meth:`_asyncio.AsyncScalarResult.one`.
.. seealso::
:meth:`_asyncio.AsyncResult.one`
:meth:`_asyncio.AsyncScalarResult.one`
:meth:`_asyncio.AsyncResult.scalars`
@@ -349,22 +348,20 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
@overload
async def scalar_one_or_none(
self: AsyncResult[Tuple[_T]],
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
async def scalar_one_or_none(self) -> Optional[Any]:
...
async def scalar_one_or_none(self) -> Optional[Any]: ...
async def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one scalar result or ``None``.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
then :meth:`_asyncio.AsyncResult.one_or_none`.
then :meth:`_asyncio.AsyncScalarResult.one_or_none`.
.. seealso::
:meth:`_asyncio.AsyncResult.one_or_none`
:meth:`_asyncio.AsyncScalarResult.one_or_none`
:meth:`_asyncio.AsyncResult.scalars`
@@ -403,12 +400,10 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
return await greenlet_spawn(self._only_one_row, True, True, False)
@overload
async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]:
...
async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: ...
@overload
async def scalar(self) -> Any:
...
async def scalar(self) -> Any: ...
async def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
@@ -452,16 +447,13 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
@overload
def scalars(
self: AsyncResult[Tuple[_T]], index: Literal[0]
) -> AsyncScalarResult[_T]:
...
) -> AsyncScalarResult[_T]: ...
@overload
def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]:
...
def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: ...
@overload
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
...
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: ...
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
"""Return an :class:`_asyncio.AsyncScalarResult` filtering object which
@@ -833,11 +825,9 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
"""
...
async def __aiter__(self) -> AsyncIterator[_R]:
...
def __aiter__(self) -> AsyncIterator[_R]: ...
async def __anext__(self) -> _R:
...
async def __anext__(self) -> _R: ...
async def first(self) -> Optional[_R]:
"""Fetch the first object or ``None`` if no object is present.
@@ -871,22 +861,20 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
...
@overload
async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T:
...
async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: ...
@overload
async def scalar_one(self) -> Any:
...
async def scalar_one(self) -> Any: ...
async def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_engine.Result.scalars`
and then :meth:`_engine.Result.one`.
and then :meth:`_engine.AsyncScalarResult.one`.
.. seealso::
:meth:`_engine.Result.one`
:meth:`_engine.AsyncScalarResult.one`
:meth:`_engine.Result.scalars`
@@ -896,22 +884,20 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
@overload
async def scalar_one_or_none(
self: AsyncTupleResult[Tuple[_T]],
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
async def scalar_one_or_none(self) -> Optional[Any]:
...
async def scalar_one_or_none(self) -> Optional[Any]: ...
async def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
This is equivalent to calling :meth:`_engine.Result.scalars`
and then :meth:`_engine.Result.one_or_none`.
and then :meth:`_engine.AsyncScalarResult.one_or_none`.
.. seealso::
:meth:`_engine.Result.one_or_none`
:meth:`_engine.AsyncScalarResult.one_or_none`
:meth:`_engine.Result.scalars`
@@ -919,12 +905,12 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
...
@overload
async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]:
...
async def scalar(
self: AsyncTupleResult[Tuple[_T]],
) -> Optional[_T]: ...
@overload
async def scalar(self) -> Any:
...
async def scalar(self) -> Any: ...
async def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result

View File

@@ -1,5 +1,5 @@
# ext/asyncio/scoping.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -364,7 +364,7 @@ class async_scoped_session(Generic[_AS]):
object is entered::
async with async_session.begin():
# .. ORM transaction is begun
... # ORM transaction is begun
Note that database IO will not normally occur when the session-level
transaction is begun, as database transactions begin on an
@@ -536,8 +536,7 @@ class async_scoped_session(Generic[_AS]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result[_T]:
...
) -> Result[_T]: ...
@overload
async def execute(
@@ -549,8 +548,7 @@ class async_scoped_session(Generic[_AS]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> CursorResult[Any]:
...
) -> CursorResult[Any]: ...
@overload
async def execute(
@@ -562,8 +560,7 @@ class async_scoped_session(Generic[_AS]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result[Any]:
...
) -> Result[Any]: ...
async def execute(
self,
@@ -811,28 +808,28 @@ class async_scoped_session(Generic[_AS]):
# construct async engines w/ async drivers
engines = {
'leader':create_async_engine("sqlite+aiosqlite:///leader.db"),
'other':create_async_engine("sqlite+aiosqlite:///other.db"),
'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"),
'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"),
"leader": create_async_engine("sqlite+aiosqlite:///leader.db"),
"other": create_async_engine("sqlite+aiosqlite:///other.db"),
"follower1": create_async_engine("sqlite+aiosqlite:///follower1.db"),
"follower2": create_async_engine("sqlite+aiosqlite:///follower2.db"),
}
class RoutingSession(Session):
def get_bind(self, mapper=None, clause=None, **kw):
# within get_bind(), return sync engines
if mapper and issubclass(mapper.class_, MyOtherClass):
return engines['other'].sync_engine
return engines["other"].sync_engine
elif self._flushing or isinstance(clause, (Update, Delete)):
return engines['leader'].sync_engine
return engines["leader"].sync_engine
else:
return engines[
random.choice(['follower1','follower2'])
random.choice(["follower1", "follower2"])
].sync_engine
# apply to AsyncSession using sync_session_class
AsyncSessionMaker = async_sessionmaker(
sync_session_class=RoutingSession
)
AsyncSessionMaker = async_sessionmaker(sync_session_class=RoutingSession)
The :meth:`_orm.Session.get_bind` method is called in a non-asyncio,
implicitly non-blocking context in the same manner as ORM event hooks
@@ -867,7 +864,7 @@ class async_scoped_session(Generic[_AS]):
This method retrieves the history for each instrumented
attribute on the instance and performs a comparison of the current
value to its previously committed value, if any.
value to its previously flushed or committed value, if any.
It is in effect a more expensive and accurate
version of checking for the given instance in the
@@ -1015,8 +1012,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
async def scalar(
@@ -1027,8 +1023,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Any:
...
) -> Any: ...
async def scalar(
self,
@@ -1070,8 +1065,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> ScalarResult[_T]:
...
) -> ScalarResult[_T]: ...
@overload
async def scalars(
@@ -1082,8 +1076,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> ScalarResult[Any]:
...
) -> ScalarResult[Any]: ...
async def scalars(
self,
@@ -1182,8 +1175,7 @@ class async_scoped_session(Generic[_AS]):
Proxied for the :class:`_asyncio.AsyncSession` class on
behalf of the :class:`_asyncio.scoping.async_scoped_session` class.
Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects
no rows.
Raises :class:`_exc.NoResultFound` if the query selects no rows.
..versionadded: 2.0.22
@@ -1213,8 +1205,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncResult[_T]:
...
) -> AsyncResult[_T]: ...
@overload
async def stream(
@@ -1225,8 +1216,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncResult[Any]:
...
) -> AsyncResult[Any]: ...
async def stream(
self,
@@ -1265,8 +1255,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncScalarResult[_T]:
...
) -> AsyncScalarResult[_T]: ...
@overload
async def stream_scalars(
@@ -1277,8 +1266,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncScalarResult[Any]:
...
) -> AsyncScalarResult[Any]: ...
async def stream_scalars(
self,

View File

@@ -1,5 +1,5 @@
# ext/asyncio/session.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -38,6 +38,9 @@ from ...orm import Session
from ...orm import SessionTransaction
from ...orm import state as _instance_state
from ...util.concurrency import greenlet_spawn
from ...util.typing import Concatenate
from ...util.typing import ParamSpec
if TYPE_CHECKING:
from .engine import AsyncConnection
@@ -71,6 +74,7 @@ if TYPE_CHECKING:
_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
_P = ParamSpec("_P")
_T = TypeVar("_T", bound=Any)
@@ -332,9 +336,12 @@ class AsyncSession(ReversibleProxy[Session]):
)
async def run_sync(
self, fn: Callable[..., _T], *arg: Any, **kw: Any
self,
fn: Callable[Concatenate[Session, _P], _T],
*arg: _P.args,
**kw: _P.kwargs,
) -> _T:
"""Invoke the given synchronous (i.e. not async) callable,
'''Invoke the given synchronous (i.e. not async) callable,
passing a synchronous-style :class:`_orm.Session` as the first
argument.
@@ -344,25 +351,27 @@ class AsyncSession(ReversibleProxy[Session]):
E.g.::
def some_business_method(session: Session, param: str) -> str:
'''A synchronous function that does not require awaiting
"""A synchronous function that does not require awaiting
:param session: a SQLAlchemy Session, used synchronously
:return: an optional return value is supported
'''
"""
session.add(MyObject(param=param))
session.flush()
return "success"
async def do_something_async(async_engine: AsyncEngine) -> None:
'''an async function that uses awaiting'''
"""an async function that uses awaiting"""
with AsyncSession(async_engine) as async_session:
# run some_business_method() with a sync-style
# Session, proxied into an awaitable
return_code = await async_session.run_sync(some_business_method, param="param1")
return_code = await async_session.run_sync(
some_business_method, param="param1"
)
print(return_code)
This method maintains the asyncio event loop all the way through
@@ -384,9 +393,11 @@ class AsyncSession(ReversibleProxy[Session]):
:meth:`.AsyncConnection.run_sync`
:ref:`session_run_sync`
""" # noqa: E501
''' # noqa: E501
return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
return await greenlet_spawn(
fn, self.sync_session, *arg, _require_await=False, **kw
)
@overload
async def execute(
@@ -398,8 +409,7 @@ class AsyncSession(ReversibleProxy[Session]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result[_T]:
...
) -> Result[_T]: ...
@overload
async def execute(
@@ -411,8 +421,7 @@ class AsyncSession(ReversibleProxy[Session]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> CursorResult[Any]:
...
) -> CursorResult[Any]: ...
@overload
async def execute(
@@ -424,8 +433,7 @@ class AsyncSession(ReversibleProxy[Session]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result[Any]:
...
) -> Result[Any]: ...
async def execute(
self,
@@ -471,8 +479,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
async def scalar(
@@ -483,8 +490,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Any:
...
) -> Any: ...
async def scalar(
self,
@@ -528,8 +534,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> ScalarResult[_T]:
...
) -> ScalarResult[_T]: ...
@overload
async def scalars(
@@ -540,8 +545,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> ScalarResult[Any]:
...
) -> ScalarResult[Any]: ...
async def scalars(
self,
@@ -624,8 +628,7 @@ class AsyncSession(ReversibleProxy[Session]):
"""Return an instance based on the given primary key identifier,
or raise an exception if not found.
Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects
no rows.
Raises :class:`_exc.NoResultFound` if the query selects no rows.
..versionadded: 2.0.22
@@ -655,8 +658,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncResult[_T]:
...
) -> AsyncResult[_T]: ...
@overload
async def stream(
@@ -667,8 +669,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncResult[Any]:
...
) -> AsyncResult[Any]: ...
async def stream(
self,
@@ -710,8 +711,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncScalarResult[_T]:
...
) -> AsyncScalarResult[_T]: ...
@overload
async def stream_scalars(
@@ -722,8 +722,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncScalarResult[Any]:
...
) -> AsyncScalarResult[Any]: ...
async def stream_scalars(
self,
@@ -812,7 +811,9 @@ class AsyncSession(ReversibleProxy[Session]):
"""
trans = self.sync_session.get_transaction()
if trans is not None:
return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
return AsyncSessionTransaction._retrieve_proxy_for_target(
trans, async_session=self
)
else:
return None
@@ -828,7 +829,9 @@ class AsyncSession(ReversibleProxy[Session]):
trans = self.sync_session.get_nested_transaction()
if trans is not None:
return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
return AsyncSessionTransaction._retrieve_proxy_for_target(
trans, async_session=self
)
else:
return None
@@ -879,28 +882,28 @@ class AsyncSession(ReversibleProxy[Session]):
# construct async engines w/ async drivers
engines = {
'leader':create_async_engine("sqlite+aiosqlite:///leader.db"),
'other':create_async_engine("sqlite+aiosqlite:///other.db"),
'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"),
'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"),
"leader": create_async_engine("sqlite+aiosqlite:///leader.db"),
"other": create_async_engine("sqlite+aiosqlite:///other.db"),
"follower1": create_async_engine("sqlite+aiosqlite:///follower1.db"),
"follower2": create_async_engine("sqlite+aiosqlite:///follower2.db"),
}
class RoutingSession(Session):
def get_bind(self, mapper=None, clause=None, **kw):
# within get_bind(), return sync engines
if mapper and issubclass(mapper.class_, MyOtherClass):
return engines['other'].sync_engine
return engines["other"].sync_engine
elif self._flushing or isinstance(clause, (Update, Delete)):
return engines['leader'].sync_engine
return engines["leader"].sync_engine
else:
return engines[
random.choice(['follower1','follower2'])
random.choice(["follower1", "follower2"])
].sync_engine
# apply to AsyncSession using sync_session_class
AsyncSessionMaker = async_sessionmaker(
sync_session_class=RoutingSession
)
AsyncSessionMaker = async_sessionmaker(sync_session_class=RoutingSession)
The :meth:`_orm.Session.get_bind` method is called in a non-asyncio,
implicitly non-blocking context in the same manner as ORM event hooks
@@ -956,7 +959,7 @@ class AsyncSession(ReversibleProxy[Session]):
object is entered::
async with async_session.begin():
# .. ORM transaction is begun
... # ORM transaction is begun
Note that database IO will not normally occur when the session-level
transaction is begun, as database transactions begin on an
@@ -1309,7 +1312,7 @@ class AsyncSession(ReversibleProxy[Session]):
This method retrieves the history for each instrumented
attribute on the instance and performs a comparison of the current
value to its previously committed value, if any.
value to its previously flushed or committed value, if any.
It is in effect a more expensive and accurate
version of checking for the given instance in the
@@ -1633,16 +1636,22 @@ class async_sessionmaker(Generic[_AS]):
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import async_sessionmaker
async def run_some_sql(async_session: async_sessionmaker[AsyncSession]) -> None:
async def run_some_sql(
async_session: async_sessionmaker[AsyncSession],
) -> None:
async with async_session() as session:
session.add(SomeObject(data="object"))
session.add(SomeOtherObject(name="other object"))
await session.commit()
async def main() -> None:
# an AsyncEngine, which the AsyncSession will use for connection
# resources
engine = create_async_engine('postgresql+asyncpg://scott:tiger@localhost/')
engine = create_async_engine(
"postgresql+asyncpg://scott:tiger@localhost/"
)
# create a reusable factory for new AsyncSession instances
async_session = async_sessionmaker(engine)
@@ -1686,8 +1695,7 @@ class async_sessionmaker(Generic[_AS]):
expire_on_commit: bool = ...,
info: Optional[_InfoType] = ...,
**kw: Any,
):
...
): ...
@overload
def __init__(
@@ -1698,8 +1706,7 @@ class async_sessionmaker(Generic[_AS]):
expire_on_commit: bool = ...,
info: Optional[_InfoType] = ...,
**kw: Any,
):
...
): ...
def __init__(
self,
@@ -1743,7 +1750,6 @@ class async_sessionmaker(Generic[_AS]):
# commits transaction, closes session
"""
session = self()
@@ -1776,7 +1782,7 @@ class async_sessionmaker(Generic[_AS]):
AsyncSession = async_sessionmaker(some_engine)
AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://'))
AsyncSession.configure(bind=create_async_engine("sqlite+aiosqlite://"))
""" # noqa E501
self.kw.update(new_kw)
@@ -1862,12 +1868,27 @@ class AsyncSessionTransaction(
await greenlet_spawn(self._sync_transaction().commit)
@classmethod
def _regenerate_proxy_for_target( # type: ignore[override]
cls,
target: SessionTransaction,
async_session: AsyncSession,
**additional_kw: Any, # noqa: U100
) -> AsyncSessionTransaction:
sync_transaction = target
nested = target.nested
obj = cls.__new__(cls)
obj.session = async_session
obj.sync_transaction = obj._assign_proxied(sync_transaction)
obj.nested = nested
return obj
async def start(
self, is_ctxmanager: bool = False
) -> AsyncSessionTransaction:
self.sync_transaction = self._assign_proxied(
await greenlet_spawn(
self.session.sync_session.begin_nested # type: ignore
self.session.sync_session.begin_nested
if self.nested
else self.session.sync_session.begin
)

View File

@@ -1,5 +1,5 @@
# ext/automap.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -11,7 +11,7 @@ schema, typically though not necessarily one which is reflected.
It is hoped that the :class:`.AutomapBase` system provides a quick
and modernized solution to the problem that the very famous
`SQLSoup <https://sqlsoup.readthedocs.io/en/latest/>`_
`SQLSoup <https://pypi.org/project/sqlsoup/>`_
also tries to solve, that of generating a quick and rudimentary object
model from an existing database on the fly. By addressing the issue strictly
at the mapper configuration level, and integrating fully with existing
@@ -64,7 +64,7 @@ asking it to reflect the schema and produce mappings::
# collection-based relationships are by default named
# "<classname>_collection"
u1 = session.query(User).first()
print (u1.address_collection)
print(u1.address_collection)
Above, calling :meth:`.AutomapBase.prepare` while passing along the
:paramref:`.AutomapBase.prepare.reflect` parameter indicates that the
@@ -101,6 +101,7 @@ explicit table declaration::
from sqlalchemy import create_engine, MetaData, Table, Column, ForeignKey
from sqlalchemy.ext.automap import automap_base
engine = create_engine("sqlite:///mydatabase.db")
# produce our own MetaData object
@@ -108,13 +109,15 @@ explicit table declaration::
# we can reflect it ourselves from a database, using options
# such as 'only' to limit what tables we look at...
metadata.reflect(engine, only=['user', 'address'])
metadata.reflect(engine, only=["user", "address"])
# ... or just define our own Table objects with it (or combine both)
Table('user_order', metadata,
Column('id', Integer, primary_key=True),
Column('user_id', ForeignKey('user.id'))
)
Table(
"user_order",
metadata,
Column("id", Integer, primary_key=True),
Column("user_id", ForeignKey("user.id")),
)
# we can then produce a set of mappings from this MetaData.
Base = automap_base(metadata=metadata)
@@ -123,8 +126,9 @@ explicit table declaration::
Base.prepare()
# mapped classes are ready
User, Address, Order = Base.classes.user, Base.classes.address,\
Base.classes.user_order
User = Base.classes.user
Address = Base.classes.address
Order = Base.classes.user_order
.. _automap_by_module:
@@ -177,18 +181,23 @@ the schema name ``default`` is used if no schema is present::
Base.metadata.create_all(e)
def module_name_for_table(cls, tablename, table):
if table.schema is not None:
return f"mymodule.{table.schema}"
else:
return f"mymodule.default"
Base = automap_base()
Base.prepare(e, modulename_for_table=module_name_for_table)
Base.prepare(e, schema="test_schema", modulename_for_table=module_name_for_table)
Base.prepare(e, schema="test_schema_2", modulename_for_table=module_name_for_table)
Base.prepare(
e, schema="test_schema", modulename_for_table=module_name_for_table
)
Base.prepare(
e, schema="test_schema_2", modulename_for_table=module_name_for_table
)
The same named-classes are organized into a hierarchical collection available
at :attr:`.AutomapBase.by_module`. This collection is traversed using the
@@ -251,12 +260,13 @@ established based on the table name we use. If our schema contains tables
# automap base
Base = automap_base()
# pre-declare User for the 'user' table
class User(Base):
__tablename__ = 'user'
__tablename__ = "user"
# override schema elements like Columns
user_name = Column('name', String)
user_name = Column("name", String)
# override relationships too, if desired.
# we must use the same name that automap would use for the
@@ -264,6 +274,7 @@ established based on the table name we use. If our schema contains tables
# generate for "address"
address_collection = relationship("address", collection_class=set)
# reflect
engine = create_engine("sqlite:///mydatabase.db")
Base.prepare(autoload_with=engine)
@@ -274,11 +285,11 @@ established based on the table name we use. If our schema contains tables
Address = Base.classes.address
u1 = session.query(User).first()
print (u1.address_collection)
print(u1.address_collection)
# the backref is still there:
a1 = session.query(Address).first()
print (a1.user)
print(a1.user)
Above, one of the more intricate details is that we illustrated overriding
one of the :func:`_orm.relationship` objects that automap would have created.
@@ -305,35 +316,49 @@ scheme for class names and a "pluralizer" for collection names using the
import re
import inflect
def camelize_classname(base, tablename, table):
"Produce a 'camelized' class name, e.g. "
"Produce a 'camelized' class name, e.g."
"'words_and_underscores' -> 'WordsAndUnderscores'"
return str(tablename[0].upper() + \
re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:]))
return str(
tablename[0].upper()
+ re.sub(
r"_([a-z])",
lambda m: m.group(1).upper(),
tablename[1:],
)
)
_pluralizer = inflect.engine()
def pluralize_collection(base, local_cls, referred_cls, constraint):
"Produce an 'uncamelized', 'pluralized' class name, e.g. "
"Produce an 'uncamelized', 'pluralized' class name, e.g."
"'SomeTerm' -> 'some_terms'"
referred_name = referred_cls.__name__
uncamelized = re.sub(r'[A-Z]',
lambda m: "_%s" % m.group(0).lower(),
referred_name)[1:]
uncamelized = re.sub(
r"[A-Z]",
lambda m: "_%s" % m.group(0).lower(),
referred_name,
)[1:]
pluralized = _pluralizer.plural(uncamelized)
return pluralized
from sqlalchemy.ext.automap import automap_base
Base = automap_base()
engine = create_engine("sqlite:///mydatabase.db")
Base.prepare(autoload_with=engine,
classname_for_table=camelize_classname,
name_for_collection_relationship=pluralize_collection
)
Base.prepare(
autoload_with=engine,
classname_for_table=camelize_classname,
name_for_collection_relationship=pluralize_collection,
)
From the above mapping, we would now have classes ``User`` and ``Address``,
where the collection from ``User`` to ``Address`` is called
@@ -422,16 +447,21 @@ Below is an illustration of how to send
options along to all one-to-many relationships::
from sqlalchemy.ext.automap import generate_relationship
from sqlalchemy.orm import interfaces
def _gen_relationship(base, direction, return_fn,
attrname, local_cls, referred_cls, **kw):
def _gen_relationship(
base, direction, return_fn, attrname, local_cls, referred_cls, **kw
):
if direction is interfaces.ONETOMANY:
kw['cascade'] = 'all, delete-orphan'
kw['passive_deletes'] = True
kw["cascade"] = "all, delete-orphan"
kw["passive_deletes"] = True
# make use of the built-in function to actually return
# the result.
return generate_relationship(base, direction, return_fn,
attrname, local_cls, referred_cls, **kw)
return generate_relationship(
base, direction, return_fn, attrname, local_cls, referred_cls, **kw
)
from sqlalchemy.ext.automap import automap_base
from sqlalchemy import create_engine
@@ -440,8 +470,7 @@ options along to all one-to-many relationships::
Base = automap_base()
engine = create_engine("sqlite:///mydatabase.db")
Base.prepare(autoload_with=engine,
generate_relationship=_gen_relationship)
Base.prepare(autoload_with=engine, generate_relationship=_gen_relationship)
Many-to-Many relationships
--------------------------
@@ -482,18 +511,20 @@ two classes that are in an inheritance relationship. That is, with two
classes given as follows::
class Employee(Base):
__tablename__ = 'employee'
__tablename__ = "employee"
id = Column(Integer, primary_key=True)
type = Column(String(50))
__mapper_args__ = {
'polymorphic_identity':'employee', 'polymorphic_on': type
"polymorphic_identity": "employee",
"polymorphic_on": type,
}
class Engineer(Employee):
__tablename__ = 'engineer'
id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
__tablename__ = "engineer"
id = Column(Integer, ForeignKey("employee.id"), primary_key=True)
__mapper_args__ = {
'polymorphic_identity':'engineer',
"polymorphic_identity": "engineer",
}
The foreign key from ``Engineer`` to ``Employee`` is used not for a
@@ -508,25 +539,28 @@ we want as well as the ``inherit_condition``, as these are not things
SQLAlchemy can guess::
class Employee(Base):
__tablename__ = 'employee'
__tablename__ = "employee"
id = Column(Integer, primary_key=True)
type = Column(String(50))
__mapper_args__ = {
'polymorphic_identity':'employee', 'polymorphic_on':type
"polymorphic_identity": "employee",
"polymorphic_on": type,
}
class Engineer(Employee):
__tablename__ = 'engineer'
id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
favorite_employee_id = Column(Integer, ForeignKey('employee.id'))
favorite_employee = relationship(Employee,
foreign_keys=favorite_employee_id)
class Engineer(Employee):
__tablename__ = "engineer"
id = Column(Integer, ForeignKey("employee.id"), primary_key=True)
favorite_employee_id = Column(Integer, ForeignKey("employee.id"))
favorite_employee = relationship(
Employee, foreign_keys=favorite_employee_id
)
__mapper_args__ = {
'polymorphic_identity':'engineer',
'inherit_condition': id == Employee.id
"polymorphic_identity": "engineer",
"inherit_condition": id == Employee.id,
}
Handling Simple Naming Conflicts
@@ -559,20 +593,24 @@ and will emit an error on mapping.
We can resolve this conflict by using an underscore as follows::
def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
def name_for_scalar_relationship(
base, local_cls, referred_cls, constraint
):
name = referred_cls.__name__.lower()
local_table = local_cls.__table__
if name in local_table.columns:
newname = name + "_"
warnings.warn(
"Already detected name %s present. using %s" %
(name, newname))
"Already detected name %s present. using %s" % (name, newname)
)
return newname
return name
Base.prepare(autoload_with=engine,
name_for_scalar_relationship=name_for_scalar_relationship)
Base.prepare(
autoload_with=engine,
name_for_scalar_relationship=name_for_scalar_relationship,
)
Alternatively, we can change the name on the column side. The columns
that are mapped can be modified using the technique described at
@@ -581,13 +619,14 @@ to a new name::
Base = automap_base()
class TableB(Base):
__tablename__ = 'table_b'
_table_a = Column('table_a', ForeignKey('table_a.id'))
__tablename__ = "table_b"
_table_a = Column("table_a", ForeignKey("table_a.id"))
Base.prepare(autoload_with=engine)
Using Automap with Explicit Declarations
========================================
@@ -603,26 +642,29 @@ defines table metadata::
Base = automap_base()
class User(Base):
__tablename__ = 'user'
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String)
class Address(Base):
__tablename__ = 'address'
__tablename__ = "address"
id = Column(Integer, primary_key=True)
email = Column(String)
user_id = Column(ForeignKey('user.id'))
user_id = Column(ForeignKey("user.id"))
# produce relationships
Base.prepare()
# mapping is complete, with "address_collection" and
# "user" relationships
a1 = Address(email='u1')
a2 = Address(email='u2')
a1 = Address(email="u1")
a2 = Address(email="u2")
u1 = User(address_collection=[a1, a2])
assert a1.user is u1
@@ -651,7 +693,8 @@ be applied as::
@event.listens_for(Base.metadata, "column_reflect")
def column_reflect(inspector, table, column_info):
# set column.key = "attr_<lower_case_name>"
column_info['key'] = "attr_%s" % column_info['name'].lower()
column_info["key"] = "attr_%s" % column_info["name"].lower()
# run reflection
Base.prepare(autoload_with=engine)
@@ -715,8 +758,9 @@ _VT = TypeVar("_VT", bound=Any)
class PythonNameForTableType(Protocol):
def __call__(self, base: Type[Any], tablename: str, table: Table) -> str:
...
def __call__(
self, base: Type[Any], tablename: str, table: Table
) -> str: ...
def classname_for_table(
@@ -763,8 +807,7 @@ class NameForScalarRelationshipType(Protocol):
local_cls: Type[Any],
referred_cls: Type[Any],
constraint: ForeignKeyConstraint,
) -> str:
...
) -> str: ...
def name_for_scalar_relationship(
@@ -804,8 +847,7 @@ class NameForCollectionRelationshipType(Protocol):
local_cls: Type[Any],
referred_cls: Type[Any],
constraint: ForeignKeyConstraint,
) -> str:
...
) -> str: ...
def name_for_collection_relationship(
@@ -850,8 +892,7 @@ class GenerateRelationshipType(Protocol):
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> Relationship[Any]:
...
) -> Relationship[Any]: ...
@overload
def __call__(
@@ -863,8 +904,7 @@ class GenerateRelationshipType(Protocol):
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> ORMBackrefArgument:
...
) -> ORMBackrefArgument: ...
def __call__(
self,
@@ -877,8 +917,7 @@ class GenerateRelationshipType(Protocol):
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> Union[ORMBackrefArgument, Relationship[Any]]:
...
) -> Union[ORMBackrefArgument, Relationship[Any]]: ...
@overload
@@ -890,8 +929,7 @@ def generate_relationship(
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> Relationship[Any]:
...
) -> Relationship[Any]: ...
@overload
@@ -903,8 +941,7 @@ def generate_relationship(
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> ORMBackrefArgument:
...
) -> ORMBackrefArgument: ...
def generate_relationship(
@@ -1008,6 +1045,12 @@ class AutomapBase:
User, Address = Base.classes.User, Base.classes.Address
For class names that overlap with a method name of
:class:`.util.Properties`, such as ``items()``, the getitem form
is also supported::
Item = Base.classes["items"]
"""
by_module: ClassVar[ByModuleProperties]

View File

@@ -1,5 +1,5 @@
# sqlalchemy/ext/baked.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# ext/baked.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -258,23 +258,19 @@ class BakedQuery:
is passed to the lambda::
sub_bq = self.bakery(lambda s: s.query(User.name))
sub_bq += lambda q: q.filter(
User.id == Address.user_id).correlate(Address)
sub_bq += lambda q: q.filter(User.id == Address.user_id).correlate(Address)
main_bq = self.bakery(lambda s: s.query(Address))
main_bq += lambda q: q.filter(
sub_bq.to_query(q).exists())
main_bq += lambda q: q.filter(sub_bq.to_query(q).exists())
In the case where the subquery is used in the first callable against
a :class:`.Session`, the :class:`.Session` is also accepted::
sub_bq = self.bakery(lambda s: s.query(User.name))
sub_bq += lambda q: q.filter(
User.id == Address.user_id).correlate(Address)
sub_bq += lambda q: q.filter(User.id == Address.user_id).correlate(Address)
main_bq = self.bakery(
lambda s: s.query(
Address.id, sub_bq.to_query(q).scalar_subquery())
lambda s: s.query(Address.id, sub_bq.to_query(q).scalar_subquery())
)
:param query_or_session: a :class:`_query.Query` object or a class
@@ -285,7 +281,7 @@ class BakedQuery:
.. versionadded:: 1.3
"""
""" # noqa: E501
if isinstance(query_or_session, Session):
session = query_or_session

View File

@@ -1,10 +1,9 @@
# ext/compiler.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""Provides an API for creation of custom ClauseElements and compilers.
@@ -18,9 +17,11 @@ more callables defining its compilation::
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnClause
class MyColumn(ColumnClause):
inherit_cache = True
@compiles(MyColumn)
def compile_mycolumn(element, compiler, **kw):
return "[%s]" % element.name
@@ -32,10 +33,12 @@ when the object is compiled to a string::
from sqlalchemy import select
s = select(MyColumn('x'), MyColumn('y'))
s = select(MyColumn("x"), MyColumn("y"))
print(str(s))
Produces::
Produces:
.. sourcecode:: sql
SELECT [x], [y]
@@ -47,6 +50,7 @@ invoked for the dialect in use::
from sqlalchemy.schema import DDLElement
class AlterColumn(DDLElement):
inherit_cache = False
@@ -54,14 +58,18 @@ invoked for the dialect in use::
self.column = column
self.cmd = cmd
@compiles(AlterColumn)
def visit_alter_column(element, compiler, **kw):
return "ALTER COLUMN %s ..." % element.column.name
@compiles(AlterColumn, 'postgresql')
@compiles(AlterColumn, "postgresql")
def visit_alter_column(element, compiler, **kw):
return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name,
element.column.name)
return "ALTER TABLE %s ALTER COLUMN %s ..." % (
element.table.name,
element.column.name,
)
The second ``visit_alter_table`` will be invoked when any ``postgresql``
dialect is used.
@@ -81,6 +89,7 @@ method which can be used for compilation of embedded attributes::
from sqlalchemy.sql.expression import Executable, ClauseElement
class InsertFromSelect(Executable, ClauseElement):
inherit_cache = False
@@ -88,20 +97,27 @@ method which can be used for compilation of embedded attributes::
self.table = table
self.select = select
@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
return "INSERT INTO %s (%s)" % (
compiler.process(element.table, asfrom=True, **kw),
compiler.process(element.select, **kw)
compiler.process(element.select, **kw),
)
insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5))
insert = InsertFromSelect(t1, select(t1).where(t1.c.x > 5))
print(insert)
Produces::
Produces (formatted for readability):
"INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z
FROM mytable WHERE mytable.x > :x_1)"
.. sourcecode:: sql
INSERT INTO mytable (
SELECT mytable.x, mytable.y, mytable.z
FROM mytable
WHERE mytable.x > :x_1
)
.. note::
@@ -121,11 +137,10 @@ below where we generate a CHECK constraint that embeds a SQL expression::
@compiles(MyConstraint)
def compile_my_constraint(constraint, ddlcompiler, **kw):
kw['literal_binds'] = True
kw["literal_binds"] = True
return "CONSTRAINT %s CHECK (%s)" % (
constraint.name,
ddlcompiler.sql_compiler.process(
constraint.expression, **kw)
ddlcompiler.sql_compiler.process(constraint.expression, **kw),
)
Above, we add an additional flag to the process step as called by
@@ -153,6 +168,7 @@ an endless loop. Such as, to add "prefix" to all insert statements::
from sqlalchemy.sql.expression import Insert
@compiles(Insert)
def prefix_inserts(insert, compiler, **kw):
return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
@@ -168,17 +184,16 @@ Changing Compilation of Types
``compiler`` works for types, too, such as below where we implement the
MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
@compiles(String, 'mssql')
@compiles(VARCHAR, 'mssql')
@compiles(String, "mssql")
@compiles(VARCHAR, "mssql")
def compile_varchar(element, compiler, **kw):
if element.length == 'max':
if element.length == "max":
return "VARCHAR('max')"
else:
return compiler.visit_VARCHAR(element, **kw)
foo = Table('foo', metadata,
Column('data', VARCHAR('max'))
)
foo = Table("foo", metadata, Column("data", VARCHAR("max")))
Subclassing Guidelines
======================
@@ -216,18 +231,23 @@ A synopsis is as follows:
from sqlalchemy.sql.expression import FunctionElement
class coalesce(FunctionElement):
name = 'coalesce'
name = "coalesce"
inherit_cache = True
@compiles(coalesce)
def compile(element, compiler, **kw):
return "coalesce(%s)" % compiler.process(element.clauses, **kw)
@compiles(coalesce, 'oracle')
@compiles(coalesce, "oracle")
def compile(element, compiler, **kw):
if len(element.clauses) > 2:
raise TypeError("coalesce only supports two arguments on Oracle")
raise TypeError(
"coalesce only supports two arguments on " "Oracle Database"
)
return "nvl(%s)" % compiler.process(element.clauses, **kw)
* :class:`.ExecutableDDLElement` - The root of all DDL expressions,
@@ -281,6 +301,7 @@ for example to the "synopsis" example indicated previously::
class MyColumn(ColumnClause):
inherit_cache = True
@compiles(MyColumn)
def compile_mycolumn(element, compiler, **kw):
return "[%s]" % element.name
@@ -319,11 +340,12 @@ caching::
self.table = table
self.select = select
@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
return "INSERT INTO %s (%s)" % (
compiler.process(element.table, asfrom=True, **kw),
compiler.process(element.select, **kw)
compiler.process(element.select, **kw),
)
While it is also possible that the above ``InsertFromSelect`` could be made to
@@ -359,28 +381,32 @@ For PostgreSQL and Microsoft SQL Server::
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import DateTime
class utcnow(expression.FunctionElement):
type = DateTime()
inherit_cache = True
@compiles(utcnow, 'postgresql')
@compiles(utcnow, "postgresql")
def pg_utcnow(element, compiler, **kw):
return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
@compiles(utcnow, 'mssql')
@compiles(utcnow, "mssql")
def ms_utcnow(element, compiler, **kw):
return "GETUTCDATE()"
Example usage::
from sqlalchemy import (
Table, Column, Integer, String, DateTime, MetaData
)
from sqlalchemy import Table, Column, Integer, String, DateTime, MetaData
metadata = MetaData()
event = Table("event", metadata,
event = Table(
"event",
metadata,
Column("id", Integer, primary_key=True),
Column("description", String(50), nullable=False),
Column("timestamp", DateTime, server_default=utcnow())
Column("timestamp", DateTime, server_default=utcnow()),
)
"GREATEST" function
@@ -395,30 +421,30 @@ accommodates two arguments::
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import Numeric
class greatest(expression.FunctionElement):
type = Numeric()
name = 'greatest'
name = "greatest"
inherit_cache = True
@compiles(greatest)
def default_greatest(element, compiler, **kw):
return compiler.visit_function(element)
@compiles(greatest, 'sqlite')
@compiles(greatest, 'mssql')
@compiles(greatest, 'oracle')
@compiles(greatest, "sqlite")
@compiles(greatest, "mssql")
@compiles(greatest, "oracle")
def case_greatest(element, compiler, **kw):
arg1, arg2 = list(element.clauses)
return compiler.process(case((arg1 > arg2, arg1), else_=arg2), **kw)
Example usage::
Session.query(Account).\
filter(
greatest(
Account.checking_balance,
Account.savings_balance) > 10000
)
Session.query(Account).filter(
greatest(Account.checking_balance, Account.savings_balance) > 10000
)
"false" expression
------------------
@@ -429,16 +455,19 @@ don't have a "false" constant::
from sqlalchemy.sql import expression
from sqlalchemy.ext.compiler import compiles
class sql_false(expression.ColumnElement):
inherit_cache = True
@compiles(sql_false)
def default_false(element, compiler, **kw):
return "false"
@compiles(sql_false, 'mssql')
@compiles(sql_false, 'mysql')
@compiles(sql_false, 'oracle')
@compiles(sql_false, "mssql")
@compiles(sql_false, "mysql")
@compiles(sql_false, "oracle")
def int_false(element, compiler, **kw):
return "0"
@@ -448,19 +477,33 @@ Example usage::
exp = union_all(
select(users.c.name, sql_false().label("enrolled")),
select(customers.c.name, customers.c.enrolled)
select(customers.c.name, customers.c.enrolled),
)
"""
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Dict
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from .. import exc
from ..sql import sqltypes
if TYPE_CHECKING:
from ..sql.compiler import SQLCompiler
def compiles(class_, *specs):
_F = TypeVar("_F", bound=Callable[..., Any])
def compiles(class_: Type[Any], *specs: str) -> Callable[[_F], _F]:
"""Register a function as a compiler for a
given :class:`_expression.ClauseElement` type."""
def decorate(fn):
def decorate(fn: _F) -> _F:
# get an existing @compiles handler
existing = class_.__dict__.get("_compiler_dispatcher", None)
@@ -473,7 +516,9 @@ def compiles(class_, *specs):
if existing_dispatch:
def _wrap_existing_dispatch(element, compiler, **kw):
def _wrap_existing_dispatch(
element: Any, compiler: SQLCompiler, **kw: Any
) -> Any:
try:
return existing_dispatch(element, compiler, **kw)
except exc.UnsupportedCompilationError as uce:
@@ -505,7 +550,7 @@ def compiles(class_, *specs):
return decorate
def deregister(class_):
def deregister(class_: Type[Any]) -> None:
"""Remove all custom compilers associated with a given
:class:`_expression.ClauseElement` type.
@@ -517,10 +562,10 @@ def deregister(class_):
class _dispatcher:
def __init__(self):
self.specs = {}
def __init__(self) -> None:
self.specs: Dict[str, Callable[..., Any]] = {}
def __call__(self, element, compiler, **kw):
def __call__(self, element: Any, compiler: SQLCompiler, **kw: Any) -> Any:
# TODO: yes, this could also switch off of DBAPI in use.
fn = self.specs.get(compiler.dialect.name, None)
if not fn:

View File

@@ -1,5 +1,5 @@
# ext/declarative/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# ext/declarative/extensions.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -50,23 +50,26 @@ class ConcreteBase:
from sqlalchemy.ext.declarative import ConcreteBase
class Employee(ConcreteBase, Base):
__tablename__ = 'employee'
__tablename__ = "employee"
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
__mapper_args__ = {
'polymorphic_identity':'employee',
'concrete':True}
"polymorphic_identity": "employee",
"concrete": True,
}
class Manager(Employee):
__tablename__ = 'manager'
__tablename__ = "manager"
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
'polymorphic_identity':'manager',
'concrete':True}
"polymorphic_identity": "manager",
"concrete": True,
}
The name of the discriminator column used by :func:`.polymorphic_union`
defaults to the name ``type``. To suit the use case of a mapping where an
@@ -75,7 +78,7 @@ class ConcreteBase:
``_concrete_discriminator_name`` attribute::
class Employee(ConcreteBase, Base):
_concrete_discriminator_name = '_concrete_discriminator'
_concrete_discriminator_name = "_concrete_discriminator"
.. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name``
attribute to :class:`_declarative.ConcreteBase` so that the
@@ -168,23 +171,27 @@ class AbstractConcreteBase(ConcreteBase):
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.ext.declarative import AbstractConcreteBase
class Base(DeclarativeBase):
pass
class Employee(AbstractConcreteBase, Base):
pass
class Manager(Employee):
__tablename__ = 'manager'
__tablename__ = "manager"
employee_id = Column(Integer, primary_key=True)
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
'polymorphic_identity':'manager',
'concrete':True
"polymorphic_identity": "manager",
"concrete": True,
}
Base.registry.configure()
The abstract base class is handled by declarative in a special way;
@@ -200,10 +207,12 @@ class AbstractConcreteBase(ConcreteBase):
from sqlalchemy.ext.declarative import AbstractConcreteBase
class Company(Base):
__tablename__ = 'company'
__tablename__ = "company"
id = Column(Integer, primary_key=True)
class Employee(AbstractConcreteBase, Base):
strict_attrs = True
@@ -211,31 +220,31 @@ class AbstractConcreteBase(ConcreteBase):
@declared_attr
def company_id(cls):
return Column(ForeignKey('company.id'))
return Column(ForeignKey("company.id"))
@declared_attr
def company(cls):
return relationship("Company")
class Manager(Employee):
__tablename__ = 'manager'
__tablename__ = "manager"
name = Column(String(50))
manager_data = Column(String(40))
__mapper_args__ = {
'polymorphic_identity':'manager',
'concrete':True
"polymorphic_identity": "manager",
"concrete": True,
}
Base.registry.configure()
When we make use of our mappings however, both ``Manager`` and
``Employee`` will have an independently usable ``.company`` attribute::
session.execute(
select(Employee).filter(Employee.company.has(id=5))
)
session.execute(select(Employee).filter(Employee.company.has(id=5)))
:param strict_attrs: when specified on the base class, "strict" attribute
mode is enabled which attempts to limit ORM mapped attributes on the
@@ -366,10 +375,12 @@ class DeferredReflection:
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import DeferredReflection
Base = declarative_base()
class MyClass(DeferredReflection, Base):
__tablename__ = 'mytable'
__tablename__ = "mytable"
Above, ``MyClass`` is not yet mapped. After a series of
classes have been defined in the above fashion, all tables
@@ -391,17 +402,22 @@ class DeferredReflection:
class ReflectedOne(DeferredReflection, Base):
__abstract__ = True
class ReflectedTwo(DeferredReflection, Base):
__abstract__ = True
class MyClass(ReflectedOne):
__tablename__ = 'mytable'
__tablename__ = "mytable"
class MyOtherClass(ReflectedOne):
__tablename__ = 'myothertable'
__tablename__ = "myothertable"
class YetAnotherClass(ReflectedTwo):
__tablename__ = 'yetanothertable'
__tablename__ = "yetanothertable"
# ... etc.

View File

@@ -1,5 +1,5 @@
# ext/horizontal_shard.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -83,8 +83,7 @@ class ShardChooser(Protocol):
mapper: Optional[Mapper[_T]],
instance: Any,
clause: Optional[ClauseElement],
) -> Any:
...
) -> Any: ...
class IdentityChooser(Protocol):
@@ -97,8 +96,7 @@ class IdentityChooser(Protocol):
execution_options: OrmExecuteOptionsParameter,
bind_arguments: _BindArguments,
**kw: Any,
) -> Any:
...
) -> Any: ...
class ShardedQuery(Query[_T]):
@@ -127,12 +125,9 @@ class ShardedQuery(Query[_T]):
The shard_id can be passed for a 2.0 style execution to the
bind_arguments dictionary of :meth:`.Session.execute`::
results = session.execute(
stmt,
bind_arguments={"shard_id": "my_shard"}
)
results = session.execute(stmt, bind_arguments={"shard_id": "my_shard"})
"""
""" # noqa: E501
return self.execution_options(_sa_shard_id=shard_id)
@@ -323,7 +318,7 @@ class ShardedSession(Session):
state.identity_token = shard_id
return shard_id
def connection_callable( # type: ignore [override]
def connection_callable(
self,
mapper: Optional[Mapper[_T]] = None,
instance: Optional[Any] = None,
@@ -384,9 +379,9 @@ class set_shard_id(ORMOption):
the :meth:`_sql.Executable.options` method of any executable statement::
stmt = (
select(MyObject).
where(MyObject.name == 'some name').
options(set_shard_id("shard1"))
select(MyObject)
.where(MyObject.name == "some name")
.options(set_shard_id("shard1"))
)
Above, the statement when invoked will limit to the "shard1" shard

View File

@@ -1,5 +1,5 @@
# ext/hybrid.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -34,8 +34,9 @@ may receive the class directly, depending on context::
class Base(DeclarativeBase):
pass
class Interval(Base):
__tablename__ = 'interval'
__tablename__ = "interval"
id: Mapped[int] = mapped_column(primary_key=True)
start: Mapped[int]
@@ -57,7 +58,6 @@ may receive the class directly, depending on context::
def intersects(self, other: Interval) -> bool:
return self.contains(other.start) | self.contains(other.end)
Above, the ``length`` property returns the difference between the
``end`` and ``start`` attributes. With an instance of ``Interval``,
this subtraction occurs in Python, using normal Python descriptor
@@ -150,6 +150,7 @@ the absolute value function::
from sqlalchemy import func
from sqlalchemy import type_coerce
class Interval(Base):
# ...
@@ -214,6 +215,7 @@ example below that illustrates the use of :meth:`.hybrid_property.setter` and
# correct use, however is not accepted by pep-484 tooling
class Interval(Base):
# ...
@@ -256,6 +258,7 @@ a single decorator under one name::
# correct use which is also accepted by pep-484 tooling
class Interval(Base):
# ...
@@ -330,6 +333,7 @@ expression is used as the column that's the target of the SET. If our
``Interval.start``, this could be substituted directly::
from sqlalchemy import update
stmt = update(Interval).values({Interval.start_point: 10})
However, when using a composite hybrid like ``Interval.length``, this
@@ -340,6 +344,7 @@ A handler that works similarly to our setter would be::
from typing import List, Tuple, Any
class Interval(Base):
# ...
@@ -352,10 +357,10 @@ A handler that works similarly to our setter would be::
self.end = self.start + value
@length.inplace.update_expression
def _length_update_expression(cls, value: Any) -> List[Tuple[Any, Any]]:
return [
(cls.end, cls.start + value)
]
def _length_update_expression(
cls, value: Any
) -> List[Tuple[Any, Any]]:
return [(cls.end, cls.start + value)]
Above, if we use ``Interval.length`` in an UPDATE expression, we get
a hybrid SET expression:
@@ -412,15 +417,16 @@ mapping which relates a ``User`` to a ``SavingsAccount``::
class SavingsAccount(Base):
__tablename__ = 'account'
__tablename__ = "account"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey('user.id'))
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
balance: Mapped[Decimal] = mapped_column(Numeric(15, 5))
owner: Mapped[User] = relationship(back_populates="accounts")
class User(Base):
__tablename__ = 'user'
__tablename__ = "user"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(100))
@@ -448,7 +454,10 @@ mapping which relates a ``User`` to a ``SavingsAccount``::
@balance.inplace.expression
@classmethod
def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]:
return cast("SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance)
return cast(
"SQLColumnExpression[Optional[Decimal]]",
SavingsAccount.balance,
)
The above hybrid property ``balance`` works with the first
``SavingsAccount`` entry in the list of accounts for this user. The
@@ -471,8 +480,11 @@ be used in an appropriate context such that an appropriate join to
.. sourcecode:: pycon+sql
>>> from sqlalchemy import select
>>> print(select(User, User.balance).
... join(User.accounts).filter(User.balance > 5000))
>>> print(
... select(User, User.balance)
... .join(User.accounts)
... .filter(User.balance > 5000)
... )
{printsql}SELECT "user".id AS user_id, "user".name AS user_name,
account.balance AS account_balance
FROM "user" JOIN account ON "user".id = account.user_id
@@ -487,8 +499,11 @@ would use an outer join:
>>> from sqlalchemy import select
>>> from sqlalchemy import or_
>>> print (select(User, User.balance).outerjoin(User.accounts).
... filter(or_(User.balance < 5000, User.balance == None)))
>>> print(
... select(User, User.balance)
... .outerjoin(User.accounts)
... .filter(or_(User.balance < 5000, User.balance == None))
... )
{printsql}SELECT "user".id AS user_id, "user".name AS user_name,
account.balance AS account_balance
FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id
@@ -528,15 +543,16 @@ we can adjust our ``SavingsAccount`` example to aggregate the balances for
class SavingsAccount(Base):
__tablename__ = 'account'
__tablename__ = "account"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey('user.id'))
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
balance: Mapped[Decimal] = mapped_column(Numeric(15, 5))
owner: Mapped[User] = relationship(back_populates="accounts")
class User(Base):
__tablename__ = 'user'
__tablename__ = "user"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(100))
@@ -546,7 +562,9 @@ we can adjust our ``SavingsAccount`` example to aggregate the balances for
@hybrid_property
def balance(self) -> Decimal:
return sum((acc.balance for acc in self.accounts), start=Decimal("0"))
return sum(
(acc.balance for acc in self.accounts), start=Decimal("0")
)
@balance.inplace.expression
@classmethod
@@ -557,7 +575,6 @@ we can adjust our ``SavingsAccount`` example to aggregate the balances for
.label("total_balance")
)
The above recipe will give us the ``balance`` column which renders
a correlated SELECT:
@@ -604,6 +621,7 @@ named ``word_insensitive``::
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
class Base(DeclarativeBase):
pass
@@ -612,8 +630,9 @@ named ``word_insensitive``::
def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
return func.lower(self.__clause_element__()) == func.lower(other)
class SearchWord(Base):
__tablename__ = 'searchword'
__tablename__ = "searchword"
id: Mapped[int] = mapped_column(primary_key=True)
word: Mapped[str]
@@ -675,6 +694,7 @@ how the standard Python ``@property`` object works::
def _name_setter(self, value: str) -> None:
self.first_name = value
class FirstNameLastName(FirstNameOnly):
# ...
@@ -684,11 +704,11 @@ how the standard Python ``@property`` object works::
# of FirstNameOnly.name that is local to FirstNameLastName
@FirstNameOnly.name.getter
def name(self) -> str:
return self.first_name + ' ' + self.last_name
return self.first_name + " " + self.last_name
@name.inplace.setter
def _name_setter(self, value: str) -> None:
self.first_name, self.last_name = value.split(' ', 1)
self.first_name, self.last_name = value.split(" ", 1)
Above, the ``FirstNameLastName`` class refers to the hybrid from
``FirstNameOnly.name`` to repurpose its getter and setter for the subclass.
@@ -709,8 +729,7 @@ reference the instrumented attribute back to the hybrid object::
@FirstNameOnly.name.overrides.expression
@classmethod
def name(cls):
return func.concat(cls.first_name, ' ', cls.last_name)
return func.concat(cls.first_name, " ", cls.last_name)
Hybrid Value Objects
--------------------
@@ -751,7 +770,7 @@ Replacing the previous ``CaseInsensitiveComparator`` class with a new
def __str__(self):
return self.word
key = 'word'
key = "word"
"Label to apply to Query tuple results"
Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may
@@ -762,7 +781,7 @@ SQL side or Python side. Our ``SearchWord`` class can now deliver the
``CaseInsensitiveWord`` object unconditionally from a single hybrid call::
class SearchWord(Base):
__tablename__ = 'searchword'
__tablename__ = "searchword"
id: Mapped[int] = mapped_column(primary_key=True)
word: Mapped[str]
@@ -904,13 +923,11 @@ class HybridExtensionType(InspectionAttrExtensionType):
class _HybridGetterType(Protocol[_T_co]):
def __call__(s, self: Any) -> _T_co:
...
def __call__(s, self: Any) -> _T_co: ...
class _HybridSetterType(Protocol[_T_con]):
def __call__(s, self: Any, value: _T_con) -> None:
...
def __call__(s, self: Any, value: _T_con) -> None: ...
class _HybridUpdaterType(Protocol[_T_con]):
@@ -918,25 +935,21 @@ class _HybridUpdaterType(Protocol[_T_con]):
s,
cls: Any,
value: Union[_T_con, _ColumnExpressionArgument[_T_con]],
) -> List[Tuple[_DMLColumnArgument, Any]]:
...
) -> List[Tuple[_DMLColumnArgument, Any]]: ...
class _HybridDeleterType(Protocol[_T_co]):
def __call__(s, self: Any) -> None:
...
def __call__(s, self: Any) -> None: ...
class _HybridExprCallableType(Protocol[_T_co]):
def __call__(
s, cls: Any
) -> Union[_HasClauseElement, SQLColumnExpression[_T_co]]:
...
) -> Union[_HasClauseElement[_T_co], SQLColumnExpression[_T_co]]: ...
class _HybridComparatorCallableType(Protocol[_T]):
def __call__(self, cls: Any) -> Comparator[_T]:
...
def __call__(self, cls: Any) -> Comparator[_T]: ...
class _HybridClassLevelAccessor(QueryableAttribute[_T]):
@@ -947,23 +960,24 @@ class _HybridClassLevelAccessor(QueryableAttribute[_T]):
if TYPE_CHECKING:
def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]:
...
def getter(
self, fget: _HybridGetterType[_T]
) -> hybrid_property[_T]: ...
def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]:
...
def setter(
self, fset: _HybridSetterType[_T]
) -> hybrid_property[_T]: ...
def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]:
...
def deleter(
self, fdel: _HybridDeleterType[_T]
) -> hybrid_property[_T]: ...
@property
def overrides(self) -> hybrid_property[_T]:
...
def overrides(self) -> hybrid_property[_T]: ...
def update_expression(
self, meth: _HybridUpdaterType[_T]
) -> hybrid_property[_T]:
...
) -> hybrid_property[_T]: ...
class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]):
@@ -988,6 +1002,7 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]):
from sqlalchemy.ext.hybrid import hybrid_method
class SomeClass:
@hybrid_method
def value(self, x, y):
@@ -1025,14 +1040,12 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]):
@overload
def __get__(
self, instance: Literal[None], owner: Type[object]
) -> Callable[_P, SQLCoreOperations[_R]]:
...
) -> Callable[_P, SQLCoreOperations[_R]]: ...
@overload
def __get__(
self, instance: object, owner: Type[object]
) -> Callable[_P, _R]:
...
) -> Callable[_P, _R]: ...
def __get__(
self, instance: Optional[object], owner: Type[object]
@@ -1087,6 +1100,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
from sqlalchemy.ext.hybrid import hybrid_property
class SomeClass:
@hybrid_property
def value(self):
@@ -1103,21 +1117,18 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
self.expr = _unwrap_classmethod(expr)
self.custom_comparator = _unwrap_classmethod(custom_comparator)
self.update_expr = _unwrap_classmethod(update_expr)
util.update_wrapper(self, fget)
util.update_wrapper(self, fget) # type: ignore[arg-type]
@overload
def __get__(self, instance: Any, owner: Literal[None]) -> Self:
...
def __get__(self, instance: Any, owner: Literal[None]) -> Self: ...
@overload
def __get__(
self, instance: Literal[None], owner: Type[object]
) -> _HybridClassLevelAccessor[_T]:
...
) -> _HybridClassLevelAccessor[_T]: ...
@overload
def __get__(self, instance: object, owner: Type[object]) -> _T:
...
def __get__(self, instance: object, owner: Type[object]) -> _T: ...
def __get__(
self, instance: Optional[object], owner: Optional[Type[object]]
@@ -1168,6 +1179,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
def foobar(self):
return self._foobar
class SubClass(SuperClass):
# ...
@@ -1377,10 +1389,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
@fullname.update_expression
def fullname(cls, value):
fname, lname = value.split(" ", 1)
return [
(cls.first_name, fname),
(cls.last_name, lname)
]
return [(cls.first_name, fname), (cls.last_name, lname)]
.. versionadded:: 1.2
@@ -1447,7 +1456,7 @@ class Comparator(interfaces.PropComparator[_T]):
classes for usage with hybrids."""
def __init__(
self, expression: Union[_HasClauseElement, SQLColumnExpression[_T]]
self, expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]]
):
self.expression = expression
@@ -1482,7 +1491,7 @@ class ExprComparator(Comparator[_T]):
def __init__(
self,
cls: Type[Any],
expression: Union[_HasClauseElement, SQLColumnExpression[_T]],
expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]],
hybrid: hybrid_property[_T],
):
self.cls = cls

View File

@@ -1,5 +1,5 @@
# ext/index.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# ext/indexable.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -36,19 +36,19 @@ as a dedicated attribute which behaves like a standalone column::
Base = declarative_base()
class Person(Base):
__tablename__ = 'person'
__tablename__ = "person"
id = Column(Integer, primary_key=True)
data = Column(JSON)
name = index_property('data', 'name')
name = index_property("data", "name")
Above, the ``name`` attribute now behaves like a mapped column. We
can compose a new ``Person`` and set the value of ``name``::
>>> person = Person(name='Alchemist')
>>> person = Person(name="Alchemist")
The value is now accessible::
@@ -59,11 +59,11 @@ Behind the scenes, the JSON field was initialized to a new blank dictionary
and the field was set::
>>> person.data
{"name": "Alchemist'}
{'name': 'Alchemist'}
The field is mutable in place::
>>> person.name = 'Renamed'
>>> person.name = "Renamed"
>>> person.name
'Renamed'
>>> person.data
@@ -87,18 +87,17 @@ A missing key will produce ``AttributeError``::
>>> person = Person()
>>> person.name
...
AttributeError: 'name'
Unless you set a default value::
>>> class Person(Base):
>>> __tablename__ = 'person'
>>>
>>> id = Column(Integer, primary_key=True)
>>> data = Column(JSON)
>>>
>>> name = index_property('data', 'name', default=None) # See default
... __tablename__ = "person"
...
... id = Column(Integer, primary_key=True)
... data = Column(JSON)
...
... name = index_property("data", "name", default=None) # See default
>>> person = Person()
>>> print(person.name)
@@ -111,11 +110,11 @@ an indexed SQL criteria::
>>> from sqlalchemy.orm import Session
>>> session = Session()
>>> query = session.query(Person).filter(Person.name == 'Alchemist')
>>> query = session.query(Person).filter(Person.name == "Alchemist")
The above query is equivalent to::
>>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')
>>> query = session.query(Person).filter(Person.data["name"] == "Alchemist")
Multiple :class:`.index_property` objects can be chained to produce
multiple levels of indexing::
@@ -126,22 +125,25 @@ multiple levels of indexing::
Base = declarative_base()
class Person(Base):
__tablename__ = 'person'
__tablename__ = "person"
id = Column(Integer, primary_key=True)
data = Column(JSON)
birthday = index_property('data', 'birthday')
year = index_property('birthday', 'year')
month = index_property('birthday', 'month')
day = index_property('birthday', 'day')
birthday = index_property("data", "birthday")
year = index_property("birthday", "year")
month = index_property("birthday", "month")
day = index_property("birthday", "day")
Above, a query such as::
q = session.query(Person).filter(Person.year == '1980')
q = session.query(Person).filter(Person.year == "1980")
On a PostgreSQL backend, the above query will render as::
On a PostgreSQL backend, the above query will render as:
.. sourcecode:: sql
SELECT person.id, person.data
FROM person
@@ -198,13 +200,14 @@ version of :class:`_postgresql.JSON`::
Base = declarative_base()
class Person(Base):
__tablename__ = 'person'
__tablename__ = "person"
id = Column(Integer, primary_key=True)
data = Column(JSON)
age = pg_json_property('data', 'age', Integer)
age = pg_json_property("data", "age", Integer)
The ``age`` attribute at the instance level works as before; however
when rendering SQL, PostgreSQL's ``->>`` operator will be used
@@ -212,7 +215,9 @@ for indexed access, instead of the usual index operator of ``->``::
>>> query = session.query(Person).filter(Person.age < 20)
The above query will render::
The above query will render:
.. sourcecode:: sql
SELECT person.id, person.data
FROM person

View File

@@ -1,5 +1,5 @@
# ext/instrumentation.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -214,9 +214,9 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
)(instance)
orm_instrumentation._instrumentation_factory = (
_instrumentation_factory
) = ExtendedInstrumentationRegistry()
orm_instrumentation._instrumentation_factory = _instrumentation_factory = (
ExtendedInstrumentationRegistry()
)
orm_instrumentation.instrumentation_finders = instrumentation_finders
@@ -436,17 +436,15 @@ def _install_lookups(lookups):
instance_dict = lookups["instance_dict"]
manager_of_class = lookups["manager_of_class"]
opt_manager_of_class = lookups["opt_manager_of_class"]
orm_base.instance_state = (
attributes.instance_state
) = orm_instrumentation.instance_state = instance_state
orm_base.instance_dict = (
attributes.instance_dict
) = orm_instrumentation.instance_dict = instance_dict
orm_base.manager_of_class = (
attributes.manager_of_class
) = orm_instrumentation.manager_of_class = manager_of_class
orm_base.opt_manager_of_class = (
orm_util.opt_manager_of_class
) = (
orm_base.instance_state = attributes.instance_state = (
orm_instrumentation.instance_state
) = instance_state
orm_base.instance_dict = attributes.instance_dict = (
orm_instrumentation.instance_dict
) = instance_dict
orm_base.manager_of_class = attributes.manager_of_class = (
orm_instrumentation.manager_of_class
) = manager_of_class
orm_base.opt_manager_of_class = orm_util.opt_manager_of_class = (
attributes.opt_manager_of_class
) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class

View File

@@ -1,5 +1,5 @@
# ext/mutable.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -21,6 +21,7 @@ JSON strings before being persisted::
from sqlalchemy.types import TypeDecorator, VARCHAR
import json
class JSONEncodedDict(TypeDecorator):
"Represents an immutable structure as a json-encoded string."
@@ -48,6 +49,7 @@ the :class:`.Mutable` mixin to a plain Python dictionary::
from sqlalchemy.ext.mutable import Mutable
class MutableDict(Mutable, dict):
@classmethod
def coerce(cls, key, value):
@@ -101,9 +103,11 @@ attribute. Such as, with classical table metadata::
from sqlalchemy import Table, Column, Integer
my_data = Table('my_data', metadata,
Column('id', Integer, primary_key=True),
Column('data', MutableDict.as_mutable(JSONEncodedDict))
my_data = Table(
"my_data",
metadata,
Column("id", Integer, primary_key=True),
Column("data", MutableDict.as_mutable(JSONEncodedDict)),
)
Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict``
@@ -115,13 +119,17 @@ mapping against the ``my_data`` table::
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
class Base(DeclarativeBase):
pass
class MyDataClass(Base):
__tablename__ = 'my_data'
__tablename__ = "my_data"
id: Mapped[int] = mapped_column(primary_key=True)
data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict))
data: Mapped[dict[str, str]] = mapped_column(
MutableDict.as_mutable(JSONEncodedDict)
)
The ``MyDataClass.data`` member will now be notified of in place changes
to its value.
@@ -132,11 +140,11 @@ will flag the attribute as "dirty" on the parent object::
>>> from sqlalchemy.orm import Session
>>> sess = Session(some_engine)
>>> m1 = MyDataClass(data={'value1':'foo'})
>>> m1 = MyDataClass(data={"value1": "foo"})
>>> sess.add(m1)
>>> sess.commit()
>>> m1.data['value1'] = 'bar'
>>> m1.data["value1"] = "bar"
>>> assert m1 in sess.dirty
True
@@ -153,15 +161,16 @@ the need to declare it individually::
MutableDict.associate_with(JSONEncodedDict)
class Base(DeclarativeBase):
pass
class MyDataClass(Base):
__tablename__ = 'my_data'
__tablename__ = "my_data"
id: Mapped[int] = mapped_column(primary_key=True)
data: Mapped[dict[str, str]] = mapped_column(JSONEncodedDict)
Supporting Pickling
--------------------
@@ -180,7 +189,7 @@ stream::
class MyMutableType(Mutable):
def __getstate__(self):
d = self.__dict__.copy()
d.pop('_parents', None)
d.pop("_parents", None)
return d
With our dictionary example, we need to return the contents of the dict itself
@@ -213,13 +222,18 @@ from within the mutable extension::
from sqlalchemy.orm import mapped_column
from sqlalchemy import event
class Base(DeclarativeBase):
pass
class MyDataClass(Base):
__tablename__ = 'my_data'
__tablename__ = "my_data"
id: Mapped[int] = mapped_column(primary_key=True)
data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict))
data: Mapped[dict[str, str]] = mapped_column(
MutableDict.as_mutable(JSONEncodedDict)
)
@event.listens_for(MyDataClass.data, "modified")
def modified_json(instance, initiator):
@@ -247,6 +261,7 @@ class introduced in :ref:`mapper_composite` to include
import dataclasses
from sqlalchemy.ext.mutable import MutableComposite
@dataclasses.dataclass
class Point(MutableComposite):
x: int
@@ -261,7 +276,6 @@ class introduced in :ref:`mapper_composite` to include
# alert all parents to the change
self.changed()
The :class:`.MutableComposite` class makes use of class mapping events to
automatically establish listeners for any usage of :func:`_orm.composite` that
specifies our ``Point`` type. Below, when ``Point`` is mapped to the ``Vertex``
@@ -271,6 +285,7 @@ objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
from sqlalchemy.orm import DeclarativeBase, Mapped
from sqlalchemy.orm import composite, mapped_column
class Base(DeclarativeBase):
pass
@@ -280,8 +295,12 @@ objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
id: Mapped[int] = mapped_column(primary_key=True)
start: Mapped[Point] = composite(mapped_column("x1"), mapped_column("y1"))
end: Mapped[Point] = composite(mapped_column("x2"), mapped_column("y2"))
start: Mapped[Point] = composite(
mapped_column("x1"), mapped_column("y1")
)
end: Mapped[Point] = composite(
mapped_column("x2"), mapped_column("y2")
)
def __repr__(self):
return f"Vertex(start={self.start}, end={self.end})"
@@ -378,6 +397,7 @@ from weakref import WeakKeyDictionary
from .. import event
from .. import inspect
from .. import types
from .. import util
from ..orm import Mapper
from ..orm._typing import _ExternalEntityType
from ..orm._typing import _O
@@ -390,6 +410,7 @@ from ..orm.context import QueryContext
from ..orm.decl_api import DeclarativeAttributeIntercept
from ..orm.state import InstanceState
from ..orm.unitofwork import UOWTransaction
from ..sql._typing import _TypeEngineArgument
from ..sql.base import SchemaEventTarget
from ..sql.schema import Column
from ..sql.type_api import TypeEngine
@@ -503,6 +524,7 @@ class MutableBase:
if val is not None:
if coerce:
val = cls.coerce(key, val)
assert val is not None
state.dict[key] = val
val._parents[state] = key
@@ -637,7 +659,7 @@ class Mutable(MutableBase):
event.listen(Mapper, "mapper_configured", listen_for_type)
@classmethod
def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]:
def as_mutable(cls, sqltype: _TypeEngineArgument[_T]) -> TypeEngine[_T]:
"""Associate a SQL type with this mutable Python type.
This establishes listeners that will detect ORM mappings against
@@ -646,9 +668,11 @@ class Mutable(MutableBase):
The type is returned, unconditionally as an instance, so that
:meth:`.as_mutable` can be used inline::
Table('mytable', metadata,
Column('id', Integer, primary_key=True),
Column('data', MyMutableType.as_mutable(PickleType))
Table(
"mytable",
metadata,
Column("id", Integer, primary_key=True),
Column("data", MyMutableType.as_mutable(PickleType)),
)
Note that the returned type is always an instance, even if a class
@@ -799,15 +823,12 @@ class MutableDict(Mutable, Dict[_KT, _VT]):
@overload
def setdefault(
self: MutableDict[_KT, Optional[_T]], key: _KT, value: None = None
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
def setdefault(self, key: _KT, value: _VT) -> _VT:
...
def setdefault(self, key: _KT, value: _VT) -> _VT: ...
def setdefault(self, key: _KT, value: object = None) -> object:
...
def setdefault(self, key: _KT, value: object = None) -> object: ...
else:
@@ -828,17 +849,14 @@ class MutableDict(Mutable, Dict[_KT, _VT]):
if TYPE_CHECKING:
@overload
def pop(self, __key: _KT) -> _VT:
...
def pop(self, __key: _KT) -> _VT: ...
@overload
def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T:
...
def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T: ...
def pop(
self, __key: _KT, __default: _VT | _T | None = None
) -> _VT | _T:
...
) -> _VT | _T: ...
else:
@@ -909,10 +927,10 @@ class MutableList(Mutable, List[_T]):
self[:] = state
def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]:
return not isinstance(value, Iterable)
return not util.is_non_string_iterable(value)
def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]:
return isinstance(value, Iterable)
return util.is_non_string_iterable(value)
def __setitem__(
self, index: SupportsIndex | slice, value: _T | Iterable[_T]

View File

@@ -0,0 +1,6 @@
# ext/mypy/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

View File

@@ -1,5 +1,5 @@
# ext/mypy/apply.py
# Copyright (C) 2021 the SQLAlchemy authors and contributors
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -161,9 +161,9 @@ def re_apply_declarative_assignments(
# update the SQLAlchemyAttribute with the better
# information
mapped_attr_lookup[
stmt.lvalues[0].name
].type = python_type_for_type
mapped_attr_lookup[stmt.lvalues[0].name].type = (
python_type_for_type
)
update_cls_metadata = True
@@ -199,11 +199,15 @@ def apply_type_to_mapped_statement(
To one that describes the final Python behavior to Mypy::
... format: off
class User(Base):
# ...
attrname : Mapped[Optional[int]] = <meaningless temp node>
... format: on
"""
left_node = lvalue.node
assert isinstance(left_node, Var)
@@ -223,9 +227,11 @@ def apply_type_to_mapped_statement(
lvalue.is_inferred_def = False
left_node.type = api.named_type(
NAMED_TYPE_SQLA_MAPPED,
[AnyType(TypeOfAny.special_form)]
if python_type_for_type is None
else [python_type_for_type],
(
[AnyType(TypeOfAny.special_form)]
if python_type_for_type is None
else [python_type_for_type]
),
)
# so to have it skip the right side totally, we can do this:

View File

@@ -1,5 +1,5 @@
# ext/mypy/decl_class.py
# Copyright (C) 2021 the SQLAlchemy authors and contributors
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -58,9 +58,9 @@ def scan_declarative_assignments_and_apply_types(
elif cls.fullname.startswith("builtins"):
return None
mapped_attributes: Optional[
List[util.SQLAlchemyAttribute]
] = util.get_mapped_attributes(info, api)
mapped_attributes: Optional[List[util.SQLAlchemyAttribute]] = (
util.get_mapped_attributes(info, api)
)
# used by assign.add_additional_orm_attributes among others
util.establish_as_sqlalchemy(info)

View File

@@ -1,5 +1,5 @@
# ext/mypy/infer.py
# Copyright (C) 2021 the SQLAlchemy authors and contributors
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -385,9 +385,9 @@ def _infer_type_from_decl_column(
class MyClass:
# ...
a : Mapped[int]
a: Mapped[int]
b : Mapped[str]
b: Mapped[str]
c: Mapped[int]

View File

@@ -1,5 +1,5 @@
# ext/mypy/names.py
# Copyright (C) 2021 the SQLAlchemy authors and contributors
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -58,6 +58,14 @@ NAMED_TYPE_BUILTINS_STR = "builtins.str"
NAMED_TYPE_BUILTINS_LIST = "builtins.list"
NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped"
_RelFullNames = {
"sqlalchemy.orm.relationships.Relationship",
"sqlalchemy.orm.relationships.RelationshipProperty",
"sqlalchemy.orm.relationships._RelationshipDeclared",
"sqlalchemy.orm.Relationship",
"sqlalchemy.orm.RelationshipProperty",
}
_lookup: Dict[str, Tuple[int, Set[str]]] = {
"Column": (
COLUMN,
@@ -66,24 +74,9 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = {
"sqlalchemy.sql.Column",
},
),
"Relationship": (
RELATIONSHIP,
{
"sqlalchemy.orm.relationships.Relationship",
"sqlalchemy.orm.relationships.RelationshipProperty",
"sqlalchemy.orm.Relationship",
"sqlalchemy.orm.RelationshipProperty",
},
),
"RelationshipProperty": (
RELATIONSHIP,
{
"sqlalchemy.orm.relationships.Relationship",
"sqlalchemy.orm.relationships.RelationshipProperty",
"sqlalchemy.orm.Relationship",
"sqlalchemy.orm.RelationshipProperty",
},
),
"Relationship": (RELATIONSHIP, _RelFullNames),
"RelationshipProperty": (RELATIONSHIP, _RelFullNames),
"_RelationshipDeclared": (RELATIONSHIP, _RelFullNames),
"registry": (
REGISTRY,
{
@@ -304,7 +297,7 @@ def type_id_for_callee(callee: Expression) -> Optional[int]:
def type_id_for_named_node(
node: Union[NameExpr, MemberExpr, SymbolNode]
node: Union[NameExpr, MemberExpr, SymbolNode],
) -> Optional[int]:
type_id, fullnames = _lookup.get(node.name, (None, None))

View File

@@ -1,5 +1,5 @@
# ext/mypy/plugin.py
# Copyright (C) 2021-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# ext/mypy/util.py
# Copyright (C) 2021-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -80,7 +80,7 @@ class SQLAlchemyAttribute:
"name": self.name,
"line": self.line,
"column": self.column,
"type": self.type.serialize(),
"type": serialize_type(self.type),
}
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
@@ -212,8 +212,7 @@ def add_global(
@overload
def get_callexpr_kwarg(
callexpr: CallExpr, name: str, *, expr_types: None = ...
) -> Optional[Union[CallExpr, NameExpr]]:
...
) -> Optional[Union[CallExpr, NameExpr]]: ...
@overload
@@ -222,8 +221,7 @@ def get_callexpr_kwarg(
name: str,
*,
expr_types: Tuple[TypingType[_TArgType], ...],
) -> Optional[_TArgType]:
...
) -> Optional[_TArgType]: ...
def get_callexpr_kwarg(
@@ -315,9 +313,11 @@ def unbound_to_instance(
return Instance(
bound_type,
[
unbound_to_instance(api, arg)
if isinstance(arg, UnboundType)
else arg
(
unbound_to_instance(api, arg)
if isinstance(arg, UnboundType)
else arg
)
for arg in typ.args
],
)
@@ -336,3 +336,22 @@ def info_for_cls(
return sym.node
return cls.info
def serialize_type(typ: Type) -> Union[str, JsonDict]:
try:
return typ.serialize()
except Exception:
pass
if hasattr(typ, "args"):
typ.args = tuple(
(
a.resolve_string_annotation()
if hasattr(a, "resolve_string_annotation")
else a
)
for a in typ.args
)
elif hasattr(typ, "resolve_string_annotation"):
typ = typ.resolve_string_annotation()
return typ.serialize()

View File

@@ -1,10 +1,9 @@
# ext/orderinglist.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""A custom list that manages index/position information for contained
elements.
@@ -26,18 +25,20 @@ displayed in order based on the value of the ``position`` column in the
Base = declarative_base()
class Slide(Base):
__tablename__ = 'slide'
__tablename__ = "slide"
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position")
class Bullet(Base):
__tablename__ = 'bullet'
__tablename__ = "bullet"
id = Column(Integer, primary_key=True)
slide_id = Column(Integer, ForeignKey('slide.id'))
slide_id = Column(Integer, ForeignKey("slide.id"))
position = Column(Integer)
text = Column(String)
@@ -57,19 +58,24 @@ constructed using the :func:`.ordering_list` factory::
Base = declarative_base()
class Slide(Base):
__tablename__ = 'slide'
__tablename__ = "slide"
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position",
collection_class=ordering_list('position'))
bullets = relationship(
"Bullet",
order_by="Bullet.position",
collection_class=ordering_list("position"),
)
class Bullet(Base):
__tablename__ = 'bullet'
__tablename__ = "bullet"
id = Column(Integer, primary_key=True)
slide_id = Column(Integer, ForeignKey('slide.id'))
slide_id = Column(Integer, ForeignKey("slide.id"))
position = Column(Integer)
text = Column(String)
@@ -122,17 +128,24 @@ start numbering at 1 or some other integer, provide ``count_from=1``.
"""
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Type
from typing import TypeVar
from typing import Union
from ..orm.collections import collection
from ..orm.collections import collection_adapter
from ..util.typing import SupportsIndex
_T = TypeVar("_T")
OrderingFunc = Callable[[int, Sequence[_T]], int]
OrderingFunc = Callable[[int, Sequence[_T]], object]
__all__ = ["ordering_list"]
@@ -141,9 +154,9 @@ __all__ = ["ordering_list"]
def ordering_list(
attr: str,
count_from: Optional[int] = None,
ordering_func: Optional[OrderingFunc] = None,
ordering_func: Optional[OrderingFunc[_T]] = None,
reorder_on_append: bool = False,
) -> Callable[[], OrderingList]:
) -> Callable[[], OrderingList[_T]]:
"""Prepares an :class:`OrderingList` factory for use in mapper definitions.
Returns an object suitable for use as an argument to a Mapper
@@ -151,14 +164,18 @@ def ordering_list(
from sqlalchemy.ext.orderinglist import ordering_list
class Slide(Base):
__tablename__ = 'slide'
__tablename__ = "slide"
id = Column(Integer, primary_key=True)
name = Column(String)
bullets = relationship("Bullet", order_by="Bullet.position",
collection_class=ordering_list('position'))
bullets = relationship(
"Bullet",
order_by="Bullet.position",
collection_class=ordering_list("position"),
)
:param attr:
Name of the mapped attribute to use for storage and retrieval of
@@ -185,22 +202,22 @@ def ordering_list(
# Ordering utility functions
def count_from_0(index, collection):
def count_from_0(index: int, collection: object) -> int:
"""Numbering function: consecutive integers starting at 0."""
return index
def count_from_1(index, collection):
def count_from_1(index: int, collection: object) -> int:
"""Numbering function: consecutive integers starting at 1."""
return index + 1
def count_from_n_factory(start):
def count_from_n_factory(start: int) -> OrderingFunc[Any]:
"""Numbering function: consecutive integers starting at arbitrary start."""
def f(index, collection):
def f(index: int, collection: object) -> int:
return index + start
try:
@@ -210,7 +227,7 @@ def count_from_n_factory(start):
return f
def _unsugar_count_from(**kw):
def _unsugar_count_from(**kw: Any) -> Dict[str, Any]:
"""Builds counting functions from keyword arguments.
Keyword argument filter, prepares a simple ``ordering_func`` from a
@@ -238,13 +255,13 @@ class OrderingList(List[_T]):
"""
ordering_attr: str
ordering_func: OrderingFunc
ordering_func: OrderingFunc[_T]
reorder_on_append: bool
def __init__(
self,
ordering_attr: Optional[str] = None,
ordering_func: Optional[OrderingFunc] = None,
ordering_attr: str,
ordering_func: Optional[OrderingFunc[_T]] = None,
reorder_on_append: bool = False,
):
"""A custom list that manages position information for its children.
@@ -304,10 +321,10 @@ class OrderingList(List[_T]):
# More complex serialization schemes (multi column, e.g.) are possible by
# subclassing and reimplementing these two methods.
def _get_order_value(self, entity):
def _get_order_value(self, entity: _T) -> Any:
return getattr(entity, self.ordering_attr)
def _set_order_value(self, entity, value):
def _set_order_value(self, entity: _T, value: Any) -> None:
setattr(entity, self.ordering_attr, value)
def reorder(self) -> None:
@@ -323,7 +340,9 @@ class OrderingList(List[_T]):
# As of 0.5, _reorder is no longer semi-private
_reorder = reorder
def _order_entity(self, index, entity, reorder=True):
def _order_entity(
self, index: int, entity: _T, reorder: bool = True
) -> None:
have = self._get_order_value(entity)
# Don't disturb existing ordering if reorder is False
@@ -334,34 +353,44 @@ class OrderingList(List[_T]):
if have != should_be:
self._set_order_value(entity, should_be)
def append(self, entity):
def append(self, entity: _T) -> None:
super().append(entity)
self._order_entity(len(self) - 1, entity, self.reorder_on_append)
def _raw_append(self, entity):
def _raw_append(self, entity: _T) -> None:
"""Append without any ordering behavior."""
super().append(entity)
_raw_append = collection.adds(1)(_raw_append)
def insert(self, index, entity):
def insert(self, index: SupportsIndex, entity: _T) -> None:
super().insert(index, entity)
self._reorder()
def remove(self, entity):
def remove(self, entity: _T) -> None:
super().remove(entity)
adapter = collection_adapter(self)
if adapter and adapter._referenced_by_owner:
self._reorder()
def pop(self, index=-1):
def pop(self, index: SupportsIndex = -1) -> _T:
entity = super().pop(index)
self._reorder()
return entity
def __setitem__(self, index, entity):
@overload
def __setitem__(self, index: SupportsIndex, entity: _T) -> None: ...
@overload
def __setitem__(self, index: slice, entity: Iterable[_T]) -> None: ...
def __setitem__(
self,
index: Union[SupportsIndex, slice],
entity: Union[_T, Iterable[_T]],
) -> None:
if isinstance(index, slice):
step = index.step or 1
start = index.start or 0
@@ -370,26 +399,18 @@ class OrderingList(List[_T]):
stop = index.stop or len(self)
if stop < 0:
stop += len(self)
entities = list(entity) # type: ignore[arg-type]
for i in range(start, stop, step):
self.__setitem__(i, entity[i])
self.__setitem__(i, entities[i])
else:
self._order_entity(index, entity, True)
super().__setitem__(index, entity)
self._order_entity(int(index), entity, True) # type: ignore[arg-type] # noqa: E501
super().__setitem__(index, entity) # type: ignore[assignment]
def __delitem__(self, index):
def __delitem__(self, index: Union[SupportsIndex, slice]) -> None:
super().__delitem__(index)
self._reorder()
def __setslice__(self, start, end, values):
super().__setslice__(start, end, values)
self._reorder()
def __delslice__(self, start, end):
super().__delslice__(start, end)
self._reorder()
def __reduce__(self):
def __reduce__(self) -> Any:
return _reconstitute, (self.__class__, self.__dict__, list(self))
for func_name, func in list(locals().items()):
@@ -403,7 +424,9 @@ class OrderingList(List[_T]):
del func_name, func
def _reconstitute(cls, dict_, items):
def _reconstitute(
cls: Type[OrderingList[_T]], dict_: Dict[str, Any], items: List[_T]
) -> OrderingList[_T]:
"""Reconstitute an :class:`.OrderingList`.
This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for

View File

@@ -1,5 +1,5 @@
# ext/serializer.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -28,13 +28,17 @@ when it is deserialized.
Usage is nearly the same as that of the standard Python pickle module::
from sqlalchemy.ext.serializer import loads, dumps
metadata = MetaData(bind=some_engine)
Session = scoped_session(sessionmaker())
# ... define mappers
query = Session.query(MyClass).
filter(MyClass.somedata=='foo').order_by(MyClass.sortkey)
query = (
Session.query(MyClass)
.filter(MyClass.somedata == "foo")
.order_by(MyClass.sortkey)
)
# pickle the query
serialized = dumps(query)
@@ -42,7 +46,7 @@ Usage is nearly the same as that of the standard Python pickle module::
# unpickle. Pass in metadata + scoped_session
query2 = loads(serialized, metadata, Session)
print query2.all()
print(query2.all())
Similar restrictions as when using raw pickle apply; mapped classes must be
themselves be pickleable, meaning they are importable from a module-level
@@ -82,10 +86,9 @@ from ..util import b64encode
__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
def Serializer(*args, **kw):
pickler = pickle.Pickler(*args, **kw)
class Serializer(pickle.Pickler):
def persistent_id(obj):
def persistent_id(self, obj):
# print "serializing:", repr(obj)
if isinstance(obj, Mapper) and not obj.non_primary:
id_ = "mapper:" + b64encode(pickle.dumps(obj.class_))
@@ -113,9 +116,6 @@ def Serializer(*args, **kw):
return None
return id_
pickler.persistent_id = persistent_id
return pickler
our_ids = re.compile(
r"(mapperprop|mapper|mapper_selectable|table|column|"
@@ -123,20 +123,23 @@ our_ids = re.compile(
)
def Deserializer(file, metadata=None, scoped_session=None, engine=None):
unpickler = pickle.Unpickler(file)
class Deserializer(pickle.Unpickler):
def get_engine():
if engine:
return engine
elif scoped_session and scoped_session().bind:
return scoped_session().bind
elif metadata and metadata.bind:
return metadata.bind
def __init__(self, file, metadata=None, scoped_session=None, engine=None):
super().__init__(file)
self.metadata = metadata
self.scoped_session = scoped_session
self.engine = engine
def get_engine(self):
if self.engine:
return self.engine
elif self.scoped_session and self.scoped_session().bind:
return self.scoped_session().bind
else:
return None
def persistent_load(id_):
def persistent_load(self, id_):
m = our_ids.match(str(id_))
if not m:
return None
@@ -157,20 +160,17 @@ def Deserializer(file, metadata=None, scoped_session=None, engine=None):
cls = pickle.loads(b64decode(mapper))
return class_mapper(cls).attrs[keyname]
elif type_ == "table":
return metadata.tables[args]
return self.metadata.tables[args]
elif type_ == "column":
table, colname = args.split(":")
return metadata.tables[table].c[colname]
return self.metadata.tables[table].c[colname]
elif type_ == "session":
return scoped_session()
return self.scoped_session()
elif type_ == "engine":
return get_engine()
return self.get_engine()
else:
raise Exception("Unknown token: %s" % type_)
unpickler.persistent_load = persistent_load
return unpickler
def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL):
buf = BytesIO()