API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,5 +1,5 @@
# engine/default.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -58,6 +58,7 @@ from ..sql import compiler
from ..sql import dml
from ..sql import expression
from ..sql import type_api
from ..sql import util as sql_util
from ..sql._typing import is_tuple_type
from ..sql.base import _NoArg
from ..sql.compiler import DDLCompiler
@@ -76,10 +77,13 @@ if typing.TYPE_CHECKING:
from .interfaces import _CoreSingleExecuteParams
from .interfaces import _DBAPICursorDescription
from .interfaces import _DBAPIMultiExecuteParams
from .interfaces import _DBAPISingleExecuteParams
from .interfaces import _ExecuteOptions
from .interfaces import _MutableCoreSingleExecuteParams
from .interfaces import _ParamStyle
from .interfaces import ConnectArgsType
from .interfaces import DBAPIConnection
from .interfaces import DBAPIModule
from .interfaces import IsolationLevel
from .row import Row
from .url import URL
@@ -95,8 +99,10 @@ if typing.TYPE_CHECKING:
from ..sql.elements import BindParameter
from ..sql.schema import Column
from ..sql.type_api import _BindProcessorType
from ..sql.type_api import _ResultProcessorType
from ..sql.type_api import TypeEngine
# When we're handed literal SQL, ensure it's a SELECT query
SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
@@ -167,7 +173,10 @@ class DefaultDialect(Dialect):
tuple_in_values = False
connection_characteristics = util.immutabledict(
{"isolation_level": characteristics.IsolationLevelCharacteristic()}
{
"isolation_level": characteristics.IsolationLevelCharacteristic(),
"logging_token": characteristics.LoggingTokenCharacteristic(),
}
)
engine_config_types: Mapping[str, Any] = util.immutabledict(
@@ -249,7 +258,7 @@ class DefaultDialect(Dialect):
default_schema_name: Optional[str] = None
# indicates symbol names are
# UPPERCASEd if they are case insensitive
# UPPERCASED if they are case insensitive
# within the database.
# if this is True, the methods normalize_name()
# and denormalize_name() must be provided.
@@ -298,6 +307,7 @@ class DefaultDialect(Dialect):
# Linting.NO_LINTING constant
compiler_linting: Linting = int(compiler.NO_LINTING), # type: ignore
server_side_cursors: bool = False,
skip_autocommit_rollback: bool = False,
**kwargs: Any,
):
if server_side_cursors:
@@ -322,6 +332,8 @@ class DefaultDialect(Dialect):
self.dbapi = dbapi
self.skip_autocommit_rollback = skip_autocommit_rollback
if paramstyle is not None:
self.paramstyle = paramstyle
elif self.dbapi is not None:
@@ -387,7 +399,8 @@ class DefaultDialect(Dialect):
available if the dialect in use has opted into using the
"use_insertmanyvalues" feature. If they haven't opted into that, then
this attribute is False, unless the dialect in question overrides this
and provides some other implementation (such as the Oracle dialect).
and provides some other implementation (such as the Oracle Database
dialects).
"""
return self.insert_returning and self.use_insertmanyvalues
@@ -410,7 +423,7 @@ class DefaultDialect(Dialect):
If the dialect in use hasn't opted into that, then this attribute is
False, unless the dialect in question overrides this and provides some
other implementation (such as the Oracle dialect).
other implementation (such as the Oracle Database dialects).
"""
return self.insert_returning and self.use_insertmanyvalues
@@ -419,7 +432,7 @@ class DefaultDialect(Dialect):
delete_executemany_returning = False
@util.memoized_property
def loaded_dbapi(self) -> ModuleType:
def loaded_dbapi(self) -> DBAPIModule:
if self.dbapi is None:
raise exc.InvalidRequestError(
f"Dialect {self} does not have a Python DBAPI established "
@@ -431,7 +444,7 @@ class DefaultDialect(Dialect):
def _bind_typing_render_casts(self):
return self.bind_typing is interfaces.BindTyping.RENDER_CASTS
def _ensure_has_table_connection(self, arg):
def _ensure_has_table_connection(self, arg: Connection) -> None:
if not isinstance(arg, Connection):
raise exc.ArgumentError(
"The argument passed to Dialect.has_table() should be a "
@@ -468,7 +481,7 @@ class DefaultDialect(Dialect):
return weakref.WeakKeyDictionary()
@property
def dialect_description(self):
def dialect_description(self): # type: ignore[override]
return self.name + "+" + self.driver
@property
@@ -509,7 +522,7 @@ class DefaultDialect(Dialect):
else:
return None
def initialize(self, connection):
def initialize(self, connection: Connection) -> None:
try:
self.server_version_info = self._get_server_version_info(
connection
@@ -545,7 +558,7 @@ class DefaultDialect(Dialect):
% (self.label_length, self.max_identifier_length)
)
def on_connect(self):
def on_connect(self) -> Optional[Callable[[Any], None]]:
# inherits the docstring from interfaces.Dialect.on_connect
return None
@@ -604,18 +617,18 @@ class DefaultDialect(Dialect):
) -> bool:
return schema_name in self.get_schema_names(connection, **kw)
def validate_identifier(self, ident):
def validate_identifier(self, ident: str) -> None:
if len(ident) > self.max_identifier_length:
raise exc.IdentifierError(
"Identifier '%s' exceeds maximum length of %d characters"
% (ident, self.max_identifier_length)
)
def connect(self, *cargs, **cparams):
def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection:
# inherits the docstring from interfaces.Dialect.connect
return self.loaded_dbapi.connect(*cargs, **cparams)
return self.loaded_dbapi.connect(*cargs, **cparams) # type: ignore[no-any-return] # NOQA: E501
def create_connect_args(self, url):
def create_connect_args(self, url: URL) -> ConnectArgsType:
# inherits the docstring from interfaces.Dialect.create_connect_args
opts = url.translate_connect_args()
opts.update(url.query)
@@ -659,7 +672,7 @@ class DefaultDialect(Dialect):
if connection.in_transaction():
trans_objs = [
(name, obj)
for name, obj, value in characteristic_values
for name, obj, _ in characteristic_values
if obj.transactional
]
if trans_objs:
@@ -672,8 +685,10 @@ class DefaultDialect(Dialect):
)
dbapi_connection = connection.connection.dbapi_connection
for name, characteristic, value in characteristic_values:
characteristic.set_characteristic(self, dbapi_connection, value)
for _, characteristic, value in characteristic_values:
characteristic.set_connection_characteristic(
self, connection, dbapi_connection, value
)
connection.connection._connection_record.finalize_callback.append(
functools.partial(self._reset_characteristics, characteristics)
)
@@ -689,6 +704,10 @@ class DefaultDialect(Dialect):
pass
def do_rollback(self, dbapi_connection):
if self.skip_autocommit_rollback and self.detect_autocommit_setting(
dbapi_connection
):
return
dbapi_connection.rollback()
def do_commit(self, dbapi_connection):
@@ -728,8 +747,6 @@ class DefaultDialect(Dialect):
raise
def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
cursor = None
cursor = dbapi_connection.cursor()
try:
cursor.execute(self._dialect_specific_select_one)
@@ -756,11 +773,25 @@ class DefaultDialect(Dialect):
connection.execute(expression.ReleaseSavepointClause(name))
def _deliver_insertmanyvalues_batches(
self, cursor, statement, parameters, generic_setinputsizes, context
self,
connection,
cursor,
statement,
parameters,
generic_setinputsizes,
context,
):
context = cast(DefaultExecutionContext, context)
compiled = cast(SQLCompiler, context.compiled)
_composite_sentinel_proc: Sequence[
Optional[_ResultProcessorType[Any]]
] = ()
_scalar_sentinel_proc: Optional[_ResultProcessorType[Any]] = None
_sentinel_proc_initialized: bool = False
compiled_parameters = context.compiled_parameters
imv = compiled._insertmanyvalues
assert imv is not None
@@ -769,7 +800,12 @@ class DefaultDialect(Dialect):
"insertmanyvalues_page_size", self.insertmanyvalues_page_size
)
sentinel_value_resolvers = None
if compiled.schema_translate_map:
schema_translate_map = context.execution_options.get(
"schema_translate_map", {}
)
else:
schema_translate_map = None
if is_returning:
result: Optional[List[Any]] = []
@@ -777,10 +813,6 @@ class DefaultDialect(Dialect):
sort_by_parameter_order = imv.sort_by_parameter_order
if imv.num_sentinel_columns:
sentinel_value_resolvers = (
compiled._imv_sentinel_value_resolvers
)
else:
sort_by_parameter_order = False
result = None
@@ -788,14 +820,27 @@ class DefaultDialect(Dialect):
for imv_batch in compiled._deliver_insertmanyvalues_batches(
statement,
parameters,
compiled_parameters,
generic_setinputsizes,
batch_size,
sort_by_parameter_order,
schema_translate_map,
):
yield imv_batch
if is_returning:
rows = context.fetchall_for_returning(cursor)
try:
rows = context.fetchall_for_returning(cursor)
except BaseException as be:
connection._handle_dbapi_exception(
be,
sql_util._long_statement(imv_batch.replaced_statement),
imv_batch.replaced_parameters,
None,
context,
is_sub_exec=True,
)
# I would have thought "is_returning: Final[bool]"
# would have assured this but pylance thinks not
@@ -815,11 +860,46 @@ class DefaultDialect(Dialect):
# otherwise, create dictionaries to match up batches
# with parameters
assert imv.sentinel_param_keys
assert imv.sentinel_columns
_nsc = imv.num_sentinel_columns
if not _sentinel_proc_initialized:
if composite_sentinel:
_composite_sentinel_proc = [
col.type._cached_result_processor(
self, cursor_desc[1]
)
for col, cursor_desc in zip(
imv.sentinel_columns,
cursor.description[-_nsc:],
)
]
else:
_scalar_sentinel_proc = (
imv.sentinel_columns[0]
).type._cached_result_processor(
self, cursor.description[-1][1]
)
_sentinel_proc_initialized = True
rows_by_sentinel: Union[
Dict[Tuple[Any, ...], Any],
Dict[Any, Any],
]
if composite_sentinel:
_nsc = imv.num_sentinel_columns
rows_by_sentinel = {
tuple(row[-_nsc:]): row for row in rows
tuple(
(proc(val) if proc else val)
for val, proc in zip(
row[-_nsc:], _composite_sentinel_proc
)
): row
for row in rows
}
elif _scalar_sentinel_proc:
rows_by_sentinel = {
_scalar_sentinel_proc(row[-1]): row for row in rows
}
else:
rows_by_sentinel = {row[-1]: row for row in rows}
@@ -838,61 +918,10 @@ class DefaultDialect(Dialect):
)
try:
if composite_sentinel:
if sentinel_value_resolvers:
# composite sentinel (PK) with value resolvers
ordered_rows = [
rows_by_sentinel[
tuple(
_resolver(parameters[_spk]) # type: ignore # noqa: E501
if _resolver
else parameters[_spk] # type: ignore # noqa: E501
for _resolver, _spk in zip(
sentinel_value_resolvers,
imv.sentinel_param_keys,
)
)
]
for parameters in imv_batch.batch
]
else:
# composite sentinel (PK) with no value
# resolvers
ordered_rows = [
rows_by_sentinel[
tuple(
parameters[_spk] # type: ignore
for _spk in imv.sentinel_param_keys
)
]
for parameters in imv_batch.batch
]
else:
_sentinel_param_key = imv.sentinel_param_keys[0]
if (
sentinel_value_resolvers
and sentinel_value_resolvers[0]
):
# single-column sentinel with value resolver
_sentinel_value_resolver = (
sentinel_value_resolvers[0]
)
ordered_rows = [
rows_by_sentinel[
_sentinel_value_resolver(
parameters[_sentinel_param_key] # type: ignore # noqa: E501
)
]
for parameters in imv_batch.batch
]
else:
# single-column sentinel with no value resolver
ordered_rows = [
rows_by_sentinel[
parameters[_sentinel_param_key] # type: ignore # noqa: E501
]
for parameters in imv_batch.batch
]
ordered_rows = [
rows_by_sentinel[sentinel_keys]
for sentinel_keys in imv_batch.sentinel_values
]
except KeyError as ke:
# see test_insert_exec.py::
# IMVSentinelTest::test_sentinel_cant_match_keys
@@ -924,7 +953,14 @@ class DefaultDialect(Dialect):
def do_execute_no_params(self, cursor, statement, context=None):
cursor.execute(statement)
def is_disconnect(self, e, connection, cursor):
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Union[
pool.PoolProxiedConnection, interfaces.DBAPIConnection, None
],
cursor: Optional[interfaces.DBAPICursor],
) -> bool:
return False
@util.memoized_instancemethod
@@ -1024,7 +1060,7 @@ class DefaultDialect(Dialect):
name = name_upper
return name
def get_driver_connection(self, connection):
def get_driver_connection(self, connection: DBAPIConnection) -> Any:
return connection
def _overrides_default(self, method):
@@ -1196,7 +1232,7 @@ class DefaultExecutionContext(ExecutionContext):
_soft_closed = False
_has_rowcount = False
_rowcount: Optional[int] = None
# a hook for SQLite's translation of
# result column names
@@ -1453,9 +1489,11 @@ class DefaultExecutionContext(ExecutionContext):
assert positiontup is not None
for compiled_params in self.compiled_parameters:
l_param: List[Any] = [
flattened_processors[key](compiled_params[key])
if key in flattened_processors
else compiled_params[key]
(
flattened_processors[key](compiled_params[key])
if key in flattened_processors
else compiled_params[key]
)
for key in positiontup
]
core_positional_parameters.append(
@@ -1476,18 +1514,20 @@ class DefaultExecutionContext(ExecutionContext):
for compiled_params in self.compiled_parameters:
if escaped_names:
d_param = {
escaped_names.get(key, key): flattened_processors[key](
compiled_params[key]
escaped_names.get(key, key): (
flattened_processors[key](compiled_params[key])
if key in flattened_processors
else compiled_params[key]
)
if key in flattened_processors
else compiled_params[key]
for key in compiled_params
}
else:
d_param = {
key: flattened_processors[key](compiled_params[key])
if key in flattened_processors
else compiled_params[key]
key: (
flattened_processors[key](compiled_params[key])
if key in flattened_processors
else compiled_params[key]
)
for key in compiled_params
}
@@ -1577,7 +1617,13 @@ class DefaultExecutionContext(ExecutionContext):
elif ch is CACHE_MISS:
return "generated in %.5fs" % (now - gen_time,)
elif ch is CACHING_DISABLED:
return "caching disabled %.5fs" % (now - gen_time,)
if "_cache_disable_reason" in self.execution_options:
return "caching disabled (%s) %.5fs " % (
self.execution_options["_cache_disable_reason"],
now - gen_time,
)
else:
return "caching disabled %.5fs" % (now - gen_time,)
elif ch is NO_DIALECT_SUPPORT:
return "dialect %s+%s does not support caching %.5fs" % (
self.dialect.name,
@@ -1588,7 +1634,7 @@ class DefaultExecutionContext(ExecutionContext):
return "unknown"
@property
def executemany(self):
def executemany(self): # type: ignore[override]
return self.execute_style in (
ExecuteStyle.EXECUTEMANY,
ExecuteStyle.INSERTMANYVALUES,
@@ -1630,7 +1676,12 @@ class DefaultExecutionContext(ExecutionContext):
def no_parameters(self):
return self.execution_options.get("no_parameters", False)
def _execute_scalar(self, stmt, type_, parameters=None):
def _execute_scalar(
self,
stmt: str,
type_: Optional[TypeEngine[Any]],
parameters: Optional[_DBAPISingleExecuteParams] = None,
) -> Any:
"""Execute a string statement on the current cursor, returning a
scalar result.
@@ -1704,7 +1755,7 @@ class DefaultExecutionContext(ExecutionContext):
return use_server_side
def create_cursor(self):
def create_cursor(self) -> DBAPICursor:
if (
# inlining initial preference checks for SS cursors
self.dialect.supports_server_side_cursors
@@ -1725,10 +1776,10 @@ class DefaultExecutionContext(ExecutionContext):
def fetchall_for_returning(self, cursor):
return cursor.fetchall()
def create_default_cursor(self):
def create_default_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor()
def create_server_side_cursor(self):
def create_server_side_cursor(self) -> DBAPICursor:
raise NotImplementedError()
def pre_exec(self):
@@ -1776,7 +1827,14 @@ class DefaultExecutionContext(ExecutionContext):
@util.non_memoized_property
def rowcount(self) -> int:
return self.cursor.rowcount
if self._rowcount is not None:
return self._rowcount
else:
return self.cursor.rowcount
@property
def _has_rowcount(self):
return self._rowcount is not None
def supports_sane_rowcount(self):
return self.dialect.supports_sane_rowcount
@@ -1787,9 +1845,13 @@ class DefaultExecutionContext(ExecutionContext):
def _setup_result_proxy(self):
exec_opt = self.execution_options
if self._rowcount is None and exec_opt.get("preserve_rowcount", False):
self._rowcount = self.cursor.rowcount
yp: Optional[Union[int, bool]]
if self.is_crud or self.is_text:
result = self._setup_dml_or_text_result()
yp = sr = False
yp = False
else:
yp = exec_opt.get("yield_per", None)
sr = self._is_server_side or exec_opt.get("stream_results", False)
@@ -1943,8 +2005,7 @@ class DefaultExecutionContext(ExecutionContext):
if rows:
self.returned_default_rows = rows
result.rowcount = len(rows)
self._has_rowcount = True
self._rowcount = len(rows)
if self._is_supplemental_returning:
result._rewind(rows)
@@ -1958,12 +2019,12 @@ class DefaultExecutionContext(ExecutionContext):
elif not result._metadata.returns_rows:
# no results, get rowcount
# (which requires open cursor on some drivers)
result.rowcount
self._has_rowcount = True
if self._rowcount is None:
self._rowcount = self.cursor.rowcount
result._soft_close()
elif self.isupdate or self.isdelete:
result.rowcount
self._has_rowcount = True
if self._rowcount is None:
self._rowcount = self.cursor.rowcount
return result
@util.memoized_property
@@ -2012,10 +2073,11 @@ class DefaultExecutionContext(ExecutionContext):
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
This method only called by those dialects which set
the :attr:`.Dialect.bind_typing` attribute to
:attr:`.BindTyping.SETINPUTSIZES`. cx_Oracle is the only DBAPI
that requires setinputsizes(), pyodbc offers it as an option.
This method only called by those dialects which set the
:attr:`.Dialect.bind_typing` attribute to
:attr:`.BindTyping.SETINPUTSIZES`. Python-oracledb and cx_Oracle are
the only DBAPIs that requires setinputsizes(); pyodbc offers it as an
option.
Prior to SQLAlchemy 2.0, the setinputsizes() approach was also used
for pg8000 and asyncpg, which has been changed to inline rendering
@@ -2143,17 +2205,21 @@ class DefaultExecutionContext(ExecutionContext):
if compiled.positional:
parameters = self.dialect.execute_sequence_format(
[
processors[key](compiled_params[key]) # type: ignore
if key in processors
else compiled_params[key]
(
processors[key](compiled_params[key]) # type: ignore
if key in processors
else compiled_params[key]
)
for key in compiled.positiontup or ()
]
)
else:
parameters = {
key: processors[key](compiled_params[key]) # type: ignore
if key in processors
else compiled_params[key]
key: (
processors[key](compiled_params[key]) # type: ignore
if key in processors
else compiled_params[key]
)
for key in compiled_params
}
return self._execute_scalar(