This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# orm/context.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
@@ -104,7 +104,6 @@ class QueryContext:
|
||||
"top_level_context",
|
||||
"compile_state",
|
||||
"query",
|
||||
"user_passed_query",
|
||||
"params",
|
||||
"load_options",
|
||||
"bind_arguments",
|
||||
@@ -148,12 +147,7 @@ class QueryContext:
|
||||
def __init__(
|
||||
self,
|
||||
compile_state: CompileState,
|
||||
statement: Union[Select[Any], FromStatement[Any], UpdateBase],
|
||||
user_passed_query: Union[
|
||||
Select[Any],
|
||||
FromStatement[Any],
|
||||
UpdateBase,
|
||||
],
|
||||
statement: Union[Select[Any], FromStatement[Any]],
|
||||
params: _CoreSingleExecuteParams,
|
||||
session: Session,
|
||||
load_options: Union[
|
||||
@@ -168,13 +162,6 @@ class QueryContext:
|
||||
self.bind_arguments = bind_arguments or _EMPTY_DICT
|
||||
self.compile_state = compile_state
|
||||
self.query = statement
|
||||
|
||||
# the query that the end user passed to Session.execute() or similar.
|
||||
# this is usually the same as .query, except in the bulk_persistence
|
||||
# routines where a separate FromStatement is manufactured in the
|
||||
# compile stage; this allows differentiation in that case.
|
||||
self.user_passed_query = user_passed_query
|
||||
|
||||
self.session = session
|
||||
self.loaders_require_buffering = False
|
||||
self.loaders_require_uniquing = False
|
||||
@@ -182,7 +169,7 @@ class QueryContext:
|
||||
self.top_level_context = load_options._sa_top_level_orm_context
|
||||
|
||||
cached_options = compile_state.select_statement._with_options
|
||||
uncached_options = user_passed_query._with_options
|
||||
uncached_options = statement._with_options
|
||||
|
||||
# see issue #7447 , #8399 for some background
|
||||
# propagated loader options will be present on loaded InstanceState
|
||||
@@ -231,7 +218,7 @@ class AbstractORMCompileState(CompileState):
|
||||
if compiler is None:
|
||||
# this is the legacy / testing only ORM _compile_state() use case.
|
||||
# there is no need to apply criteria options for this.
|
||||
self.global_attributes = {}
|
||||
self.global_attributes = ga = {}
|
||||
assert toplevel
|
||||
return
|
||||
else:
|
||||
@@ -265,10 +252,10 @@ class AbstractORMCompileState(CompileState):
|
||||
@classmethod
|
||||
def create_for_statement(
|
||||
cls,
|
||||
statement: Executable,
|
||||
compiler: SQLCompiler,
|
||||
statement: Union[Select, FromStatement],
|
||||
compiler: Optional[SQLCompiler],
|
||||
**kw: Any,
|
||||
) -> CompileState:
|
||||
) -> AbstractORMCompileState:
|
||||
"""Create a context for a statement given a :class:`.Compiler`.
|
||||
|
||||
This method is always invoked in the context of SQLCompiler.process().
|
||||
@@ -414,8 +401,8 @@ class ORMCompileState(AbstractORMCompileState):
|
||||
attributes: Dict[Any, Any]
|
||||
global_attributes: Dict[Any, Any]
|
||||
|
||||
statement: Union[Select[Any], FromStatement[Any], UpdateBase]
|
||||
select_statement: Union[Select[Any], FromStatement[Any], UpdateBase]
|
||||
statement: Union[Select[Any], FromStatement[Any]]
|
||||
select_statement: Union[Select[Any], FromStatement[Any]]
|
||||
_entities: List[_QueryEntity]
|
||||
_polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter]
|
||||
compile_options: Union[
|
||||
@@ -437,30 +424,16 @@ class ORMCompileState(AbstractORMCompileState):
|
||||
def __init__(self, *arg, **kw):
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def create_for_statement(
|
||||
cls,
|
||||
statement: Executable,
|
||||
compiler: SQLCompiler,
|
||||
**kw: Any,
|
||||
) -> ORMCompileState:
|
||||
return cls._create_orm_context(
|
||||
cast("Union[Select, FromStatement]", statement),
|
||||
toplevel=not compiler.stack,
|
||||
compiler=compiler,
|
||||
**kw,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@classmethod
|
||||
def _create_orm_context(
|
||||
cls,
|
||||
statement: Union[Select, FromStatement],
|
||||
*,
|
||||
toplevel: bool,
|
||||
compiler: Optional[SQLCompiler],
|
||||
**kw: Any,
|
||||
) -> ORMCompileState:
|
||||
raise NotImplementedError()
|
||||
@classmethod
|
||||
def create_for_statement(
|
||||
cls,
|
||||
statement: Union[Select, FromStatement],
|
||||
compiler: Optional[SQLCompiler],
|
||||
**kw: Any,
|
||||
) -> ORMCompileState:
|
||||
...
|
||||
|
||||
def _append_dedupe_col_collection(self, obj, col_collection):
|
||||
dedupe = self.dedupe_columns
|
||||
@@ -544,14 +517,15 @@ class ORMCompileState(AbstractORMCompileState):
|
||||
and len(statement._compile_options._current_path) > 10
|
||||
and execution_options.get("compiled_cache", True) is not None
|
||||
):
|
||||
execution_options: util.immutabledict[str, Any] = (
|
||||
execution_options.union(
|
||||
{
|
||||
"compiled_cache": None,
|
||||
"_cache_disable_reason": "excess depth for "
|
||||
"ORM loader options",
|
||||
}
|
||||
)
|
||||
util.warn(
|
||||
"Loader depth for query is excessively deep; caching will "
|
||||
"be disabled for additional loaders. Consider using the "
|
||||
"recursion_depth feature for deeply nested recursive eager "
|
||||
"loaders. Use the compiled_cache=None execution option to "
|
||||
"skip this warning."
|
||||
)
|
||||
execution_options = execution_options.union(
|
||||
{"compiled_cache": None}
|
||||
)
|
||||
|
||||
bind_arguments["clause"] = statement
|
||||
@@ -606,7 +580,6 @@ class ORMCompileState(AbstractORMCompileState):
|
||||
querycontext = QueryContext(
|
||||
compile_state,
|
||||
statement,
|
||||
statement,
|
||||
params,
|
||||
session,
|
||||
load_options,
|
||||
@@ -670,8 +643,8 @@ class ORMCompileState(AbstractORMCompileState):
|
||||
)
|
||||
|
||||
|
||||
class _DMLReturningColFilter:
|
||||
"""a base for an adapter used for the DML RETURNING cases
|
||||
class DMLReturningColFilter:
|
||||
"""an adapter used for the DML RETURNING case.
|
||||
|
||||
Has a subset of the interface used by
|
||||
:class:`.ORMAdapter` and is used for :class:`._QueryEntity`
|
||||
@@ -705,21 +678,6 @@ class _DMLReturningColFilter:
|
||||
else:
|
||||
return None
|
||||
|
||||
def adapt_check_present(self, col):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class _DMLBulkInsertReturningColFilter(_DMLReturningColFilter):
|
||||
"""an adapter used for the DML RETURNING case specifically
|
||||
for ORM bulk insert (or any hypothetical DML that is splitting out a class
|
||||
hierarchy among multiple DML statements....ORM bulk insert is the only
|
||||
example right now)
|
||||
|
||||
its main job is to limit the columns in a RETURNING to only a specific
|
||||
mapped table in a hierarchy.
|
||||
|
||||
"""
|
||||
|
||||
def adapt_check_present(self, col):
|
||||
mapper = self.mapper
|
||||
prop = mapper._columntoproperty.get(col, None)
|
||||
@@ -728,30 +686,6 @@ class _DMLBulkInsertReturningColFilter(_DMLReturningColFilter):
|
||||
return mapper.local_table.c.corresponding_column(col)
|
||||
|
||||
|
||||
class _DMLUpdateDeleteReturningColFilter(_DMLReturningColFilter):
|
||||
"""an adapter used for the DML RETURNING case specifically
|
||||
for ORM enabled UPDATE/DELETE
|
||||
|
||||
its main job is to limit the columns in a RETURNING to include
|
||||
only direct persisted columns from the immediate selectable, not
|
||||
expressions like column_property(), or to also allow columns from other
|
||||
mappers for the UPDATE..FROM use case.
|
||||
|
||||
"""
|
||||
|
||||
def adapt_check_present(self, col):
|
||||
mapper = self.mapper
|
||||
prop = mapper._columntoproperty.get(col, None)
|
||||
if prop is not None:
|
||||
# if the col is from the immediate mapper, only return a persisted
|
||||
# column, not any kind of column_property expression
|
||||
return mapper.persist_selectable.c.corresponding_column(col)
|
||||
|
||||
# if the col is from some other mapper, just return it, assume the
|
||||
# user knows what they are doing
|
||||
return col
|
||||
|
||||
|
||||
@sql.base.CompileState.plugin_for("orm", "orm_from_statement")
|
||||
class ORMFromStatementCompileState(ORMCompileState):
|
||||
_from_obj_alias = None
|
||||
@@ -770,16 +704,12 @@ class ORMFromStatementCompileState(ORMCompileState):
|
||||
eager_joins = _EMPTY_DICT
|
||||
|
||||
@classmethod
|
||||
def _create_orm_context(
|
||||
def create_for_statement(
|
||||
cls,
|
||||
statement: Union[Select, FromStatement],
|
||||
*,
|
||||
toplevel: bool,
|
||||
statement_container: Union[Select, FromStatement],
|
||||
compiler: Optional[SQLCompiler],
|
||||
**kw: Any,
|
||||
) -> ORMFromStatementCompileState:
|
||||
statement_container = statement
|
||||
|
||||
assert isinstance(statement_container, FromStatement)
|
||||
|
||||
if compiler is not None and compiler.stack:
|
||||
@@ -821,11 +751,9 @@ class ORMFromStatementCompileState(ORMCompileState):
|
||||
self.statement = statement
|
||||
|
||||
self._label_convention = self._column_naming_convention(
|
||||
(
|
||||
statement._label_style
|
||||
if not statement._is_textual and not statement.is_dml
|
||||
else LABEL_STYLE_NONE
|
||||
),
|
||||
statement._label_style
|
||||
if not statement._is_textual and not statement.is_dml
|
||||
else LABEL_STYLE_NONE,
|
||||
self.use_legacy_query_style,
|
||||
)
|
||||
|
||||
@@ -871,9 +799,9 @@ class ORMFromStatementCompileState(ORMCompileState):
|
||||
for entity in self._entities:
|
||||
entity.setup_compile_state(self)
|
||||
|
||||
compiler._ordered_columns = compiler._textual_ordered_columns = (
|
||||
False
|
||||
)
|
||||
compiler._ordered_columns = (
|
||||
compiler._textual_ordered_columns
|
||||
) = False
|
||||
|
||||
# enable looser result column matching. this is shown to be
|
||||
# needed by test_query.py::TextTest
|
||||
@@ -910,24 +838,14 @@ class ORMFromStatementCompileState(ORMCompileState):
|
||||
return None
|
||||
|
||||
def setup_dml_returning_compile_state(self, dml_mapper):
|
||||
"""used by BulkORMInsert, Update, Delete to set up a handler
|
||||
"""used by BulkORMInsert (and Update / Delete?) to set up a handler
|
||||
for RETURNING to return ORM objects and expressions
|
||||
|
||||
"""
|
||||
target_mapper = self.statement._propagate_attrs.get(
|
||||
"plugin_subject", None
|
||||
)
|
||||
|
||||
if self.statement.is_insert:
|
||||
adapter = _DMLBulkInsertReturningColFilter(
|
||||
target_mapper, dml_mapper
|
||||
)
|
||||
elif self.statement.is_update or self.statement.is_delete:
|
||||
adapter = _DMLUpdateDeleteReturningColFilter(
|
||||
target_mapper, dml_mapper
|
||||
)
|
||||
else:
|
||||
adapter = None
|
||||
adapter = DMLReturningColFilter(target_mapper, dml_mapper)
|
||||
|
||||
if self.compile_options._is_star and (len(self._entities) != 1):
|
||||
raise sa_exc.CompileError(
|
||||
@@ -970,8 +888,6 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
|
||||
("_compile_options", InternalTraversal.dp_has_cache_key)
|
||||
]
|
||||
|
||||
is_from_statement = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entities: Iterable[_ColumnsClauseArgument[Any]],
|
||||
@@ -989,10 +905,6 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
|
||||
]
|
||||
self.element = element
|
||||
self.is_dml = element.is_dml
|
||||
self.is_select = element.is_select
|
||||
self.is_delete = element.is_delete
|
||||
self.is_insert = element.is_insert
|
||||
self.is_update = element.is_update
|
||||
self._label_style = (
|
||||
element._label_style if is_select_base(element) else None
|
||||
)
|
||||
@@ -1086,17 +998,21 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
|
||||
_having_criteria = ()
|
||||
|
||||
@classmethod
|
||||
def _create_orm_context(
|
||||
def create_for_statement(
|
||||
cls,
|
||||
statement: Union[Select, FromStatement],
|
||||
*,
|
||||
toplevel: bool,
|
||||
compiler: Optional[SQLCompiler],
|
||||
**kw: Any,
|
||||
) -> ORMSelectCompileState:
|
||||
"""compiler hook, we arrive here from compiler.visit_select() only."""
|
||||
|
||||
self = cls.__new__(cls)
|
||||
|
||||
if compiler is not None:
|
||||
toplevel = not compiler.stack
|
||||
else:
|
||||
toplevel = True
|
||||
|
||||
select_statement = statement
|
||||
|
||||
# if we are a select() that was never a legacy Query, we won't
|
||||
@@ -1452,15 +1368,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
|
||||
def get_columns_clause_froms(cls, statement):
|
||||
return cls._normalize_froms(
|
||||
itertools.chain.from_iterable(
|
||||
(
|
||||
element._from_objects
|
||||
if "parententity" not in element._annotations
|
||||
else [
|
||||
element._annotations[
|
||||
"parententity"
|
||||
].__clause_element__()
|
||||
]
|
||||
)
|
||||
element._from_objects
|
||||
if "parententity" not in element._annotations
|
||||
else [
|
||||
element._annotations["parententity"].__clause_element__()
|
||||
]
|
||||
for element in statement._raw_columns
|
||||
)
|
||||
)
|
||||
@@ -1589,11 +1501,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
|
||||
# the original expressions outside of the label references
|
||||
# in order to have them render.
|
||||
unwrapped_order_by = [
|
||||
(
|
||||
elem.element
|
||||
if isinstance(elem, sql.elements._label_reference)
|
||||
else elem
|
||||
)
|
||||
elem.element
|
||||
if isinstance(elem, sql.elements._label_reference)
|
||||
else elem
|
||||
for elem in self.order_by
|
||||
]
|
||||
|
||||
@@ -1635,10 +1545,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
|
||||
)
|
||||
statement._label_style = self.label_style
|
||||
|
||||
# Oracle Database however does not allow FOR UPDATE on the subquery,
|
||||
# and the Oracle Database dialects ignore it, plus for PostgreSQL,
|
||||
# MySQL we expect that all elements of the row are locked, so also put
|
||||
# it on the outside (except in the case of PG when OF is used)
|
||||
# Oracle however does not allow FOR UPDATE on the subquery,
|
||||
# and the Oracle dialect ignores it, plus for PostgreSQL, MySQL
|
||||
# we expect that all elements of the row are locked, so also put it
|
||||
# on the outside (except in the case of PG when OF is used)
|
||||
if (
|
||||
self._for_update_arg is not None
|
||||
and self._for_update_arg.of is None
|
||||
@@ -1864,6 +1774,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
|
||||
"selectable/table as join target"
|
||||
)
|
||||
|
||||
of_type = None
|
||||
|
||||
if isinstance(onclause, interfaces.PropComparator):
|
||||
# descriptor/property given (or determined); this tells us
|
||||
# explicitly what the expected "left" side of the join is.
|
||||
@@ -2510,12 +2422,9 @@ def _column_descriptions(
|
||||
"type": ent.type,
|
||||
"aliased": getattr(insp_ent, "is_aliased_class", False),
|
||||
"expr": ent.expr,
|
||||
"entity": (
|
||||
getattr(insp_ent, "entity", None)
|
||||
if ent.entity_zero is not None
|
||||
and not insp_ent.is_clause_element
|
||||
else None
|
||||
),
|
||||
"entity": getattr(insp_ent, "entity", None)
|
||||
if ent.entity_zero is not None and not insp_ent.is_clause_element
|
||||
else None,
|
||||
}
|
||||
for ent, insp_ent in [
|
||||
(_ent, _ent.entity_zero) for _ent in ctx._entities
|
||||
@@ -2525,7 +2434,7 @@ def _column_descriptions(
|
||||
|
||||
|
||||
def _legacy_filter_by_entity_zero(
|
||||
query_or_augmented_select: Union[Query[Any], Select[Any]],
|
||||
query_or_augmented_select: Union[Query[Any], Select[Any]]
|
||||
) -> Optional[_InternalEntityType[Any]]:
|
||||
self = query_or_augmented_select
|
||||
if self._setup_joins:
|
||||
@@ -2540,7 +2449,7 @@ def _legacy_filter_by_entity_zero(
|
||||
|
||||
|
||||
def _entity_from_pre_ent_zero(
|
||||
query_or_augmented_select: Union[Query[Any], Select[Any]],
|
||||
query_or_augmented_select: Union[Query[Any], Select[Any]]
|
||||
) -> Optional[_InternalEntityType[Any]]:
|
||||
self = query_or_augmented_select
|
||||
if not self._raw_columns:
|
||||
@@ -2598,7 +2507,7 @@ class _QueryEntity:
|
||||
def setup_dml_returning_compile_state(
|
||||
self,
|
||||
compile_state: ORMCompileState,
|
||||
adapter: Optional[_DMLReturningColFilter],
|
||||
adapter: DMLReturningColFilter,
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -2800,7 +2709,7 @@ class _MapperEntity(_QueryEntity):
|
||||
def setup_dml_returning_compile_state(
|
||||
self,
|
||||
compile_state: ORMCompileState,
|
||||
adapter: Optional[_DMLReturningColFilter],
|
||||
adapter: DMLReturningColFilter,
|
||||
) -> None:
|
||||
loading._setup_entity_query(
|
||||
compile_state,
|
||||
@@ -2956,13 +2865,6 @@ class _BundleEntity(_QueryEntity):
|
||||
for ent in self._entities:
|
||||
ent.setup_compile_state(compile_state)
|
||||
|
||||
def setup_dml_returning_compile_state(
|
||||
self,
|
||||
compile_state: ORMCompileState,
|
||||
adapter: Optional[_DMLReturningColFilter],
|
||||
) -> None:
|
||||
return self.setup_compile_state(compile_state)
|
||||
|
||||
def row_processor(self, context, result):
|
||||
procs, labels, extra = zip(
|
||||
*[ent.row_processor(context, result) for ent in self._entities]
|
||||
@@ -3126,10 +3028,7 @@ class _RawColumnEntity(_ColumnEntity):
|
||||
if not is_current_entities or column._is_text_clause:
|
||||
self._label_name = None
|
||||
else:
|
||||
if parent_bundle:
|
||||
self._label_name = column._proxy_key
|
||||
else:
|
||||
self._label_name = compile_state._label_convention(column)
|
||||
self._label_name = compile_state._label_convention(column)
|
||||
|
||||
if parent_bundle:
|
||||
parent_bundle._entities.append(self)
|
||||
@@ -3149,7 +3048,7 @@ class _RawColumnEntity(_ColumnEntity):
|
||||
def setup_dml_returning_compile_state(
|
||||
self,
|
||||
compile_state: ORMCompileState,
|
||||
adapter: Optional[_DMLReturningColFilter],
|
||||
adapter: DMLReturningColFilter,
|
||||
) -> None:
|
||||
return self.setup_compile_state(compile_state)
|
||||
|
||||
@@ -3223,12 +3122,9 @@ class _ORMColumnEntity(_ColumnEntity):
|
||||
self.raw_column_index = raw_column_index
|
||||
|
||||
if is_current_entities:
|
||||
if parent_bundle:
|
||||
self._label_name = orm_key if orm_key else column._proxy_key
|
||||
else:
|
||||
self._label_name = compile_state._label_convention(
|
||||
column, col_name=orm_key
|
||||
)
|
||||
self._label_name = compile_state._label_convention(
|
||||
column, col_name=orm_key
|
||||
)
|
||||
else:
|
||||
self._label_name = None
|
||||
|
||||
@@ -3266,13 +3162,10 @@ class _ORMColumnEntity(_ColumnEntity):
|
||||
def setup_dml_returning_compile_state(
|
||||
self,
|
||||
compile_state: ORMCompileState,
|
||||
adapter: Optional[_DMLReturningColFilter],
|
||||
adapter: DMLReturningColFilter,
|
||||
) -> None:
|
||||
|
||||
self._fetch_column = column = self.column
|
||||
if adapter:
|
||||
column = adapter(column, False)
|
||||
|
||||
self._fetch_column = self.column
|
||||
column = adapter(self.column, False)
|
||||
if column is not None:
|
||||
compile_state.dedupe_columns.add(column)
|
||||
compile_state.primary_columns.append(column)
|
||||
|
||||
Reference in New Issue
Block a user