This commit is contained in:
@@ -1,14 +1,12 @@
|
||||
# sql/util.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: allow-untyped-defs, allow-untyped-calls
|
||||
|
||||
"""High level utilities which build upon other modules here.
|
||||
|
||||
"""
|
||||
"""High level utilities which build upon other modules here."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
@@ -106,7 +104,7 @@ def join_condition(
|
||||
|
||||
would produce an expression along the lines of::
|
||||
|
||||
tablea.c.id==tableb.c.tablea_id
|
||||
tablea.c.id == tableb.c.tablea_id
|
||||
|
||||
The join is determined based on the foreign key relationships
|
||||
between the two selectables. If there are multiple ways
|
||||
@@ -268,7 +266,7 @@ def visit_binary_product(
|
||||
|
||||
The function is of the form::
|
||||
|
||||
def my_fn(binary, left, right)
|
||||
def my_fn(binary, left, right): ...
|
||||
|
||||
For each binary expression located which has a
|
||||
comparison operator, the product of "left" and
|
||||
@@ -277,12 +275,11 @@ def visit_binary_product(
|
||||
|
||||
Hence an expression like::
|
||||
|
||||
and_(
|
||||
(a + b) == q + func.sum(e + f),
|
||||
j == r
|
||||
)
|
||||
and_((a + b) == q + func.sum(e + f), j == r)
|
||||
|
||||
would have the traversal::
|
||||
would have the traversal:
|
||||
|
||||
.. sourcecode:: text
|
||||
|
||||
a <eq> q
|
||||
a <eq> e
|
||||
@@ -350,9 +347,9 @@ def find_tables(
|
||||
] = _visitors["lateral"] = tables.append
|
||||
|
||||
if include_crud:
|
||||
_visitors["insert"] = _visitors["update"] = _visitors[
|
||||
"delete"
|
||||
] = lambda ent: tables.append(ent.table)
|
||||
_visitors["insert"] = _visitors["update"] = _visitors["delete"] = (
|
||||
lambda ent: tables.append(ent.table)
|
||||
)
|
||||
|
||||
if check_columns:
|
||||
|
||||
@@ -367,7 +364,7 @@ def find_tables(
|
||||
return tables
|
||||
|
||||
|
||||
def unwrap_order_by(clause):
|
||||
def unwrap_order_by(clause: Any) -> Any:
|
||||
"""Break up an 'order by' expression into individual column-expressions,
|
||||
without DESC/ASC/NULLS FIRST/NULLS LAST"""
|
||||
|
||||
@@ -481,7 +478,7 @@ def surface_selectables(clause):
|
||||
stack.append(elem.element)
|
||||
|
||||
|
||||
def surface_selectables_only(clause):
|
||||
def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]:
|
||||
stack = [clause]
|
||||
while stack:
|
||||
elem = stack.pop()
|
||||
@@ -528,9 +525,7 @@ def bind_values(clause):
|
||||
|
||||
E.g.::
|
||||
|
||||
>>> expr = and_(
|
||||
... table.c.foo==5, table.c.foo==7
|
||||
... )
|
||||
>>> expr = and_(table.c.foo == 5, table.c.foo == 7)
|
||||
>>> bind_values(expr)
|
||||
[5, 7]
|
||||
"""
|
||||
@@ -878,8 +873,7 @@ def reduce_columns(
|
||||
columns: Iterable[ColumnElement[Any]],
|
||||
*clauses: Optional[ClauseElement],
|
||||
**kw: bool,
|
||||
) -> Sequence[ColumnElement[Any]]:
|
||||
...
|
||||
) -> Sequence[ColumnElement[Any]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -887,8 +881,7 @@ def reduce_columns(
|
||||
columns: _SelectIterable,
|
||||
*clauses: Optional[ClauseElement],
|
||||
**kw: bool,
|
||||
) -> Sequence[Union[ColumnElement[Any], TextClause]]:
|
||||
...
|
||||
) -> Sequence[Union[ColumnElement[Any], TextClause]]: ...
|
||||
|
||||
|
||||
def reduce_columns(
|
||||
@@ -1043,20 +1036,24 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
|
||||
|
||||
E.g.::
|
||||
|
||||
table1 = Table('sometable', metadata,
|
||||
Column('col1', Integer),
|
||||
Column('col2', Integer)
|
||||
)
|
||||
table2 = Table('someothertable', metadata,
|
||||
Column('col1', Integer),
|
||||
Column('col2', Integer)
|
||||
)
|
||||
table1 = Table(
|
||||
"sometable",
|
||||
metadata,
|
||||
Column("col1", Integer),
|
||||
Column("col2", Integer),
|
||||
)
|
||||
table2 = Table(
|
||||
"someothertable",
|
||||
metadata,
|
||||
Column("col1", Integer),
|
||||
Column("col2", Integer),
|
||||
)
|
||||
|
||||
condition = table1.c.col1 == table2.c.col1
|
||||
|
||||
make an alias of table1::
|
||||
|
||||
s = table1.alias('foo')
|
||||
s = table1.alias("foo")
|
||||
|
||||
calling ``ClauseAdapter(s).traverse(condition)`` converts
|
||||
condition to read::
|
||||
@@ -1099,8 +1096,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def traverse(self, obj: Literal[None]) -> None:
|
||||
...
|
||||
def traverse(self, obj: Literal[None]) -> None: ...
|
||||
|
||||
# note this specializes the ReplacingExternalTraversal.traverse()
|
||||
# method to state
|
||||
@@ -1111,13 +1107,11 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
|
||||
# FromClause but Mypy is not accepting those as compatible with
|
||||
# the base ReplacingExternalTraversal
|
||||
@overload
|
||||
def traverse(self, obj: _ET) -> _ET:
|
||||
...
|
||||
def traverse(self, obj: _ET) -> _ET: ...
|
||||
|
||||
def traverse(
|
||||
self, obj: Optional[ExternallyTraversible]
|
||||
) -> Optional[ExternallyTraversible]:
|
||||
...
|
||||
) -> Optional[ExternallyTraversible]: ...
|
||||
|
||||
def _corresponding_column(
|
||||
self, col, require_embedded, _seen=util.EMPTY_SET
|
||||
@@ -1177,7 +1171,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
|
||||
# we are an alias of a table and we are not derived from an
|
||||
# alias of a table (which nonetheless may be the same table
|
||||
# as ours) so, same thing
|
||||
return col # type: ignore
|
||||
return col
|
||||
else:
|
||||
# other cases where we are a selectable and the element
|
||||
# is another join or selectable that contains a table which our
|
||||
@@ -1219,23 +1213,18 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
|
||||
|
||||
class _ColumnLookup(Protocol):
|
||||
@overload
|
||||
def __getitem__(self, key: None) -> None:
|
||||
...
|
||||
def __getitem__(self, key: None) -> None: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]:
|
||||
...
|
||||
def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]:
|
||||
...
|
||||
def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: _ET) -> _ET:
|
||||
...
|
||||
def __getitem__(self, key: _ET) -> _ET: ...
|
||||
|
||||
def __getitem__(self, key: Any) -> Any:
|
||||
...
|
||||
def __getitem__(self, key: Any) -> Any: ...
|
||||
|
||||
|
||||
class ColumnAdapter(ClauseAdapter):
|
||||
@@ -1333,12 +1322,10 @@ class ColumnAdapter(ClauseAdapter):
|
||||
return ac
|
||||
|
||||
@overload
|
||||
def traverse(self, obj: Literal[None]) -> None:
|
||||
...
|
||||
def traverse(self, obj: Literal[None]) -> None: ...
|
||||
|
||||
@overload
|
||||
def traverse(self, obj: _ET) -> _ET:
|
||||
...
|
||||
def traverse(self, obj: _ET) -> _ET: ...
|
||||
|
||||
def traverse(
|
||||
self, obj: Optional[ExternallyTraversible]
|
||||
@@ -1353,8 +1340,7 @@ class ColumnAdapter(ClauseAdapter):
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@property
|
||||
def visitor_iterator(self) -> Iterator[ColumnAdapter]:
|
||||
...
|
||||
def visitor_iterator(self) -> Iterator[ColumnAdapter]: ...
|
||||
|
||||
adapt_clause = traverse
|
||||
adapt_list = ClauseAdapter.copy_and_process
|
||||
|
||||
Reference in New Issue
Block a user