API refactor

2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions


@@ -1,5 +1,5 @@
# orm/context.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -104,6 +104,7 @@ class QueryContext:
"top_level_context",
"compile_state",
"query",
"user_passed_query",
"params",
"load_options",
"bind_arguments",
@@ -147,7 +148,12 @@ class QueryContext:
def __init__(
self,
compile_state: CompileState,
statement: Union[Select[Any], FromStatement[Any]],
statement: Union[Select[Any], FromStatement[Any], UpdateBase],
user_passed_query: Union[
Select[Any],
FromStatement[Any],
UpdateBase,
],
params: _CoreSingleExecuteParams,
session: Session,
load_options: Union[
@@ -162,6 +168,13 @@ class QueryContext:
self.bind_arguments = bind_arguments or _EMPTY_DICT
self.compile_state = compile_state
self.query = statement
# the query that the end user passed to Session.execute() or similar.
# this is usually the same as .query, except in the bulk_persistence
# routines where a separate FromStatement is manufactured in the
# compile stage; this allows differentiation in that case.
self.user_passed_query = user_passed_query
self.session = session
self.loaders_require_buffering = False
self.loaders_require_uniquing = False
@@ -169,7 +182,7 @@ class QueryContext:
self.top_level_context = load_options._sa_top_level_orm_context
cached_options = compile_state.select_statement._with_options
uncached_options = statement._with_options
uncached_options = user_passed_query._with_options
# see issue #7447 , #8399 for some background
# propagated loader options will be present on loaded InstanceState
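The new user_passed_query attribute is easiest to see from the bulk-persistence path the comment above refers to. A minimal, hedged sketch (the User mapping, the engine, and the internals noted in the comments are illustrative assumptions, not part of this diff): the statement the caller executes is not necessarily the statement the ORM compiles, while loader options (._with_options) must still be read from the caller's statement.

from sqlalchemy import create_engine, insert
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class User(Base):
    # minimal mapping, assumed for illustration only
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # ORM bulk INSERT..RETURNING: the caller executes this insert(), but the
    # bulk_persistence routines manufacture a separate FromStatement while
    # setting up the RETURNING rows.  QueryContext.query then points at the
    # manufactured statement; QueryContext.user_passed_query stays this insert().
    stmt = insert(User).returning(User)
    users = session.scalars(stmt, [{"name": "a"}, {"name": "b"}]).all()
    session.commit()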
@@ -218,7 +231,7 @@ class AbstractORMCompileState(CompileState):
if compiler is None:
# this is the legacy / testing only ORM _compile_state() use case.
# there is no need to apply criteria options for this.
self.global_attributes = ga = {}
self.global_attributes = {}
assert toplevel
return
else:
@@ -252,10 +265,10 @@ class AbstractORMCompileState(CompileState):
@classmethod
def create_for_statement(
cls,
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
statement: Executable,
compiler: SQLCompiler,
**kw: Any,
) -> AbstractORMCompileState:
) -> CompileState:
"""Create a context for a statement given a :class:`.Compiler`.
This method is always invoked in the context of SQLCompiler.process().
@@ -401,8 +414,8 @@ class ORMCompileState(AbstractORMCompileState):
attributes: Dict[Any, Any]
global_attributes: Dict[Any, Any]
statement: Union[Select[Any], FromStatement[Any]]
select_statement: Union[Select[Any], FromStatement[Any]]
statement: Union[Select[Any], FromStatement[Any], UpdateBase]
select_statement: Union[Select[Any], FromStatement[Any], UpdateBase]
_entities: List[_QueryEntity]
_polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter]
compile_options: Union[
@@ -424,16 +437,30 @@ class ORMCompileState(AbstractORMCompileState):
def __init__(self, *arg, **kw):
raise NotImplementedError()
if TYPE_CHECKING:
@classmethod
def create_for_statement(
cls,
statement: Executable,
compiler: SQLCompiler,
**kw: Any,
) -> ORMCompileState:
return cls._create_orm_context(
cast("Union[Select, FromStatement]", statement),
toplevel=not compiler.stack,
compiler=compiler,
**kw,
)
@classmethod
def create_for_statement(
cls,
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
) -> ORMCompileState:
...
@classmethod
def _create_orm_context(
cls,
statement: Union[Select, FromStatement],
*,
toplevel: bool,
compiler: Optional[SQLCompiler],
**kw: Any,
) -> ORMCompileState:
raise NotImplementedError()
def _append_dedupe_col_collection(self, obj, col_collection):
dedupe = self.dedupe_columns
@@ -517,15 +544,14 @@ class ORMCompileState(AbstractORMCompileState):
and len(statement._compile_options._current_path) > 10
and execution_options.get("compiled_cache", True) is not None
):
util.warn(
"Loader depth for query is excessively deep; caching will "
"be disabled for additional loaders. Consider using the "
"recursion_depth feature for deeply nested recursive eager "
"loaders. Use the compiled_cache=None execution option to "
"skip this warning."
)
execution_options = execution_options.union(
{"compiled_cache": None}
execution_options: util.immutabledict[str, Any] = (
execution_options.union(
{
"compiled_cache": None,
"_cache_disable_reason": "excess depth for "
"ORM loader options",
}
)
)
bind_arguments["clause"] = statement
@@ -580,6 +606,7 @@ class ORMCompileState(AbstractORMCompileState):
querycontext = QueryContext(
compile_state,
statement,
statement,
params,
session,
load_options,
@@ -643,8 +670,8 @@ class ORMCompileState(AbstractORMCompileState):
)
class DMLReturningColFilter:
"""an adapter used for the DML RETURNING case.
class _DMLReturningColFilter:
"""a base for an adapter used for the DML RETURNING cases
Has a subset of the interface used by
:class:`.ORMAdapter` and is used for :class:`._QueryEntity`
@@ -678,6 +705,21 @@ class DMLReturningColFilter:
else:
return None
def adapt_check_present(self, col):
raise NotImplementedError()
class _DMLBulkInsertReturningColFilter(_DMLReturningColFilter):
"""an adapter used for the DML RETURNING case specifically
for ORM bulk insert (or any hypothetical DML that is splitting out a class
hierarchy among multiple DML statements....ORM bulk insert is the only
example right now)
its main job is to limit the columns in a RETURNING to only a specific
mapped table in a hierarchy.
"""
def adapt_check_present(self, col):
mapper = self.mapper
prop = mapper._columntoproperty.get(col, None)
@@ -686,6 +728,30 @@ class DMLReturningColFilter:
return mapper.local_table.c.corresponding_column(col)
class _DMLUpdateDeleteReturningColFilter(_DMLReturningColFilter):
"""an adapter used for the DML RETURNING case specifically
for ORM enabled UPDATE/DELETE
its main job is to limit the columns in a RETURNING to include
only direct persisted columns from the immediate selectable, not
expressions like column_property(), or to also allow columns from other
mappers for the UPDATE..FROM use case.
"""
def adapt_check_present(self, col):
mapper = self.mapper
prop = mapper._columntoproperty.get(col, None)
if prop is not None:
# if the col is from the immediate mapper, only return a persisted
# column, not any kind of column_property expression
return mapper.persist_selectable.c.corresponding_column(col)
# if the col is from some other mapper, just return it, assume the
# user knows what they are doing
return col
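A user-facing sketch of the case this filter serves, reusing the assumed User mapping and engine from the earlier sketch on a backend that supports RETURNING: an ORM-enabled UPDATE that returns entities only hands persisted columns of the target table to RETURNING, matching the adapt_check_present() behavior above; a column_property() expression on the mapper would be filtered out.

from sqlalchemy import update
from sqlalchemy.orm import Session

with Session(engine) as session:
    stmt = (
        update(User)
        .where(User.name == "a")
        .values(name="a2")
        .returning(User)  # routed through _DMLUpdateDeleteReturningColFilter
    )
    for user in session.scalars(stmt):
        print(user.id, user.name)
    session.commit()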
@sql.base.CompileState.plugin_for("orm", "orm_from_statement")
class ORMFromStatementCompileState(ORMCompileState):
_from_obj_alias = None
@@ -704,12 +770,16 @@ class ORMFromStatementCompileState(ORMCompileState):
eager_joins = _EMPTY_DICT
@classmethod
def create_for_statement(
def _create_orm_context(
cls,
statement_container: Union[Select, FromStatement],
statement: Union[Select, FromStatement],
*,
toplevel: bool,
compiler: Optional[SQLCompiler],
**kw: Any,
) -> ORMFromStatementCompileState:
statement_container = statement
assert isinstance(statement_container, FromStatement)
if compiler is not None and compiler.stack:
@@ -751,9 +821,11 @@ class ORMFromStatementCompileState(ORMCompileState):
self.statement = statement
self._label_convention = self._column_naming_convention(
statement._label_style
if not statement._is_textual and not statement.is_dml
else LABEL_STYLE_NONE,
(
statement._label_style
if not statement._is_textual and not statement.is_dml
else LABEL_STYLE_NONE
),
self.use_legacy_query_style,
)
@@ -799,9 +871,9 @@ class ORMFromStatementCompileState(ORMCompileState):
for entity in self._entities:
entity.setup_compile_state(self)
compiler._ordered_columns = (
compiler._textual_ordered_columns
) = False
compiler._ordered_columns = compiler._textual_ordered_columns = (
False
)
# enable looser result column matching. this is shown to be
# needed by test_query.py::TextTest
@@ -838,14 +910,24 @@ class ORMFromStatementCompileState(ORMCompileState):
return None
def setup_dml_returning_compile_state(self, dml_mapper):
"""used by BulkORMInsert (and Update / Delete?) to set up a handler
"""used by BulkORMInsert, Update, Delete to set up a handler
for RETURNING to return ORM objects and expressions
"""
target_mapper = self.statement._propagate_attrs.get(
"plugin_subject", None
)
adapter = DMLReturningColFilter(target_mapper, dml_mapper)
if self.statement.is_insert:
adapter = _DMLBulkInsertReturningColFilter(
target_mapper, dml_mapper
)
elif self.statement.is_update or self.statement.is_delete:
adapter = _DMLUpdateDeleteReturningColFilter(
target_mapper, dml_mapper
)
else:
adapter = None
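The remaining branch of the dispatch above, ORM-enabled DELETE..RETURNING, in the same hedged style (mapping, engine and session assumed as in the earlier sketches):

from sqlalchemy import delete

# statement.is_delete selects _DMLUpdateDeleteReturningColFilter, as for UPDATE
stmt = delete(User).where(User.name == "b").returning(User.id, User.name)
for user_id, name in session.execute(stmt):
    print(user_id, name)
session.commit()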
if self.compile_options._is_star and (len(self._entities) != 1):
raise sa_exc.CompileError(
@@ -888,6 +970,8 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
("_compile_options", InternalTraversal.dp_has_cache_key)
]
is_from_statement = True
def __init__(
self,
entities: Iterable[_ColumnsClauseArgument[Any]],
@@ -905,6 +989,10 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
]
self.element = element
self.is_dml = element.is_dml
self.is_select = element.is_select
self.is_delete = element.is_delete
self.is_insert = element.is_insert
self.is_update = element.is_update
self._label_style = (
element._label_style if is_select_base(element) else None
)
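FromStatement now mirrors the DML flags of the element it wraps, which is what lets the RETURNING dispatch earlier in this diff branch on self.statement.is_insert and friends. The usual public route to a FromStatement is Select.from_statement(); a hedged sketch with the assumed User mapping and session:

from sqlalchemy import select, text

stmt = select(User).from_statement(
    text("SELECT id, name FROM user_account")
)
# the resulting FromStatement wraps the text() element and, with this change,
# copies that element's is_select / is_insert / is_update / is_delete flags
# instead of exposing only is_dml.
users = session.scalars(stmt).all()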
@@ -998,21 +1086,17 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
_having_criteria = ()
@classmethod
def create_for_statement(
def _create_orm_context(
cls,
statement: Union[Select, FromStatement],
*,
toplevel: bool,
compiler: Optional[SQLCompiler],
**kw: Any,
) -> ORMSelectCompileState:
"""compiler hook, we arrive here from compiler.visit_select() only."""
self = cls.__new__(cls)
if compiler is not None:
toplevel = not compiler.stack
else:
toplevel = True
select_statement = statement
# if we are a select() that was never a legacy Query, we won't
@@ -1368,11 +1452,15 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def get_columns_clause_froms(cls, statement):
return cls._normalize_froms(
itertools.chain.from_iterable(
element._from_objects
if "parententity" not in element._annotations
else [
element._annotations["parententity"].__clause_element__()
]
(
element._from_objects
if "parententity" not in element._annotations
else [
element._annotations[
"parententity"
].__clause_element__()
]
)
for element in statement._raw_columns
)
)
@@ -1501,9 +1589,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# the original expressions outside of the label references
# in order to have them render.
unwrapped_order_by = [
elem.element
if isinstance(elem, sql.elements._label_reference)
else elem
(
elem.element
if isinstance(elem, sql.elements._label_reference)
else elem
)
for elem in self.order_by
]
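The (elided) comment block here concerns re-rendering ORDER BY expressions when the ORM has to wrap the SELECT in a subquery; a hedged user-level example of a query that exercises that path, assuming a User.addresses collection relationship exists on the mapping:

from sqlalchemy import select
from sqlalchemy.orm import joinedload

# joined eager loading against a collection combined with LIMIT makes the ORM
# wrap the SELECT in a subquery; the unwrapping above keeps the original
# order_by expressions available so the ordering still renders correctly
# around that wrapping.
stmt = (
    select(User)
    .options(joinedload(User.addresses))
    .order_by(User.name.desc())
    .limit(5)
)
rows = session.scalars(stmt).unique().all()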
@@ -1545,10 +1635,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
)
statement._label_style = self.label_style
# Oracle however does not allow FOR UPDATE on the subquery,
# and the Oracle dialect ignores it, plus for PostgreSQL, MySQL
# we expect that all elements of the row are locked, so also put it
# on the outside (except in the case of PG when OF is used)
# Oracle Database however does not allow FOR UPDATE on the subquery,
# and the Oracle Database dialects ignore it, plus for PostgreSQL,
# MySQL we expect that all elements of the row are locked, so also put
# it on the outside (except in the case of PG when OF is used)
if (
self._for_update_arg is not None
and self._for_update_arg.of is None
@@ -1774,8 +1864,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
"selectable/table as join target"
)
of_type = None
if isinstance(onclause, interfaces.PropComparator):
# descriptor/property given (or determined); this tells us
# explicitly what the expected "left" side of the join is.
@@ -2422,9 +2510,12 @@ def _column_descriptions(
"type": ent.type,
"aliased": getattr(insp_ent, "is_aliased_class", False),
"expr": ent.expr,
"entity": getattr(insp_ent, "entity", None)
if ent.entity_zero is not None and not insp_ent.is_clause_element
else None,
"entity": (
getattr(insp_ent, "entity", None)
if ent.entity_zero is not None
and not insp_ent.is_clause_element
else None
),
}
for ent, insp_ent in [
(_ent, _ent.entity_zero) for _ent in ctx._entities
@@ -2434,7 +2525,7 @@ def _column_descriptions(
def _legacy_filter_by_entity_zero(
query_or_augmented_select: Union[Query[Any], Select[Any]]
query_or_augmented_select: Union[Query[Any], Select[Any]],
) -> Optional[_InternalEntityType[Any]]:
self = query_or_augmented_select
if self._setup_joins:
@@ -2449,7 +2540,7 @@ def _legacy_filter_by_entity_zero(
def _entity_from_pre_ent_zero(
query_or_augmented_select: Union[Query[Any], Select[Any]]
query_or_augmented_select: Union[Query[Any], Select[Any]],
) -> Optional[_InternalEntityType[Any]]:
self = query_or_augmented_select
if not self._raw_columns:
@@ -2507,7 +2598,7 @@ class _QueryEntity:
def setup_dml_returning_compile_state(
self,
compile_state: ORMCompileState,
adapter: DMLReturningColFilter,
adapter: Optional[_DMLReturningColFilter],
) -> None:
raise NotImplementedError()
@@ -2709,7 +2800,7 @@ class _MapperEntity(_QueryEntity):
def setup_dml_returning_compile_state(
self,
compile_state: ORMCompileState,
adapter: DMLReturningColFilter,
adapter: Optional[_DMLReturningColFilter],
) -> None:
loading._setup_entity_query(
compile_state,
@@ -2865,6 +2956,13 @@ class _BundleEntity(_QueryEntity):
for ent in self._entities:
ent.setup_compile_state(compile_state)
def setup_dml_returning_compile_state(
self,
compile_state: ORMCompileState,
adapter: Optional[_DMLReturningColFilter],
) -> None:
return self.setup_compile_state(compile_state)
def row_processor(self, context, result):
procs, labels, extra = zip(
*[ent.row_processor(context, result) for ent in self._entities]
@@ -3028,7 +3126,10 @@ class _RawColumnEntity(_ColumnEntity):
if not is_current_entities or column._is_text_clause:
self._label_name = None
else:
self._label_name = compile_state._label_convention(column)
if parent_bundle:
self._label_name = column._proxy_key
else:
self._label_name = compile_state._label_convention(column)
if parent_bundle:
parent_bundle._entities.append(self)
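With the change above, a plain column expression placed inside a Bundle now takes its label from the column's proxy key rather than from the compile-level naming convention. A hedged example of the user-facing construct (assumed mapping and session as before):

from sqlalchemy import func, select
from sqlalchemy.orm import Bundle

b = Bundle("summary", User.id, func.length(User.name).label("name_len"))
for row in session.execute(select(b)):
    # the bundle appears on each row as a lightweight namespace keyed by the
    # member columns' keys, here row.summary.id and row.summary.name_len
    print(row.summary.id, row.summary.name_len)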
@@ -3048,7 +3149,7 @@ class _RawColumnEntity(_ColumnEntity):
def setup_dml_returning_compile_state(
self,
compile_state: ORMCompileState,
adapter: DMLReturningColFilter,
adapter: Optional[_DMLReturningColFilter],
) -> None:
return self.setup_compile_state(compile_state)
@@ -3122,9 +3223,12 @@ class _ORMColumnEntity(_ColumnEntity):
self.raw_column_index = raw_column_index
if is_current_entities:
self._label_name = compile_state._label_convention(
column, col_name=orm_key
)
if parent_bundle:
self._label_name = orm_key if orm_key else column._proxy_key
else:
self._label_name = compile_state._label_convention(
column, col_name=orm_key
)
else:
self._label_name = None
@@ -3162,10 +3266,13 @@ class _ORMColumnEntity(_ColumnEntity):
def setup_dml_returning_compile_state(
self,
compile_state: ORMCompileState,
adapter: DMLReturningColFilter,
adapter: Optional[_DMLReturningColFilter],
) -> None:
self._fetch_column = self.column
column = adapter(self.column, False)
self._fetch_column = column = self.column
if adapter:
column = adapter(column, False)
if column is not None:
compile_state.dedupe_columns.add(column)
compile_state.primary_columns.append(column)