API refactor

2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions


@@ -1,5 +1,5 @@
# orm/util.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -35,6 +35,7 @@ import weakref
from . import attributes # noqa
from . import exc
from . import exc as orm_exc
from ._typing import _O
from ._typing import insp_is_aliased_class
from ._typing import insp_is_mapper
@@ -42,6 +43,7 @@ from ._typing import prop_is_relationship
from .base import _class_to_mapper as _class_to_mapper
from .base import _MappedAnnotationBase
from .base import _never_set as _never_set # noqa: F401
from .base import _none_only_set as _none_only_set # noqa: F401
from .base import _none_set as _none_set # noqa: F401
from .base import attribute_str as attribute_str # noqa: F401
from .base import class_mapper as class_mapper
@@ -85,14 +87,12 @@ from ..sql.elements import KeyedColumnElement
from ..sql.selectable import FromClause
from ..util.langhelpers import MemoizedSlots
from ..util.typing import de_stringify_annotation as _de_stringify_annotation
from ..util.typing import (
de_stringify_union_elements as _de_stringify_union_elements,
)
from ..util.typing import eval_name_only as _eval_name_only
from ..util.typing import fixup_container_fwd_refs
from ..util.typing import get_origin
from ..util.typing import is_origin_of_cls
from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import typing_get_origin
if typing.TYPE_CHECKING:
from ._typing import _EntityType
@@ -121,7 +121,6 @@ if typing.TYPE_CHECKING:
from ..sql.selectable import Selectable
from ..sql.visitors import anon_map
from ..util.typing import _AnnotationScanType
from ..util.typing import ArgsTypeProcotol
_T = TypeVar("_T", bound=Any)
@@ -138,7 +137,6 @@ all_cascades = frozenset(
)
)
_de_stringify_partial = functools.partial(
functools.partial,
locals_=util.immutabledict(
@@ -163,8 +161,7 @@ class _DeStringifyAnnotation(Protocol):
*,
str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
include_generic: bool = False,
) -> Type[Any]:
...
) -> Type[Any]: ...
de_stringify_annotation = cast(
@@ -172,27 +169,8 @@ de_stringify_annotation = cast(
)
class _DeStringifyUnionElements(Protocol):
def __call__(
self,
cls: Type[Any],
annotation: ArgsTypeProcotol,
originating_module: str,
*,
str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
) -> Type[Any]:
...
de_stringify_union_elements = cast(
_DeStringifyUnionElements,
_de_stringify_partial(_de_stringify_union_elements),
)
class _EvalNameOnly(Protocol):
def __call__(self, name: str, module_name: str) -> Any:
...
def __call__(self, name: str, module_name: str) -> Any: ...
eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only))
@@ -250,7 +228,7 @@ class CascadeOptions(FrozenSet[str]):
values.clear()
values.discard("all")
self = super().__new__(cls, values) # type: ignore
self = super().__new__(cls, values)
self.save_update = "save-update" in values
self.delete = "delete" in values
self.refresh_expire = "refresh-expire" in values
@@ -259,9 +237,7 @@ class CascadeOptions(FrozenSet[str]):
self.delete_orphan = "delete-orphan" in values
if self.delete_orphan and not self.delete:
util.warn(
"The 'delete-orphan' cascade " "option requires 'delete'."
)
util.warn("The 'delete-orphan' cascade option requires 'delete'.")
return self
def __repr__(self):
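
For context on the warning reworked in this hunk: it fires when a relationship's cascade setting names "delete-orphan" without "delete". A minimal sketch, with hypothetical models, of the usual configuration versus the one that would warn:

```python
from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Order(Base):  # hypothetical model, for illustration only
    __tablename__ = "orders"
    id: Mapped[int] = mapped_column(primary_key=True)
    # usual form: "delete" accompanies "delete-orphan" (covered by "all")
    items: Mapped[list["Item"]] = relationship(cascade="all, delete-orphan")
    # by contrast, cascade="save-update, delete-orphan" omits "delete" and
    # would emit the warning above when the cascade string is processed


class Item(Base):
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    order_id: Mapped[int] = mapped_column(ForeignKey("orders.id"))
```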
@@ -478,9 +454,7 @@ def identity_key(
E.g.::
>>> row = engine.execute(\
text("select * from table where a=1 and b=2")\
).first()
>>> row = engine.execute(text("select * from table where a=1 and b=2")).first()
>>> identity_key(MyClass, row=row)
(<class '__main__.MyClass'>, (1, 2), None)
@@ -491,7 +465,7 @@ def identity_key(
.. versionadded:: 1.2 added identity_token
"""
""" # noqa: E501
if class_ is not None:
mapper = class_mapper(class_)
if row is None:
@@ -669,9 +643,9 @@ class AliasedClass(
# find all pairs of users with the same name
user_alias = aliased(User)
session.query(User, user_alias).\
join((user_alias, User.id > user_alias.id)).\
filter(User.name == user_alias.name)
session.query(User, user_alias).join(
(user_alias, User.id > user_alias.id)
).filter(User.name == user_alias.name)
:class:`.AliasedClass` is also capable of mapping an existing mapped
class to an entirely new selectable, provided this selectable is column-
@@ -695,6 +669,7 @@ class AliasedClass(
using :func:`_sa.inspect`::
from sqlalchemy import inspect
my_alias = aliased(MyClass)
insp = inspect(my_alias)
@@ -755,12 +730,16 @@ class AliasedClass(
insp,
alias,
name,
with_polymorphic_mappers
if with_polymorphic_mappers
else mapper.with_polymorphic_mappers,
with_polymorphic_discriminator
if with_polymorphic_discriminator is not None
else mapper.polymorphic_on,
(
with_polymorphic_mappers
if with_polymorphic_mappers
else mapper.with_polymorphic_mappers
),
(
with_polymorphic_discriminator
if with_polymorphic_discriminator is not None
else mapper.polymorphic_on
),
base_alias,
use_mapper_path,
adapt_on_names,
@@ -971,9 +950,9 @@ class AliasedInsp(
self._weak_entity = weakref.ref(entity)
self.mapper = mapper
self.selectable = (
self.persist_selectable
) = self.local_table = selectable
self.selectable = self.persist_selectable = self.local_table = (
selectable
)
self.name = name
self.polymorphic_on = polymorphic_on
self._base_alias = weakref.ref(_base_alias or self)
@@ -1068,6 +1047,7 @@ class AliasedInsp(
aliased: bool = False,
innerjoin: bool = False,
adapt_on_names: bool = False,
name: Optional[str] = None,
_use_mapper_path: bool = False,
) -> AliasedClass[_O]:
primary_mapper = _class_to_mapper(base)
@@ -1088,6 +1068,7 @@ class AliasedInsp(
return AliasedClass(
base,
selectable,
name=name,
with_polymorphic_mappers=mappers,
adapt_on_names=adapt_on_names,
with_polymorphic_discriminator=polymorphic_on,
@@ -1229,8 +1210,7 @@ class AliasedInsp(
self,
obj: _CE,
key: Optional[str] = None,
) -> _CE:
...
) -> _CE: ...
else:
_orm_adapt_element = _adapt_element
@@ -1380,7 +1360,10 @@ class LoaderCriteriaOption(CriteriaOption):
def __init__(
self,
entity_or_base: _EntityType[Any],
where_criteria: _ColumnExpressionArgument[bool],
where_criteria: Union[
_ColumnExpressionArgument[bool],
Callable[[Any], _ColumnExpressionArgument[bool]],
],
loader_only: bool = False,
include_aliases: bool = False,
propagate_to_loaders: bool = True,
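
The widened where_criteria annotation in this hunk covers the documented lambda form of with_loader_criteria, where a callable receives the entity and returns the criteria. A minimal sketch, assuming a hypothetical User model with a deleted flag:

```python
from sqlalchemy import event
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    with_loader_criteria,
)


class Base(DeclarativeBase):
    pass


class User(Base):  # hypothetical model, for illustration only
    __tablename__ = "user_account"
    id: Mapped[int] = mapped_column(primary_key=True)
    deleted: Mapped[bool] = mapped_column(default=False)


@event.listens_for(Session, "do_orm_execute")
def _filter_deleted(execute_state):
    if not execute_state.is_column_load and not execute_state.is_relationship_load:
        execute_state.statement = execute_state.statement.options(
            # the lambda corresponds to the Callable branch added to
            # where_criteria above; it receives the (possibly aliased) entity
            with_loader_criteria(User, lambda cls: cls.deleted == False)
        )
```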
@@ -1539,7 +1522,7 @@ GenericAlias = type(List[Any])
def _inspect_generic_alias(
class_: Type[_O],
) -> Optional[Mapper[_O]]:
origin = cast("Type[_O]", typing_get_origin(class_))
origin = cast("Type[_O]", get_origin(class_))
return _inspect_mc(origin)
@@ -1583,7 +1566,7 @@ class Bundle(
_propagate_attrs: _PropagateAttrsType = util.immutabledict()
proxy_set = util.EMPTY_SET # type: ignore
proxy_set = util.EMPTY_SET
exprs: List[_ColumnsClauseElement]
@@ -1596,8 +1579,7 @@ class Bundle(
bn = Bundle("mybundle", MyClass.x, MyClass.y)
for row in session.query(bn).filter(
bn.c.x == 5).filter(bn.c.y == 4):
for row in session.query(bn).filter(bn.c.x == 5).filter(bn.c.y == 4):
print(row.mybundle.x, row.mybundle.y)
:param name: name of the bundle.
@@ -1606,7 +1588,7 @@ class Bundle(
can be returned as a "single entity" outside of any enclosing tuple
in the same manner as a mapped entity.
"""
""" # noqa: E501
self.name = self._label = name
coerced_exprs = [
coercions.expect(
@@ -1661,24 +1643,24 @@ class Bundle(
Nesting of bundles is also supported::
b1 = Bundle("b1",
Bundle('b2', MyClass.a, MyClass.b),
Bundle('b3', MyClass.x, MyClass.y)
)
b1 = Bundle(
"b1",
Bundle("b2", MyClass.a, MyClass.b),
Bundle("b3", MyClass.x, MyClass.y),
)
q = sess.query(b1).filter(
b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9)
q = sess.query(b1).filter(b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9)
.. seealso::
:attr:`.Bundle.c`
"""
""" # noqa: E501
c: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]
"""An alias for :attr:`.Bundle.columns`."""
def _clone(self):
def _clone(self, **kw):
cloned = self.__class__.__new__(self.__class__)
cloned.__dict__.update(self.__dict__)
return cloned
@@ -1739,25 +1721,24 @@ class Bundle(
from sqlalchemy.orm import Bundle
class DictBundle(Bundle):
def create_row_processor(self, query, procs, labels):
'Override create_row_processor to return values as
dictionaries'
"Override create_row_processor to return values as dictionaries"
def proc(row):
return dict(
zip(labels, (proc(row) for proc in procs))
)
return dict(zip(labels, (proc(row) for proc in procs)))
return proc
A result from the above :class:`_orm.Bundle` will return dictionary
values::
bn = DictBundle('mybundle', MyClass.data1, MyClass.data2)
for row in session.execute(select(bn)).where(bn.c.data1 == 'd1'):
print(row.mybundle['data1'], row.mybundle['data2'])
bn = DictBundle("mybundle", MyClass.data1, MyClass.data2)
for row in session.execute(select(bn)).where(bn.c.data1 == "d1"):
print(row.mybundle["data1"], row.mybundle["data2"])
"""
""" # noqa: E501
keyed_tuple = result_tuple(labels, [() for l in labels])
def proc(row: Row[Any]) -> Any:
@@ -1940,7 +1921,7 @@ class _ORMJoin(expression.Join):
self.onclause,
isouter=self.isouter,
_left_memo=self._left_memo,
_right_memo=other._left_memo,
_right_memo=other._left_memo._path_registry,
)
return _ORMJoin(
@@ -1983,7 +1964,6 @@ def with_parent(
stmt = select(Address).where(with_parent(some_user, User.addresses))
The SQL rendered is the same as that rendered when a lazy loader
would fire off from the given parent on that attribute, meaning
that the appropriate state is taken from the parent object in
@@ -1996,9 +1976,7 @@ def with_parent(
a1 = aliased(Address)
a2 = aliased(Address)
stmt = select(a1, a2).where(
with_parent(u1, User.addresses.of_type(a2))
)
stmt = select(a1, a2).where(with_parent(u1, User.addresses.of_type(a2)))
The above use is equivalent to using the
:func:`_orm.with_parent.from_entity` argument::
@@ -2023,7 +2001,7 @@ def with_parent(
.. versionadded:: 1.2
"""
""" # noqa: E501
prop_t: RelationshipProperty[Any]
if isinstance(prop, str):
@@ -2117,14 +2095,13 @@ def _entity_corresponds_to_use_path_impl(
someoption(A).someoption(C.d) # -> fn(A, C) -> False
a1 = aliased(A)
someoption(a1).someoption(A.b) # -> fn(a1, A) -> False
someoption(a1).someoption(a1.b) # -> fn(a1, a1) -> True
someoption(a1).someoption(A.b) # -> fn(a1, A) -> False
someoption(a1).someoption(a1.b) # -> fn(a1, a1) -> True
wp = with_polymorphic(A, [A1, A2])
someoption(wp).someoption(A1.foo) # -> fn(wp, A1) -> False
someoption(wp).someoption(wp.A1.foo) # -> fn(wp, wp.A1) -> True
"""
if insp_is_aliased_class(given):
return (
@@ -2151,7 +2128,7 @@ def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool:
mapper
)
elif given.with_polymorphic_mappers:
return mapper in given.with_polymorphic_mappers
return mapper in given.with_polymorphic_mappers or given.isa(mapper)
else:
return given.isa(mapper)
@@ -2233,7 +2210,7 @@ def _cleanup_mapped_str_annotation(
inner: Optional[Match[str]]
mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
mm = re.match(r"^([^ \|]+?)\[(.+)\]$", annotation)
if not mm:
return annotation
@@ -2273,7 +2250,7 @@ def _cleanup_mapped_str_annotation(
while True:
stack.append(real_symbol if mm is inner else inner.group(1))
g2 = inner.group(2)
inner = re.match(r"^(.+?)\[(.+)\]$", g2)
inner = re.match(r"^([^ \|]+?)\[(.+)\]$", g2)
if inner is None:
stack.append(g2)
break
@@ -2295,8 +2272,10 @@ def _cleanup_mapped_str_annotation(
# ['Mapped', "'Optional[Dict[str, str]]'"]
not re.match(r"""^["'].*["']$""", stack[-1])
# avoid further generics like Dict[] such as
# ['Mapped', 'dict[str, str] | None']
and not re.match(r".*\[.*\]", stack[-1])
# ['Mapped', 'dict[str, str] | None'],
# ['Mapped', 'list[int] | list[str]'],
# ['Mapped', 'Union[list[int], list[str]]'],
and not re.search(r"[\[\]]", stack[-1])
):
stripchars = "\"' "
stack[-1] = ", ".join(
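
The regex change in these hunks (stopping the generic match at a space or `|`) targets string annotations whose top level is a union, as the updated inline comments enumerate. A sketch of a declarative mapping that produces such a string under future-annotations mode; the model and the JSON mapping are illustrative assumptions, and how fully a given release resolves it depends on the installed version:

```python
from __future__ import annotations  # annotations arrive as strings

from sqlalchemy import JSON
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    # map the container type to a SQL type so the union's non-None member resolves
    type_annotation_map = {dict[str, str]: JSON}


class Config(Base):
    __tablename__ = "config"
    id: Mapped[int] = mapped_column(primary_key=True)
    # under future annotations this is the string "Mapped[dict[str, str] | None]";
    # _cleanup_mapped_str_annotation must leave the union intact rather than
    # treating "dict[str, str] | None" as a single generic to recurse into
    data: Mapped[dict[str, str] | None] = mapped_column()
```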
@@ -2318,7 +2297,7 @@ def _extract_mapped_subtype(
is_dataclass_field: bool,
expect_mapped: bool = True,
raiseerr: bool = True,
) -> Optional[Tuple[Union[type, str], Optional[type]]]:
) -> Optional[Tuple[Union[_AnnotationScanType, str], Optional[type]]]:
"""given an annotation, figure out if it's ``Mapped[something]`` and if
so, return the ``something`` part.
@@ -2328,7 +2307,7 @@ def _extract_mapped_subtype(
if raw_annotation is None:
if required:
raise sa_exc.ArgumentError(
raise orm_exc.MappedAnnotationError(
f"Python typing annotation is required for attribute "
f'"{cls.__name__}.{key}" when primary argument(s) for '
f'"{attr_cls.__name__}" construct are None or not present'
@@ -2336,6 +2315,11 @@ def _extract_mapped_subtype(
return None
try:
# destringify the "outside" of the annotation. note we are not
# adding include_generic so it will *not* dig into generic contents,
# which will remain as ForwardRef or plain str under future annotations
# mode. The full destringify happens later when mapped_column goes
# to do a full lookup in the registry type_annotations_map.
annotated = de_stringify_annotation(
cls,
raw_annotation,
@@ -2343,14 +2327,14 @@ def _extract_mapped_subtype(
str_cleanup_fn=_cleanup_mapped_str_annotation,
)
except _CleanupError as ce:
raise sa_exc.ArgumentError(
raise orm_exc.MappedAnnotationError(
f"Could not interpret annotation {raw_annotation}. "
"Check that it uses names that are correctly imported at the "
"module level. See chained stack trace for more hints."
) from ce
except NameError as ne:
if raiseerr and "Mapped[" in raw_annotation: # type: ignore
raise sa_exc.ArgumentError(
raise orm_exc.MappedAnnotationError(
f"Could not interpret annotation {raw_annotation}. "
"Check that it uses names that are correctly imported at the "
"module level. See chained stack trace for more hints."
@@ -2379,7 +2363,7 @@ def _extract_mapped_subtype(
):
return None
raise sa_exc.ArgumentError(
raise orm_exc.MappedAnnotationError(
f'Type annotation for "{cls.__name__}.{key}" '
"can't be correctly interpreted for "
"Annotated Declarative Table form. ORM annotations "
@@ -2400,8 +2384,20 @@ def _extract_mapped_subtype(
return annotated, None
if len(annotated.__args__) != 1:
raise sa_exc.ArgumentError(
raise orm_exc.MappedAnnotationError(
"Expected sub-type for Mapped[] annotation"
)
return annotated.__args__[0], annotated.__origin__
return (
# fix dict/list/set args to be ForwardRef, see #11814
fixup_container_fwd_refs(annotated.__args__[0]),
annotated.__origin__,
)
def _mapper_property_as_plain_name(prop: Type[Any]) -> str:
if hasattr(prop, "_mapper_property_name"):
name = prop._mapper_property_name()
else:
name = None
return util.clsname_as_plain_name(prop, name)
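
The fixup_container_fwd_refs call added in the final hunk (per its #11814 reference) concerns element types that remain forward references inside container annotations. A typical declarative pattern exercising that path, sketched with hypothetical models:

```python
from __future__ import annotations  # element types stay as forward references

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Parent(Base):  # hypothetical model, for illustration only
    __tablename__ = "parent"
    id: Mapped[int] = mapped_column(primary_key=True)
    # "Child" inside the container is a forward reference; _extract_mapped_subtype
    # now passes the extracted argument through fixup_container_fwd_refs
    children: Mapped[list["Child"]] = relationship(back_populates="parent")


class Child(Base):
    __tablename__ = "child"
    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
    parent: Mapped["Parent"] = relationship(back_populates="children")
```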