API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,5 +1,5 @@
# ext/asyncio/__init__.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# ext/asyncio/base.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -44,12 +44,10 @@ class ReversibleProxy(Generic[_PT]):
__slots__ = ("__weakref__",)
@overload
def _assign_proxied(self, target: _PT) -> _PT:
...
def _assign_proxied(self, target: _PT) -> _PT: ...
@overload
def _assign_proxied(self, target: None) -> None:
...
def _assign_proxied(self, target: None) -> None: ...
def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
if target is not None:
@@ -73,28 +71,26 @@ class ReversibleProxy(Generic[_PT]):
cls._proxy_objects.pop(ref, None)
@classmethod
def _regenerate_proxy_for_target(cls, target: _PT) -> Self:
def _regenerate_proxy_for_target(
cls, target: _PT, **additional_kw: Any
) -> Self:
raise NotImplementedError()
@overload
@classmethod
def _retrieve_proxy_for_target(
cls,
target: _PT,
regenerate: Literal[True] = ...,
) -> Self:
...
cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any
) -> Self: ...
@overload
@classmethod
def _retrieve_proxy_for_target(
cls, target: _PT, regenerate: bool = True
) -> Optional[Self]:
...
cls, target: _PT, regenerate: bool = True, **additional_kw: Any
) -> Optional[Self]: ...
@classmethod
def _retrieve_proxy_for_target(
cls, target: _PT, regenerate: bool = True
cls, target: _PT, regenerate: bool = True, **additional_kw: Any
) -> Optional[Self]:
try:
proxy_ref = cls._proxy_objects[weakref.ref(target)]
@@ -106,7 +102,7 @@ class ReversibleProxy(Generic[_PT]):
return proxy # type: ignore
if regenerate:
return cls._regenerate_proxy_for_target(target)
return cls._regenerate_proxy_for_target(target, **additional_kw)
else:
return None
@@ -182,7 +178,7 @@ class GeneratorStartableContext(StartableContext[_T_co]):
# tell if we get the same exception back
value = typ()
try:
await util.athrow(self.gen, typ, value, traceback)
await self.gen.athrow(value)
except StopAsyncIteration as exc:
# Suppress StopIteration *unless* it's the same exception that
# was passed to throw(). This prevents a StopIteration
@@ -219,7 +215,7 @@ class GeneratorStartableContext(StartableContext[_T_co]):
def asyncstartablecontext(
func: Callable[..., AsyncIterator[_T_co]]
func: Callable[..., AsyncIterator[_T_co]],
) -> Callable[..., GeneratorStartableContext[_T_co]]:
"""@asyncstartablecontext decorator.
@@ -228,7 +224,9 @@ def asyncstartablecontext(
``@contextlib.asynccontextmanager`` supports, and the usage pattern
is different as well.
Typical usage::
Typical usage:
.. sourcecode:: text
@asyncstartablecontext
async def some_async_generator(<arguments>):

View File

@@ -1,5 +1,5 @@
# ext/asyncio/engine.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -41,6 +41,8 @@ from ...engine.base import NestedTransaction
from ...engine.base import Transaction
from ...exc import ArgumentError
from ...util.concurrency import greenlet_spawn
from ...util.typing import Concatenate
from ...util.typing import ParamSpec
if TYPE_CHECKING:
from ...engine.cursor import CursorResult
@@ -61,6 +63,7 @@ if TYPE_CHECKING:
from ...sql.base import Executable
from ...sql.selectable import TypedReturnsRows
_P = ParamSpec("_P")
_T = TypeVar("_T", bound=Any)
@@ -195,6 +198,7 @@ class AsyncConnection(
method of :class:`_asyncio.AsyncEngine`::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
async with engine.connect() as conn:
@@ -251,7 +255,7 @@ class AsyncConnection(
@classmethod
def _regenerate_proxy_for_target(
cls, target: Connection
cls, target: Connection, **additional_kw: Any # noqa: U100
) -> AsyncConnection:
return AsyncConnection(
AsyncEngine._retrieve_proxy_for_target(target.engine), target
@@ -414,13 +418,12 @@ class AsyncConnection(
yield_per: int = ...,
insertmanyvalues_page_size: int = ...,
schema_translate_map: Optional[SchemaTranslateMapType] = ...,
preserve_rowcount: bool = False,
**opt: Any,
) -> AsyncConnection:
...
) -> AsyncConnection: ...
@overload
async def execution_options(self, **opt: Any) -> AsyncConnection:
...
async def execution_options(self, **opt: Any) -> AsyncConnection: ...
async def execution_options(self, **opt: Any) -> AsyncConnection:
r"""Set non-SQL options for the connection which take effect
@@ -518,8 +521,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> GeneratorStartableContext[AsyncResult[_T]]:
...
) -> GeneratorStartableContext[AsyncResult[_T]]: ...
@overload
def stream(
@@ -528,8 +530,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> GeneratorStartableContext[AsyncResult[Any]]:
...
) -> GeneratorStartableContext[AsyncResult[Any]]: ...
@asyncstartablecontext
async def stream(
@@ -544,7 +545,7 @@ class AsyncConnection(
E.g.::
result = await conn.stream(stmt):
result = await conn.stream(stmt)
async for row in result:
print(f"{row}")
@@ -573,6 +574,11 @@ class AsyncConnection(
:meth:`.AsyncConnection.stream_scalars`
"""
if not self.dialect.supports_server_side_cursors:
raise exc.InvalidRequestError(
"Cant use `stream` or `stream_scalars` with the current "
"dialect since it does not support server side cursors."
)
result = await greenlet_spawn(
self._proxied.execute,
@@ -600,8 +606,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> CursorResult[_T]:
...
) -> CursorResult[_T]: ...
@overload
async def execute(
@@ -610,8 +615,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> CursorResult[Any]:
...
) -> CursorResult[Any]: ...
async def execute(
self,
@@ -667,8 +671,7 @@ class AsyncConnection(
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
async def scalar(
@@ -677,8 +680,7 @@ class AsyncConnection(
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> Any:
...
) -> Any: ...
async def scalar(
self,
@@ -709,8 +711,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[_T]:
...
) -> ScalarResult[_T]: ...
@overload
async def scalars(
@@ -719,8 +720,7 @@ class AsyncConnection(
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
...
) -> ScalarResult[Any]: ...
async def scalars(
self,
@@ -752,8 +752,7 @@ class AsyncConnection(
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> GeneratorStartableContext[AsyncScalarResult[_T]]:
...
) -> GeneratorStartableContext[AsyncScalarResult[_T]]: ...
@overload
def stream_scalars(
@@ -762,8 +761,7 @@ class AsyncConnection(
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> GeneratorStartableContext[AsyncScalarResult[Any]]:
...
) -> GeneratorStartableContext[AsyncScalarResult[Any]]: ...
@asyncstartablecontext
async def stream_scalars(
@@ -819,9 +817,12 @@ class AsyncConnection(
yield result.scalars()
async def run_sync(
self, fn: Callable[..., _T], *arg: Any, **kw: Any
self,
fn: Callable[Concatenate[Connection, _P], _T],
*arg: _P.args,
**kw: _P.kwargs,
) -> _T:
"""Invoke the given synchronous (i.e. not async) callable,
'''Invoke the given synchronous (i.e. not async) callable,
passing a synchronous-style :class:`_engine.Connection` as the first
argument.
@@ -831,26 +832,26 @@ class AsyncConnection(
E.g.::
def do_something_with_core(conn: Connection, arg1: int, arg2: str) -> str:
'''A synchronous function that does not require awaiting
"""A synchronous function that does not require awaiting
:param conn: a Core SQLAlchemy Connection, used synchronously
:return: an optional return value is supported
'''
conn.execute(
some_table.insert().values(int_col=arg1, str_col=arg2)
)
"""
conn.execute(some_table.insert().values(int_col=arg1, str_col=arg2))
return "success"
async def do_something_async(async_engine: AsyncEngine) -> None:
'''an async function that uses awaiting'''
"""an async function that uses awaiting"""
async with async_engine.begin() as async_conn:
# run do_something_with_core() with a sync-style
# Connection, proxied into an awaitable
return_code = await async_conn.run_sync(do_something_with_core, 5, "strval")
return_code = await async_conn.run_sync(
do_something_with_core, 5, "strval"
)
print(return_code)
This method maintains the asyncio event loop all the way through
@@ -881,9 +882,11 @@ class AsyncConnection(
:ref:`session_run_sync`
""" # noqa: E501
''' # noqa: E501
return await greenlet_spawn(fn, self._proxied, *arg, **kw)
return await greenlet_spawn(
fn, self._proxied, *arg, _require_await=False, **kw
)
def __await__(self) -> Generator[Any, None, AsyncConnection]:
return self.start().__await__()
@@ -928,7 +931,7 @@ class AsyncConnection(
return self._proxied.invalidated
@property
def dialect(self) -> Any:
def dialect(self) -> Dialect:
r"""Proxy for the :attr:`_engine.Connection.dialect` attribute
on behalf of the :class:`_asyncio.AsyncConnection` class.
@@ -937,7 +940,7 @@ class AsyncConnection(
return self._proxied.dialect
@dialect.setter
def dialect(self, attr: Any) -> None:
def dialect(self, attr: Dialect) -> None:
self._proxied.dialect = attr
@property
@@ -998,6 +1001,7 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
:func:`_asyncio.create_async_engine` function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
.. versionadded:: 1.4
@@ -1037,7 +1041,9 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
return self.sync_engine
@classmethod
def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine:
def _regenerate_proxy_for_target(
cls, target: Engine, **additional_kw: Any # noqa: U100
) -> AsyncEngine:
return AsyncEngine(target)
@contextlib.asynccontextmanager
@@ -1054,7 +1060,6 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
)
await conn.execute(text("my_special_procedure(5)"))
"""
conn = self.connect()
@@ -1100,12 +1105,10 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
insertmanyvalues_page_size: int = ...,
schema_translate_map: Optional[SchemaTranslateMapType] = ...,
**opt: Any,
) -> AsyncEngine:
...
) -> AsyncEngine: ...
@overload
def execution_options(self, **opt: Any) -> AsyncEngine:
...
def execution_options(self, **opt: Any) -> AsyncEngine: ...
def execution_options(self, **opt: Any) -> AsyncEngine:
"""Return a new :class:`_asyncio.AsyncEngine` that will provide
@@ -1160,7 +1163,7 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
This applies **only** to the built-in cache that is established
via the :paramref:`_engine.create_engine.query_cache_size` parameter.
It will not impact any dictionary caches that were passed via the
:paramref:`.Connection.execution_options.query_cache` parameter.
:paramref:`.Connection.execution_options.compiled_cache` parameter.
.. versionadded:: 1.4
@@ -1343,7 +1346,7 @@ class AsyncTransaction(
@classmethod
def _regenerate_proxy_for_target(
cls, target: Transaction
cls, target: Transaction, **additional_kw: Any # noqa: U100
) -> AsyncTransaction:
sync_connection = target.connection
sync_transaction = target
@@ -1418,19 +1421,17 @@ class AsyncTransaction(
@overload
def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine:
...
def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: ...
@overload
def _get_sync_engine_or_connection(
async_engine: AsyncConnection,
) -> Connection:
...
) -> Connection: ...
def _get_sync_engine_or_connection(
async_engine: Union[AsyncEngine, AsyncConnection]
async_engine: Union[AsyncEngine, AsyncConnection],
) -> Union[Engine, Connection]:
if isinstance(async_engine, AsyncConnection):
return async_engine._proxied

View File

@@ -1,5 +1,5 @@
# ext/asyncio/exc.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# ext/asyncio/result.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -93,6 +93,7 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
self._metadata = real_result._metadata
self._unique_filter_state = real_result._unique_filter_state
self._source_supports_scalars = real_result._source_supports_scalars
self._post_creational_filter = None
# BaseCursorResult pre-generates the "_row_getter". Use that
@@ -324,22 +325,20 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
return await greenlet_spawn(self._only_one_row, True, False, False)
@overload
async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T:
...
async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: ...
@overload
async def scalar_one(self) -> Any:
...
async def scalar_one(self) -> Any: ...
async def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
then :meth:`_asyncio.AsyncResult.one`.
then :meth:`_asyncio.AsyncScalarResult.one`.
.. seealso::
:meth:`_asyncio.AsyncResult.one`
:meth:`_asyncio.AsyncScalarResult.one`
:meth:`_asyncio.AsyncResult.scalars`
@@ -349,22 +348,20 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
@overload
async def scalar_one_or_none(
self: AsyncResult[Tuple[_T]],
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
async def scalar_one_or_none(self) -> Optional[Any]:
...
async def scalar_one_or_none(self) -> Optional[Any]: ...
async def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one scalar result or ``None``.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
then :meth:`_asyncio.AsyncResult.one_or_none`.
then :meth:`_asyncio.AsyncScalarResult.one_or_none`.
.. seealso::
:meth:`_asyncio.AsyncResult.one_or_none`
:meth:`_asyncio.AsyncScalarResult.one_or_none`
:meth:`_asyncio.AsyncResult.scalars`
@@ -403,12 +400,10 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
return await greenlet_spawn(self._only_one_row, True, True, False)
@overload
async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]:
...
async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: ...
@overload
async def scalar(self) -> Any:
...
async def scalar(self) -> Any: ...
async def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
@@ -452,16 +447,13 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]):
@overload
def scalars(
self: AsyncResult[Tuple[_T]], index: Literal[0]
) -> AsyncScalarResult[_T]:
...
) -> AsyncScalarResult[_T]: ...
@overload
def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]:
...
def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: ...
@overload
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
...
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: ...
def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
"""Return an :class:`_asyncio.AsyncScalarResult` filtering object which
@@ -833,11 +825,9 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
"""
...
async def __aiter__(self) -> AsyncIterator[_R]:
...
def __aiter__(self) -> AsyncIterator[_R]: ...
async def __anext__(self) -> _R:
...
async def __anext__(self) -> _R: ...
async def first(self) -> Optional[_R]:
"""Fetch the first object or ``None`` if no object is present.
@@ -871,22 +861,20 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
...
@overload
async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T:
...
async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: ...
@overload
async def scalar_one(self) -> Any:
...
async def scalar_one(self) -> Any: ...
async def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_engine.Result.scalars`
and then :meth:`_engine.Result.one`.
and then :meth:`_engine.AsyncScalarResult.one`.
.. seealso::
:meth:`_engine.Result.one`
:meth:`_engine.AsyncScalarResult.one`
:meth:`_engine.Result.scalars`
@@ -896,22 +884,20 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
@overload
async def scalar_one_or_none(
self: AsyncTupleResult[Tuple[_T]],
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
async def scalar_one_or_none(self) -> Optional[Any]:
...
async def scalar_one_or_none(self) -> Optional[Any]: ...
async def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
This is equivalent to calling :meth:`_engine.Result.scalars`
and then :meth:`_engine.Result.one_or_none`.
and then :meth:`_engine.AsyncScalarResult.one_or_none`.
.. seealso::
:meth:`_engine.Result.one_or_none`
:meth:`_engine.AsyncScalarResult.one_or_none`
:meth:`_engine.Result.scalars`
@@ -919,12 +905,12 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly):
...
@overload
async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]:
...
async def scalar(
self: AsyncTupleResult[Tuple[_T]],
) -> Optional[_T]: ...
@overload
async def scalar(self) -> Any:
...
async def scalar(self) -> Any: ...
async def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result

View File

@@ -1,5 +1,5 @@
# ext/asyncio/scoping.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -364,7 +364,7 @@ class async_scoped_session(Generic[_AS]):
object is entered::
async with async_session.begin():
# .. ORM transaction is begun
... # ORM transaction is begun
Note that database IO will not normally occur when the session-level
transaction is begun, as database transactions begin on an
@@ -536,8 +536,7 @@ class async_scoped_session(Generic[_AS]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result[_T]:
...
) -> Result[_T]: ...
@overload
async def execute(
@@ -549,8 +548,7 @@ class async_scoped_session(Generic[_AS]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> CursorResult[Any]:
...
) -> CursorResult[Any]: ...
@overload
async def execute(
@@ -562,8 +560,7 @@ class async_scoped_session(Generic[_AS]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result[Any]:
...
) -> Result[Any]: ...
async def execute(
self,
@@ -811,28 +808,28 @@ class async_scoped_session(Generic[_AS]):
# construct async engines w/ async drivers
engines = {
'leader':create_async_engine("sqlite+aiosqlite:///leader.db"),
'other':create_async_engine("sqlite+aiosqlite:///other.db"),
'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"),
'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"),
"leader": create_async_engine("sqlite+aiosqlite:///leader.db"),
"other": create_async_engine("sqlite+aiosqlite:///other.db"),
"follower1": create_async_engine("sqlite+aiosqlite:///follower1.db"),
"follower2": create_async_engine("sqlite+aiosqlite:///follower2.db"),
}
class RoutingSession(Session):
def get_bind(self, mapper=None, clause=None, **kw):
# within get_bind(), return sync engines
if mapper and issubclass(mapper.class_, MyOtherClass):
return engines['other'].sync_engine
return engines["other"].sync_engine
elif self._flushing or isinstance(clause, (Update, Delete)):
return engines['leader'].sync_engine
return engines["leader"].sync_engine
else:
return engines[
random.choice(['follower1','follower2'])
random.choice(["follower1", "follower2"])
].sync_engine
# apply to AsyncSession using sync_session_class
AsyncSessionMaker = async_sessionmaker(
sync_session_class=RoutingSession
)
AsyncSessionMaker = async_sessionmaker(sync_session_class=RoutingSession)
The :meth:`_orm.Session.get_bind` method is called in a non-asyncio,
implicitly non-blocking context in the same manner as ORM event hooks
@@ -867,7 +864,7 @@ class async_scoped_session(Generic[_AS]):
This method retrieves the history for each instrumented
attribute on the instance and performs a comparison of the current
value to its previously committed value, if any.
value to its previously flushed or committed value, if any.
It is in effect a more expensive and accurate
version of checking for the given instance in the
@@ -1015,8 +1012,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
async def scalar(
@@ -1027,8 +1023,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Any:
...
) -> Any: ...
async def scalar(
self,
@@ -1070,8 +1065,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> ScalarResult[_T]:
...
) -> ScalarResult[_T]: ...
@overload
async def scalars(
@@ -1082,8 +1076,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> ScalarResult[Any]:
...
) -> ScalarResult[Any]: ...
async def scalars(
self,
@@ -1182,8 +1175,7 @@ class async_scoped_session(Generic[_AS]):
Proxied for the :class:`_asyncio.AsyncSession` class on
behalf of the :class:`_asyncio.scoping.async_scoped_session` class.
Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects
no rows.
Raises :class:`_exc.NoResultFound` if the query selects no rows.
.. versionadded:: 2.0.22
@@ -1213,8 +1205,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncResult[_T]:
...
) -> AsyncResult[_T]: ...
@overload
async def stream(
@@ -1225,8 +1216,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncResult[Any]:
...
) -> AsyncResult[Any]: ...
async def stream(
self,
@@ -1265,8 +1255,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncScalarResult[_T]:
...
) -> AsyncScalarResult[_T]: ...
@overload
async def stream_scalars(
@@ -1277,8 +1266,7 @@ class async_scoped_session(Generic[_AS]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncScalarResult[Any]:
...
) -> AsyncScalarResult[Any]: ...
async def stream_scalars(
self,

View File

@@ -1,5 +1,5 @@
# ext/asyncio/session.py
# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -38,6 +38,9 @@ from ...orm import Session
from ...orm import SessionTransaction
from ...orm import state as _instance_state
from ...util.concurrency import greenlet_spawn
from ...util.typing import Concatenate
from ...util.typing import ParamSpec
if TYPE_CHECKING:
from .engine import AsyncConnection
@@ -71,6 +74,7 @@ if TYPE_CHECKING:
_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
_P = ParamSpec("_P")
_T = TypeVar("_T", bound=Any)
@@ -332,9 +336,12 @@ class AsyncSession(ReversibleProxy[Session]):
)
async def run_sync(
self, fn: Callable[..., _T], *arg: Any, **kw: Any
self,
fn: Callable[Concatenate[Session, _P], _T],
*arg: _P.args,
**kw: _P.kwargs,
) -> _T:
"""Invoke the given synchronous (i.e. not async) callable,
'''Invoke the given synchronous (i.e. not async) callable,
passing a synchronous-style :class:`_orm.Session` as the first
argument.
@@ -344,25 +351,27 @@ class AsyncSession(ReversibleProxy[Session]):
E.g.::
def some_business_method(session: Session, param: str) -> str:
'''A synchronous function that does not require awaiting
"""A synchronous function that does not require awaiting
:param session: a SQLAlchemy Session, used synchronously
:return: an optional return value is supported
'''
"""
session.add(MyObject(param=param))
session.flush()
return "success"
async def do_something_async(async_engine: AsyncEngine) -> None:
'''an async function that uses awaiting'''
"""an async function that uses awaiting"""
with AsyncSession(async_engine) as async_session:
# run some_business_method() with a sync-style
# Session, proxied into an awaitable
return_code = await async_session.run_sync(some_business_method, param="param1")
return_code = await async_session.run_sync(
some_business_method, param="param1"
)
print(return_code)
This method maintains the asyncio event loop all the way through
@@ -384,9 +393,11 @@ class AsyncSession(ReversibleProxy[Session]):
:meth:`.AsyncConnection.run_sync`
:ref:`session_run_sync`
""" # noqa: E501
''' # noqa: E501
return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
return await greenlet_spawn(
fn, self.sync_session, *arg, _require_await=False, **kw
)
@overload
async def execute(
@@ -398,8 +409,7 @@ class AsyncSession(ReversibleProxy[Session]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result[_T]:
...
) -> Result[_T]: ...
@overload
async def execute(
@@ -411,8 +421,7 @@ class AsyncSession(ReversibleProxy[Session]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> CursorResult[Any]:
...
) -> CursorResult[Any]: ...
@overload
async def execute(
@@ -424,8 +433,7 @@ class AsyncSession(ReversibleProxy[Session]):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result[Any]:
...
) -> Result[Any]: ...
async def execute(
self,
@@ -471,8 +479,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
async def scalar(
@@ -483,8 +490,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Any:
...
) -> Any: ...
async def scalar(
self,
@@ -528,8 +534,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> ScalarResult[_T]:
...
) -> ScalarResult[_T]: ...
@overload
async def scalars(
@@ -540,8 +545,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> ScalarResult[Any]:
...
) -> ScalarResult[Any]: ...
async def scalars(
self,
@@ -624,8 +628,7 @@ class AsyncSession(ReversibleProxy[Session]):
"""Return an instance based on the given primary key identifier,
or raise an exception if not found.
Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects
no rows.
Raises :class:`_exc.NoResultFound` if the query selects no rows.
.. versionadded:: 2.0.22
@@ -655,8 +658,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncResult[_T]:
...
) -> AsyncResult[_T]: ...
@overload
async def stream(
@@ -667,8 +669,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncResult[Any]:
...
) -> AsyncResult[Any]: ...
async def stream(
self,
@@ -710,8 +711,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncScalarResult[_T]:
...
) -> AsyncScalarResult[_T]: ...
@overload
async def stream_scalars(
@@ -722,8 +722,7 @@ class AsyncSession(ReversibleProxy[Session]):
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncScalarResult[Any]:
...
) -> AsyncScalarResult[Any]: ...
async def stream_scalars(
self,
@@ -812,7 +811,9 @@ class AsyncSession(ReversibleProxy[Session]):
"""
trans = self.sync_session.get_transaction()
if trans is not None:
return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
return AsyncSessionTransaction._retrieve_proxy_for_target(
trans, async_session=self
)
else:
return None
@@ -828,7 +829,9 @@ class AsyncSession(ReversibleProxy[Session]):
trans = self.sync_session.get_nested_transaction()
if trans is not None:
return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
return AsyncSessionTransaction._retrieve_proxy_for_target(
trans, async_session=self
)
else:
return None
@@ -879,28 +882,28 @@ class AsyncSession(ReversibleProxy[Session]):
# construct async engines w/ async drivers
engines = {
'leader':create_async_engine("sqlite+aiosqlite:///leader.db"),
'other':create_async_engine("sqlite+aiosqlite:///other.db"),
'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"),
'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"),
"leader": create_async_engine("sqlite+aiosqlite:///leader.db"),
"other": create_async_engine("sqlite+aiosqlite:///other.db"),
"follower1": create_async_engine("sqlite+aiosqlite:///follower1.db"),
"follower2": create_async_engine("sqlite+aiosqlite:///follower2.db"),
}
class RoutingSession(Session):
def get_bind(self, mapper=None, clause=None, **kw):
# within get_bind(), return sync engines
if mapper and issubclass(mapper.class_, MyOtherClass):
return engines['other'].sync_engine
return engines["other"].sync_engine
elif self._flushing or isinstance(clause, (Update, Delete)):
return engines['leader'].sync_engine
return engines["leader"].sync_engine
else:
return engines[
random.choice(['follower1','follower2'])
random.choice(["follower1", "follower2"])
].sync_engine
# apply to AsyncSession using sync_session_class
AsyncSessionMaker = async_sessionmaker(
sync_session_class=RoutingSession
)
AsyncSessionMaker = async_sessionmaker(sync_session_class=RoutingSession)
The :meth:`_orm.Session.get_bind` method is called in a non-asyncio,
implicitly non-blocking context in the same manner as ORM event hooks
@@ -956,7 +959,7 @@ class AsyncSession(ReversibleProxy[Session]):
object is entered::
async with async_session.begin():
# .. ORM transaction is begun
... # ORM transaction is begun
Note that database IO will not normally occur when the session-level
transaction is begun, as database transactions begin on an
@@ -1309,7 +1312,7 @@ class AsyncSession(ReversibleProxy[Session]):
This method retrieves the history for each instrumented
attribute on the instance and performs a comparison of the current
value to its previously committed value, if any.
value to its previously flushed or committed value, if any.
It is in effect a more expensive and accurate
version of checking for the given instance in the
@@ -1633,16 +1636,22 @@ class async_sessionmaker(Generic[_AS]):
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import async_sessionmaker
async def run_some_sql(async_session: async_sessionmaker[AsyncSession]) -> None:
async def run_some_sql(
async_session: async_sessionmaker[AsyncSession],
) -> None:
async with async_session() as session:
session.add(SomeObject(data="object"))
session.add(SomeOtherObject(name="other object"))
await session.commit()
async def main() -> None:
# an AsyncEngine, which the AsyncSession will use for connection
# resources
engine = create_async_engine('postgresql+asyncpg://scott:tiger@localhost/')
engine = create_async_engine(
"postgresql+asyncpg://scott:tiger@localhost/"
)
# create a reusable factory for new AsyncSession instances
async_session = async_sessionmaker(engine)
@@ -1686,8 +1695,7 @@ class async_sessionmaker(Generic[_AS]):
expire_on_commit: bool = ...,
info: Optional[_InfoType] = ...,
**kw: Any,
):
...
): ...
@overload
def __init__(
@@ -1698,8 +1706,7 @@ class async_sessionmaker(Generic[_AS]):
expire_on_commit: bool = ...,
info: Optional[_InfoType] = ...,
**kw: Any,
):
...
): ...
def __init__(
self,
@@ -1743,7 +1750,6 @@ class async_sessionmaker(Generic[_AS]):
# commits transaction, closes session
"""
session = self()
@@ -1776,7 +1782,7 @@ class async_sessionmaker(Generic[_AS]):
AsyncSession = async_sessionmaker(some_engine)
AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://'))
AsyncSession.configure(bind=create_async_engine("sqlite+aiosqlite://"))
""" # noqa E501
self.kw.update(new_kw)
@@ -1862,12 +1868,27 @@ class AsyncSessionTransaction(
await greenlet_spawn(self._sync_transaction().commit)
@classmethod
def _regenerate_proxy_for_target( # type: ignore[override]
cls,
target: SessionTransaction,
async_session: AsyncSession,
**additional_kw: Any, # noqa: U100
) -> AsyncSessionTransaction:
sync_transaction = target
nested = target.nested
obj = cls.__new__(cls)
obj.session = async_session
obj.sync_transaction = obj._assign_proxied(sync_transaction)
obj.nested = nested
return obj
async def start(
self, is_ctxmanager: bool = False
) -> AsyncSessionTransaction:
self.sync_transaction = self._assign_proxied(
await greenlet_spawn(
self.session.sync_session.begin_nested # type: ignore
self.session.sync_session.begin_nested
if self.nested
else self.session.sync_session.begin
)