API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,5 +1,5 @@
# engine/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# sqlalchemy/processors.py
# Copyright (C) 2010-2023 the SQLAlchemy authors and contributors
# engine/_py_processors.py
# Copyright (C) 2010-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
#

View File

@@ -1,3 +1,9 @@
# engine/_py_row.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import operator

View File

@@ -1,3 +1,9 @@
# engine/_py_util.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import typing
@@ -26,9 +32,9 @@ def _distill_params_20(
# Assume list is more likely than tuple
elif isinstance(params, list) or isinstance(params, tuple):
# collections_abc.MutableSequence): # avoid abc.__instancecheck__
if params and not isinstance(params[0], (tuple, Mapping)):
if params and not isinstance(params[0], Mapping):
raise exc.ArgumentError(
"List argument must consist only of tuples or dictionaries"
"List argument must consist only of dictionaries"
)
return params

View File

@@ -1,12 +1,10 @@
# engine/base.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`.
"""
"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`."""
from __future__ import annotations
import contextlib
@@ -70,12 +68,11 @@ if typing.TYPE_CHECKING:
from ..sql._typing import _InfoType
from ..sql.compiler import Compiled
from ..sql.ddl import ExecutableDDLElement
from ..sql.ddl import SchemaDropper
from ..sql.ddl import SchemaGenerator
from ..sql.ddl import InvokeDDLBase
from ..sql.functions import FunctionElement
from ..sql.schema import DefaultGenerator
from ..sql.schema import HasSchemaAttr
from ..sql.schema import SchemaItem
from ..sql.schema import SchemaVisitable
from ..sql.selectable import TypedReturnsRows
@@ -109,6 +106,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"""
dialect: Dialect
dispatch: dispatcher[ConnectionEventsTarget]
_sqla_logger_namespace = "sqlalchemy.engine.Connection"
@@ -173,13 +171,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
if self._has_events or self.engine._has_events:
self.dispatch.engine_connect(self)
@util.memoized_property
def _message_formatter(self) -> Any:
if "logging_token" in self._execution_options:
token = self._execution_options["logging_token"]
return lambda msg: "[%s] %s" % (token, msg)
else:
return None
# this can be assigned differently via
# characteristics.LoggingTokenCharacteristic
_message_formatter: Any = None
def _log_info(self, message: str, *arg: Any, **kw: Any) -> None:
fmt = self._message_formatter
@@ -205,9 +199,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
@property
def _schema_translate_map(self) -> Optional[SchemaTranslateMapType]:
schema_translate_map: Optional[
SchemaTranslateMapType
] = self._execution_options.get("schema_translate_map", None)
schema_translate_map: Optional[SchemaTranslateMapType] = (
self._execution_options.get("schema_translate_map", None)
)
return schema_translate_map
@@ -218,9 +212,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"""
name = obj.schema
schema_translate_map: Optional[
SchemaTranslateMapType
] = self._execution_options.get("schema_translate_map", None)
schema_translate_map: Optional[SchemaTranslateMapType] = (
self._execution_options.get("schema_translate_map", None)
)
if (
schema_translate_map
@@ -250,13 +244,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
yield_per: int = ...,
insertmanyvalues_page_size: int = ...,
schema_translate_map: Optional[SchemaTranslateMapType] = ...,
preserve_rowcount: bool = False,
**opt: Any,
) -> Connection:
...
) -> Connection: ...
@overload
def execution_options(self, **opt: Any) -> Connection:
...
def execution_options(self, **opt: Any) -> Connection: ...
def execution_options(self, **opt: Any) -> Connection:
r"""Set non-SQL options for the connection which take effect
@@ -382,12 +375,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
:param stream_results: Available on: :class:`_engine.Connection`,
:class:`_sql.Executable`.
Indicate to the dialect that results should be
"streamed" and not pre-buffered, if possible. For backends
such as PostgreSQL, MySQL and MariaDB, this indicates the use of
a "server side cursor" as opposed to a client side cursor.
Other backends such as that of Oracle may already use server
side cursors by default.
Indicate to the dialect that results should be "streamed" and not
pre-buffered, if possible. For backends such as PostgreSQL, MySQL
and MariaDB, this indicates the use of a "server side cursor" as
opposed to a client side cursor. Other backends such as that of
Oracle Database may already use server side cursors by default.
The usage of
:paramref:`_engine.Connection.execution_options.stream_results` is
@@ -492,6 +484,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
:ref:`schema_translating`
:param preserve_rowcount: Boolean; when True, the ``cursor.rowcount``
attribute will be unconditionally memoized within the result and
made available via the :attr:`.CursorResult.rowcount` attribute.
Normally, this attribute is only preserved for UPDATE and DELETE
statements. Using this option, the DBAPIs rowcount value can
be accessed for other kinds of statements such as INSERT and SELECT,
to the degree that the DBAPI supports these statements. See
:attr:`.CursorResult.rowcount` for notes regarding the behavior
of this attribute.
.. versionadded:: 2.0.28
.. seealso::
:meth:`_engine.Engine.execution_options`
@@ -793,7 +797,6 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
with conn.begin() as trans:
conn.execute(table.insert(), {"username": "sandy"})
The returned object is an instance of :class:`_engine.RootTransaction`.
This object represents the "scope" of the transaction,
which completes when either the :meth:`_engine.Transaction.rollback`
@@ -899,7 +902,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
trans.rollback() # rollback to savepoint
# outer transaction continues
connection.execute( ... )
connection.execute(...)
If :meth:`_engine.Connection.begin_nested` is called without first
calling :meth:`_engine.Connection.begin` or
@@ -909,11 +912,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
with engine.connect() as connection: # begin() wasn't called
with connection.begin_nested(): will auto-"begin()" first
connection.execute( ... )
with connection.begin_nested(): # will auto-"begin()" first
connection.execute(...)
# savepoint is released
connection.execute( ... )
connection.execute(...)
# explicitly commit outer transaction
connection.commit()
@@ -1109,10 +1112,16 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
if self._still_open_and_dbapi_connection_is_valid:
if self._echo:
if self._is_autocommit_isolation():
self._log_info(
"ROLLBACK using DBAPI connection.rollback(), "
"DBAPI should ignore due to autocommit mode"
)
if self.dialect.skip_autocommit_rollback:
self._log_info(
"ROLLBACK will be skipped by "
"skip_autocommit_rollback"
)
else:
self._log_info(
"ROLLBACK using DBAPI connection.rollback(); "
"set skip_autocommit_rollback to prevent fully"
)
else:
self._log_info("ROLLBACK")
try:
@@ -1128,7 +1137,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
if self._is_autocommit_isolation():
self._log_info(
"COMMIT using DBAPI connection.commit(), "
"DBAPI should ignore due to autocommit mode"
"has no effect due to autocommit mode"
)
else:
self._log_info("COMMIT")
@@ -1262,8 +1271,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> Optional[_T]:
...
) -> Optional[_T]: ...
@overload
def scalar(
@@ -1272,8 +1280,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
parameters: Optional[_CoreSingleExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> Any:
...
) -> Any: ...
def scalar(
self,
@@ -1311,8 +1318,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[_T]:
...
) -> ScalarResult[_T]: ...
@overload
def scalars(
@@ -1321,8 +1327,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> ScalarResult[Any]:
...
) -> ScalarResult[Any]: ...
def scalars(
self,
@@ -1356,8 +1361,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> CursorResult[_T]:
...
) -> CursorResult[_T]: ...
@overload
def execute(
@@ -1366,8 +1370,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
parameters: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: Optional[CoreExecuteOptionsParameter] = None,
) -> CursorResult[Any]:
...
) -> CursorResult[Any]: ...
def execute(
self,
@@ -1498,7 +1501,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
) -> CursorResult[Any]:
"""Execute a schema.DDL object."""
execution_options = ddl._execution_options.merge_with(
exec_opts = ddl._execution_options.merge_with(
self._execution_options, execution_options
)
@@ -1512,12 +1515,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
event_multiparams,
event_params,
) = self._invoke_before_exec_event(
ddl, distilled_parameters, execution_options
ddl, distilled_parameters, exec_opts
)
else:
event_multiparams = event_params = None
exec_opts = self._execution_options.merge_with(execution_options)
schema_translate_map = exec_opts.get("schema_translate_map", None)
dialect = self.dialect
@@ -1530,7 +1532,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
dialect.execution_ctx_cls._init_ddl,
compiled,
None,
execution_options,
exec_opts,
compiled,
)
if self._has_events or self.engine._has_events:
@@ -1539,7 +1541,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
ddl,
event_multiparams,
event_params,
execution_options,
exec_opts,
ret,
)
return ret
@@ -1737,21 +1739,20 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
conn.exec_driver_sql(
"INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)",
[{"id":1, "value":"v1"}, {"id":2, "value":"v2"}]
[{"id": 1, "value": "v1"}, {"id": 2, "value": "v2"}],
)
Single dictionary::
conn.exec_driver_sql(
"INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)",
dict(id=1, value="v1")
dict(id=1, value="v1"),
)
Single tuple::
conn.exec_driver_sql(
"INSERT INTO table (id, value) VALUES (?, ?)",
(1, 'v1')
"INSERT INTO table (id, value) VALUES (?, ?)", (1, "v1")
)
.. note:: The :meth:`_engine.Connection.exec_driver_sql` method does
@@ -1840,10 +1841,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
context.pre_exec()
if context.execute_style is ExecuteStyle.INSERTMANYVALUES:
return self._exec_insertmany_context(
dialect,
context,
)
return self._exec_insertmany_context(dialect, context)
else:
return self._exec_single_context(
dialect, context, statement, parameters
@@ -2018,16 +2016,22 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
engine_events = self._has_events or self.engine._has_events
if self.dialect._has_events:
do_execute_dispatch: Iterable[
Any
] = self.dialect.dispatch.do_execute
do_execute_dispatch: Iterable[Any] = (
self.dialect.dispatch.do_execute
)
else:
do_execute_dispatch = ()
if self._echo:
stats = context._get_cache_stats() + " (insertmanyvalues)"
preserve_rowcount = context.execution_options.get(
"preserve_rowcount", False
)
rowcount = 0
for imv_batch in dialect._deliver_insertmanyvalues_batches(
self,
cursor,
str_statement,
effective_parameters,
@@ -2048,6 +2052,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
imv_batch.replaced_parameters,
None,
context,
is_sub_exec=True,
)
sub_stmt = imv_batch.replaced_statement
@@ -2067,15 +2072,16 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
if self._echo:
self._log_info(sql_util._long_statement(sub_stmt))
imv_stats = f""" {
imv_batch.batchnum}/{imv_batch.total_batches} ({
'ordered'
if imv_batch.rows_sorted else 'unordered'
}{
'; batch not supported'
if imv_batch.is_downgraded
else ''
})"""
imv_stats = f""" {imv_batch.batchnum}/{
imv_batch.total_batches
} ({
'ordered'
if imv_batch.rows_sorted else 'unordered'
}{
'; batch not supported'
if imv_batch.is_downgraded
else ''
})"""
if imv_batch.batchnum == 1:
stats += imv_stats
@@ -2136,9 +2142,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
context.executemany,
)
if preserve_rowcount:
rowcount += imv_batch.current_batch_size
try:
context.post_exec()
if preserve_rowcount:
context._rowcount = rowcount # type: ignore[attr-defined]
result = context._setup_result_proxy()
except BaseException as e:
@@ -2380,9 +2392,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
None,
cast(Exception, e),
dialect.loaded_dbapi.Error,
hide_parameters=engine.hide_parameters
if engine is not None
else False,
hide_parameters=(
engine.hide_parameters if engine is not None else False
),
connection_invalidated=is_disconnect,
dialect=dialect,
)
@@ -2419,9 +2431,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
break
if sqlalchemy_exception and is_disconnect != ctx.is_disconnect:
sqlalchemy_exception.connection_invalidated = (
is_disconnect
) = ctx.is_disconnect
sqlalchemy_exception.connection_invalidated = ctx.is_disconnect
if newraise:
raise newraise.with_traceback(exc_info[2]) from e
@@ -2434,8 +2444,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def _run_ddl_visitor(
self,
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
element: SchemaItem,
visitorcallable: Type[InvokeDDLBase],
element: SchemaVisitable,
**kwargs: Any,
) -> None:
"""run a DDL visitor.
@@ -2444,7 +2454,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
options given to the visitor so that "checkfirst" is skipped.
"""
visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
visitorcallable(
dialect=self.dialect, connection=self, **kwargs
).traverse_single(element)
class ExceptionContextImpl(ExceptionContext):
@@ -2502,6 +2514,7 @@ class Transaction(TransactionalContext):
:class:`_engine.Connection`::
from sqlalchemy import create_engine
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
connection = engine.connect()
trans = connection.begin()
@@ -2990,7 +3003,7 @@ class Engine(
This applies **only** to the built-in cache that is established
via the :paramref:`_engine.create_engine.query_cache_size` parameter.
It will not impact any dictionary caches that were passed via the
:paramref:`.Connection.execution_options.query_cache` parameter.
:paramref:`.Connection.execution_options.compiled_cache` parameter.
.. versionadded:: 1.4
@@ -3029,12 +3042,10 @@ class Engine(
insertmanyvalues_page_size: int = ...,
schema_translate_map: Optional[SchemaTranslateMapType] = ...,
**opt: Any,
) -> OptionEngine:
...
) -> OptionEngine: ...
@overload
def execution_options(self, **opt: Any) -> OptionEngine:
...
def execution_options(self, **opt: Any) -> OptionEngine: ...
def execution_options(self, **opt: Any) -> OptionEngine:
"""Return a new :class:`_engine.Engine` that will provide
@@ -3081,10 +3092,10 @@ class Engine(
shards = {"default": "base", "shard_1": "db1", "shard_2": "db2"}
@event.listens_for(Engine, "before_cursor_execute")
def _switch_shard(conn, cursor, stmt,
params, context, executemany):
shard_id = conn.get_execution_options().get('shard_id', "default")
def _switch_shard(conn, cursor, stmt, params, context, executemany):
shard_id = conn.get_execution_options().get("shard_id", "default")
current_shard = conn.info.get("current_shard", None)
if current_shard != shard_id:
@@ -3210,9 +3221,7 @@ class Engine(
E.g.::
with engine.begin() as conn:
conn.execute(
text("insert into table (x, y, z) values (1, 2, 3)")
)
conn.execute(text("insert into table (x, y, z) values (1, 2, 3)"))
conn.execute(text("my_special_procedure(5)"))
Upon successful operation, the :class:`.Transaction`
@@ -3228,15 +3237,15 @@ class Engine(
:meth:`_engine.Connection.begin` - start a :class:`.Transaction`
for a particular :class:`_engine.Connection`.
"""
""" # noqa: E501
with self.connect() as conn:
with conn.begin():
yield conn
def _run_ddl_visitor(
self,
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
element: SchemaItem,
visitorcallable: Type[InvokeDDLBase],
element: SchemaVisitable,
**kwargs: Any,
) -> None:
with self.begin() as conn:

View File

@@ -1,3 +1,9 @@
# engine/characteristics.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
import abc
@@ -6,6 +12,7 @@ from typing import Any
from typing import ClassVar
if typing.TYPE_CHECKING:
from .base import Connection
from .interfaces import DBAPIConnection
from .interfaces import Dialect
@@ -38,13 +45,30 @@ class ConnectionCharacteristic(abc.ABC):
def reset_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection
) -> None:
"""Reset the characteristic on the connection to its default value."""
"""Reset the characteristic on the DBAPI connection to its default
value."""
@abc.abstractmethod
def set_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any
) -> None:
"""set characteristic on the connection to a given value."""
"""set characteristic on the DBAPI connection to a given value."""
def set_connection_characteristic(
self,
dialect: Dialect,
conn: Connection,
dbapi_conn: DBAPIConnection,
value: Any,
) -> None:
"""set characteristic on the :class:`_engine.Connection` to a given
value.
.. versionadded:: 2.0.30 - added to support elements that are local
to the :class:`_engine.Connection` itself.
"""
self.set_characteristic(dialect, dbapi_conn, value)
@abc.abstractmethod
def get_characteristic(
@@ -55,8 +79,22 @@ class ConnectionCharacteristic(abc.ABC):
"""
def get_connection_characteristic(
self, dialect: Dialect, conn: Connection, dbapi_conn: DBAPIConnection
) -> Any:
"""Given a :class:`_engine.Connection`, get the current value of the
characteristic.
.. versionadded:: 2.0.30 - added to support elements that are local
to the :class:`_engine.Connection` itself.
"""
return self.get_characteristic(dialect, dbapi_conn)
class IsolationLevelCharacteristic(ConnectionCharacteristic):
"""Manage the isolation level on a DBAPI connection"""
transactional: ClassVar[bool] = True
def reset_characteristic(
@@ -73,3 +111,45 @@ class IsolationLevelCharacteristic(ConnectionCharacteristic):
self, dialect: Dialect, dbapi_conn: DBAPIConnection
) -> Any:
return dialect.get_isolation_level(dbapi_conn)
class LoggingTokenCharacteristic(ConnectionCharacteristic):
"""Manage the 'logging_token' option of a :class:`_engine.Connection`.
.. versionadded:: 2.0.30
"""
transactional: ClassVar[bool] = False
def reset_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection
) -> None:
pass
def set_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any
) -> None:
raise NotImplementedError()
def set_connection_characteristic(
self,
dialect: Dialect,
conn: Connection,
dbapi_conn: DBAPIConnection,
value: Any,
) -> None:
if value:
conn._message_formatter = lambda msg: "[%s] %s" % (value, msg)
else:
del conn._message_formatter
def get_characteristic(
self, dialect: Dialect, dbapi_conn: DBAPIConnection
) -> Any:
raise NotImplementedError()
def get_connection_characteristic(
self, dialect: Dialect, conn: Connection, dbapi_conn: DBAPIConnection
) -> Any:
return conn._execution_options.get("logging_token", None)

View File

@@ -1,5 +1,5 @@
# engine/create.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -82,13 +82,11 @@ def create_engine(
query_cache_size: int = ...,
use_insertmanyvalues: bool = ...,
**kwargs: Any,
) -> Engine:
...
) -> Engine: ...
@overload
def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine:
...
def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: ...
@util.deprecated_params(
@@ -135,8 +133,11 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine:
and its underlying :class:`.Dialect` and :class:`_pool.Pool`
constructs::
engine = create_engine("mysql+mysqldb://scott:tiger@hostname/dbname",
pool_recycle=3600, echo=True)
engine = create_engine(
"mysql+mysqldb://scott:tiger@hostname/dbname",
pool_recycle=3600,
echo=True,
)
The string form of the URL is
``dialect[+driver]://user:password@host/dbname[?key=value..]``, where
@@ -467,6 +468,9 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine:
:ref:`pool_reset_on_return`
:ref:`dbapi_autocommit_skip_rollback` - a more modern approach
to using connections with no transactional instructions
:param pool_timeout=30: number of seconds to wait before giving
up on getting a connection from the pool. This is only used
with :class:`~sqlalchemy.pool.QueuePool`. This can be a float but is
@@ -523,6 +527,18 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine:
.. versionadded:: 1.4
:param skip_autocommit_rollback: When True, the dialect will
unconditionally skip all calls to the DBAPI ``connection.rollback()``
method if the DBAPI connection is confirmed to be in "autocommit" mode.
The availability of this feature is dialect specific; if not available,
a ``NotImplementedError`` is raised by the dialect when rollback occurs.
.. seealso::
:ref:`dbapi_autocommit_skip_rollback`
.. versionadded:: 2.0.43
:param use_insertmanyvalues: True by default, use the "insertmanyvalues"
execution style for INSERT..RETURNING statements by default.
@@ -616,6 +632,14 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine:
# assemble connection arguments
(cargs_tup, cparams) = dialect.create_connect_args(u)
cparams.update(pop_kwarg("connect_args", {}))
if "async_fallback" in cparams and util.asbool(cparams["async_fallback"]):
util.warn_deprecated(
"The async_fallback dialect argument is deprecated and will be "
"removed in SQLAlchemy 2.1.",
"2.0",
)
cargs = list(cargs_tup) # allow mutability
# look for existing pool or create
@@ -657,6 +681,17 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine:
else:
pool._dialect = dialect
if (
hasattr(pool, "_is_asyncio")
and pool._is_asyncio is not dialect.is_async
):
raise exc.ArgumentError(
f"Pool class {pool.__class__.__name__} cannot be "
f"used with {'non-' if not dialect.is_async else ''}"
"asyncio engine",
code="pcls",
)
# create engine.
if not pop_kwarg("future", True):
raise exc.ArgumentError(
@@ -816,13 +851,11 @@ def create_pool_from_url(
timeout: float = ...,
use_lifo: bool = ...,
**kwargs: Any,
) -> Pool:
...
) -> Pool: ...
@overload
def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool:
...
def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: ...
def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool:

View File

@@ -1,5 +1,5 @@
# engine/cursor.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -20,6 +20,7 @@ from typing import Any
from typing import cast
from typing import ClassVar
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
@@ -120,7 +121,7 @@ _CursorKeyMapRecType = Tuple[
List[Any], # MD_OBJECTS
str, # MD_LOOKUP_KEY
str, # MD_RENDERED_NAME
Optional["_ResultProcessorType"], # MD_PROCESSOR
Optional["_ResultProcessorType[Any]"], # MD_PROCESSOR
Optional[str], # MD_UNTRANSLATED
]
@@ -134,7 +135,7 @@ _NonAmbigCursorKeyMapRecType = Tuple[
List[Any],
str,
str,
Optional["_ResultProcessorType"],
Optional["_ResultProcessorType[Any]"],
str,
]
@@ -151,7 +152,7 @@ class CursorResultMetaData(ResultMetaData):
"_translated_indexes",
"_safe_for_cache",
"_unpickled",
"_key_to_index"
"_key_to_index",
# don't need _unique_filters support here for now. Can be added
# if a need arises.
)
@@ -225,9 +226,11 @@ class CursorResultMetaData(ResultMetaData):
{
key: (
# int index should be None for ambiguous key
value[0] + offset
if value[0] is not None and key not in keymap
else None,
(
value[0] + offset
if value[0] is not None and key not in keymap
else None
),
value[1] + offset,
*value[2:],
)
@@ -362,13 +365,11 @@ class CursorResultMetaData(ResultMetaData):
) = context.result_column_struct
num_ctx_cols = len(result_columns)
else:
result_columns = ( # type: ignore
cols_are_ordered
) = (
result_columns = cols_are_ordered = ( # type: ignore
num_ctx_cols
) = (
ad_hoc_textual
) = loose_column_name_matching = textual_ordered = False
) = ad_hoc_textual = loose_column_name_matching = (
textual_ordered
) = False
# merge cursor.description with the column info
# present in the compiled structure, if any
@@ -688,6 +689,7 @@ class CursorResultMetaData(ResultMetaData):
% (num_ctx_cols, len(cursor_description))
)
seen = set()
for (
idx,
colname,
@@ -1161,7 +1163,7 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy):
result = conn.execution_options(
stream_results=True, max_row_buffer=50
).execute(text("select * from table"))
).execute(text("select * from table"))
.. versionadded:: 1.4 ``max_row_buffer`` may now exceed 1000 rows.
@@ -1246,8 +1248,9 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy):
if size is None:
return self.fetchall(result, dbapi_cursor)
buf = list(self._rowbuffer)
lb = len(buf)
rb = self._rowbuffer
lb = len(rb)
close = False
if size > lb:
try:
new = dbapi_cursor.fetchmany(size - lb)
@@ -1255,13 +1258,15 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy):
self.handle_exception(result, dbapi_cursor, e)
else:
if not new:
result._soft_close()
# defer closing since it may clear the row buffer
close = True
else:
buf.extend(new)
rb.extend(new)
result = buf[0:size]
self._rowbuffer = collections.deque(buf[size:])
return result
res = [rb.popleft() for _ in range(min(size, len(rb)))]
if close:
result._soft_close()
return res
def fetchall(self, result, dbapi_cursor):
try:
@@ -1285,12 +1290,16 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy):
__slots__ = ("_rowbuffer", "alternate_cursor_description")
def __init__(
self, dbapi_cursor, alternate_description=None, initial_buffer=None
self,
dbapi_cursor: Optional[DBAPICursor],
alternate_description: Optional[_DBAPICursorDescription] = None,
initial_buffer: Optional[Iterable[Any]] = None,
):
self.alternate_cursor_description = alternate_description
if initial_buffer is not None:
self._rowbuffer = collections.deque(initial_buffer)
else:
assert dbapi_cursor is not None
self._rowbuffer = collections.deque(dbapi_cursor.fetchall())
def yield_per(self, result, dbapi_cursor, num):
@@ -1315,9 +1324,8 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy):
if size is None:
return self.fetchall(result, dbapi_cursor)
buf = list(self._rowbuffer)
rows = buf[0:size]
self._rowbuffer = collections.deque(buf[size:])
rb = self._rowbuffer
rows = [rb.popleft() for _ in range(min(size, len(rb)))]
if not rows:
result._soft_close()
return rows
@@ -1350,15 +1358,15 @@ class _NoResultMetaData(ResultMetaData):
self._we_dont_return_rows()
@property
def _keymap(self):
def _keymap(self): # type: ignore[override]
self._we_dont_return_rows()
@property
def _key_to_index(self):
def _key_to_index(self): # type: ignore[override]
self._we_dont_return_rows()
@property
def _processors(self):
def _processors(self): # type: ignore[override]
self._we_dont_return_rows()
@property
@@ -1438,6 +1446,7 @@ class CursorResult(Result[_T]):
metadata = self._init_metadata(context, cursor_description)
_make_row: Any
_make_row = functools.partial(
Row,
metadata,
@@ -1610,11 +1619,11 @@ class CursorResult(Result[_T]):
"""
if not self.context.compiled:
raise exc.InvalidRequestError(
"Statement is not a compiled " "expression construct."
"Statement is not a compiled expression construct."
)
elif not self.context.isinsert:
raise exc.InvalidRequestError(
"Statement is not an insert() " "expression construct."
"Statement is not an insert() expression construct."
)
elif self.context._is_explicit_returning:
raise exc.InvalidRequestError(
@@ -1681,11 +1690,11 @@ class CursorResult(Result[_T]):
"""
if not self.context.compiled:
raise exc.InvalidRequestError(
"Statement is not a compiled " "expression construct."
"Statement is not a compiled expression construct."
)
elif not self.context.isupdate:
raise exc.InvalidRequestError(
"Statement is not an update() " "expression construct."
"Statement is not an update() expression construct."
)
elif self.context.executemany:
return self.context.compiled_parameters
@@ -1703,11 +1712,11 @@ class CursorResult(Result[_T]):
"""
if not self.context.compiled:
raise exc.InvalidRequestError(
"Statement is not a compiled " "expression construct."
"Statement is not a compiled expression construct."
)
elif not self.context.isinsert:
raise exc.InvalidRequestError(
"Statement is not an insert() " "expression construct."
"Statement is not an insert() expression construct."
)
elif self.context.executemany:
return self.context.compiled_parameters
@@ -1752,11 +1761,9 @@ class CursorResult(Result[_T]):
r1 = connection.execute(
users.insert().returning(
users.c.user_name,
users.c.user_id,
sort_by_parameter_order=True
users.c.user_name, users.c.user_id, sort_by_parameter_order=True
),
user_values
user_values,
)
r2 = connection.execute(
@@ -1764,19 +1771,16 @@ class CursorResult(Result[_T]):
addresses.c.address_id,
addresses.c.address,
addresses.c.user_id,
sort_by_parameter_order=True
sort_by_parameter_order=True,
),
address_values
address_values,
)
rows = r1.splice_horizontally(r2).all()
assert (
rows ==
[
("john", 1, 1, "foo@bar.com", 1),
("jack", 2, 2, "bar@bat.com", 2),
]
)
assert rows == [
("john", 1, 1, "foo@bar.com", 1),
("jack", 2, 2, "bar@bat.com", 2),
]
.. versionadded:: 2.0
@@ -1785,7 +1789,7 @@ class CursorResult(Result[_T]):
:meth:`.CursorResult.splice_vertically`
"""
""" # noqa: E501
clone = self._generate()
total_rows = [
@@ -1920,7 +1924,7 @@ class CursorResult(Result[_T]):
if not self.context.compiled:
raise exc.InvalidRequestError(
"Statement is not a compiled " "expression construct."
"Statement is not a compiled expression construct."
)
elif not self.context.isinsert and not self.context.isupdate:
raise exc.InvalidRequestError(
@@ -1943,7 +1947,7 @@ class CursorResult(Result[_T]):
if not self.context.compiled:
raise exc.InvalidRequestError(
"Statement is not a compiled " "expression construct."
"Statement is not a compiled expression construct."
)
elif not self.context.isinsert and not self.context.isupdate:
raise exc.InvalidRequestError(
@@ -1974,8 +1978,28 @@ class CursorResult(Result[_T]):
def rowcount(self) -> int:
"""Return the 'rowcount' for this result.
The 'rowcount' reports the number of rows *matched*
by the WHERE criterion of an UPDATE or DELETE statement.
The primary purpose of 'rowcount' is to report the number of rows
matched by the WHERE criterion of an UPDATE or DELETE statement
executed once (i.e. for a single parameter set), which may then be
compared to the number of rows expected to be updated or deleted as a
means of asserting data integrity.
This attribute is transferred from the ``cursor.rowcount`` attribute
of the DBAPI before the cursor is closed, to support DBAPIs that
don't make this value available after cursor close. Some DBAPIs may
offer meaningful values for other kinds of statements, such as INSERT
and SELECT statements as well. In order to retrieve ``cursor.rowcount``
for these statements, set the
:paramref:`.Connection.execution_options.preserve_rowcount`
execution option to True, which will cause the ``cursor.rowcount``
value to be unconditionally memoized before any results are returned
or the cursor is closed, regardless of statement type.
For cases where the DBAPI does not support rowcount for a particular
kind of statement and/or execution, the returned value will be ``-1``,
which is delivered directly from the DBAPI and is part of :pep:`249`.
All DBAPIs should support rowcount for single-parameter-set
UPDATE and DELETE statements, however.
.. note::
@@ -1984,38 +2008,47 @@ class CursorResult(Result[_T]):
* This attribute returns the number of rows *matched*,
which is not necessarily the same as the number of rows
that were actually *modified* - an UPDATE statement, for example,
that were actually *modified*. For example, an UPDATE statement
may have no net change on a given row if the SET values
given are the same as those present in the row already.
Such a row would be matched but not modified.
On backends that feature both styles, such as MySQL,
rowcount is configured by default to return the match
rowcount is configured to return the match
count in all cases.
* :attr:`_engine.CursorResult.rowcount`
is *only* useful in conjunction
with an UPDATE or DELETE statement. Contrary to what the Python
DBAPI says, it does *not* reliably return the
number of rows available from the results of a SELECT statement
as DBAPIs cannot support this functionality when rows are
unbuffered.
* :attr:`_engine.CursorResult.rowcount` in the default case is
*only* useful in conjunction with an UPDATE or DELETE statement,
and only with a single set of parameters. For other kinds of
statements, SQLAlchemy will not attempt to pre-memoize the value
unless the
:paramref:`.Connection.execution_options.preserve_rowcount`
execution option is used. Note that contrary to :pep:`249`, many
DBAPIs do not support rowcount values for statements that are not
UPDATE or DELETE, particularly when rows are being returned which
are not fully pre-buffered. DBAPIs that dont support rowcount
for a particular kind of statement should return the value ``-1``
for such statements.
* :attr:`_engine.CursorResult.rowcount`
may not be fully implemented by
all dialects. In particular, most DBAPIs do not support an
aggregate rowcount result from an executemany call.
The :meth:`_engine.CursorResult.supports_sane_rowcount` and
:meth:`_engine.CursorResult.supports_sane_multi_rowcount` methods
will report from the dialect if each usage is known to be
supported.
* :attr:`_engine.CursorResult.rowcount` may not be meaningful
when executing a single statement with multiple parameter sets
(i.e. an :term:`executemany`). Most DBAPIs do not sum "rowcount"
values across multiple parameter sets and will return ``-1``
when accessed.
* Statements that use RETURNING may not return a correct
rowcount.
* SQLAlchemy's :ref:`engine_insertmanyvalues` feature does support
a correct population of :attr:`_engine.CursorResult.rowcount`
when the :paramref:`.Connection.execution_options.preserve_rowcount`
execution option is set to True.
* Statements that use RETURNING may not support rowcount, returning
a ``-1`` value instead.
.. seealso::
:ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial`
:paramref:`.Connection.execution_options.preserve_rowcount`
""" # noqa: E501
try:
return self.context.rowcount
@@ -2109,8 +2142,7 @@ class CursorResult(Result[_T]):
def merge(self, *others: Result[Any]) -> MergedResult[Any]:
merged_result = super().merge(*others)
setup_rowcounts = self.context._has_rowcount
if setup_rowcounts:
if self.context._has_rowcount:
merged_result.rowcount = sum(
cast("CursorResult[Any]", result).rowcount
for result in (self,) + others

View File

@@ -1,5 +1,5 @@
# engine/default.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -58,6 +58,7 @@ from ..sql import compiler
from ..sql import dml
from ..sql import expression
from ..sql import type_api
from ..sql import util as sql_util
from ..sql._typing import is_tuple_type
from ..sql.base import _NoArg
from ..sql.compiler import DDLCompiler
@@ -76,10 +77,13 @@ if typing.TYPE_CHECKING:
from .interfaces import _CoreSingleExecuteParams
from .interfaces import _DBAPICursorDescription
from .interfaces import _DBAPIMultiExecuteParams
from .interfaces import _DBAPISingleExecuteParams
from .interfaces import _ExecuteOptions
from .interfaces import _MutableCoreSingleExecuteParams
from .interfaces import _ParamStyle
from .interfaces import ConnectArgsType
from .interfaces import DBAPIConnection
from .interfaces import DBAPIModule
from .interfaces import IsolationLevel
from .row import Row
from .url import URL
@@ -95,8 +99,10 @@ if typing.TYPE_CHECKING:
from ..sql.elements import BindParameter
from ..sql.schema import Column
from ..sql.type_api import _BindProcessorType
from ..sql.type_api import _ResultProcessorType
from ..sql.type_api import TypeEngine
# When we're handed literal SQL, ensure it's a SELECT query
SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
@@ -167,7 +173,10 @@ class DefaultDialect(Dialect):
tuple_in_values = False
connection_characteristics = util.immutabledict(
{"isolation_level": characteristics.IsolationLevelCharacteristic()}
{
"isolation_level": characteristics.IsolationLevelCharacteristic(),
"logging_token": characteristics.LoggingTokenCharacteristic(),
}
)
engine_config_types: Mapping[str, Any] = util.immutabledict(
@@ -249,7 +258,7 @@ class DefaultDialect(Dialect):
default_schema_name: Optional[str] = None
# indicates symbol names are
# UPPERCASEd if they are case insensitive
# UPPERCASED if they are case insensitive
# within the database.
# if this is True, the methods normalize_name()
# and denormalize_name() must be provided.
@@ -298,6 +307,7 @@ class DefaultDialect(Dialect):
# Linting.NO_LINTING constant
compiler_linting: Linting = int(compiler.NO_LINTING), # type: ignore
server_side_cursors: bool = False,
skip_autocommit_rollback: bool = False,
**kwargs: Any,
):
if server_side_cursors:
@@ -322,6 +332,8 @@ class DefaultDialect(Dialect):
self.dbapi = dbapi
self.skip_autocommit_rollback = skip_autocommit_rollback
if paramstyle is not None:
self.paramstyle = paramstyle
elif self.dbapi is not None:
@@ -387,7 +399,8 @@ class DefaultDialect(Dialect):
available if the dialect in use has opted into using the
"use_insertmanyvalues" feature. If they haven't opted into that, then
this attribute is False, unless the dialect in question overrides this
and provides some other implementation (such as the Oracle dialect).
and provides some other implementation (such as the Oracle Database
dialects).
"""
return self.insert_returning and self.use_insertmanyvalues
@@ -410,7 +423,7 @@ class DefaultDialect(Dialect):
If the dialect in use hasn't opted into that, then this attribute is
False, unless the dialect in question overrides this and provides some
other implementation (such as the Oracle dialect).
other implementation (such as the Oracle Database dialects).
"""
return self.insert_returning and self.use_insertmanyvalues
@@ -419,7 +432,7 @@ class DefaultDialect(Dialect):
delete_executemany_returning = False
@util.memoized_property
def loaded_dbapi(self) -> ModuleType:
def loaded_dbapi(self) -> DBAPIModule:
if self.dbapi is None:
raise exc.InvalidRequestError(
f"Dialect {self} does not have a Python DBAPI established "
@@ -431,7 +444,7 @@ class DefaultDialect(Dialect):
def _bind_typing_render_casts(self):
return self.bind_typing is interfaces.BindTyping.RENDER_CASTS
def _ensure_has_table_connection(self, arg):
def _ensure_has_table_connection(self, arg: Connection) -> None:
if not isinstance(arg, Connection):
raise exc.ArgumentError(
"The argument passed to Dialect.has_table() should be a "
@@ -468,7 +481,7 @@ class DefaultDialect(Dialect):
return weakref.WeakKeyDictionary()
@property
def dialect_description(self):
def dialect_description(self): # type: ignore[override]
return self.name + "+" + self.driver
@property
@@ -509,7 +522,7 @@ class DefaultDialect(Dialect):
else:
return None
def initialize(self, connection):
def initialize(self, connection: Connection) -> None:
try:
self.server_version_info = self._get_server_version_info(
connection
@@ -545,7 +558,7 @@ class DefaultDialect(Dialect):
% (self.label_length, self.max_identifier_length)
)
def on_connect(self):
def on_connect(self) -> Optional[Callable[[Any], None]]:
# inherits the docstring from interfaces.Dialect.on_connect
return None
@@ -604,18 +617,18 @@ class DefaultDialect(Dialect):
) -> bool:
return schema_name in self.get_schema_names(connection, **kw)
def validate_identifier(self, ident):
def validate_identifier(self, ident: str) -> None:
if len(ident) > self.max_identifier_length:
raise exc.IdentifierError(
"Identifier '%s' exceeds maximum length of %d characters"
% (ident, self.max_identifier_length)
)
def connect(self, *cargs, **cparams):
def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection:
# inherits the docstring from interfaces.Dialect.connect
return self.loaded_dbapi.connect(*cargs, **cparams)
return self.loaded_dbapi.connect(*cargs, **cparams) # type: ignore[no-any-return] # NOQA: E501
def create_connect_args(self, url):
def create_connect_args(self, url: URL) -> ConnectArgsType:
# inherits the docstring from interfaces.Dialect.create_connect_args
opts = url.translate_connect_args()
opts.update(url.query)
@@ -659,7 +672,7 @@ class DefaultDialect(Dialect):
if connection.in_transaction():
trans_objs = [
(name, obj)
for name, obj, value in characteristic_values
for name, obj, _ in characteristic_values
if obj.transactional
]
if trans_objs:
@@ -672,8 +685,10 @@ class DefaultDialect(Dialect):
)
dbapi_connection = connection.connection.dbapi_connection
for name, characteristic, value in characteristic_values:
characteristic.set_characteristic(self, dbapi_connection, value)
for _, characteristic, value in characteristic_values:
characteristic.set_connection_characteristic(
self, connection, dbapi_connection, value
)
connection.connection._connection_record.finalize_callback.append(
functools.partial(self._reset_characteristics, characteristics)
)
@@ -689,6 +704,10 @@ class DefaultDialect(Dialect):
pass
def do_rollback(self, dbapi_connection):
if self.skip_autocommit_rollback and self.detect_autocommit_setting(
dbapi_connection
):
return
dbapi_connection.rollback()
def do_commit(self, dbapi_connection):
@@ -728,8 +747,6 @@ class DefaultDialect(Dialect):
raise
def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
cursor = None
cursor = dbapi_connection.cursor()
try:
cursor.execute(self._dialect_specific_select_one)
@@ -756,11 +773,25 @@ class DefaultDialect(Dialect):
connection.execute(expression.ReleaseSavepointClause(name))
def _deliver_insertmanyvalues_batches(
self, cursor, statement, parameters, generic_setinputsizes, context
self,
connection,
cursor,
statement,
parameters,
generic_setinputsizes,
context,
):
context = cast(DefaultExecutionContext, context)
compiled = cast(SQLCompiler, context.compiled)
_composite_sentinel_proc: Sequence[
Optional[_ResultProcessorType[Any]]
] = ()
_scalar_sentinel_proc: Optional[_ResultProcessorType[Any]] = None
_sentinel_proc_initialized: bool = False
compiled_parameters = context.compiled_parameters
imv = compiled._insertmanyvalues
assert imv is not None
@@ -769,7 +800,12 @@ class DefaultDialect(Dialect):
"insertmanyvalues_page_size", self.insertmanyvalues_page_size
)
sentinel_value_resolvers = None
if compiled.schema_translate_map:
schema_translate_map = context.execution_options.get(
"schema_translate_map", {}
)
else:
schema_translate_map = None
if is_returning:
result: Optional[List[Any]] = []
@@ -777,10 +813,6 @@ class DefaultDialect(Dialect):
sort_by_parameter_order = imv.sort_by_parameter_order
if imv.num_sentinel_columns:
sentinel_value_resolvers = (
compiled._imv_sentinel_value_resolvers
)
else:
sort_by_parameter_order = False
result = None
@@ -788,14 +820,27 @@ class DefaultDialect(Dialect):
for imv_batch in compiled._deliver_insertmanyvalues_batches(
statement,
parameters,
compiled_parameters,
generic_setinputsizes,
batch_size,
sort_by_parameter_order,
schema_translate_map,
):
yield imv_batch
if is_returning:
rows = context.fetchall_for_returning(cursor)
try:
rows = context.fetchall_for_returning(cursor)
except BaseException as be:
connection._handle_dbapi_exception(
be,
sql_util._long_statement(imv_batch.replaced_statement),
imv_batch.replaced_parameters,
None,
context,
is_sub_exec=True,
)
# I would have thought "is_returning: Final[bool]"
# would have assured this but pylance thinks not
@@ -815,11 +860,46 @@ class DefaultDialect(Dialect):
# otherwise, create dictionaries to match up batches
# with parameters
assert imv.sentinel_param_keys
assert imv.sentinel_columns
_nsc = imv.num_sentinel_columns
if not _sentinel_proc_initialized:
if composite_sentinel:
_composite_sentinel_proc = [
col.type._cached_result_processor(
self, cursor_desc[1]
)
for col, cursor_desc in zip(
imv.sentinel_columns,
cursor.description[-_nsc:],
)
]
else:
_scalar_sentinel_proc = (
imv.sentinel_columns[0]
).type._cached_result_processor(
self, cursor.description[-1][1]
)
_sentinel_proc_initialized = True
rows_by_sentinel: Union[
Dict[Tuple[Any, ...], Any],
Dict[Any, Any],
]
if composite_sentinel:
_nsc = imv.num_sentinel_columns
rows_by_sentinel = {
tuple(row[-_nsc:]): row for row in rows
tuple(
(proc(val) if proc else val)
for val, proc in zip(
row[-_nsc:], _composite_sentinel_proc
)
): row
for row in rows
}
elif _scalar_sentinel_proc:
rows_by_sentinel = {
_scalar_sentinel_proc(row[-1]): row for row in rows
}
else:
rows_by_sentinel = {row[-1]: row for row in rows}
@@ -838,61 +918,10 @@ class DefaultDialect(Dialect):
)
try:
if composite_sentinel:
if sentinel_value_resolvers:
# composite sentinel (PK) with value resolvers
ordered_rows = [
rows_by_sentinel[
tuple(
_resolver(parameters[_spk]) # type: ignore # noqa: E501
if _resolver
else parameters[_spk] # type: ignore # noqa: E501
for _resolver, _spk in zip(
sentinel_value_resolvers,
imv.sentinel_param_keys,
)
)
]
for parameters in imv_batch.batch
]
else:
# composite sentinel (PK) with no value
# resolvers
ordered_rows = [
rows_by_sentinel[
tuple(
parameters[_spk] # type: ignore
for _spk in imv.sentinel_param_keys
)
]
for parameters in imv_batch.batch
]
else:
_sentinel_param_key = imv.sentinel_param_keys[0]
if (
sentinel_value_resolvers
and sentinel_value_resolvers[0]
):
# single-column sentinel with value resolver
_sentinel_value_resolver = (
sentinel_value_resolvers[0]
)
ordered_rows = [
rows_by_sentinel[
_sentinel_value_resolver(
parameters[_sentinel_param_key] # type: ignore # noqa: E501
)
]
for parameters in imv_batch.batch
]
else:
# single-column sentinel with no value resolver
ordered_rows = [
rows_by_sentinel[
parameters[_sentinel_param_key] # type: ignore # noqa: E501
]
for parameters in imv_batch.batch
]
ordered_rows = [
rows_by_sentinel[sentinel_keys]
for sentinel_keys in imv_batch.sentinel_values
]
except KeyError as ke:
# see test_insert_exec.py::
# IMVSentinelTest::test_sentinel_cant_match_keys
@@ -924,7 +953,14 @@ class DefaultDialect(Dialect):
def do_execute_no_params(self, cursor, statement, context=None):
cursor.execute(statement)
def is_disconnect(self, e, connection, cursor):
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Union[
pool.PoolProxiedConnection, interfaces.DBAPIConnection, None
],
cursor: Optional[interfaces.DBAPICursor],
) -> bool:
return False
@util.memoized_instancemethod
@@ -1024,7 +1060,7 @@ class DefaultDialect(Dialect):
name = name_upper
return name
def get_driver_connection(self, connection):
def get_driver_connection(self, connection: DBAPIConnection) -> Any:
return connection
def _overrides_default(self, method):
@@ -1196,7 +1232,7 @@ class DefaultExecutionContext(ExecutionContext):
_soft_closed = False
_has_rowcount = False
_rowcount: Optional[int] = None
# a hook for SQLite's translation of
# result column names
@@ -1453,9 +1489,11 @@ class DefaultExecutionContext(ExecutionContext):
assert positiontup is not None
for compiled_params in self.compiled_parameters:
l_param: List[Any] = [
flattened_processors[key](compiled_params[key])
if key in flattened_processors
else compiled_params[key]
(
flattened_processors[key](compiled_params[key])
if key in flattened_processors
else compiled_params[key]
)
for key in positiontup
]
core_positional_parameters.append(
@@ -1476,18 +1514,20 @@ class DefaultExecutionContext(ExecutionContext):
for compiled_params in self.compiled_parameters:
if escaped_names:
d_param = {
escaped_names.get(key, key): flattened_processors[key](
compiled_params[key]
escaped_names.get(key, key): (
flattened_processors[key](compiled_params[key])
if key in flattened_processors
else compiled_params[key]
)
if key in flattened_processors
else compiled_params[key]
for key in compiled_params
}
else:
d_param = {
key: flattened_processors[key](compiled_params[key])
if key in flattened_processors
else compiled_params[key]
key: (
flattened_processors[key](compiled_params[key])
if key in flattened_processors
else compiled_params[key]
)
for key in compiled_params
}
@@ -1577,7 +1617,13 @@ class DefaultExecutionContext(ExecutionContext):
elif ch is CACHE_MISS:
return "generated in %.5fs" % (now - gen_time,)
elif ch is CACHING_DISABLED:
return "caching disabled %.5fs" % (now - gen_time,)
if "_cache_disable_reason" in self.execution_options:
return "caching disabled (%s) %.5fs " % (
self.execution_options["_cache_disable_reason"],
now - gen_time,
)
else:
return "caching disabled %.5fs" % (now - gen_time,)
elif ch is NO_DIALECT_SUPPORT:
return "dialect %s+%s does not support caching %.5fs" % (
self.dialect.name,
@@ -1588,7 +1634,7 @@ class DefaultExecutionContext(ExecutionContext):
return "unknown"
@property
def executemany(self):
def executemany(self): # type: ignore[override]
return self.execute_style in (
ExecuteStyle.EXECUTEMANY,
ExecuteStyle.INSERTMANYVALUES,
@@ -1630,7 +1676,12 @@ class DefaultExecutionContext(ExecutionContext):
def no_parameters(self):
return self.execution_options.get("no_parameters", False)
def _execute_scalar(self, stmt, type_, parameters=None):
def _execute_scalar(
self,
stmt: str,
type_: Optional[TypeEngine[Any]],
parameters: Optional[_DBAPISingleExecuteParams] = None,
) -> Any:
"""Execute a string statement on the current cursor, returning a
scalar result.
@@ -1704,7 +1755,7 @@ class DefaultExecutionContext(ExecutionContext):
return use_server_side
def create_cursor(self):
def create_cursor(self) -> DBAPICursor:
if (
# inlining initial preference checks for SS cursors
self.dialect.supports_server_side_cursors
@@ -1725,10 +1776,10 @@ class DefaultExecutionContext(ExecutionContext):
def fetchall_for_returning(self, cursor):
return cursor.fetchall()
def create_default_cursor(self):
def create_default_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor()
def create_server_side_cursor(self):
def create_server_side_cursor(self) -> DBAPICursor:
raise NotImplementedError()
def pre_exec(self):
@@ -1776,7 +1827,14 @@ class DefaultExecutionContext(ExecutionContext):
@util.non_memoized_property
def rowcount(self) -> int:
return self.cursor.rowcount
if self._rowcount is not None:
return self._rowcount
else:
return self.cursor.rowcount
@property
def _has_rowcount(self):
return self._rowcount is not None
def supports_sane_rowcount(self):
return self.dialect.supports_sane_rowcount
@@ -1787,9 +1845,13 @@ class DefaultExecutionContext(ExecutionContext):
def _setup_result_proxy(self):
exec_opt = self.execution_options
if self._rowcount is None and exec_opt.get("preserve_rowcount", False):
self._rowcount = self.cursor.rowcount
yp: Optional[Union[int, bool]]
if self.is_crud or self.is_text:
result = self._setup_dml_or_text_result()
yp = sr = False
yp = False
else:
yp = exec_opt.get("yield_per", None)
sr = self._is_server_side or exec_opt.get("stream_results", False)
@@ -1943,8 +2005,7 @@ class DefaultExecutionContext(ExecutionContext):
if rows:
self.returned_default_rows = rows
result.rowcount = len(rows)
self._has_rowcount = True
self._rowcount = len(rows)
if self._is_supplemental_returning:
result._rewind(rows)
@@ -1958,12 +2019,12 @@ class DefaultExecutionContext(ExecutionContext):
elif not result._metadata.returns_rows:
# no results, get rowcount
# (which requires open cursor on some drivers)
result.rowcount
self._has_rowcount = True
if self._rowcount is None:
self._rowcount = self.cursor.rowcount
result._soft_close()
elif self.isupdate or self.isdelete:
result.rowcount
self._has_rowcount = True
if self._rowcount is None:
self._rowcount = self.cursor.rowcount
return result
@util.memoized_property
@@ -2012,10 +2073,11 @@ class DefaultExecutionContext(ExecutionContext):
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
This method only called by those dialects which set
the :attr:`.Dialect.bind_typing` attribute to
:attr:`.BindTyping.SETINPUTSIZES`. cx_Oracle is the only DBAPI
that requires setinputsizes(), pyodbc offers it as an option.
This method only called by those dialects which set the
:attr:`.Dialect.bind_typing` attribute to
:attr:`.BindTyping.SETINPUTSIZES`. Python-oracledb and cx_Oracle are
the only DBAPIs that requires setinputsizes(); pyodbc offers it as an
option.
Prior to SQLAlchemy 2.0, the setinputsizes() approach was also used
for pg8000 and asyncpg, which has been changed to inline rendering
@@ -2143,17 +2205,21 @@ class DefaultExecutionContext(ExecutionContext):
if compiled.positional:
parameters = self.dialect.execute_sequence_format(
[
processors[key](compiled_params[key]) # type: ignore
if key in processors
else compiled_params[key]
(
processors[key](compiled_params[key]) # type: ignore
if key in processors
else compiled_params[key]
)
for key in compiled.positiontup or ()
]
)
else:
parameters = {
key: processors[key](compiled_params[key]) # type: ignore
if key in processors
else compiled_params[key]
key: (
processors[key](compiled_params[key]) # type: ignore
if key in processors
else compiled_params[key]
)
for key in compiled_params
}
return self._execute_scalar(

View File

@@ -1,5 +1,5 @@
# sqlalchemy/engine/events.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# engine/events.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -54,19 +54,24 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]):
from sqlalchemy import event, create_engine
def before_cursor_execute(conn, cursor, statement, parameters, context,
executemany):
def before_cursor_execute(
conn, cursor, statement, parameters, context, executemany
):
log.info("Received statement: %s", statement)
engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test')
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
event.listen(engine, "before_cursor_execute", before_cursor_execute)
or with a specific :class:`_engine.Connection`::
with engine.begin() as conn:
@event.listens_for(conn, 'before_cursor_execute')
def before_cursor_execute(conn, cursor, statement, parameters,
context, executemany):
@event.listens_for(conn, "before_cursor_execute")
def before_cursor_execute(
conn, cursor, statement, parameters, context, executemany
):
log.info("Received statement: %s", statement)
When the methods are called with a `statement` parameter, such as in
@@ -84,9 +89,11 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]):
from sqlalchemy.engine import Engine
from sqlalchemy import event
@event.listens_for(Engine, "before_cursor_execute", retval=True)
def comment_sql_calls(conn, cursor, statement, parameters,
context, executemany):
def comment_sql_calls(
conn, cursor, statement, parameters, context, executemany
):
statement = statement + " -- some comment"
return statement, parameters
@@ -316,8 +323,9 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]):
returned as a two-tuple in this case::
@event.listens_for(Engine, "before_cursor_execute", retval=True)
def before_cursor_execute(conn, cursor, statement,
parameters, context, executemany):
def before_cursor_execute(
conn, cursor, statement, parameters, context, executemany
):
# do something with statement, parameters
return statement, parameters
@@ -766,9 +774,9 @@ class DialectEvents(event.Events[Dialect]):
@event.listens_for(Engine, "handle_error")
def handle_exception(context):
if isinstance(context.original_exception,
psycopg2.OperationalError) and \
"failed" in str(context.original_exception):
if isinstance(
context.original_exception, psycopg2.OperationalError
) and "failed" in str(context.original_exception):
raise MySpecialException("failed operation")
.. warning:: Because the
@@ -791,10 +799,13 @@ class DialectEvents(event.Events[Dialect]):
@event.listens_for(Engine, "handle_error", retval=True)
def handle_exception(context):
if context.chained_exception is not None and \
"special" in context.chained_exception.message:
return MySpecialException("failed",
cause=context.chained_exception)
if (
context.chained_exception is not None
and "special" in context.chained_exception.message
):
return MySpecialException(
"failed", cause=context.chained_exception
)
Handlers that return ``None`` may be used within the chain; when
a handler returns ``None``, the previous exception instance,
@@ -836,7 +847,8 @@ class DialectEvents(event.Events[Dialect]):
e = create_engine("postgresql+psycopg2://user@host/dbname")
@event.listens_for(e, 'do_connect')
@event.listens_for(e, "do_connect")
def receive_do_connect(dialect, conn_rec, cargs, cparams):
cparams["password"] = "some_password"
@@ -845,7 +857,8 @@ class DialectEvents(event.Events[Dialect]):
e = create_engine("postgresql+psycopg2://user@host/dbname")
@event.listens_for(e, 'do_connect')
@event.listens_for(e, "do_connect")
def receive_do_connect(dialect, conn_rec, cargs, cparams):
return psycopg2.connect(*cargs, **cparams)
@@ -928,7 +941,8 @@ class DialectEvents(event.Events[Dialect]):
The setinputsizes hook overall is only used for dialects which include
the flag ``use_setinputsizes=True``. Dialects which use this
include cx_Oracle, pg8000, asyncpg, and pyodbc dialects.
include python-oracledb, cx_Oracle, pg8000, asyncpg, and pyodbc
dialects.
.. note::

View File

@@ -1,5 +1,5 @@
# engine/interfaces.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -10,7 +10,6 @@
from __future__ import annotations
from enum import Enum
from types import ModuleType
from typing import Any
from typing import Awaitable
from typing import Callable
@@ -34,7 +33,7 @@ from typing import Union
from .. import util
from ..event import EventTarget
from ..pool import Pool
from ..pool import PoolProxiedConnection
from ..pool import PoolProxiedConnection as PoolProxiedConnection
from ..sql.compiler import Compiled as Compiled
from ..sql.compiler import Compiled # noqa
from ..sql.compiler import TypeCompiler as TypeCompiler
@@ -51,6 +50,7 @@ if TYPE_CHECKING:
from .base import Engine
from .cursor import CursorResult
from .url import URL
from ..connectors.asyncio import AsyncIODBAPIConnection
from ..event import _ListenerFnType
from ..event import dispatcher
from ..exc import StatementError
@@ -70,6 +70,7 @@ if TYPE_CHECKING:
from ..sql.sqltypes import Integer
from ..sql.type_api import _TypeMemoDict
from ..sql.type_api import TypeEngine
from ..util.langhelpers import generic_fn_descriptor
ConnectArgsType = Tuple[Sequence[str], MutableMapping[str, Any]]
@@ -106,6 +107,22 @@ class ExecuteStyle(Enum):
"""
class DBAPIModule(Protocol):
class Error(Exception):
def __getattr__(self, key: str) -> Any: ...
class OperationalError(Error):
pass
class InterfaceError(Error):
pass
class IntegrityError(Error):
pass
def __getattr__(self, key: str) -> Any: ...
class DBAPIConnection(Protocol):
"""protocol representing a :pep:`249` database connection.
@@ -118,19 +135,17 @@ class DBAPIConnection(Protocol):
""" # noqa: E501
def close(self) -> None:
...
def close(self) -> None: ...
def commit(self) -> None:
...
def commit(self) -> None: ...
def cursor(self) -> DBAPICursor:
...
def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ...
def rollback(self) -> None:
...
def rollback(self) -> None: ...
autocommit: bool
def __getattr__(self, key: str) -> Any: ...
def __setattr__(self, key: str, value: Any) -> None: ...
class DBAPIType(Protocol):
@@ -174,53 +189,43 @@ class DBAPICursor(Protocol):
...
@property
def rowcount(self) -> int:
...
def rowcount(self) -> int: ...
arraysize: int
lastrowid: int
def close(self) -> None:
...
def close(self) -> None: ...
def execute(
self,
operation: Any,
parameters: Optional[_DBAPISingleExecuteParams] = None,
) -> Any:
...
) -> Any: ...
def executemany(
self,
operation: Any,
parameters: Sequence[_DBAPIMultiExecuteParams],
) -> Any:
...
parameters: _DBAPIMultiExecuteParams,
) -> Any: ...
def fetchone(self) -> Optional[Any]:
...
def fetchone(self) -> Optional[Any]: ...
def fetchmany(self, size: int = ...) -> Sequence[Any]:
...
def fetchmany(self, size: int = ...) -> Sequence[Any]: ...
def fetchall(self) -> Sequence[Any]:
...
def fetchall(self) -> Sequence[Any]: ...
def setinputsizes(self, sizes: Sequence[Any]) -> None:
...
def setinputsizes(self, sizes: Sequence[Any]) -> None: ...
def setoutputsize(self, size: Any, column: Any) -> None:
...
def setoutputsize(self, size: Any, column: Any) -> None: ...
def callproc(self, procname: str, parameters: Sequence[Any] = ...) -> Any:
...
def callproc(
self, procname: str, parameters: Sequence[Any] = ...
) -> Any: ...
def nextset(self) -> Optional[bool]:
...
def nextset(self) -> Optional[bool]: ...
def __getattr__(self, key: str) -> Any:
...
def __getattr__(self, key: str) -> Any: ...
_CoreSingleExecuteParams = Mapping[str, Any]
@@ -284,6 +289,7 @@ class _CoreKnownExecutionOptions(TypedDict, total=False):
yield_per: int
insertmanyvalues_page_size: int
schema_translate_map: Optional[SchemaTranslateMapType]
preserve_rowcount: bool
_ExecuteOptions = immutabledict[str, Any]
@@ -593,8 +599,8 @@ class BindTyping(Enum):
"""Use the pep-249 setinputsizes method.
This is only implemented for DBAPIs that support this method and for which
the SQLAlchemy dialect has the appropriate infrastructure for that
dialect set up. Current dialects include cx_Oracle as well as
the SQLAlchemy dialect has the appropriate infrastructure for that dialect
set up. Current dialects include python-oracledb, cx_Oracle as well as
optional support for SQL Server using pyodbc.
When using setinputsizes, dialects also have a means of only using the
@@ -671,7 +677,7 @@ class Dialect(EventTarget):
dialect_description: str
dbapi: Optional[ModuleType]
dbapi: Optional[DBAPIModule]
"""A reference to the DBAPI module object itself.
SQLAlchemy dialects import DBAPI modules using the classmethod
@@ -695,7 +701,7 @@ class Dialect(EventTarget):
"""
@util.non_memoized_property
def loaded_dbapi(self) -> ModuleType:
def loaded_dbapi(self) -> DBAPIModule:
"""same as .dbapi, but is never None; will raise an error if no
DBAPI was set up.
@@ -773,6 +779,14 @@ class Dialect(EventTarget):
default_isolation_level: Optional[IsolationLevel]
"""the isolation that is implicitly present on new connections"""
skip_autocommit_rollback: bool
"""Whether or not the :paramref:`.create_engine.skip_autocommit_rollback`
parameter was set.
.. versionadded:: 2.0.43
"""
# create_engine() -> isolation_level currently goes here
_on_connect_isolation_level: Optional[IsolationLevel]
@@ -792,8 +806,14 @@ class Dialect(EventTarget):
max_identifier_length: int
"""The maximum length of identifier names."""
max_index_name_length: Optional[int]
"""The maximum length of index names if different from
``max_identifier_length``."""
max_constraint_name_length: Optional[int]
"""The maximum length of constraint names if different from
``max_identifier_length``."""
supports_server_side_cursors: bool
supports_server_side_cursors: Union[generic_fn_descriptor[bool], bool]
"""indicates if the dialect supports server side cursors"""
server_side_cursors: bool
@@ -884,12 +904,12 @@ class Dialect(EventTarget):
the statement multiple times for a series of batches when large numbers
of rows are given.
The parameter is False for the default dialect, and is set to
True for SQLAlchemy internal dialects SQLite, MySQL/MariaDB, PostgreSQL,
SQL Server. It remains at False for Oracle, which provides native
"executemany with RETURNING" support and also does not support
``supports_multivalues_insert``. For MySQL/MariaDB, those MySQL
dialects that don't support RETURNING will not report
The parameter is False for the default dialect, and is set to True for
SQLAlchemy internal dialects SQLite, MySQL/MariaDB, PostgreSQL, SQL Server.
It remains at False for Oracle Database, which provides native "executemany
with RETURNING" support and also does not support
``supports_multivalues_insert``. For MySQL/MariaDB, those MySQL dialects
that don't support RETURNING will not report
``insert_executemany_returning`` as True.
.. versionadded:: 2.0
@@ -1073,11 +1093,7 @@ class Dialect(EventTarget):
To implement, establish as a series of tuples, as in::
construct_arguments = [
(schema.Index, {
"using": False,
"where": None,
"ops": None
})
(schema.Index, {"using": False, "where": None, "ops": None}),
]
If the above construct is established on the PostgreSQL dialect,
@@ -1106,7 +1122,8 @@ class Dialect(EventTarget):
established on a :class:`.Table` object which will be passed as
"reflection options" when using :paramref:`.Table.autoload_with`.
Current example is "oracle_resolve_synonyms" in the Oracle dialect.
Current example is "oracle_resolve_synonyms" in the Oracle Database
dialects.
"""
@@ -1130,7 +1147,7 @@ class Dialect(EventTarget):
supports_constraint_comments: bool
"""Indicates if the dialect supports comment DDL on constraints.
.. versionadded: 2.0
.. versionadded:: 2.0
"""
_has_events = False
@@ -1249,7 +1266,7 @@ class Dialect(EventTarget):
raise NotImplementedError()
@classmethod
def import_dbapi(cls) -> ModuleType:
def import_dbapi(cls) -> DBAPIModule:
"""Import the DBAPI module that is used by this dialect.
The Python module object returned here will be assigned as an
@@ -1266,8 +1283,7 @@ class Dialect(EventTarget):
"""
raise NotImplementedError()
@classmethod
def type_descriptor(cls, typeobj: TypeEngine[_T]) -> TypeEngine[_T]:
def type_descriptor(self, typeobj: TypeEngine[_T]) -> TypeEngine[_T]:
"""Transform a generic type to a dialect-specific type.
Dialect classes will usually use the
@@ -1299,12 +1315,9 @@ class Dialect(EventTarget):
"""
pass
if TYPE_CHECKING:
def _overrides_default(self, method_name: str) -> bool:
...
def _overrides_default(self, method_name: str) -> bool: ...
def get_columns(
self,
@@ -1330,6 +1343,7 @@ class Dialect(EventTarget):
def get_multi_columns(
self,
connection: Connection,
*,
schema: Optional[str] = None,
filter_names: Optional[Collection[str]] = None,
**kw: Any,
@@ -1378,6 +1392,7 @@ class Dialect(EventTarget):
def get_multi_pk_constraint(
self,
connection: Connection,
*,
schema: Optional[str] = None,
filter_names: Optional[Collection[str]] = None,
**kw: Any,
@@ -1424,6 +1439,7 @@ class Dialect(EventTarget):
def get_multi_foreign_keys(
self,
connection: Connection,
*,
schema: Optional[str] = None,
filter_names: Optional[Collection[str]] = None,
**kw: Any,
@@ -1583,6 +1599,7 @@ class Dialect(EventTarget):
def get_multi_indexes(
self,
connection: Connection,
*,
schema: Optional[str] = None,
filter_names: Optional[Collection[str]] = None,
**kw: Any,
@@ -1629,6 +1646,7 @@ class Dialect(EventTarget):
def get_multi_unique_constraints(
self,
connection: Connection,
*,
schema: Optional[str] = None,
filter_names: Optional[Collection[str]] = None,
**kw: Any,
@@ -1676,6 +1694,7 @@ class Dialect(EventTarget):
def get_multi_check_constraints(
self,
connection: Connection,
*,
schema: Optional[str] = None,
filter_names: Optional[Collection[str]] = None,
**kw: Any,
@@ -1718,6 +1737,7 @@ class Dialect(EventTarget):
def get_multi_table_options(
self,
connection: Connection,
*,
schema: Optional[str] = None,
filter_names: Optional[Collection[str]] = None,
**kw: Any,
@@ -1769,6 +1789,7 @@ class Dialect(EventTarget):
def get_multi_table_comment(
self,
connection: Connection,
*,
schema: Optional[str] = None,
filter_names: Optional[Collection[str]] = None,
**kw: Any,
@@ -2161,6 +2182,7 @@ class Dialect(EventTarget):
def _deliver_insertmanyvalues_batches(
self,
connection: Connection,
cursor: DBAPICursor,
statement: str,
parameters: _DBAPIMultiExecuteParams,
@@ -2214,7 +2236,7 @@ class Dialect(EventTarget):
def is_disconnect(
self,
e: Exception,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
@@ -2318,7 +2340,7 @@ class Dialect(EventTarget):
"""
return self.on_connect()
def on_connect(self) -> Optional[Callable[[Any], Any]]:
def on_connect(self) -> Optional[Callable[[Any], None]]:
"""return a callable which sets up a newly created DBAPI connection.
The callable should accept a single argument "conn" which is the
@@ -2467,6 +2489,30 @@ class Dialect(EventTarget):
raise NotImplementedError()
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
"""Detect the current autocommit setting for a DBAPI connection.
:param dbapi_connection: a DBAPI connection object
:return: True if autocommit is enabled, False if disabled
:rtype: bool
This method inspects the given DBAPI connection to determine
whether autocommit mode is currently enabled. The specific
mechanism for detecting autocommit varies by database dialect
and DBAPI driver, however it should be done **without** network
round trips.
.. note::
Not all dialects support autocommit detection. Dialects
that do not support this feature will raise
:exc:`NotImplementedError`.
"""
raise NotImplementedError(
"This dialect cannot detect autocommit on a DBAPI connection"
)
def get_default_isolation_level(
self, dbapi_conn: DBAPIConnection
) -> IsolationLevel:
@@ -2491,7 +2537,7 @@ class Dialect(EventTarget):
def get_isolation_level_values(
self, dbapi_conn: DBAPIConnection
) -> List[IsolationLevel]:
) -> Sequence[IsolationLevel]:
"""return a sequence of string isolation level names that are accepted
by this dialect.
@@ -2504,7 +2550,7 @@ class Dialect(EventTarget):
``REPEATABLE READ``. isolation level names will have underscores
converted to spaces before being passed along to the dialect.
* The names for the four standard isolation names to the extent that
they are supported by the backend should be ``READ UNCOMMITTED``
they are supported by the backend should be ``READ UNCOMMITTED``,
``READ COMMITTED``, ``REPEATABLE READ``, ``SERIALIZABLE``
* if the dialect supports an autocommit option it should be provided
using the isolation level name ``AUTOCOMMIT``.
@@ -2665,6 +2711,9 @@ class Dialect(EventTarget):
"""return a Pool class to use for a given URL"""
raise NotImplementedError()
def validate_identifier(self, ident: str) -> None:
"""Validates an identifier name, raising an exception if invalid"""
class CreateEnginePlugin:
"""A set of hooks intended to augment the construction of an
@@ -2690,11 +2739,14 @@ class CreateEnginePlugin:
from sqlalchemy.engine import CreateEnginePlugin
from sqlalchemy import event
class LogCursorEventsPlugin(CreateEnginePlugin):
def __init__(self, url, kwargs):
# consume the parameter "log_cursor_logging_name" from the
# URL query
logging_name = url.query.get("log_cursor_logging_name", "log_cursor")
logging_name = url.query.get(
"log_cursor_logging_name", "log_cursor"
)
self.log = logging.getLogger(logging_name)
@@ -2706,7 +2758,6 @@ class CreateEnginePlugin:
"attach an event listener after the new Engine is constructed"
event.listen(engine, "before_cursor_execute", self._log_event)
def _log_event(
self,
conn,
@@ -2714,19 +2765,19 @@ class CreateEnginePlugin:
statement,
parameters,
context,
executemany):
executemany,
):
self.log.info("Plugin logged cursor event: %s", statement)
Plugins are registered using entry points in a similar way as that
of dialects::
entry_points={
'sqlalchemy.plugins': [
'log_cursor_plugin = myapp.plugins:LogCursorEventsPlugin'
entry_points = {
"sqlalchemy.plugins": [
"log_cursor_plugin = myapp.plugins:LogCursorEventsPlugin"
]
}
A plugin that uses the above names would be invoked from a database
URL as in::
@@ -2743,15 +2794,16 @@ class CreateEnginePlugin:
in the URL::
engine = create_engine(
"mysql+pymysql://scott:tiger@localhost/test?"
"plugin=plugin_one&plugin=plugin_twp&plugin=plugin_three")
"mysql+pymysql://scott:tiger@localhost/test?"
"plugin=plugin_one&plugin=plugin_twp&plugin=plugin_three"
)
The plugin names may also be passed directly to :func:`_sa.create_engine`
using the :paramref:`_sa.create_engine.plugins` argument::
engine = create_engine(
"mysql+pymysql://scott:tiger@localhost/test",
plugins=["myplugin"])
"mysql+pymysql://scott:tiger@localhost/test", plugins=["myplugin"]
)
.. versionadded:: 1.2.3 plugin names can also be specified
to :func:`_sa.create_engine` as a list
@@ -2773,9 +2825,9 @@ class CreateEnginePlugin:
class MyPlugin(CreateEnginePlugin):
def __init__(self, url, kwargs):
self.my_argument_one = url.query['my_argument_one']
self.my_argument_two = url.query['my_argument_two']
self.my_argument_three = kwargs.pop('my_argument_three', None)
self.my_argument_one = url.query["my_argument_one"]
self.my_argument_two = url.query["my_argument_two"]
self.my_argument_three = kwargs.pop("my_argument_three", None)
def update_url(self, url):
return url.difference_update_query(
@@ -2788,9 +2840,9 @@ class CreateEnginePlugin:
from sqlalchemy import create_engine
engine = create_engine(
"mysql+pymysql://scott:tiger@localhost/test?"
"plugin=myplugin&my_argument_one=foo&my_argument_two=bar",
my_argument_three='bat'
"mysql+pymysql://scott:tiger@localhost/test?"
"plugin=myplugin&my_argument_one=foo&my_argument_two=bar",
my_argument_three="bat",
)
.. versionchanged:: 1.4
@@ -2809,15 +2861,15 @@ class CreateEnginePlugin:
def __init__(self, url, kwargs):
if hasattr(CreateEnginePlugin, "update_url"):
# detect the 1.4 API
self.my_argument_one = url.query['my_argument_one']
self.my_argument_two = url.query['my_argument_two']
self.my_argument_one = url.query["my_argument_one"]
self.my_argument_two = url.query["my_argument_two"]
else:
# detect the 1.3 and earlier API - mutate the
# URL directly
self.my_argument_one = url.query.pop('my_argument_one')
self.my_argument_two = url.query.pop('my_argument_two')
self.my_argument_one = url.query.pop("my_argument_one")
self.my_argument_two = url.query.pop("my_argument_two")
self.my_argument_three = kwargs.pop('my_argument_three', None)
self.my_argument_three = kwargs.pop("my_argument_three", None)
def update_url(self, url):
# this method is only called in the 1.4 version
@@ -2992,6 +3044,9 @@ class ExecutionContext:
inline SQL expression value was fired off. Applies to inserts
and updates."""
execution_options: _ExecuteOptions
"""Execution options associated with the current statement execution"""
@classmethod
def _init_ddl(
cls,
@@ -3366,7 +3421,7 @@ class AdaptedConnection:
__slots__ = ("_connection",)
_connection: Any
_connection: AsyncIODBAPIConnection
@property
def driver_connection(self) -> Any:
@@ -3385,11 +3440,14 @@ class AdaptedConnection:
engine = create_async_engine(...)
@event.listens_for(engine.sync_engine, "connect")
def register_custom_types(dbapi_connection, ...):
def register_custom_types(
dbapi_connection, # ...
):
dbapi_connection.run_async(
lambda connection: connection.set_type_codec(
'MyCustomType', encoder, decoder, ...
"MyCustomType", encoder, decoder, ...
)
)

View File

@@ -1,5 +1,5 @@
# engine/mock.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -27,10 +27,9 @@ if typing.TYPE_CHECKING:
from .interfaces import Dialect
from .url import URL
from ..sql.base import Executable
from ..sql.ddl import SchemaDropper
from ..sql.ddl import SchemaGenerator
from ..sql.ddl import InvokeDDLBase
from ..sql.schema import HasSchemaAttr
from ..sql.schema import SchemaItem
from ..sql.visitors import Visitable
class MockConnection:
@@ -53,12 +52,14 @@ class MockConnection:
def _run_ddl_visitor(
self,
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
element: SchemaItem,
visitorcallable: Type[InvokeDDLBase],
element: Visitable,
**kwargs: Any,
) -> None:
kwargs["checkfirst"] = False
visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
visitorcallable(
dialect=self.dialect, connection=self, **kwargs
).traverse_single(element)
def execute(
self,
@@ -90,10 +91,12 @@ def create_mock_engine(
from sqlalchemy import create_mock_engine
def dump(sql, *multiparams, **params):
print(sql.compile(dialect=engine.dialect))
engine = create_mock_engine('postgresql+psycopg2://', dump)
engine = create_mock_engine("postgresql+psycopg2://", dump)
metadata.create_all(engine, checkfirst=False)
:param url: A string URL which typically needs to contain only the

View File

@@ -1,5 +1,5 @@
# sqlalchemy/processors.py
# Copyright (C) 2010-2023 the SQLAlchemy authors and contributors
# engine/processors.py
# Copyright (C) 2010-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
#

View File

@@ -1,5 +1,5 @@
# engine/reflection.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -55,6 +55,7 @@ from .. import util
from ..sql import operators
from ..sql import schema as sa_schema
from ..sql.cache_key import _ad_hoc_cache_key_from_args
from ..sql.elements import quoted_name
from ..sql.elements import TextClause
from ..sql.type_api import TypeEngine
from ..sql.visitors import InternalTraversal
@@ -89,8 +90,16 @@ def cache(
exclude = {"info_cache", "unreflectable"}
key = (
fn.__name__,
tuple(a for a in args if isinstance(a, str)),
tuple((k, v) for k, v in kw.items() if k not in exclude),
tuple(
(str(a), a.quote) if isinstance(a, quoted_name) else a
for a in args
if isinstance(a, str)
),
tuple(
(k, (str(v), v.quote) if isinstance(v, quoted_name) else v)
for k, v in kw.items()
if k not in exclude
),
)
ret: _R = info_cache.get(key)
if ret is None:
@@ -184,7 +193,8 @@ class Inspector(inspection.Inspectable["Inspector"]):
or a :class:`_engine.Connection`::
from sqlalchemy import inspect, create_engine
engine = create_engine('...')
engine = create_engine("...")
insp = inspect(engine)
Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` associated
@@ -621,7 +631,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
r"""Return a list of temporary table names for the current bind.
This method is unsupported by most dialects; currently
only Oracle, PostgreSQL and SQLite implements it.
only Oracle Database, PostgreSQL and SQLite implements it.
:param \**kw: Additional keyword argument to pass to the dialect
specific implementation. See the documentation of the dialect
@@ -657,7 +667,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
given name was created.
This currently includes some options that apply to MySQL and Oracle
tables.
Database tables.
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
@@ -1483,9 +1493,9 @@ class Inspector(inspection.Inspectable["Inspector"]):
from sqlalchemy import create_engine, MetaData, Table
from sqlalchemy import inspect
engine = create_engine('...')
engine = create_engine("...")
meta = MetaData()
user_table = Table('user', meta)
user_table = Table("user", meta)
insp = inspect(engine)
insp.reflect_table(user_table, None)
@@ -1704,9 +1714,12 @@ class Inspector(inspection.Inspectable["Inspector"]):
if pk in cols_by_orig_name and pk not in exclude_columns
]
# update pk constraint name and comment
# update pk constraint name, comment and dialect_kwargs
table.primary_key.name = pk_cons.get("name")
table.primary_key.comment = pk_cons.get("comment", None)
dialect_options = pk_cons.get("dialect_options")
if dialect_options:
table.primary_key.dialect_kwargs.update(dialect_options)
# tell the PKConstraint to re-initialize
# its column collection
@@ -1843,7 +1856,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
if not expressions:
util.warn(
f"Skipping {flavor} {name!r} because key "
f"{index+1} reflected as None but no "
f"{index + 1} reflected as None but no "
"'expressions' were returned"
)
break

View File

@@ -1,5 +1,5 @@
# engine/result.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -52,11 +52,11 @@ else:
from sqlalchemy.cyextension.resultproxy import tuplegetter as tuplegetter
if typing.TYPE_CHECKING:
from ..sql.schema import Column
from ..sql.elements import SQLCoreOperations
from ..sql.type_api import _ResultProcessorType
_KeyType = Union[str, "Column[Any]"]
_KeyIndexType = Union[str, "Column[Any]", int]
_KeyType = Union[str, "SQLCoreOperations[Any]"]
_KeyIndexType = Union[_KeyType, int]
# is overridden in cursor using _CursorKeyMapRecType
_KeyMapRecType = Any
@@ -64,7 +64,7 @@ _KeyMapRecType = Any
_KeyMapType = Mapping[_KeyType, _KeyMapRecType]
_RowData = Union[Row, RowMapping, Any]
_RowData = Union[Row[Any], RowMapping, Any]
"""A generic form of "row" that accommodates for the different kinds of
"rows" that different result objects return, including row, row mapping, and
scalar values"""
@@ -82,7 +82,7 @@ across all the result types
"""
_InterimSupportsScalarsRowType = Union[Row, Any]
_InterimSupportsScalarsRowType = Union[Row[Any], Any]
_ProcessorsType = Sequence[Optional["_ResultProcessorType[Any]"]]
_TupleGetterType = Callable[[Sequence[Any]], Sequence[Any]]
@@ -116,8 +116,7 @@ class ResultMetaData:
@overload
def _key_fallback(
self, key: Any, err: Optional[Exception], raiseerr: Literal[True] = ...
) -> NoReturn:
...
) -> NoReturn: ...
@overload
def _key_fallback(
@@ -125,14 +124,12 @@ class ResultMetaData:
key: Any,
err: Optional[Exception],
raiseerr: Literal[False] = ...,
) -> None:
...
) -> None: ...
@overload
def _key_fallback(
self, key: Any, err: Optional[Exception], raiseerr: bool = ...
) -> Optional[NoReturn]:
...
) -> Optional[NoReturn]: ...
def _key_fallback(
self, key: Any, err: Optional[Exception], raiseerr: bool = True
@@ -329,9 +326,6 @@ class SimpleResultMetaData(ResultMetaData):
_tuplefilter=_tuplefilter,
)
def _contains(self, value: Any, row: Row[Any]) -> bool:
return value in row._data
def _index_for_key(self, key: Any, raiseerr: bool = True) -> int:
if int in key.__class__.__mro__:
key = self._keys[key]
@@ -728,14 +722,21 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
return manyrows
@overload
def _only_one_row(
self: ResultInternal[Row[Any]],
raise_for_second_row: bool,
raise_for_none: bool,
scalar: Literal[True],
) -> Any: ...
@overload
def _only_one_row(
self,
raise_for_second_row: bool,
raise_for_none: Literal[True],
scalar: bool,
) -> _R:
...
) -> _R: ...
@overload
def _only_one_row(
@@ -743,8 +744,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
raise_for_second_row: bool,
raise_for_none: bool,
scalar: bool,
) -> Optional[_R]:
...
) -> Optional[_R]: ...
def _only_one_row(
self,
@@ -817,7 +817,6 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
"was required"
)
else:
next_row = _NO_ROW
# if we checked for second row then that would have
# closed us :)
self._soft_close(hard=True)
@@ -1107,17 +1106,15 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
statement = select(table.c.x, table.c.y, table.c.z)
result = connection.execute(statement)
for z, y in result.columns('z', 'y'):
# ...
for z, y in result.columns("z", "y"):
...
Example of using the column objects from the statement itself::
for z, y in result.columns(
statement.selected_columns.c.z,
statement.selected_columns.c.y
statement.selected_columns.c.z, statement.selected_columns.c.y
):
# ...
...
.. versionadded:: 1.4
@@ -1132,18 +1129,15 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
return self._column_slices(col_expressions)
@overload
def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]:
...
def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]: ...
@overload
def scalars(
self: Result[Tuple[_T]], index: Literal[0]
) -> ScalarResult[_T]:
...
) -> ScalarResult[_T]: ...
@overload
def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]:
...
def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: ...
def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]:
"""Return a :class:`_engine.ScalarResult` filtering object which
@@ -1352,7 +1346,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]:
"""Fetch many rows.
When all rows are exhausted, returns an empty list.
When all rows are exhausted, returns an empty sequence.
This method is provided for backwards compatibility with
SQLAlchemy 1.x.x.
@@ -1360,7 +1354,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
To fetch rows in groups, use the :meth:`_engine.Result.partitions`
method.
:return: a list of :class:`_engine.Row` objects.
:return: a sequence of :class:`_engine.Row` objects.
.. seealso::
@@ -1371,14 +1365,14 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
return self._manyrow_getter(self, size)
def all(self) -> Sequence[Row[_TP]]:
"""Return all rows in a list.
"""Return all rows in a sequence.
Closes the result set after invocation. Subsequent invocations
will return an empty list.
will return an empty sequence.
.. versionadded:: 1.4
:return: a list of :class:`_engine.Row` objects.
:return: a sequence of :class:`_engine.Row` objects.
.. seealso::
@@ -1454,22 +1448,20 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
)
@overload
def scalar_one(self: Result[Tuple[_T]]) -> _T:
...
def scalar_one(self: Result[Tuple[_T]]) -> _T: ...
@overload
def scalar_one(self) -> Any:
...
def scalar_one(self) -> Any: ...
def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_engine.Result.scalars` and
then :meth:`_engine.Result.one`.
then :meth:`_engine.ScalarResult.one`.
.. seealso::
:meth:`_engine.Result.one`
:meth:`_engine.ScalarResult.one`
:meth:`_engine.Result.scalars`
@@ -1479,22 +1471,20 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
)
@overload
def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]:
...
def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]: ...
@overload
def scalar_one_or_none(self) -> Optional[Any]:
...
def scalar_one_or_none(self) -> Optional[Any]: ...
def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one scalar result or ``None``.
This is equivalent to calling :meth:`_engine.Result.scalars` and
then :meth:`_engine.Result.one_or_none`.
then :meth:`_engine.ScalarResult.one_or_none`.
.. seealso::
:meth:`_engine.Result.one_or_none`
:meth:`_engine.ScalarResult.one_or_none`
:meth:`_engine.Result.scalars`
@@ -1506,8 +1496,8 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
def one(self) -> Row[_TP]:
"""Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
rows, or :class:`.MultipleResultsFound` if multiple rows
Raises :class:`_exc.NoResultFound` if the result returns no
rows, or :class:`_exc.MultipleResultsFound` if multiple rows
would be returned.
.. note:: This method returns one **row**, e.g. tuple, by default.
@@ -1537,12 +1527,10 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]):
)
@overload
def scalar(self: Result[Tuple[_T]]) -> Optional[_T]:
...
def scalar(self: Result[Tuple[_T]]) -> Optional[_T]: ...
@overload
def scalar(self) -> Any:
...
def scalar(self) -> Any: ...
def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
@@ -1776,7 +1764,7 @@ class ScalarResult(FilterResult[_R]):
return self._manyrow_getter(self, size)
def all(self) -> Sequence[_R]:
"""Return all scalar values in a list.
"""Return all scalar values in a sequence.
Equivalent to :meth:`_engine.Result.all` except that
scalar values, rather than :class:`_engine.Row` objects,
@@ -1880,7 +1868,7 @@ class TupleResult(FilterResult[_R], util.TypingOnly):
...
def all(self) -> Sequence[_R]: # noqa: A001
"""Return all scalar values in a list.
"""Return all scalar values in a sequence.
Equivalent to :meth:`_engine.Result.all` except that
tuple values, rather than :class:`_engine.Row` objects,
@@ -1889,11 +1877,9 @@ class TupleResult(FilterResult[_R], util.TypingOnly):
"""
...
def __iter__(self) -> Iterator[_R]:
...
def __iter__(self) -> Iterator[_R]: ...
def __next__(self) -> _R:
...
def __next__(self) -> _R: ...
def first(self) -> Optional[_R]:
"""Fetch the first object or ``None`` if no object is present.
@@ -1927,22 +1913,20 @@ class TupleResult(FilterResult[_R], util.TypingOnly):
...
@overload
def scalar_one(self: TupleResult[Tuple[_T]]) -> _T:
...
def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: ...
@overload
def scalar_one(self) -> Any:
...
def scalar_one(self) -> Any: ...
def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_engine.Result.scalars`
and then :meth:`_engine.Result.one`.
and then :meth:`_engine.ScalarResult.one`.
.. seealso::
:meth:`_engine.Result.one`
:meth:`_engine.ScalarResult.one`
:meth:`_engine.Result.scalars`
@@ -1950,22 +1934,22 @@ class TupleResult(FilterResult[_R], util.TypingOnly):
...
@overload
def scalar_one_or_none(self: TupleResult[Tuple[_T]]) -> Optional[_T]:
...
def scalar_one_or_none(
self: TupleResult[Tuple[_T]],
) -> Optional[_T]: ...
@overload
def scalar_one_or_none(self) -> Optional[Any]:
...
def scalar_one_or_none(self) -> Optional[Any]: ...
def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
This is equivalent to calling :meth:`_engine.Result.scalars`
and then :meth:`_engine.Result.one_or_none`.
and then :meth:`_engine.ScalarResult.one_or_none`.
.. seealso::
:meth:`_engine.Result.one_or_none`
:meth:`_engine.ScalarResult.one_or_none`
:meth:`_engine.Result.scalars`
@@ -1973,12 +1957,10 @@ class TupleResult(FilterResult[_R], util.TypingOnly):
...
@overload
def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]:
...
def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: ...
@overload
def scalar(self) -> Any:
...
def scalar(self) -> Any: ...
def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result
@@ -2031,7 +2013,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
return self
def columns(self, *col_expressions: _KeyIndexType) -> Self:
r"""Establish the columns that should be returned in each row."""
"""Establish the columns that should be returned in each row."""
return self._column_slices(col_expressions)
def partitions(
@@ -2086,7 +2068,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]):
return self._manyrow_getter(self, size)
def all(self) -> Sequence[RowMapping]:
"""Return all scalar values in a list.
"""Return all scalar values in a sequence.
Equivalent to :meth:`_engine.Result.all` except that
:class:`_engine.RowMapping` values, rather than :class:`_engine.Row`

View File

@@ -1,5 +1,5 @@
# engine/row.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -213,15 +213,12 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]):
if TYPE_CHECKING:
@overload
def __getitem__(self, index: int) -> Any:
...
def __getitem__(self, index: int) -> Any: ...
@overload
def __getitem__(self, index: slice) -> Sequence[Any]:
...
def __getitem__(self, index: slice) -> Sequence[Any]: ...
def __getitem__(self, index: Union[int, slice]) -> Any:
...
def __getitem__(self, index: Union[int, slice]) -> Any: ...
def __lt__(self, other: Any) -> bool:
return self._op(other, operator.lt)
@@ -296,8 +293,8 @@ class ROMappingView(ABC):
def __init__(
self, mapping: Mapping["_KeyType", Any], items: Sequence[Any]
):
self._mapping = mapping
self._items = items
self._mapping = mapping # type: ignore[misc]
self._items = items # type: ignore[misc]
def __len__(self) -> int:
return len(self._items)
@@ -321,11 +318,11 @@ class ROMappingView(ABC):
class ROMappingKeysValuesView(
ROMappingView, typing.KeysView["_KeyType"], typing.ValuesView[Any]
):
__slots__ = ("_items",)
__slots__ = ("_items",) # mapping slot is provided by KeysView
class ROMappingItemsView(ROMappingView, typing.ItemsView["_KeyType", Any]):
__slots__ = ("_items",)
__slots__ = ("_items",) # mapping slot is provided by ItemsView
class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]):
@@ -343,12 +340,11 @@ class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]):
as iteration of keys, values, and items::
for row in result:
if 'a' in row._mapping:
print("Column 'a': %s" % row._mapping['a'])
if "a" in row._mapping:
print("Column 'a': %s" % row._mapping["a"])
print("Column b: %s" % row._mapping[table.c.b])
.. versionadded:: 1.4 The :class:`.RowMapping` object replaces the
mapping-like access previously provided by a database result row,
which now seeks to behave mostly like a named tuple.
@@ -359,8 +355,7 @@ class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]):
if TYPE_CHECKING:
def __getitem__(self, key: _KeyType) -> Any:
...
def __getitem__(self, key: _KeyType) -> Any: ...
else:
__getitem__ = BaseRow._get_by_key_impl_mapping

View File

@@ -1,14 +1,11 @@
# engine/strategies.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Deprecated mock engine strategy used by Alembic.
"""
"""Deprecated mock engine strategy used by Alembic."""
from __future__ import annotations

View File

@@ -1,5 +1,5 @@
# engine/url.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -32,6 +32,7 @@ from typing import Tuple
from typing import Type
from typing import Union
from urllib.parse import parse_qsl
from urllib.parse import quote
from urllib.parse import quote_plus
from urllib.parse import unquote
@@ -121,7 +122,9 @@ class URL(NamedTuple):
for keys and either strings or tuples of strings for values, e.g.::
>>> from sqlalchemy.engine import make_url
>>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt")
>>> url = make_url(
... "postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt"
... )
>>> url.query
immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'})
@@ -170,6 +173,11 @@ class URL(NamedTuple):
:param password: database password. Is typically a string, but may
also be an object that can be stringified with ``str()``.
.. note:: The password string should **not** be URL encoded when
passed as an argument to :meth:`_engine.URL.create`; the string
should contain the password characters exactly as they would be
typed.
.. note:: A password-producing object will be stringified only
**once** per :class:`_engine.Engine` object. For dynamic password
generation per connect, see :ref:`engines_dynamic_tokens`.
@@ -247,14 +255,12 @@ class URL(NamedTuple):
@overload
def _assert_value(
val: str,
) -> str:
...
) -> str: ...
@overload
def _assert_value(
val: Sequence[str],
) -> Union[str, Tuple[str, ...]]:
...
) -> Union[str, Tuple[str, ...]]: ...
def _assert_value(
val: Union[str, Sequence[str]],
@@ -367,7 +373,9 @@ class URL(NamedTuple):
>>> from sqlalchemy.engine import make_url
>>> url = make_url("postgresql+psycopg2://user:pass@host/dbname")
>>> url = url.update_query_string("alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt")
>>> url = url.update_query_string(
... "alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt"
... )
>>> str(url)
'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
@@ -403,7 +411,13 @@ class URL(NamedTuple):
>>> from sqlalchemy.engine import make_url
>>> url = make_url("postgresql+psycopg2://user:pass@host/dbname")
>>> url = url.update_query_pairs([("alt_host", "host1"), ("alt_host", "host2"), ("ssl_cipher", "/path/to/crt")])
>>> url = url.update_query_pairs(
... [
... ("alt_host", "host1"),
... ("alt_host", "host2"),
... ("ssl_cipher", "/path/to/crt"),
... ]
... )
>>> str(url)
'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
@@ -485,7 +499,9 @@ class URL(NamedTuple):
>>> from sqlalchemy.engine import make_url
>>> url = make_url("postgresql+psycopg2://user:pass@host/dbname")
>>> url = url.update_query_dict({"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"})
>>> url = url.update_query_dict(
... {"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"}
... )
>>> str(url)
'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
@@ -523,14 +539,14 @@ class URL(NamedTuple):
E.g.::
url = url.difference_update_query(['foo', 'bar'])
url = url.difference_update_query(["foo", "bar"])
Equivalent to using :meth:`_engine.URL.set` as follows::
url = url.set(
query={
key: url.query[key]
for key in set(url.query).difference(['foo', 'bar'])
for key in set(url.query).difference(["foo", "bar"])
}
)
@@ -579,7 +595,9 @@ class URL(NamedTuple):
>>> from sqlalchemy.engine import make_url
>>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt")
>>> url = make_url(
... "postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt"
... )
>>> url.query
immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'})
>>> url.normalized_query
@@ -621,17 +639,17 @@ class URL(NamedTuple):
"""
s = self.drivername + "://"
if self.username is not None:
s += _sqla_url_quote(self.username)
s += quote(self.username, safe=" +")
if self.password is not None:
s += ":" + (
"***"
if hide_password
else _sqla_url_quote(str(self.password))
else quote(str(self.password), safe=" +")
)
s += "@"
if self.host is not None:
if ":" in self.host:
s += "[%s]" % self.host
s += f"[{self.host}]"
else:
s += self.host
if self.port is not None:
@@ -642,7 +660,7 @@ class URL(NamedTuple):
keys = list(self.query)
keys.sort()
s += "?" + "&".join(
"%s=%s" % (quote_plus(k), quote_plus(element))
f"{quote_plus(k)}={quote_plus(element)}"
for k in keys
for element in util.to_list(self.query[k])
)
@@ -885,10 +903,10 @@ def _parse_url(name: str) -> URL:
components["query"] = query
if components["username"] is not None:
components["username"] = _sqla_url_unquote(components["username"])
components["username"] = unquote(components["username"])
if components["password"] is not None:
components["password"] = _sqla_url_unquote(components["password"])
components["password"] = unquote(components["password"])
ipv4host = components.pop("ipv4host")
ipv6host = components.pop("ipv6host")
@@ -902,12 +920,5 @@ def _parse_url(name: str) -> URL:
else:
raise exc.ArgumentError(
"Could not parse SQLAlchemy URL from string '%s'" % name
"Could not parse SQLAlchemy URL from given URL string"
)
def _sqla_url_quote(text: str) -> str:
return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text)
_sqla_url_unquote = unquote

View File

@@ -1,5 +1,5 @@
# engine/util.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -17,6 +17,7 @@ from .. import exc
from .. import util
from ..util._has_cy import HAS_CYEXTENSION
from ..util.typing import Protocol
from ..util.typing import Self
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_util import _distill_params_20 as _distill_params_20
@@ -113,7 +114,7 @@ class TransactionalContext:
"before emitting further commands."
)
def __enter__(self) -> TransactionalContext:
def __enter__(self) -> Self:
subject = self._get_subject()
# none for outer transaction, may be non-None for nested