API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,5 +1,5 @@
# sql/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# sql/_dml_constructors.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -24,10 +24,7 @@ def insert(table: _DMLTableArgument) -> Insert:
from sqlalchemy import insert
stmt = (
insert(user_table).
values(name='username', fullname='Full Username')
)
stmt = insert(user_table).values(name="username", fullname="Full Username")
Similar functionality is available via the
:meth:`_expression.TableClause.insert` method on
@@ -78,7 +75,7 @@ def insert(table: _DMLTableArgument) -> Insert:
:ref:`tutorial_core_insert` - in the :ref:`unified_tutorial`
"""
""" # noqa: E501
return Insert(table)
@@ -90,9 +87,7 @@ def update(table: _DMLTableArgument) -> Update:
from sqlalchemy import update
stmt = (
update(user_table).
where(user_table.c.id == 5).
values(name='user #5')
update(user_table).where(user_table.c.id == 5).values(name="user #5")
)
Similar functionality is available via the
@@ -109,7 +104,7 @@ def update(table: _DMLTableArgument) -> Update:
:ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial`
"""
""" # noqa: E501
return Update(table)
@@ -120,10 +115,7 @@ def delete(table: _DMLTableArgument) -> Delete:
from sqlalchemy import delete
stmt = (
delete(user_table).
where(user_table.c.id == 5)
)
stmt = delete(user_table).where(user_table.c.id == 5)
Similar functionality is available via the
:meth:`_expression.TableClause.delete` method on

View File

@@ -1,5 +1,5 @@
# sql/_elements_constructors.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -10,7 +10,6 @@ from __future__ import annotations
import typing
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import overload
@@ -49,6 +48,7 @@ from .functions import FunctionElement
from ..util.typing import Literal
if typing.TYPE_CHECKING:
from ._typing import _ByArgument
from ._typing import _ColumnExpressionArgument
from ._typing import _ColumnExpressionOrLiteralArgument
from ._typing import _ColumnExpressionOrStrLabelArgument
@@ -125,11 +125,8 @@ def and_( # type: ignore[empty-body]
from sqlalchemy import and_
stmt = select(users_table).where(
and_(
users_table.c.name == 'wendy',
users_table.c.enrolled == True
)
)
and_(users_table.c.name == "wendy", users_table.c.enrolled == True)
)
The :func:`.and_` conjunction is also available using the
Python ``&`` operator (though note that compound expressions
@@ -137,9 +134,8 @@ def and_( # type: ignore[empty-body]
operator precedence behavior)::
stmt = select(users_table).where(
(users_table.c.name == 'wendy') &
(users_table.c.enrolled == True)
)
(users_table.c.name == "wendy") & (users_table.c.enrolled == True)
)
The :func:`.and_` operation is also implicit in some cases;
the :meth:`_expression.Select.where`
@@ -147,9 +143,11 @@ def and_( # type: ignore[empty-body]
times against a statement, which will have the effect of each
clause being combined using :func:`.and_`::
stmt = select(users_table).\
where(users_table.c.name == 'wendy').\
where(users_table.c.enrolled == True)
stmt = (
select(users_table)
.where(users_table.c.name == "wendy")
.where(users_table.c.enrolled == True)
)
The :func:`.and_` construct must be given at least one positional
argument in order to be valid; a :func:`.and_` construct with no
@@ -159,6 +157,7 @@ def and_( # type: ignore[empty-body]
specified::
from sqlalchemy import true
criteria = and_(true(), *expressions)
The above expression will compile to SQL as the expression ``true``
@@ -190,11 +189,8 @@ if not TYPE_CHECKING:
from sqlalchemy import and_
stmt = select(users_table).where(
and_(
users_table.c.name == 'wendy',
users_table.c.enrolled == True
)
)
and_(users_table.c.name == "wendy", users_table.c.enrolled == True)
)
The :func:`.and_` conjunction is also available using the
Python ``&`` operator (though note that compound expressions
@@ -202,9 +198,8 @@ if not TYPE_CHECKING:
operator precedence behavior)::
stmt = select(users_table).where(
(users_table.c.name == 'wendy') &
(users_table.c.enrolled == True)
)
(users_table.c.name == "wendy") & (users_table.c.enrolled == True)
)
The :func:`.and_` operation is also implicit in some cases;
the :meth:`_expression.Select.where`
@@ -212,9 +207,11 @@ if not TYPE_CHECKING:
times against a statement, which will have the effect of each
clause being combined using :func:`.and_`::
stmt = select(users_table).\
where(users_table.c.name == 'wendy').\
where(users_table.c.enrolled == True)
stmt = (
select(users_table)
.where(users_table.c.name == "wendy")
.where(users_table.c.enrolled == True)
)
The :func:`.and_` construct must be given at least one positional
argument in order to be valid; a :func:`.and_` construct with no
@@ -224,6 +221,7 @@ if not TYPE_CHECKING:
specified::
from sqlalchemy import true
criteria = and_(true(), *expressions)
The above expression will compile to SQL as the expression ``true``
@@ -241,7 +239,7 @@ if not TYPE_CHECKING:
:func:`.or_`
"""
""" # noqa: E501
return BooleanClauseList.and_(*clauses)
@@ -307,9 +305,12 @@ def asc(
e.g.::
from sqlalchemy import asc
stmt = select(users_table).order_by(asc(users_table.c.name))
will produce SQL as::
will produce SQL as:
.. sourcecode:: sql
SELECT id, name FROM user ORDER BY name ASC
@@ -346,9 +347,11 @@ def collate(
e.g.::
collate(mycolumn, 'utf8_bin')
collate(mycolumn, "utf8_bin")
produces::
produces:
.. sourcecode:: sql
mycolumn COLLATE utf8_bin
@@ -373,9 +376,12 @@ def between(
E.g.::
from sqlalchemy import between
stmt = select(users_table).where(between(users_table.c.id, 5, 7))
Would produce SQL resembling::
Would produce SQL resembling:
.. sourcecode:: sql
SELECT id, name FROM user WHERE id BETWEEN :id_1 AND :id_2
@@ -436,16 +442,12 @@ def outparam(
return BindParameter(key, None, type_=type_, unique=False, isoutparam=True)
# mypy insists that BinaryExpression and _HasClauseElement protocol overlap.
# they do not. at all. bug in mypy?
@overload
def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: # type: ignore
...
def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: ...
@overload
def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]:
...
def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: ...
def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]:
@@ -497,10 +499,13 @@ def bindparam(
from sqlalchemy import bindparam
stmt = select(users_table).\
where(users_table.c.name == bindparam('username'))
stmt = select(users_table).where(
users_table.c.name == bindparam("username")
)
The above statement, when rendered, will produce SQL similar to::
The above statement, when rendered, will produce SQL similar to:
.. sourcecode:: sql
SELECT id, name FROM user WHERE name = :username
@@ -508,22 +513,25 @@ def bindparam(
would typically be applied at execution time to a method
like :meth:`_engine.Connection.execute`::
result = connection.execute(stmt, username='wendy')
result = connection.execute(stmt, {"username": "wendy"})
Explicit use of :func:`.bindparam` is also common when producing
UPDATE or DELETE statements that are to be invoked multiple times,
where the WHERE criterion of the statement is to change on each
invocation, such as::
stmt = (users_table.update().
where(user_table.c.name == bindparam('username')).
values(fullname=bindparam('fullname'))
)
stmt = (
users_table.update()
.where(user_table.c.name == bindparam("username"))
.values(fullname=bindparam("fullname"))
)
connection.execute(
stmt, [{"username": "wendy", "fullname": "Wendy Smith"},
{"username": "jack", "fullname": "Jack Jones"},
]
stmt,
[
{"username": "wendy", "fullname": "Wendy Smith"},
{"username": "jack", "fullname": "Jack Jones"},
],
)
SQLAlchemy's Core expression system makes wide use of
@@ -532,7 +540,7 @@ def bindparam(
coerced into fixed :func:`.bindparam` constructs. For example, given
a comparison operation such as::
expr = users_table.c.name == 'Wendy'
expr = users_table.c.name == "Wendy"
The above expression will produce a :class:`.BinaryExpression`
construct, where the left side is the :class:`_schema.Column` object
@@ -540,9 +548,11 @@ def bindparam(
:class:`.BindParameter` representing the literal value::
print(repr(expr.right))
BindParameter('%(4327771088 name)s', 'Wendy', type_=String())
BindParameter("%(4327771088 name)s", "Wendy", type_=String())
The expression above will render SQL such as::
The expression above will render SQL such as:
.. sourcecode:: sql
user.name = :name_1
@@ -551,10 +561,12 @@ def bindparam(
along where it is later used within statement execution. If we
invoke a statement like the following::
stmt = select(users_table).where(users_table.c.name == 'Wendy')
stmt = select(users_table).where(users_table.c.name == "Wendy")
result = connection.execute(stmt)
We would see SQL logging output as::
We would see SQL logging output as:
.. sourcecode:: sql
SELECT "user".id, "user".name
FROM "user"
@@ -572,9 +584,11 @@ def bindparam(
bound placeholders based on the arguments passed, as in::
stmt = users_table.insert()
result = connection.execute(stmt, name='Wendy')
result = connection.execute(stmt, {"name": "Wendy"})
The above will produce SQL output as::
The above will produce SQL output as:
.. sourcecode:: sql
INSERT INTO "user" (name) VALUES (%(name)s)
{'name': 'Wendy'}
@@ -647,12 +661,12 @@ def bindparam(
:param quote:
True if this parameter name requires quoting and is not
currently known as a SQLAlchemy reserved word; this currently
only applies to the Oracle backend, where bound names must
only applies to the Oracle Database backends, where bound names must
sometimes be quoted.
:param isoutparam:
if True, the parameter should be treated like a stored procedure
"OUT" parameter. This applies to backends such as Oracle which
"OUT" parameter. This applies to backends such as Oracle Database which
support OUT parameters.
:param expanding:
@@ -738,16 +752,17 @@ def case(
from sqlalchemy import case
stmt = select(users_table).\
where(
case(
(users_table.c.name == 'wendy', 'W'),
(users_table.c.name == 'jack', 'J'),
else_='E'
)
)
stmt = select(users_table).where(
case(
(users_table.c.name == "wendy", "W"),
(users_table.c.name == "jack", "J"),
else_="E",
)
)
The above statement will produce SQL resembling::
The above statement will produce SQL resembling:
.. sourcecode:: sql
SELECT id, name FROM user
WHERE CASE
@@ -765,14 +780,9 @@ def case(
compared against keyed to result expressions. The statement below is
equivalent to the preceding statement::
stmt = select(users_table).\
where(
case(
{"wendy": "W", "jack": "J"},
value=users_table.c.name,
else_='E'
)
)
stmt = select(users_table).where(
case({"wendy": "W", "jack": "J"}, value=users_table.c.name, else_="E")
)
The values which are accepted as result values in
:paramref:`.case.whens` as well as with :paramref:`.case.else_` are
@@ -787,20 +797,16 @@ def case(
from sqlalchemy import case, literal_column
case(
(
orderline.c.qty > 100,
literal_column("'greaterthan100'")
),
(
orderline.c.qty > 10,
literal_column("'greaterthan10'")
),
else_=literal_column("'lessthan10'")
(orderline.c.qty > 100, literal_column("'greaterthan100'")),
(orderline.c.qty > 10, literal_column("'greaterthan10'")),
else_=literal_column("'lessthan10'"),
)
The above will render the given constants without using bound
parameters for the result values (but still for the comparison
values), as in::
values), as in:
.. sourcecode:: sql
CASE
WHEN (orderline.qty > :qty_1) THEN 'greaterthan100'
@@ -821,8 +827,8 @@ def case(
resulting value, e.g.::
case(
(users_table.c.name == 'wendy', 'W'),
(users_table.c.name == 'jack', 'J')
(users_table.c.name == "wendy", "W"),
(users_table.c.name == "jack", "J"),
)
In the second form, it accepts a Python dictionary of comparison
@@ -830,10 +836,7 @@ def case(
:paramref:`.case.value` to be present, and values will be compared
using the ``==`` operator, e.g.::
case(
{"wendy": "W", "jack": "J"},
value=users_table.c.name
)
case({"wendy": "W", "jack": "J"}, value=users_table.c.name)
:param value: An optional SQL expression which will be used as a
fixed "comparison point" for candidate values within a dictionary
@@ -846,7 +849,7 @@ def case(
expressions evaluate to true.
"""
""" # noqa: E501
return Case(*whens, value=value, else_=else_)
@@ -864,7 +867,9 @@ def cast(
stmt = select(cast(product_table.c.unit_price, Numeric(10, 4)))
The above statement will produce SQL resembling::
The above statement will produce SQL resembling:
.. sourcecode:: sql
SELECT CAST(unit_price AS NUMERIC(10, 4)) FROM product
@@ -933,11 +938,11 @@ def try_cast(
from sqlalchemy import select, try_cast, Numeric
stmt = select(
try_cast(product_table.c.unit_price, Numeric(10, 4))
)
stmt = select(try_cast(product_table.c.unit_price, Numeric(10, 4)))
The above would render on Microsoft SQL Server as::
The above would render on Microsoft SQL Server as:
.. sourcecode:: sql
SELECT TRY_CAST (product_table.unit_price AS NUMERIC(10, 4))
FROM product_table
@@ -968,7 +973,9 @@ def column(
id, name = column("id"), column("name")
stmt = select(id, name).select_from("user")
The above statement would produce SQL like::
The above statement would produce SQL like:
.. sourcecode:: sql
SELECT id, name FROM user
@@ -1004,13 +1011,14 @@ def column(
from sqlalchemy import table, column, select
user = table("user",
column("id"),
column("name"),
column("description"),
user = table(
"user",
column("id"),
column("name"),
column("description"),
)
stmt = select(user.c.description).where(user.c.name == 'wendy')
stmt = select(user.c.description).where(user.c.name == "wendy")
A :func:`_expression.column` / :func:`.table`
construct like that illustrated
@@ -1057,7 +1065,9 @@ def desc(
stmt = select(users_table).order_by(desc(users_table.c.name))
will produce SQL as::
will produce SQL as:
.. sourcecode:: sql
SELECT id, name FROM user ORDER BY name DESC
@@ -1090,16 +1100,26 @@ def desc(
def distinct(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
"""Produce an column-expression-level unary ``DISTINCT`` clause.
This applies the ``DISTINCT`` keyword to an individual column
expression, and is typically contained within an aggregate function,
as in::
This applies the ``DISTINCT`` keyword to an **individual column
expression** (e.g. not the whole statement), and renders **specifically
in that column position**; this is used for containment within
an aggregate function, as in::
from sqlalchemy import distinct, func
stmt = select(func.count(distinct(users_table.c.name)))
The above would produce an expression resembling::
stmt = select(users_table.c.id, func.count(distinct(users_table.c.name)))
SELECT COUNT(DISTINCT name) FROM user
The above would produce an statement resembling:
.. sourcecode:: sql
SELECT user.id, count(DISTINCT user.name) FROM user
.. tip:: The :func:`_sql.distinct` function does **not** apply DISTINCT
to the full SELECT statement, instead applying a DISTINCT modifier
to **individual column expressions**. For general ``SELECT DISTINCT``
support, use the
:meth:`_sql.Select.distinct` method on :class:`_sql.Select`.
The :func:`.distinct` function is also available as a column-level
method, e.g. :meth:`_expression.ColumnElement.distinct`, as in::
@@ -1122,7 +1142,7 @@ def distinct(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
:data:`.func`
"""
""" # noqa: E501
return UnaryExpression._create_distinct(expr)
@@ -1152,6 +1172,9 @@ def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract:
:param field: The field to extract.
.. warning:: This field is used as a literal SQL string.
**DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
:param expr: A column or Python scalar expression serving as the
right side of the ``EXTRACT`` expression.
@@ -1160,9 +1183,10 @@ def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract:
from sqlalchemy import extract
from sqlalchemy import table, column
logged_table = table("user",
column("id"),
column("date_created"),
logged_table = table(
"user",
column("id"),
column("date_created"),
)
stmt = select(logged_table.c.id).where(
@@ -1174,9 +1198,9 @@ def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract:
Similarly, one can also select an extracted component::
stmt = select(
extract("YEAR", logged_table.c.date_created)
).where(logged_table.c.id == 1)
stmt = select(extract("YEAR", logged_table.c.date_created)).where(
logged_table.c.id == 1
)
The implementation of ``EXTRACT`` may vary across database backends.
Users are reminded to consult their database documentation.
@@ -1235,7 +1259,8 @@ def funcfilter(
E.g.::
from sqlalchemy import funcfilter
funcfilter(func.count(1), MyClass.name == 'some name')
funcfilter(func.count(1), MyClass.name == "some name")
Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')".
@@ -1292,10 +1317,11 @@ def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
from sqlalchemy import desc, nulls_first
stmt = select(users_table).order_by(
nulls_first(desc(users_table.c.name)))
stmt = select(users_table).order_by(nulls_first(desc(users_table.c.name)))
The SQL expression from the above would resemble::
The SQL expression from the above would resemble:
.. sourcecode:: sql
SELECT id, name FROM user ORDER BY name DESC NULLS FIRST
@@ -1306,7 +1332,8 @@ def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
function version, as in::
stmt = select(users_table).order_by(
users_table.c.name.desc().nulls_first())
users_table.c.name.desc().nulls_first()
)
.. versionchanged:: 1.4 :func:`.nulls_first` is renamed from
:func:`.nullsfirst` in previous releases.
@@ -1322,7 +1349,7 @@ def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
:meth:`_expression.Select.order_by`
"""
""" # noqa: E501
return UnaryExpression._create_nulls_first(column)
@@ -1336,10 +1363,11 @@ def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
from sqlalchemy import desc, nulls_last
stmt = select(users_table).order_by(
nulls_last(desc(users_table.c.name)))
stmt = select(users_table).order_by(nulls_last(desc(users_table.c.name)))
The SQL expression from the above would resemble::
The SQL expression from the above would resemble:
.. sourcecode:: sql
SELECT id, name FROM user ORDER BY name DESC NULLS LAST
@@ -1349,8 +1377,7 @@ def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
rather than as its standalone
function version, as in::
stmt = select(users_table).order_by(
users_table.c.name.desc().nulls_last())
stmt = select(users_table).order_by(users_table.c.name.desc().nulls_last())
.. versionchanged:: 1.4 :func:`.nulls_last` is renamed from
:func:`.nullslast` in previous releases.
@@ -1366,7 +1393,7 @@ def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]:
:meth:`_expression.Select.order_by`
"""
""" # noqa: E501
return UnaryExpression._create_nulls_last(column)
@@ -1381,11 +1408,8 @@ def or_( # type: ignore[empty-body]
from sqlalchemy import or_
stmt = select(users_table).where(
or_(
users_table.c.name == 'wendy',
users_table.c.name == 'jack'
)
)
or_(users_table.c.name == "wendy", users_table.c.name == "jack")
)
The :func:`.or_` conjunction is also available using the
Python ``|`` operator (though note that compound expressions
@@ -1393,9 +1417,8 @@ def or_( # type: ignore[empty-body]
operator precedence behavior)::
stmt = select(users_table).where(
(users_table.c.name == 'wendy') |
(users_table.c.name == 'jack')
)
(users_table.c.name == "wendy") | (users_table.c.name == "jack")
)
The :func:`.or_` construct must be given at least one positional
argument in order to be valid; a :func:`.or_` construct with no
@@ -1405,6 +1428,7 @@ def or_( # type: ignore[empty-body]
specified::
from sqlalchemy import false
or_criteria = or_(false(), *expressions)
The above expression will compile to SQL as the expression ``false``
@@ -1436,11 +1460,8 @@ if not TYPE_CHECKING:
from sqlalchemy import or_
stmt = select(users_table).where(
or_(
users_table.c.name == 'wendy',
users_table.c.name == 'jack'
)
)
or_(users_table.c.name == "wendy", users_table.c.name == "jack")
)
The :func:`.or_` conjunction is also available using the
Python ``|`` operator (though note that compound expressions
@@ -1448,9 +1469,8 @@ if not TYPE_CHECKING:
operator precedence behavior)::
stmt = select(users_table).where(
(users_table.c.name == 'wendy') |
(users_table.c.name == 'jack')
)
(users_table.c.name == "wendy") | (users_table.c.name == "jack")
)
The :func:`.or_` construct must be given at least one positional
argument in order to be valid; a :func:`.or_` construct with no
@@ -1460,6 +1480,7 @@ if not TYPE_CHECKING:
specified::
from sqlalchemy import false
or_criteria = or_(false(), *expressions)
The above expression will compile to SQL as the expression ``false``
@@ -1477,26 +1498,17 @@ if not TYPE_CHECKING:
:func:`.and_`
"""
""" # noqa: E501
return BooleanClauseList.or_(*clauses)
def over(
element: FunctionElement[_T],
partition_by: Optional[
Union[
Iterable[_ColumnExpressionArgument[Any]],
_ColumnExpressionArgument[Any],
]
] = None,
order_by: Optional[
Union[
Iterable[_ColumnExpressionArgument[Any]],
_ColumnExpressionArgument[Any],
]
] = None,
partition_by: Optional[_ByArgument] = None,
order_by: Optional[_ByArgument] = None,
range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
) -> Over[_T]:
r"""Produce an :class:`.Over` object against a function.
@@ -1508,19 +1520,23 @@ def over(
func.row_number().over(order_by=mytable.c.some_column)
Would produce::
Would produce:
.. sourcecode:: sql
ROW_NUMBER() OVER(ORDER BY some_column)
Ranges are also possible using the :paramref:`.expression.over.range_`
and :paramref:`.expression.over.rows` parameters. These
Ranges are also possible using the :paramref:`.expression.over.range_`,
:paramref:`.expression.over.rows`, and :paramref:`.expression.over.groups`
parameters. These
mutually-exclusive parameters each accept a 2-tuple, which contains
a combination of integers and None::
func.row_number().over(
order_by=my_table.c.some_column, range_=(None, 0))
func.row_number().over(order_by=my_table.c.some_column, range_=(None, 0))
The above would produce::
The above would produce:
.. sourcecode:: sql
ROW_NUMBER() OVER(ORDER BY some_column
RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
@@ -1531,19 +1547,23 @@ def over(
* RANGE BETWEEN 5 PRECEDING AND 10 FOLLOWING::
func.row_number().over(order_by='x', range_=(-5, 10))
func.row_number().over(order_by="x", range_=(-5, 10))
* ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW::
func.row_number().over(order_by='x', rows=(None, 0))
func.row_number().over(order_by="x", rows=(None, 0))
* RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING::
func.row_number().over(order_by='x', range_=(-2, None))
func.row_number().over(order_by="x", range_=(-2, None))
* RANGE BETWEEN 1 FOLLOWING AND 3 FOLLOWING::
func.row_number().over(order_by='x', range_=(1, 3))
func.row_number().over(order_by="x", range_=(1, 3))
* GROUPS BETWEEN 1 FOLLOWING AND 3 FOLLOWING::
func.row_number().over(order_by="x", groups=(1, 3))
:param element: a :class:`.FunctionElement`, :class:`.WithinGroup`,
or other compatible construct.
@@ -1556,10 +1576,14 @@ def over(
:param range\_: optional range clause for the window. This is a
tuple value which can contain integer values or ``None``,
and will render a RANGE BETWEEN PRECEDING / FOLLOWING clause.
:param rows: optional rows clause for the window. This is a tuple
value which can contain integer values or None, and will render
a ROWS BETWEEN PRECEDING / FOLLOWING clause.
:param groups: optional groups clause for the window. This is a
tuple value which can contain integer values or ``None``,
and will render a GROUPS BETWEEN PRECEDING / FOLLOWING clause.
.. versionadded:: 2.0.40
This function is also available from the :data:`~.expression.func`
construct itself via the :meth:`.FunctionElement.over` method.
@@ -1572,8 +1596,8 @@ def over(
:func:`_expression.within_group`
"""
return Over(element, partition_by, order_by, range_, rows)
""" # noqa: E501
return Over(element, partition_by, order_by, range_, rows, groups)
@_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`")
@@ -1603,7 +1627,7 @@ def text(text: str) -> TextClause:
E.g.::
t = text("SELECT * FROM users WHERE id=:user_id")
result = connection.execute(t, user_id=12)
result = connection.execute(t, {"user_id": 12})
For SQL statements where a colon is required verbatim, as within
an inline string, use a backslash to escape::
@@ -1621,9 +1645,11 @@ def text(text: str) -> TextClause:
method allows
specification of return columns including names and types::
t = text("SELECT * FROM users WHERE id=:user_id").\
bindparams(user_id=7).\
columns(id=Integer, name=String)
t = (
text("SELECT * FROM users WHERE id=:user_id")
.bindparams(user_id=7)
.columns(id=Integer, name=String)
)
for id, name in connection.execute(t):
print(id, name)
@@ -1633,7 +1659,7 @@ def text(text: str) -> TextClause:
such as for the WHERE clause of a SELECT statement::
s = select(users.c.id, users.c.name).where(text("id=:user_id"))
result = connection.execute(s, user_id=12)
result = connection.execute(s, {"user_id": 12})
:func:`_expression.text` is also used for the construction
of a full, standalone statement using plain text.
@@ -1705,9 +1731,7 @@ def tuple_(
from sqlalchemy import tuple_
tuple_(table.c.col1, table.c.col2).in_(
[(1, 2), (5, 12), (10, 19)]
)
tuple_(table.c.col1, table.c.col2).in_([(1, 2), (5, 12), (10, 19)])
.. versionchanged:: 1.3.6 Added support for SQLite IN tuples.
@@ -1757,10 +1781,9 @@ def type_coerce(
:meth:`_expression.ColumnElement.label`::
stmt = select(
type_coerce(log_table.date_string, StringDateTime()).label('date')
type_coerce(log_table.date_string, StringDateTime()).label("date")
)
A type that features bound-value handling will also have that behavior
take effect when literal values or :func:`.bindparam` constructs are
passed to :func:`.type_coerce` as targets.
@@ -1821,11 +1844,10 @@ def within_group(
the :meth:`.FunctionElement.within_group` method, e.g.::
from sqlalchemy import within_group
stmt = select(
department.c.id,
func.percentile_cont(0.5).within_group(
department.c.salary.desc()
)
func.percentile_cont(0.5).within_group(department.c.salary.desc()),
)
The above statement would produce SQL similar to

View File

@@ -1,5 +1,5 @@
# sql/_orm_types.py
# Copyright (C) 2022 the SQLAlchemy authors and contributors
# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# sql/_py_util.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# sql/_selectable_constructors.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -12,7 +12,6 @@ from typing import Optional
from typing import overload
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import coercions
@@ -47,6 +46,7 @@ if TYPE_CHECKING:
from ._typing import _T7
from ._typing import _T8
from ._typing import _T9
from ._typing import _TP
from ._typing import _TypedColumnClauseArgument as _TCCA
from .functions import Function
from .selectable import CTE
@@ -55,9 +55,6 @@ if TYPE_CHECKING:
from .selectable import SelectBase
_T = TypeVar("_T", bound=Any)
def alias(
selectable: FromClause, name: Optional[str] = None, flat: bool = False
) -> NamedFromClause:
@@ -106,9 +103,28 @@ def cte(
)
# TODO: mypy requires the _TypedSelectable overloads in all compound select
# constructors since _SelectStatementForCompoundArgument includes
# untyped args that make it return CompoundSelect[Unpack[tuple[Never, ...]]]
# pyright does not have this issue
_TypedSelectable = Union["Select[_TP]", "CompoundSelect[_TP]"]
@overload
def except_(
*selects: _SelectStatementForCompoundArgument,
) -> CompoundSelect:
*selects: _TypedSelectable[_TP],
) -> CompoundSelect[_TP]: ...
@overload
def except_(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]: ...
def except_(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]:
r"""Return an ``EXCEPT`` of multiple selectables.
The returned object is an instance of
@@ -121,9 +137,21 @@ def except_(
return CompoundSelect._create_except(*selects)
@overload
def except_all(
*selects: _SelectStatementForCompoundArgument,
) -> CompoundSelect:
*selects: _TypedSelectable[_TP],
) -> CompoundSelect[_TP]: ...
@overload
def except_all(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]: ...
def except_all(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]:
r"""Return an ``EXCEPT ALL`` of multiple selectables.
The returned object is an instance of
@@ -155,16 +183,16 @@ def exists(
:meth:`_sql.SelectBase.exists` method::
exists_criteria = (
select(table2.c.col2).
where(table1.c.col1 == table2.c.col2).
exists()
select(table2.c.col2).where(table1.c.col1 == table2.c.col2).exists()
)
The EXISTS criteria is then used inside of an enclosing SELECT::
stmt = select(table1.c.col1).where(exists_criteria)
The above statement will then be of the form::
The above statement will then be of the form:
.. sourcecode:: sql
SELECT col1 FROM table1 WHERE EXISTS
(SELECT table2.col2 FROM table2 WHERE table2.col2 = table1.col1)
@@ -181,9 +209,21 @@ def exists(
return Exists(__argument)
@overload
def intersect(
*selects: _SelectStatementForCompoundArgument,
) -> CompoundSelect:
*selects: _TypedSelectable[_TP],
) -> CompoundSelect[_TP]: ...
@overload
def intersect(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]: ...
def intersect(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]:
r"""Return an ``INTERSECT`` of multiple selectables.
The returned object is an instance of
@@ -196,9 +236,21 @@ def intersect(
return CompoundSelect._create_intersect(*selects)
@overload
def intersect_all(
*selects: _SelectStatementForCompoundArgument,
) -> CompoundSelect:
*selects: _TypedSelectable[_TP],
) -> CompoundSelect[_TP]: ...
@overload
def intersect_all(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]: ...
def intersect_all(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]:
r"""Return an ``INTERSECT ALL`` of multiple selectables.
The returned object is an instance of
@@ -225,11 +277,14 @@ def join(
E.g.::
j = join(user_table, address_table,
user_table.c.id == address_table.c.user_id)
j = join(
user_table, address_table, user_table.c.id == address_table.c.user_id
)
stmt = select(user_table).select_from(j)
would emit SQL along the lines of::
would emit SQL along the lines of:
.. sourcecode:: sql
SELECT user.id, user.name FROM user
JOIN address ON user.id = address.user_id
@@ -263,7 +318,7 @@ def join(
:class:`_expression.Join` - the type of object produced.
"""
""" # noqa: E501
return Join(left, right, onclause, isouter, full)
@@ -330,20 +385,19 @@ def outerjoin(
@overload
def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]:
...
def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: ...
@overload
def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]:
...
def select(
__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
) -> Select[Tuple[_T0, _T1]]: ...
@overload
def select(
__ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
) -> Select[Tuple[_T0, _T1, _T2]]:
...
) -> Select[Tuple[_T0, _T1, _T2]]: ...
@overload
@@ -352,8 +406,7 @@ def select(
__ent1: _TCCA[_T1],
__ent2: _TCCA[_T2],
__ent3: _TCCA[_T3],
) -> Select[Tuple[_T0, _T1, _T2, _T3]]:
...
) -> Select[Tuple[_T0, _T1, _T2, _T3]]: ...
@overload
@@ -363,8 +416,7 @@ def select(
__ent2: _TCCA[_T2],
__ent3: _TCCA[_T3],
__ent4: _TCCA[_T4],
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]:
...
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: ...
@overload
@@ -375,8 +427,7 @@ def select(
__ent3: _TCCA[_T3],
__ent4: _TCCA[_T4],
__ent5: _TCCA[_T5],
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
...
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ...
@overload
@@ -388,8 +439,7 @@ def select(
__ent4: _TCCA[_T4],
__ent5: _TCCA[_T5],
__ent6: _TCCA[_T6],
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
...
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ...
@overload
@@ -402,8 +452,7 @@ def select(
__ent5: _TCCA[_T5],
__ent6: _TCCA[_T6],
__ent7: _TCCA[_T7],
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
...
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ...
@overload
@@ -417,8 +466,7 @@ def select(
__ent6: _TCCA[_T6],
__ent7: _TCCA[_T7],
__ent8: _TCCA[_T8],
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]:
...
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: ...
@overload
@@ -433,16 +481,16 @@ def select(
__ent7: _TCCA[_T7],
__ent8: _TCCA[_T8],
__ent9: _TCCA[_T9],
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]:
...
) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: ...
# END OVERLOADED FUNCTIONS select
@overload
def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]:
...
def select(
*entities: _ColumnsClauseArgument[Any], **__kw: Any
) -> Select[Any]: ...
def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]:
@@ -536,13 +584,14 @@ def tablesample(
from sqlalchemy import func
selectable = people.tablesample(
func.bernoulli(1),
name='alias',
seed=func.random())
func.bernoulli(1), name="alias", seed=func.random()
)
stmt = select(selectable.c.people_id)
Assuming ``people`` with a column ``people_id``, the above
statement would render as::
statement would render as:
.. sourcecode:: sql
SELECT alias.people_id FROM
people AS alias TABLESAMPLE bernoulli(:bernoulli_1)
@@ -560,9 +609,21 @@ def tablesample(
return TableSample._factory(selectable, sampling, name=name, seed=seed)
@overload
def union(
*selects: _SelectStatementForCompoundArgument,
) -> CompoundSelect:
*selects: _TypedSelectable[_TP],
) -> CompoundSelect[_TP]: ...
@overload
def union(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]: ...
def union(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]:
r"""Return a ``UNION`` of multiple selectables.
The returned object is an instance of
@@ -582,9 +643,21 @@ def union(
return CompoundSelect._create_union(*selects)
@overload
def union_all(
*selects: _SelectStatementForCompoundArgument,
) -> CompoundSelect:
*selects: _TypedSelectable[_TP],
) -> CompoundSelect[_TP]: ...
@overload
def union_all(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]: ...
def union_all(
*selects: _SelectStatementForCompoundArgument[_TP],
) -> CompoundSelect[_TP]:
r"""Return a ``UNION ALL`` of multiple selectables.
The returned object is an instance of
@@ -605,28 +678,75 @@ def values(
name: Optional[str] = None,
literal_binds: bool = False,
) -> Values:
r"""Construct a :class:`_expression.Values` construct.
r"""Construct a :class:`_expression.Values` construct representing the
SQL ``VALUES`` clause.
The column expressions and the actual data for
:class:`_expression.Values` are given in two separate steps. The
constructor receives the column expressions typically as
:func:`_expression.column` constructs,
and the data is then passed via the
:meth:`_expression.Values.data` method as a list,
which can be called multiple
times to add more data, e.g.::
The column expressions and the actual data for :class:`_expression.Values`
are given in two separate steps. The constructor receives the column
expressions typically as :func:`_expression.column` constructs, and the
data is then passed via the :meth:`_expression.Values.data` method as a
list, which can be called multiple times to add more data, e.g.::
from sqlalchemy import column
from sqlalchemy import values
from sqlalchemy import Integer
from sqlalchemy import String
value_expr = (
values(
column("id", Integer),
column("name", String),
)
.data([(1, "name1"), (2, "name2")])
.data([(3, "name3")])
)
Would represent a SQL fragment like::
VALUES(1, "name1"), (2, "name2"), (3, "name3")
The :class:`_sql.values` construct has an optional
:paramref:`_sql.values.name` field; when using this field, the
PostgreSQL-specific "named VALUES" clause may be generated::
value_expr = values(
column('id', Integer),
column('name', String),
name="my_values"
).data(
[(1, 'name1'), (2, 'name2'), (3, 'name3')]
column("id", Integer), column("name", String), name="somename"
).data([(1, "name1"), (2, "name2"), (3, "name3")])
When selecting from the above construct, the name and column names will
be listed out using a PostgreSQL-specific syntax::
>>> print(value_expr.select())
SELECT somename.id, somename.name
FROM (VALUES (:param_1, :param_2), (:param_3, :param_4),
(:param_5, :param_6)) AS somename (id, name)
For a more database-agnostic means of SELECTing named columns from a
VALUES expression, the :meth:`.Values.cte` method may be used, which
produces a named CTE with explicit column names against the VALUES
construct within; this syntax works on PostgreSQL, SQLite, and MariaDB::
value_expr = (
values(
column("id", Integer),
column("name", String),
)
.data([(1, "name1"), (2, "name2"), (3, "name3")])
.cte()
)
Rendering as::
>>> print(value_expr.select())
WITH anon_1(id, name) AS
(VALUES (:param_1, :param_2), (:param_3, :param_4), (:param_5, :param_6))
SELECT anon_1.id, anon_1.name
FROM anon_1
.. versionadded:: 2.0.42 Added the :meth:`.Values.cte` method to
:class:`.Values`
:param \*columns: column expressions, typically composed using
:func:`_expression.column` objects.
@@ -638,5 +758,6 @@ def values(
the data values inline in the SQL output, rather than using bound
parameters.
"""
""" # noqa: E501
return Values(*columns, literal_binds=literal_binds, name=name)

View File

@@ -1,5 +1,5 @@
# sql/_typing.py
# Copyright (C) 2022 the SQLAlchemy authors and contributors
# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -11,6 +11,8 @@ import operator
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import Mapping
from typing import NoReturn
from typing import Optional
@@ -51,10 +53,10 @@ if TYPE_CHECKING:
from .elements import SQLCoreOperations
from .elements import TextClause
from .lambdas import LambdaElement
from .roles import ColumnsClauseRole
from .roles import FromClauseRole
from .schema import Column
from .selectable import Alias
from .selectable import CompoundSelect
from .selectable import CTE
from .selectable import FromClause
from .selectable import Join
@@ -68,9 +70,14 @@ if TYPE_CHECKING:
from .sqltypes import TableValueType
from .sqltypes import TupleType
from .type_api import TypeEngine
from ..engine import Connection
from ..engine import Dialect
from ..engine import Engine
from ..engine.mock import MockConnection
from ..util.typing import TypeGuard
_T = TypeVar("_T", bound=Any)
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_CE = TypeVar("_CE", bound="ColumnElement[Any]")
@@ -78,18 +85,25 @@ _CE = TypeVar("_CE", bound="ColumnElement[Any]")
_CLE = TypeVar("_CLE", bound="ClauseElement")
class _HasClauseElement(Protocol):
class _HasClauseElement(Protocol, Generic[_T_co]):
"""indicates a class that has a __clause_element__() method"""
def __clause_element__(self) -> ColumnsClauseRole:
...
def __clause_element__(self) -> roles.ExpressionElementRole[_T_co]: ...
class _CoreAdapterProto(Protocol):
"""protocol for the ClauseAdapter/ColumnAdapter.traverse() method."""
def __call__(self, obj: _CE) -> _CE:
...
def __call__(self, obj: _CE) -> _CE: ...
class _HasDialect(Protocol):
"""protocol for Engine/Connection-like objects that have dialect
attribute.
"""
@property
def dialect(self) -> Dialect: ...
# match column types that are not ORM entities
@@ -97,6 +111,7 @@ _NOT_ENTITY = TypeVar(
"_NOT_ENTITY",
int,
str,
bool,
"datetime",
"date",
"time",
@@ -106,13 +121,15 @@ _NOT_ENTITY = TypeVar(
"Decimal",
)
_StarOrOne = Literal["*", 1]
_MAYBE_ENTITY = TypeVar(
"_MAYBE_ENTITY",
roles.ColumnsClauseRole,
Literal["*", 1],
_StarOrOne,
Type[Any],
Inspectable[_HasClauseElement],
_HasClauseElement,
Inspectable[_HasClauseElement[Any]],
_HasClauseElement[Any],
)
@@ -126,7 +143,7 @@ _TextCoercedExpressionArgument = Union[
str,
"TextClause",
"ColumnElement[_T]",
_HasClauseElement,
_HasClauseElement[_T],
roles.ExpressionElementRole[_T],
]
@@ -134,10 +151,10 @@ _ColumnsClauseArgument = Union[
roles.TypedColumnsClauseRole[_T],
roles.ColumnsClauseRole,
"SQLCoreOperations[_T]",
Literal["*", 1],
_StarOrOne,
Type[_T],
Inspectable[_HasClauseElement],
_HasClauseElement,
Inspectable[_HasClauseElement[_T]],
_HasClauseElement[_T],
]
"""open-ended SELECT columns clause argument.
@@ -171,9 +188,10 @@ _T9 = TypeVar("_T9", bound=Any)
_ColumnExpressionArgument = Union[
"ColumnElement[_T]",
_HasClauseElement,
_HasClauseElement[_T],
"SQLCoreOperations[_T]",
roles.ExpressionElementRole[_T],
roles.TypedColumnsClauseRole[_T],
Callable[[], "ColumnElement[_T]"],
"LambdaElement",
]
@@ -198,6 +216,12 @@ _ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]]
_ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]]
_ByArgument = Union[
Iterable[_ColumnExpressionOrStrLabelArgument[Any]],
_ColumnExpressionOrStrLabelArgument[Any],
]
"""Used for keyword-based ``order_by`` and ``partition_by`` parameters."""
_InfoType = Dict[Any, Any]
"""the .info dictionary accepted and used throughout Core /ORM"""
@@ -205,8 +229,8 @@ _InfoType = Dict[Any, Any]
_FromClauseArgument = Union[
roles.FromClauseRole,
Type[Any],
Inspectable[_HasClauseElement],
_HasClauseElement,
Inspectable[_HasClauseElement[Any]],
_HasClauseElement[Any],
]
"""A FROM clause, like we would send to select().select_from().
@@ -227,13 +251,15 @@ come from the ORM.
"""
_SelectStatementForCompoundArgument = Union[
"SelectBase", roles.CompoundElementRole
"Select[_TP]",
"CompoundSelect[_TP]",
roles.CompoundElementRole,
]
"""SELECT statement acceptable by ``union()`` and other SQL set operations"""
_DMLColumnArgument = Union[
str,
_HasClauseElement,
_HasClauseElement[Any],
roles.DMLColumnRole,
"SQLCoreOperations[Any]",
]
@@ -264,8 +290,8 @@ _DMLTableArgument = Union[
"Alias",
"CTE",
Type[Any],
Inspectable[_HasClauseElement],
_HasClauseElement,
Inspectable[_HasClauseElement[Any]],
_HasClauseElement[Any],
]
_PropagateAttrsType = util.immutabledict[str, Any]
@@ -278,58 +304,51 @@ _LimitOffsetType = Union[int, _ColumnExpressionArgument[int], None]
_AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]]
_CreateDropBind = Union["Engine", "Connection", "MockConnection"]
if TYPE_CHECKING:
def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]:
...
def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: ...
def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]:
...
def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: ...
def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]:
...
def is_named_from_clause(
t: FromClauseRole,
) -> TypeGuard[NamedFromClause]: ...
def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]:
...
def is_column_element(
c: ClauseElement,
) -> TypeGuard[ColumnElement[Any]]: ...
def is_keyed_column_element(
c: ClauseElement,
) -> TypeGuard[KeyedColumnElement[Any]]:
...
) -> TypeGuard[KeyedColumnElement[Any]]: ...
def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]:
...
def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: ...
def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]:
...
def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: ...
def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]:
...
def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: ...
def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]:
...
def is_table_value_type(
t: TypeEngine[Any],
) -> TypeGuard[TableValueType]: ...
def is_selectable(t: Any) -> TypeGuard[Selectable]:
...
def is_selectable(t: Any) -> TypeGuard[Selectable]: ...
def is_select_base(
t: Union[Executable, ReturnsRows]
) -> TypeGuard[SelectBase]:
...
t: Union[Executable, ReturnsRows],
) -> TypeGuard[SelectBase]: ...
def is_select_statement(
t: Union[Executable, ReturnsRows]
) -> TypeGuard[Select[Any]]:
...
t: Union[Executable, ReturnsRows],
) -> TypeGuard[Select[Any]]: ...
def is_table(t: FromClause) -> TypeGuard[TableClause]:
...
def is_table(t: FromClause) -> TypeGuard[TableClause]: ...
def is_subquery(t: FromClause) -> TypeGuard[Subquery]:
...
def is_subquery(t: FromClause) -> TypeGuard[Subquery]: ...
def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]:
...
def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: ...
else:
is_sql_compiler = operator.attrgetter("is_sql")
@@ -357,7 +376,7 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]:
return hasattr(s, "quote")
def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]:
def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement[Any]]:
return hasattr(s, "__clause_element__")
@@ -380,20 +399,17 @@ def _unexpected_kw(methname: str, kw: Dict[str, Any]) -> NoReturn:
@overload
def Nullable(
val: "SQLCoreOperations[_T]",
) -> "SQLCoreOperations[Optional[_T]]":
...
) -> "SQLCoreOperations[Optional[_T]]": ...
@overload
def Nullable(
val: roles.ExpressionElementRole[_T],
) -> roles.ExpressionElementRole[Optional[_T]]:
...
) -> roles.ExpressionElementRole[Optional[_T]]: ...
@overload
def Nullable(val: Type[_T]) -> Type[Optional[_T]]:
...
def Nullable(val: Type[_T]) -> Type[Optional[_T]]: ...
def Nullable(
@@ -417,25 +433,21 @@ def Nullable(
@overload
def NotNullable(
val: "SQLCoreOperations[Optional[_T]]",
) -> "SQLCoreOperations[_T]":
...
) -> "SQLCoreOperations[_T]": ...
@overload
def NotNullable(
val: roles.ExpressionElementRole[Optional[_T]],
) -> roles.ExpressionElementRole[_T]:
...
) -> roles.ExpressionElementRole[_T]: ...
@overload
def NotNullable(val: Type[Optional[_T]]) -> Type[_T]:
...
def NotNullable(val: Type[Optional[_T]]) -> Type[_T]: ...
@overload
def NotNullable(val: Optional[Type[_T]]) -> Type[_T]:
...
def NotNullable(val: Optional[Type[_T]]) -> Type[_T]: ...
def NotNullable(

View File

@@ -1,5 +1,5 @@
# sql/annotation.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -67,16 +67,14 @@ class SupportsAnnotations(ExternallyTraversible):
self,
values: Literal[None] = ...,
clone: bool = ...,
) -> Self:
...
) -> Self: ...
@overload
def _deannotate(
self,
values: Sequence[str] = ...,
clone: bool = ...,
) -> SupportsAnnotations:
...
) -> SupportsAnnotations: ...
def _deannotate(
self,
@@ -99,9 +97,11 @@ class SupportsAnnotations(ExternallyTraversible):
tuple(
(
key,
value._gen_cache_key(anon_map, [])
if isinstance(value, HasCacheKey)
else value,
(
value._gen_cache_key(anon_map, [])
if isinstance(value, HasCacheKey)
else value
),
)
for key, value in [
(key, self._annotations[key])
@@ -119,8 +119,7 @@ class SupportsWrappingAnnotations(SupportsAnnotations):
if TYPE_CHECKING:
@util.ro_non_memoized_property
def entity_namespace(self) -> _EntityNamespace:
...
def entity_namespace(self) -> _EntityNamespace: ...
def _annotate(self, values: _AnnotationDict) -> Self:
"""return a copy of this ClauseElement with annotations
@@ -141,16 +140,14 @@ class SupportsWrappingAnnotations(SupportsAnnotations):
self,
values: Literal[None] = ...,
clone: bool = ...,
) -> Self:
...
) -> Self: ...
@overload
def _deannotate(
self,
values: Sequence[str] = ...,
clone: bool = ...,
) -> SupportsAnnotations:
...
) -> SupportsAnnotations: ...
def _deannotate(
self,
@@ -214,16 +211,14 @@ class SupportsCloneAnnotations(SupportsWrappingAnnotations):
self,
values: Literal[None] = ...,
clone: bool = ...,
) -> Self:
...
) -> Self: ...
@overload
def _deannotate(
self,
values: Sequence[str] = ...,
clone: bool = ...,
) -> SupportsAnnotations:
...
) -> SupportsAnnotations: ...
def _deannotate(
self,
@@ -316,16 +311,14 @@ class Annotated(SupportsAnnotations):
self,
values: Literal[None] = ...,
clone: bool = ...,
) -> Self:
...
) -> Self: ...
@overload
def _deannotate(
self,
values: Sequence[str] = ...,
clone: bool = ...,
) -> Annotated:
...
) -> Annotated: ...
def _deannotate(
self,
@@ -395,9 +388,9 @@ class Annotated(SupportsAnnotations):
# so that the resulting objects are pickleable; additionally, other
# decisions can be made up front about the type of object being annotated
# just once per class rather than per-instance.
annotated_classes: Dict[
Type[SupportsWrappingAnnotations], Type[Annotated]
] = {}
annotated_classes: Dict[Type[SupportsWrappingAnnotations], Type[Annotated]] = (
{}
)
_SA = TypeVar("_SA", bound="SupportsAnnotations")
@@ -487,15 +480,13 @@ def _deep_annotate(
@overload
def _deep_deannotate(
element: Literal[None], values: Optional[Sequence[str]] = None
) -> Literal[None]:
...
) -> Literal[None]: ...
@overload
def _deep_deannotate(
element: _SA, values: Optional[Sequence[str]] = None
) -> _SA:
...
) -> _SA: ...
def _deep_deannotate(

View File

@@ -1,14 +1,12 @@
# sql/base.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
"""Foundational utilities common to many sql modules.
"""
"""Foundational utilities common to many sql modules."""
from __future__ import annotations
@@ -24,6 +22,7 @@ from typing import Callable
from typing import cast
from typing import Dict
from typing import FrozenSet
from typing import Generator
from typing import Generic
from typing import Iterable
from typing import Iterator
@@ -57,6 +56,7 @@ from .. import util
from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
from ..util.typing import Final
from ..util.typing import Protocol
from ..util.typing import Self
from ..util.typing import TypeGuard
@@ -68,11 +68,12 @@ if TYPE_CHECKING:
from ._orm_types import DMLStrategyArgument
from ._orm_types import SynchronizeSessionArgument
from ._typing import _CLE
from .cache_key import CacheKey
from .compiler import SQLCompiler
from .elements import BindParameter
from .elements import ClauseList
from .elements import ColumnClause # noqa
from .elements import ColumnElement
from .elements import KeyedColumnElement
from .elements import NamedColumn
from .elements import SQLCoreOperations
from .elements import TextClause
@@ -81,6 +82,7 @@ if TYPE_CHECKING:
from .selectable import _JoinTargetElement
from .selectable import _SelectIterable
from .selectable import FromClause
from .visitors import anon_map
from ..engine import Connection
from ..engine import CursorResult
from ..engine.interfaces import _CoreMultiExecuteParams
@@ -108,7 +110,7 @@ class _NoArg(Enum):
return f"_NoArg.{self.name}"
NO_ARG = _NoArg.NO_ARG
NO_ARG: Final = _NoArg.NO_ARG
class _NoneName(Enum):
@@ -116,7 +118,7 @@ class _NoneName(Enum):
"""indicate a 'deferred' name that was ultimately the value None."""
_NONE_NAME = _NoneName.NONE_NAME
_NONE_NAME: Final = _NoneName.NONE_NAME
_T = TypeVar("_T", bound=Any)
@@ -151,18 +153,18 @@ class _DefaultDescriptionTuple(NamedTuple):
)
_never_select_column = operator.attrgetter("_omit_from_statements")
_never_select_column: operator.attrgetter[Any] = operator.attrgetter(
"_omit_from_statements"
)
class _EntityNamespace(Protocol):
def __getattr__(self, key: str) -> SQLCoreOperations[Any]:
...
def __getattr__(self, key: str) -> SQLCoreOperations[Any]: ...
class _HasEntityNamespace(Protocol):
@util.ro_non_memoized_property
def entity_namespace(self) -> _EntityNamespace:
...
def entity_namespace(self) -> _EntityNamespace: ...
def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]:
@@ -188,12 +190,12 @@ class Immutable:
__slots__ = ()
_is_immutable = True
_is_immutable: bool = True
def unique_params(self, *optionaldict, **kwargs):
def unique_params(self, *optionaldict: Any, **kwargs: Any) -> NoReturn:
raise NotImplementedError("Immutable objects do not support copying")
def params(self, *optionaldict, **kwargs):
def params(self, *optionaldict: Any, **kwargs: Any) -> NoReturn:
raise NotImplementedError("Immutable objects do not support copying")
def _clone(self: _Self, **kw: Any) -> _Self:
@@ -208,7 +210,7 @@ class Immutable:
class SingletonConstant(Immutable):
"""Represent SQL constants like NULL, TRUE, FALSE"""
_is_singleton_constant = True
_is_singleton_constant: bool = True
_singleton: SingletonConstant
@@ -220,7 +222,7 @@ class SingletonConstant(Immutable):
raise NotImplementedError()
@classmethod
def _create_singleton(cls):
def _create_singleton(cls) -> None:
obj = object.__new__(cls)
obj.__init__() # type: ignore
@@ -261,8 +263,7 @@ _SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType")
class _GenerativeType(compat_typing.Protocol):
def _generate(self) -> Self:
...
def _generate(self) -> Self: ...
def _generative(fn: _Fn) -> _Fn:
@@ -290,17 +291,17 @@ def _generative(fn: _Fn) -> _Fn:
def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]:
msgs = kw.pop("msgs", {})
msgs: Dict[str, str] = kw.pop("msgs", {})
defaults = kw.pop("defaults", {})
defaults: Dict[str, str] = kw.pop("defaults", {})
getters = [
getters: List[Tuple[str, operator.attrgetter[Any], Optional[str]]] = [
(name, operator.attrgetter(name), defaults.get(name, None))
for name in names
]
@util.decorator
def check(fn, *args, **kw):
def check(fn: _Fn, *args: Any, **kw: Any) -> Any:
# make pylance happy by not including "self" in the argument
# list
self = args[0]
@@ -349,12 +350,16 @@ def _cloned_intersection(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]:
The returned set is in terms of the entities present within 'a'.
"""
all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
all_overlap: Set[_CLE] = set(_expand_cloned(a)).intersection(
_expand_cloned(b)
)
return {elem for elem in a if all_overlap.intersection(elem._cloned_set)}
def _cloned_difference(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]:
all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
all_overlap: Set[_CLE] = set(_expand_cloned(a)).intersection(
_expand_cloned(b)
)
return {
elem for elem in a if not all_overlap.intersection(elem._cloned_set)
}
@@ -366,10 +371,12 @@ class _DialectArgView(MutableMapping[str, Any]):
"""
def __init__(self, obj):
__slots__ = ("obj",)
def __init__(self, obj: DialectKWArgs) -> None:
self.obj = obj
def _key(self, key):
def _key(self, key: str) -> Tuple[str, str]:
try:
dialect, value_key = key.split("_", 1)
except ValueError as err:
@@ -377,7 +384,7 @@ class _DialectArgView(MutableMapping[str, Any]):
else:
return dialect, value_key
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
dialect, value_key = self._key(key)
try:
@@ -387,7 +394,7 @@ class _DialectArgView(MutableMapping[str, Any]):
else:
return opt[value_key]
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
try:
dialect, value_key = self._key(key)
except KeyError as err:
@@ -397,17 +404,17 @@ class _DialectArgView(MutableMapping[str, Any]):
else:
self.obj.dialect_options[dialect][value_key] = value
def __delitem__(self, key):
def __delitem__(self, key: str) -> None:
dialect, value_key = self._key(key)
del self.obj.dialect_options[dialect][value_key]
def __len__(self):
def __len__(self) -> int:
return sum(
len(args._non_defaults)
for args in self.obj.dialect_options.values()
)
def __iter__(self):
def __iter__(self) -> Generator[str, None, None]:
return (
"%s_%s" % (dialect_name, value_name)
for dialect_name in self.obj.dialect_options
@@ -426,31 +433,31 @@ class _DialectArgDict(MutableMapping[str, Any]):
"""
def __init__(self):
self._non_defaults = {}
self._defaults = {}
def __init__(self) -> None:
self._non_defaults: Dict[str, Any] = {}
self._defaults: Dict[str, Any] = {}
def __len__(self):
def __len__(self) -> int:
return len(set(self._non_defaults).union(self._defaults))
def __iter__(self):
def __iter__(self) -> Iterator[str]:
return iter(set(self._non_defaults).union(self._defaults))
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
if key in self._non_defaults:
return self._non_defaults[key]
else:
return self._defaults[key]
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
self._non_defaults[key] = value
def __delitem__(self, key):
def __delitem__(self, key: str) -> None:
del self._non_defaults[key]
@util.preload_module("sqlalchemy.dialects")
def _kw_reg_for_dialect(dialect_name):
def _kw_reg_for_dialect(dialect_name: str) -> Optional[Dict[Any, Any]]:
dialect_cls = util.preloaded.dialects.registry.load(dialect_name)
if dialect_cls.construct_arguments is None:
return None
@@ -472,19 +479,21 @@ class DialectKWArgs:
__slots__ = ()
_dialect_kwargs_traverse_internals = [
_dialect_kwargs_traverse_internals: List[Tuple[str, Any]] = [
("dialect_options", InternalTraversal.dp_dialect_options)
]
@classmethod
def argument_for(cls, dialect_name, argument_name, default):
def argument_for(
cls, dialect_name: str, argument_name: str, default: Any
) -> None:
"""Add a new kind of dialect-specific keyword argument for this class.
E.g.::
Index.argument_for("mydialect", "length", None)
some_index = Index('a', 'b', mydialect_length=5)
some_index = Index("a", "b", mydialect_length=5)
The :meth:`.DialectKWArgs.argument_for` method is a per-argument
way adding extra arguments to the
@@ -514,7 +523,9 @@ class DialectKWArgs:
"""
construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
construct_arg_dictionary: Optional[Dict[Any, Any]] = (
DialectKWArgs._kw_registry[dialect_name]
)
if construct_arg_dictionary is None:
raise exc.ArgumentError(
"Dialect '%s' does have keyword-argument "
@@ -524,8 +535,8 @@ class DialectKWArgs:
construct_arg_dictionary[cls] = {}
construct_arg_dictionary[cls][argument_name] = default
@util.memoized_property
def dialect_kwargs(self):
@property
def dialect_kwargs(self) -> _DialectArgView:
"""A collection of keyword arguments specified as dialect-specific
options to this construct.
@@ -546,26 +557,29 @@ class DialectKWArgs:
return _DialectArgView(self)
@property
def kwargs(self):
def kwargs(self) -> _DialectArgView:
"""A synonym for :attr:`.DialectKWArgs.dialect_kwargs`."""
return self.dialect_kwargs
_kw_registry = util.PopulateDict(_kw_reg_for_dialect)
_kw_registry: util.PopulateDict[str, Optional[Dict[Any, Any]]] = (
util.PopulateDict(_kw_reg_for_dialect)
)
def _kw_reg_for_dialect_cls(self, dialect_name):
@classmethod
def _kw_reg_for_dialect_cls(cls, dialect_name: str) -> _DialectArgDict:
construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
d = _DialectArgDict()
if construct_arg_dictionary is None:
d._defaults.update({"*": None})
else:
for cls in reversed(self.__class__.__mro__):
for cls in reversed(cls.__mro__):
if cls in construct_arg_dictionary:
d._defaults.update(construct_arg_dictionary[cls])
return d
@util.memoized_property
def dialect_options(self):
def dialect_options(self) -> util.PopulateDict[str, _DialectArgDict]:
"""A collection of keyword arguments specified as dialect-specific
options to this construct.
@@ -573,7 +587,7 @@ class DialectKWArgs:
and ``<argument_name>``. For example, the ``postgresql_where``
argument would be locatable as::
arg = my_object.dialect_options['postgresql']['where']
arg = my_object.dialect_options["postgresql"]["where"]
.. versionadded:: 0.9.2
@@ -583,9 +597,7 @@ class DialectKWArgs:
"""
return util.PopulateDict(
util.portable_instancemethod(self._kw_reg_for_dialect_cls)
)
return util.PopulateDict(self._kw_reg_for_dialect_cls)
def _validate_dialect_kwargs(self, kwargs: Dict[str, Any]) -> None:
# validate remaining kwargs that they all specify DB prefixes
@@ -661,7 +673,9 @@ class CompileState:
_ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
@classmethod
def create_for_statement(cls, statement, compiler, **kw):
def create_for_statement(
cls, statement: Executable, compiler: SQLCompiler, **kw: Any
) -> CompileState:
# factory construction.
if statement._propagate_attrs:
@@ -801,14 +815,11 @@ class _MetaOptions(type):
if TYPE_CHECKING:
def __getattr__(self, key: str) -> Any:
...
def __getattr__(self, key: str) -> Any: ...
def __setattr__(self, key: str, value: Any) -> None:
...
def __setattr__(self, key: str, value: Any) -> None: ...
def __delattr__(self, key: str) -> None:
...
def __delattr__(self, key: str) -> None: ...
class Options(metaclass=_MetaOptions):
@@ -830,7 +841,7 @@ class Options(metaclass=_MetaOptions):
)
super().__init_subclass__()
def __init__(self, **kw):
def __init__(self, **kw: Any) -> None:
self.__dict__.update(kw)
def __add__(self, other):
@@ -855,7 +866,7 @@ class Options(metaclass=_MetaOptions):
return False
return True
def __repr__(self):
def __repr__(self) -> str:
# TODO: fairly inefficient, used only in debugging right now.
return "%s(%s)" % (
@@ -872,7 +883,7 @@ class Options(metaclass=_MetaOptions):
return issubclass(cls, klass)
@hybridmethod
def add_to_element(self, name, value):
def add_to_element(self, name: str, value: str) -> Any:
return self + {name: getattr(self, name) + value}
@hybridmethod
@@ -886,7 +897,7 @@ class Options(metaclass=_MetaOptions):
return cls._state_dict_const
@classmethod
def safe_merge(cls, other):
def safe_merge(cls, other: "Options") -> Any:
d = other._state_dict()
# only support a merge with another object of our class
@@ -912,8 +923,12 @@ class Options(metaclass=_MetaOptions):
@classmethod
def from_execution_options(
cls, key, attrs, exec_options, statement_exec_options
):
cls,
key: str,
attrs: set[str],
exec_options: Mapping[str, Any],
statement_exec_options: Mapping[str, Any],
) -> Tuple["Options", Mapping[str, Any]]:
"""process Options argument in terms of execution options.
@@ -924,11 +939,7 @@ class Options(metaclass=_MetaOptions):
execution_options,
) = QueryContext.default_load_options.from_execution_options(
"_sa_orm_load_options",
{
"populate_existing",
"autoflush",
"yield_per"
},
{"populate_existing", "autoflush", "yield_per"},
execution_options,
statement._execution_options,
)
@@ -956,8 +967,8 @@ class Options(metaclass=_MetaOptions):
result[local] = statement_exec_options[argname]
new_options = existing_options + result
exec_options = util.immutabledict().merge_with(
exec_options, {key: new_options}
exec_options = util.immutabledict(exec_options).merge_with(
{key: new_options}
)
return new_options, exec_options
@@ -966,42 +977,43 @@ class Options(metaclass=_MetaOptions):
if TYPE_CHECKING:
def __getattr__(self, key: str) -> Any:
...
def __getattr__(self, key: str) -> Any: ...
def __setattr__(self, key: str, value: Any) -> None:
...
def __setattr__(self, key: str, value: Any) -> None: ...
def __delattr__(self, key: str) -> None:
...
def __delattr__(self, key: str) -> None: ...
class CacheableOptions(Options, HasCacheKey):
__slots__ = ()
@hybridmethod
def _gen_cache_key_inst(self, anon_map, bindparams):
def _gen_cache_key_inst(
self, anon_map: Any, bindparams: List[BindParameter[Any]]
) -> Optional[Tuple[Any]]:
return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
@_gen_cache_key_inst.classlevel
def _gen_cache_key(cls, anon_map, bindparams):
def _gen_cache_key(
cls, anon_map: "anon_map", bindparams: List[BindParameter[Any]]
) -> Tuple[CacheableOptions, Any]:
return (cls, ())
@hybridmethod
def _generate_cache_key(self):
def _generate_cache_key(self) -> Optional[CacheKey]:
return HasCacheKey._generate_cache_key_for_object(self)
class ExecutableOption(HasCopyInternals):
__slots__ = ()
_annotations = util.EMPTY_DICT
_annotations: _ImmutableExecuteOptions = util.EMPTY_DICT
__visit_name__ = "executable_option"
__visit_name__: str = "executable_option"
_is_has_cache_key = False
_is_has_cache_key: bool = False
_is_core = True
_is_core: bool = True
def _clone(self, **kw):
"""Create a shallow copy of this ExecutableOption."""
@@ -1021,7 +1033,7 @@ class Executable(roles.StatementRole):
supports_execution: bool = True
_execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
_is_default_generator = False
_is_default_generator: bool = False
_with_options: Tuple[ExecutableOption, ...] = ()
_with_context_options: Tuple[
Tuple[Callable[[CompileState], None], Any], ...
@@ -1037,12 +1049,13 @@ class Executable(roles.StatementRole):
("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs),
]
is_select = False
is_update = False
is_insert = False
is_text = False
is_delete = False
is_dml = False
is_select: bool = False
is_from_statement: bool = False
is_update: bool = False
is_insert: bool = False
is_text: bool = False
is_delete: bool = False
is_dml: bool = False
if TYPE_CHECKING:
__visit_name__: str
@@ -1058,27 +1071,24 @@ class Executable(roles.StatementRole):
**kw: Any,
) -> Tuple[
Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats
]:
...
]: ...
def _execute_on_connection(
self,
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
execution_options: CoreExecuteOptionsParameter,
) -> CursorResult[Any]:
...
) -> CursorResult[Any]: ...
def _execute_on_scalar(
self,
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
execution_options: CoreExecuteOptionsParameter,
) -> Any:
...
) -> Any: ...
@util.ro_non_memoized_property
def _all_selected_columns(self):
def _all_selected_columns(self) -> _SelectIterable:
raise NotImplementedError()
@property
@@ -1179,13 +1189,12 @@ class Executable(roles.StatementRole):
render_nulls: bool = ...,
is_delete_using: bool = ...,
is_update_from: bool = ...,
preserve_rowcount: bool = False,
**opt: Any,
) -> Self:
...
) -> Self: ...
@overload
def execution_options(self, **opt: Any) -> Self:
...
def execution_options(self, **opt: Any) -> Self: ...
@_generative
def execution_options(self, **kw: Any) -> Self:
@@ -1237,6 +1246,7 @@ class Executable(roles.StatementRole):
from sqlalchemy import event
@event.listens_for(some_engine, "before_execute")
def _process_opt(conn, statement, multiparams, params, execution_options):
"run a SQL function before invoking a statement"
@@ -1338,10 +1348,21 @@ class SchemaEventTarget(event.EventTarget):
self.dispatch.after_parent_attach(self, parent)
class SchemaVisitor(ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
class SchemaVisitable(SchemaEventTarget, visitors.Visitable):
"""Base class for elements that are targets of a :class:`.SchemaVisitor`.
__traverse_options__ = {"schema_visitor": True}
.. versionadded:: 2.0.41
"""
class SchemaVisitor(ClauseVisitor):
"""Define the visiting for ``SchemaItem`` and more
generally ``SchemaVisitable`` objects.
"""
__traverse_options__: Dict[str, Any] = {"schema_visitor": True}
class _SentinelDefaultCharacterization(Enum):
@@ -1366,7 +1387,7 @@ class _SentinelColumnCharacterization(NamedTuple):
_COLKEY = TypeVar("_COLKEY", Union[None, str], str)
_COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True)
_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]")
_COL = TypeVar("_COL", bound="ColumnElement[Any]")
class _ColumnMetrics(Generic[_COL_co]):
@@ -1376,7 +1397,7 @@ class _ColumnMetrics(Generic[_COL_co]):
def __init__(
self, collection: ColumnCollection[Any, _COL_co], col: _COL_co
):
) -> None:
self.column = col
# proxy_index being non-empty means it was initialized.
@@ -1386,10 +1407,10 @@ class _ColumnMetrics(Generic[_COL_co]):
for eps_col in col._expanded_proxy_set:
pi[eps_col].add(self)
def get_expanded_proxy_set(self):
def get_expanded_proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
return self.column._expanded_proxy_set
def dispose(self, collection):
def dispose(self, collection: ColumnCollection[_COLKEY, _COL_co]) -> None:
pi = collection._proxy_index
if not pi:
return
@@ -1488,14 +1509,14 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
mean either two columns with the same key, in which case the column
returned by key access is **arbitrary**::
>>> x1, x2 = Column('x', Integer), Column('x', Integer)
>>> x1, x2 = Column("x", Integer), Column("x", Integer)
>>> cc = ColumnCollection(columns=[(x1.name, x1), (x2.name, x2)])
>>> list(cc)
[Column('x', Integer(), table=None),
Column('x', Integer(), table=None)]
>>> cc['x'] is x1
>>> cc["x"] is x1
False
>>> cc['x'] is x2
>>> cc["x"] is x2
True
Or it can also mean the same column multiple times. These cases are
@@ -1522,7 +1543,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
"""
__slots__ = "_collection", "_index", "_colset", "_proxy_index"
__slots__ = ("_collection", "_index", "_colset", "_proxy_index")
_collection: List[Tuple[_COLKEY, _COL_co, _ColumnMetrics[_COL_co]]]
_index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]]
@@ -1591,20 +1612,17 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
return iter([col for _, col, _ in self._collection])
@overload
def __getitem__(self, key: Union[str, int]) -> _COL_co:
...
def __getitem__(self, key: Union[str, int]) -> _COL_co: ...
@overload
def __getitem__(
self, key: Tuple[Union[str, int], ...]
) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]:
...
) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ...
@overload
def __getitem__(
self, key: slice
) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]:
...
) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ...
def __getitem__(
self, key: Union[str, int, slice, Tuple[Union[str, int], ...]]
@@ -1644,7 +1662,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
else:
return True
def compare(self, other: ColumnCollection[Any, Any]) -> bool:
def compare(self, other: ColumnCollection[_COLKEY, _COL_co]) -> bool:
"""Compare this :class:`_expression.ColumnCollection` to another
based on the names of the keys"""
@@ -1657,9 +1675,15 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
def __eq__(self, other: Any) -> bool:
return self.compare(other)
@overload
def get(self, key: str, default: None = None) -> Optional[_COL_co]: ...
@overload
def get(self, key: str, default: _COL) -> Union[_COL_co, _COL]: ...
def get(
self, key: str, default: Optional[_COL_co] = None
) -> Optional[_COL_co]:
self, key: str, default: Optional[_COL] = None
) -> Optional[Union[_COL_co, _COL]]:
"""Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
based on a string key name from this
:class:`_expression.ColumnCollection`."""
@@ -1689,7 +1713,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
:class:`_sql.ColumnCollection`."""
raise NotImplementedError()
def remove(self, column: Any) -> None:
def remove(self, column: Any) -> NoReturn:
raise NotImplementedError()
def update(self, iter_: Any) -> NoReturn:
@@ -1698,7 +1722,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
raise NotImplementedError()
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore
__hash__: Optional[int] = None # type: ignore
def _populate_separate_keys(
self, iter_: Iterable[Tuple[_COLKEY, _COL_co]]
@@ -1791,7 +1815,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
return ReadOnlyColumnCollection(self)
def _init_proxy_index(self):
def _init_proxy_index(self) -> None:
"""populate the "proxy index", if empty.
proxy index is added in 2.0 to provide more efficient operation
@@ -1940,16 +1964,15 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
"""
def add(
self, column: ColumnElement[Any], key: Optional[str] = None
def add( # type: ignore[override]
self, column: _NAMEDCOL, key: Optional[str] = None
) -> None:
named_column = cast(_NAMEDCOL, column)
if key is not None and named_column.key != key:
if key is not None and column.key != key:
raise exc.ArgumentError(
"DedupeColumnCollection requires columns be under "
"the same key as their .key"
)
key = named_column.key
key = column.key
if key is None:
raise exc.ArgumentError(
@@ -1959,17 +1982,17 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
if key in self._index:
existing = self._index[key][1]
if existing is named_column:
if existing is column:
return
self.replace(named_column)
self.replace(column)
# pop out memoized proxy_set as this
# operation may very well be occurring
# in a _make_proxy operation
util.memoized_property.reset(named_column, "proxy_set")
util.memoized_property.reset(column, "proxy_set")
else:
self._append_new_column(key, named_column)
self._append_new_column(key, column)
def _append_new_column(self, key: str, named_column: _NAMEDCOL) -> None:
l = len(self._collection)
@@ -2011,7 +2034,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
def extend(self, iter_: Iterable[_NAMEDCOL]) -> None:
self._populate_separate_keys((col.key, col) for col in iter_)
def remove(self, column: _NAMEDCOL) -> None:
def remove(self, column: _NAMEDCOL) -> None: # type: ignore[override]
if column not in self._colset:
raise ValueError(
"Can't remove column %r; column is not in this collection"
@@ -2044,8 +2067,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
e.g.::
t = Table('sometable', metadata, Column('col1', Integer))
t.columns.replace(Column('col1', Integer, key='columnone'))
t = Table("sometable", metadata, Column("col1", Integer))
t.columns.replace(Column("col1", Integer, key="columnone"))
will remove the original 'col1' from the collection, and add
the new column under the name 'columnname'.
@@ -2108,17 +2131,17 @@ class ReadOnlyColumnCollection(
):
__slots__ = ("_parent",)
def __init__(self, collection):
def __init__(self, collection: ColumnCollection[_COLKEY, _COL_co]):
object.__setattr__(self, "_parent", collection)
object.__setattr__(self, "_colset", collection._colset)
object.__setattr__(self, "_index", collection._index)
object.__setattr__(self, "_collection", collection._collection)
object.__setattr__(self, "_proxy_index", collection._proxy_index)
def __getstate__(self):
def __getstate__(self) -> Dict[str, _COL_co]:
return {"_parent": self._parent}
def __setstate__(self, state):
def __setstate__(self, state: Dict[str, Any]) -> None:
parent = state["_parent"]
self.__init__(parent) # type: ignore
@@ -2133,10 +2156,10 @@ class ReadOnlyColumnCollection(
class ColumnSet(util.OrderedSet["ColumnClause[Any]"]):
def contains_column(self, col):
def contains_column(self, col: ColumnClause[Any]) -> bool:
return col in self
def extend(self, cols):
def extend(self, cols: Iterable[Any]) -> None:
for col in cols:
self.add(col)
@@ -2148,12 +2171,12 @@ class ColumnSet(util.OrderedSet["ColumnClause[Any]"]):
l.append(c == local)
return elements.and_(*l)
def __hash__(self):
def __hash__(self) -> int: # type: ignore[override]
return hash(tuple(x for x in self))
def _entity_namespace(
entity: Union[_HasEntityNamespace, ExternallyTraversible]
entity: Union[_HasEntityNamespace, ExternallyTraversible],
) -> _EntityNamespace:
"""Return the nearest .entity_namespace for the given entity.

View File

@@ -1,5 +1,5 @@
# sql/cache_key.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -11,6 +11,7 @@ import enum
from itertools import zip_longest
import typing
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
@@ -36,6 +37,7 @@ from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from .elements import BindParameter
from .elements import ClauseElement
from .elements import ColumnElement
from .visitors import _TraverseInternalsType
from ..engine.interfaces import _CoreSingleExecuteParams
@@ -43,8 +45,7 @@ if typing.TYPE_CHECKING:
class _CacheKeyTraversalDispatchType(Protocol):
def __call__(
s, self: HasCacheKey, visitor: _CacheKeyTraversal
) -> CacheKey:
...
) -> _CacheKeyTraversalDispatchTypeReturn: ...
class CacheConst(enum.Enum):
@@ -75,6 +76,18 @@ class CacheTraverseTarget(enum.Enum):
ANON_NAME,
) = tuple(CacheTraverseTarget)
_CacheKeyTraversalDispatchTypeReturn = Sequence[
Tuple[
str,
Any,
Union[
Callable[..., Tuple[Any, ...]],
CacheTraverseTarget,
InternalTraversal,
],
]
]
class HasCacheKey:
"""Mixin for objects which can produce a cache key.
@@ -290,11 +303,13 @@ class HasCacheKey:
result += (
attrname,
obj["compile_state_plugin"],
obj["plugin_subject"]._gen_cache_key(
anon_map, bindparams
)
if obj["plugin_subject"]
else None,
(
obj["plugin_subject"]._gen_cache_key(
anon_map, bindparams
)
if obj["plugin_subject"]
else None
),
)
elif meth is InternalTraversal.dp_annotations_key:
# obj is here is the _annotations dict. Table uses
@@ -324,7 +339,7 @@ class HasCacheKey:
),
)
else:
result += meth(
result += meth( # type: ignore
attrname, obj, self, anon_map, bindparams
)
return result
@@ -501,7 +516,7 @@ class CacheKey(NamedTuple):
e2,
)
else:
pickup_index = stack.pop(-1)
stack.pop(-1)
break
def _diff(self, other: CacheKey) -> str:
@@ -543,18 +558,17 @@ class CacheKey(NamedTuple):
_anon_map = prefix_anon_map()
return {b.key % _anon_map: b.effective_value for b in self.bindparams}
@util.preload_module("sqlalchemy.sql.elements")
def _apply_params_to_element(
self, original_cache_key: CacheKey, target_element: ClauseElement
) -> ClauseElement:
if target_element._is_immutable:
self, original_cache_key: CacheKey, target_element: ColumnElement[Any]
) -> ColumnElement[Any]:
if target_element._is_immutable or original_cache_key is self:
return target_element
translate = {
k.key: v.value
for k, v in zip(original_cache_key.bindparams, self.bindparams)
}
return target_element.params(translate)
elements = util.preloaded.sql_elements
return elements._OverrideBinds(
target_element, self.bindparams, original_cache_key.bindparams
)
def _ad_hoc_cache_key_from_args(
@@ -606,9 +620,9 @@ class _CacheKeyTraversal(HasTraversalDispatch):
InternalTraversal.dp_memoized_select_entities
)
visit_string = (
visit_boolean
) = visit_operator = visit_plain_obj = CACHE_IN_PLACE
visit_string = visit_boolean = visit_operator = visit_plain_obj = (
CACHE_IN_PLACE
)
visit_statement_hint_list = CACHE_IN_PLACE
visit_type = STATIC_CACHE_KEY
visit_anon_name = ANON_NAME
@@ -655,9 +669,11 @@ class _CacheKeyTraversal(HasTraversalDispatch):
) -> Tuple[Any, ...]:
return (
attrname,
obj._gen_cache_key(anon_map, bindparams)
if isinstance(obj, HasCacheKey)
else obj,
(
obj._gen_cache_key(anon_map, bindparams)
if isinstance(obj, HasCacheKey)
else obj
),
)
def visit_multi_list(
@@ -671,9 +687,11 @@ class _CacheKeyTraversal(HasTraversalDispatch):
return (
attrname,
tuple(
elem._gen_cache_key(anon_map, bindparams)
if isinstance(elem, HasCacheKey)
else elem
(
elem._gen_cache_key(anon_map, bindparams)
if isinstance(elem, HasCacheKey)
else elem
)
for elem in obj
),
)
@@ -834,12 +852,16 @@ class _CacheKeyTraversal(HasTraversalDispatch):
return tuple(
(
target._gen_cache_key(anon_map, bindparams),
onclause._gen_cache_key(anon_map, bindparams)
if onclause is not None
else None,
from_._gen_cache_key(anon_map, bindparams)
if from_ is not None
else None,
(
onclause._gen_cache_key(anon_map, bindparams)
if onclause is not None
else None
),
(
from_._gen_cache_key(anon_map, bindparams)
if from_ is not None
else None
),
tuple([(key, flags[key]) for key in sorted(flags)]),
)
for (target, onclause, from_, flags) in obj
@@ -933,9 +955,11 @@ class _CacheKeyTraversal(HasTraversalDispatch):
tuple(
(
key,
value._gen_cache_key(anon_map, bindparams)
if isinstance(value, HasCacheKey)
else value,
(
value._gen_cache_key(anon_map, bindparams)
if isinstance(value, HasCacheKey)
else value
),
)
for key, value in [(key, obj[key]) for key in sorted(obj)]
),
@@ -981,9 +1005,11 @@ class _CacheKeyTraversal(HasTraversalDispatch):
attrname,
tuple(
(
key._gen_cache_key(anon_map, bindparams)
if hasattr(key, "__clause_element__")
else key,
(
key._gen_cache_key(anon_map, bindparams)
if hasattr(key, "__clause_element__")
else key
),
value._gen_cache_key(anon_map, bindparams),
)
for key, value in obj
@@ -1004,9 +1030,11 @@ class _CacheKeyTraversal(HasTraversalDispatch):
attrname,
tuple(
(
k._gen_cache_key(anon_map, bindparams)
if hasattr(k, "__clause_element__")
else k,
(
k._gen_cache_key(anon_map, bindparams)
if hasattr(k, "__clause_element__")
else k
),
obj[k]._gen_cache_key(anon_map, bindparams),
)
for k in obj

View File

@@ -1,5 +1,5 @@
# sql/coercions.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -29,7 +29,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import operators
from . import roles
from . import visitors
from ._typing import is_from_clause
@@ -58,9 +57,9 @@ if typing.TYPE_CHECKING:
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import DQLDMLClauseElement
from .elements import NamedColumn
from .elements import SQLCoreOperations
from .elements import TextClause
from .schema import Column
from .selectable import _ColumnsClauseElement
from .selectable import _JoinTargetProtocol
@@ -76,7 +75,7 @@ _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole)
_T = TypeVar("_T", bound=Any)
def _is_literal(element):
def _is_literal(element: Any) -> bool:
"""Return whether or not the element is a "literal" in the context
of a SQL expression construct.
@@ -165,8 +164,7 @@ def expect(
role: Type[roles.TruncatedLabelRole],
element: Any,
**kw: Any,
) -> str:
...
) -> str: ...
@overload
@@ -176,8 +174,7 @@ def expect(
*,
as_key: Literal[True] = ...,
**kw: Any,
) -> str:
...
) -> str: ...
@overload
@@ -185,8 +182,7 @@ def expect(
role: Type[roles.LiteralValueRole],
element: Any,
**kw: Any,
) -> BindParameter[Any]:
...
) -> BindParameter[Any]: ...
@overload
@@ -194,8 +190,7 @@ def expect(
role: Type[roles.DDLReferredColumnRole],
element: Any,
**kw: Any,
) -> Column[Any]:
...
) -> Union[Column[Any], str]: ...
@overload
@@ -203,8 +198,7 @@ def expect(
role: Type[roles.DDLConstraintColumnRole],
element: Any,
**kw: Any,
) -> Union[Column[Any], str]:
...
) -> Union[Column[Any], str]: ...
@overload
@@ -212,8 +206,7 @@ def expect(
role: Type[roles.StatementOptionRole],
element: Any,
**kw: Any,
) -> DQLDMLClauseElement:
...
) -> Union[ColumnElement[Any], TextClause]: ...
@overload
@@ -221,8 +214,7 @@ def expect(
role: Type[roles.LabeledColumnExprRole[Any]],
element: _ColumnExpressionArgument[_T],
**kw: Any,
) -> NamedColumn[_T]:
...
) -> NamedColumn[_T]: ...
@overload
@@ -234,8 +226,7 @@ def expect(
],
element: _ColumnExpressionArgument[_T],
**kw: Any,
) -> ColumnElement[_T]:
...
) -> ColumnElement[_T]: ...
@overload
@@ -249,8 +240,7 @@ def expect(
],
element: Any,
**kw: Any,
) -> ColumnElement[Any]:
...
) -> ColumnElement[Any]: ...
@overload
@@ -258,8 +248,7 @@ def expect(
role: Type[roles.DMLTableRole],
element: _DMLTableArgument,
**kw: Any,
) -> _DMLTableElement:
...
) -> _DMLTableElement: ...
@overload
@@ -267,8 +256,7 @@ def expect(
role: Type[roles.HasCTERole],
element: HasCTE,
**kw: Any,
) -> HasCTE:
...
) -> HasCTE: ...
@overload
@@ -276,8 +264,7 @@ def expect(
role: Type[roles.SelectStatementRole],
element: SelectBase,
**kw: Any,
) -> SelectBase:
...
) -> SelectBase: ...
@overload
@@ -285,8 +272,7 @@ def expect(
role: Type[roles.FromClauseRole],
element: _FromClauseArgument,
**kw: Any,
) -> FromClause:
...
) -> FromClause: ...
@overload
@@ -296,8 +282,7 @@ def expect(
*,
explicit_subquery: Literal[True] = ...,
**kw: Any,
) -> Subquery:
...
) -> Subquery: ...
@overload
@@ -305,8 +290,7 @@ def expect(
role: Type[roles.ColumnsClauseRole],
element: _ColumnsClauseArgument[Any],
**kw: Any,
) -> _ColumnsClauseElement:
...
) -> _ColumnsClauseElement: ...
@overload
@@ -314,8 +298,7 @@ def expect(
role: Type[roles.JoinTargetRole],
element: _JoinTargetProtocol,
**kw: Any,
) -> _JoinTargetProtocol:
...
) -> _JoinTargetProtocol: ...
# catchall for not-yet-implemented overloads
@@ -324,8 +307,7 @@ def expect(
role: Type[_SR],
element: Any,
**kw: Any,
) -> Any:
...
) -> Any: ...
def expect(
@@ -510,6 +492,7 @@ class RoleImpl:
element: Any,
argname: Optional[str] = None,
resolved: Optional[Any] = None,
*,
advice: Optional[str] = None,
code: Optional[str] = None,
err: Optional[Exception] = None,
@@ -612,7 +595,7 @@ def _no_text_coercion(
class _NoTextCoercion(RoleImpl):
__slots__ = ()
def _literal_coercion(self, element, argname=None, **kw):
def _literal_coercion(self, element, *, argname=None, **kw):
if isinstance(element, str) and issubclass(
elements.TextClause, self._role_class
):
@@ -630,7 +613,7 @@ class _CoerceLiterals(RoleImpl):
def _text_coercion(self, element, argname=None):
return _no_text_coercion(element, argname)
def _literal_coercion(self, element, argname=None, **kw):
def _literal_coercion(self, element, *, argname=None, **kw):
if isinstance(element, str):
if self._coerce_star and element == "*":
return elements.ColumnClause("*", is_literal=True)
@@ -658,7 +641,8 @@ class LiteralValueImpl(RoleImpl):
self,
element,
resolved,
argname,
argname=None,
*,
type_=None,
literal_execute=False,
**kw,
@@ -676,7 +660,7 @@ class LiteralValueImpl(RoleImpl):
literal_execute=literal_execute,
)
def _literal_coercion(self, element, argname=None, type_=None, **kw):
def _literal_coercion(self, element, **kw):
return element
@@ -688,6 +672,7 @@ class _SelectIsNotFrom(RoleImpl):
element: Any,
argname: Optional[str] = None,
resolved: Optional[Any] = None,
*,
advice: Optional[str] = None,
code: Optional[str] = None,
err: Optional[Exception] = None,
@@ -762,7 +747,7 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl):
__slots__ = ()
def _literal_coercion(
self, element, name=None, type_=None, argname=None, is_crud=False, **kw
self, element, *, name=None, type_=None, is_crud=False, **kw
):
if (
element is None
@@ -804,15 +789,22 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl):
class BinaryElementImpl(ExpressionElementImpl, RoleImpl):
__slots__ = ()
def _literal_coercion(
self, element, expr, operator, bindparam_type=None, argname=None, **kw
def _literal_coercion( # type: ignore[override]
self,
element,
*,
expr,
operator,
bindparam_type=None,
argname=None,
**kw,
):
try:
return expr._bind_param(operator, element, type_=bindparam_type)
except exc.ArgumentError as err:
self._raise_for_expected(element, err=err)
def _post_coercion(self, resolved, expr, bindparam_type=None, **kw):
def _post_coercion(self, resolved, *, expr, bindparam_type=None, **kw):
if resolved.type._isnull and not expr.type._isnull:
resolved = resolved._with_binary_element_type(
bindparam_type if bindparam_type is not None else expr.type
@@ -850,31 +842,32 @@ class InElementImpl(RoleImpl):
% (elem.__class__.__name__)
)
def _literal_coercion(self, element, expr, operator, **kw):
if isinstance(element, collections_abc.Iterable) and not isinstance(
element, str
):
@util.preload_module("sqlalchemy.sql.elements")
def _literal_coercion(self, element, *, expr, operator, **kw): # type: ignore[override] # noqa: E501
if util.is_non_string_iterable(element):
non_literal_expressions: Dict[
Optional[operators.ColumnOperators],
operators.ColumnOperators,
Optional[_ColumnExpressionArgument[Any]],
_ColumnExpressionArgument[Any],
] = {}
element = list(element)
for o in element:
if not _is_literal(o):
if not isinstance(o, operators.ColumnOperators):
if not isinstance(
o, util.preloaded.sql_elements.ColumnElement
) and not hasattr(o, "__clause_element__"):
self._raise_for_expected(element, **kw)
else:
non_literal_expressions[o] = o
elif o is None:
non_literal_expressions[o] = elements.Null()
if non_literal_expressions:
return elements.ClauseList(
*[
non_literal_expressions[o]
if o in non_literal_expressions
else expr._bind_param(operator, o)
(
non_literal_expressions[o]
if o in non_literal_expressions
else expr._bind_param(operator, o)
)
for o in element
]
)
@@ -884,7 +877,7 @@ class InElementImpl(RoleImpl):
else:
self._raise_for_expected(element, **kw)
def _post_coercion(self, element, expr, operator, **kw):
def _post_coercion(self, element, *, expr, operator, **kw):
if element._is_select_base:
# for IN, we are doing scalar_subquery() coercion without
# a warning
@@ -910,12 +903,10 @@ class OnClauseImpl(_ColumnCoercions, RoleImpl):
_coerce_consts = True
def _literal_coercion(
self, element, name=None, type_=None, argname=None, is_crud=False, **kw
):
def _literal_coercion(self, element, **kw):
self._raise_for_expected(element)
def _post_coercion(self, resolved, original_element=None, **kw):
def _post_coercion(self, resolved, *, original_element=None, **kw):
# this is a hack right now as we want to use coercion on an
# ORM InstrumentedAttribute, but we want to return the object
# itself if it is one, not its clause element.
@@ -1000,7 +991,7 @@ class GroupByImpl(ByOfImpl, RoleImpl):
class DMLColumnImpl(_ReturnsStringKey, RoleImpl):
__slots__ = ()
def _post_coercion(self, element, as_key=False, **kw):
def _post_coercion(self, element, *, as_key=False, **kw):
if as_key:
return element.key
else:
@@ -1010,7 +1001,7 @@ class DMLColumnImpl(_ReturnsStringKey, RoleImpl):
class ConstExprImpl(RoleImpl):
__slots__ = ()
def _literal_coercion(self, element, argname=None, **kw):
def _literal_coercion(self, element, *, argname=None, **kw):
if element is None:
return elements.Null()
elif element is False:
@@ -1036,7 +1027,7 @@ class TruncatedLabelImpl(_StringOnly, RoleImpl):
else:
self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, argname=None, **kw):
def _literal_coercion(self, element, **kw):
"""coerce the given value to :class:`._truncated_label`.
Existing :class:`._truncated_label` and
@@ -1086,7 +1077,9 @@ class LimitOffsetImpl(RoleImpl):
else:
self._raise_for_expected(element, argname, resolved)
def _literal_coercion(self, element, name, type_, **kw):
def _literal_coercion( # type: ignore[override]
self, element, *, name, type_, **kw
):
if element is None:
return None
else:
@@ -1128,7 +1121,7 @@ class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl):
_guess_straight_column = re.compile(r"^\w\S*$", re.I)
def _raise_for_expected(
self, element, argname=None, resolved=None, advice=None, **kw
self, element, argname=None, resolved=None, *, advice=None, **kw
):
if not advice and isinstance(element, list):
advice = (
@@ -1152,9 +1145,9 @@ class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl):
% {
"column": util.ellipses_string(element),
"argname": "for argument %s" % (argname,) if argname else "",
"literal_column": "literal_column"
if guess_is_literal
else "column",
"literal_column": (
"literal_column" if guess_is_literal else "column"
),
}
)
@@ -1166,7 +1159,9 @@ class ReturnsRowsImpl(RoleImpl):
class StatementImpl(_CoerceLiterals, RoleImpl):
__slots__ = ()
def _post_coercion(self, resolved, original_element, argname=None, **kw):
def _post_coercion(
self, resolved, *, original_element, argname=None, **kw
):
if resolved is not original_element and not isinstance(
original_element, str
):
@@ -1232,7 +1227,7 @@ class JoinTargetImpl(RoleImpl):
_skip_clauseelement_for_target_match = True
def _literal_coercion(self, element, argname=None, **kw):
def _literal_coercion(self, element, *, argname=None, **kw):
self._raise_for_expected(element, argname)
def _implicit_coercions(
@@ -1240,6 +1235,7 @@ class JoinTargetImpl(RoleImpl):
element: Any,
resolved: Any,
argname: Optional[str] = None,
*,
legacy: bool = False,
**kw: Any,
) -> Any:
@@ -1273,6 +1269,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
element: Any,
resolved: Any,
argname: Optional[str] = None,
*,
explicit_subquery: bool = False,
allow_select: bool = True,
**kw: Any,
@@ -1294,7 +1291,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
else:
self._raise_for_expected(element, argname, resolved)
def _post_coercion(self, element, deannotate=False, **kw):
def _post_coercion(self, element, *, deannotate=False, **kw):
if deannotate:
return element._deannotate()
else:
@@ -1309,7 +1306,7 @@ class StrictFromClauseImpl(FromClauseImpl):
element: Any,
resolved: Any,
argname: Optional[str] = None,
explicit_subquery: bool = False,
*,
allow_select: bool = False,
**kw: Any,
) -> Any:
@@ -1329,7 +1326,7 @@ class StrictFromClauseImpl(FromClauseImpl):
class AnonymizedFromClauseImpl(StrictFromClauseImpl):
__slots__ = ()
def _post_coercion(self, element, flat=False, name=None, **kw):
def _post_coercion(self, element, *, flat=False, name=None, **kw):
assert name is None
return element._anonymous_fromclause(flat=flat)

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
# sql/crud.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -241,7 +241,7 @@ def _get_crud_params(
stmt_parameter_tuples = list(spd.items())
spd_str_key = {_column_as_key(key) for key in spd}
else:
stmt_parameter_tuples = spd = spd_str_key = None
stmt_parameter_tuples = spd_str_key = None
# if we have statement parameters - set defaults in the
# compiled params
@@ -332,6 +332,52 @@ def _get_crud_params(
.difference(check_columns)
)
if check:
if dml.isupdate(compile_state):
tables_mentioned = set(
c.table
for c, v in stmt_parameter_tuples
if isinstance(c, ColumnClause) and c.table is not None
).difference([compile_state.dml_table])
multi_not_in_from = tables_mentioned.difference(
compile_state._extra_froms
)
if tables_mentioned and (
not compile_state.is_multitable
or not compiler.render_table_with_column_in_update_from
):
if not compiler.render_table_with_column_in_update_from:
preamble = (
"Backend does not support additional "
"tables in the SET clause"
)
else:
preamble = (
"Statement is not a multi-table UPDATE statement"
)
raise exc.CompileError(
f"{preamble}; cannot "
f"""include columns from table(s) {
", ".join(f"'{t.description}'"
for t in tables_mentioned)
} in SET clause"""
)
elif multi_not_in_from:
assert compiler.render_table_with_column_in_update_from
raise exc.CompileError(
f"Multi-table UPDATE statement does not include "
"table(s) "
f"""{
", ".join(
f"'{t.description}'" for
t in multi_not_in_from)
}"""
)
raise exc.CompileError(
"Unconsumed column names: %s"
% (", ".join("%s" % (c,) for c in check))
@@ -393,9 +439,9 @@ def _create_bind_param(
process: Literal[True] = ...,
required: bool = False,
name: Optional[str] = None,
force_anonymous: bool = False,
**kw: Any,
) -> str:
...
) -> str: ...
@overload
@@ -404,8 +450,7 @@ def _create_bind_param(
col: ColumnElement[Any],
value: Any,
**kw: Any,
) -> str:
...
) -> str: ...
def _create_bind_param(
@@ -415,10 +460,14 @@ def _create_bind_param(
process: bool = True,
required: bool = False,
name: Optional[str] = None,
force_anonymous: bool = False,
**kw: Any,
) -> Union[str, elements.BindParameter[Any]]:
if name is None:
if force_anonymous:
name = None
elif name is None:
name = col.key
bindparam = elements.BindParameter(
name, value, type_=col.type, required=required
)
@@ -488,7 +537,7 @@ def _key_getters_for_crud_column(
)
def _column_as_key(
key: Union[ColumnClause[Any], str]
key: Union[ColumnClause[Any], str],
) -> Union[str, Tuple[str, str]]:
str_key = c_key_role(key)
if hasattr(key, "table") and key.table in _et:
@@ -834,6 +883,7 @@ def _append_param_parameter(
):
value = parameters.pop(col_key)
has_visiting_cte = kw.get("visiting_cte") is not None
col_value = compiler.preparer.format_column(
c, use_table=compile_state.include_table_with_column_exprs
)
@@ -859,11 +909,14 @@ def _append_param_parameter(
c,
value,
required=value is REQUIRED,
name=_col_bind_name(c)
if not _compile_state_isinsert(compile_state)
or not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
name=(
_col_bind_name(c)
if not _compile_state_isinsert(compile_state)
or not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c)
),
accumulate_bind_names=accumulated_bind_names,
force_anonymous=has_visiting_cte,
**kw,
)
elif value._is_bind_parameter:
@@ -884,10 +937,12 @@ def _append_param_parameter(
compiler,
c,
value,
name=_col_bind_name(c)
if not _compile_state_isinsert(compile_state)
or not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
name=(
_col_bind_name(c)
if not _compile_state_isinsert(compile_state)
or not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c)
),
accumulate_bind_names=accumulated_bind_names,
**kw,
)
@@ -1213,8 +1268,7 @@ def _create_insert_prefetch_bind_param(
c: ColumnElement[Any],
process: Literal[True] = ...,
**kw: Any,
) -> str:
...
) -> str: ...
@overload
@@ -1223,8 +1277,7 @@ def _create_insert_prefetch_bind_param(
c: ColumnElement[Any],
process: Literal[False],
**kw: Any,
) -> elements.BindParameter[Any]:
...
) -> elements.BindParameter[Any]: ...
def _create_insert_prefetch_bind_param(
@@ -1247,8 +1300,7 @@ def _create_update_prefetch_bind_param(
c: ColumnElement[Any],
process: Literal[True] = ...,
**kw: Any,
) -> str:
...
) -> str: ...
@overload
@@ -1257,8 +1309,7 @@ def _create_update_prefetch_bind_param(
c: ColumnElement[Any],
process: Literal[False],
**kw: Any,
) -> elements.BindParameter[Any]:
...
) -> elements.BindParameter[Any]: ...
def _create_update_prefetch_bind_param(
@@ -1288,7 +1339,7 @@ class _multiparam_column(elements.ColumnElement[Any]):
def compare(self, other, **kw):
raise NotImplementedError()
def _copy_internals(self, other, **kw):
def _copy_internals(self, **kw):
raise NotImplementedError()
def __eq__(self, other):
@@ -1364,9 +1415,28 @@ def _get_update_multitable_params(
affected_tables = set()
for t in compile_state._extra_froms:
# extra gymnastics to support the probably-shouldnt-have-supported
# case of "UPDATE table AS alias SET table.foo = bar", but it's
# supported
we_shouldnt_be_here_if_columns_found = (
not include_table
and not compile_state.dml_table.is_derived_from(t)
)
for c in t.c:
if c in normalized_params:
if we_shouldnt_be_here_if_columns_found:
raise exc.CompileError(
"Backend does not support additional tables "
"in the SET "
"clause; cannot include columns from table(s) "
f"'{t.description}' in "
"SET clause"
)
affected_tables.add(t)
check_columns[_getattr_col_key(c)] = c
value = normalized_params[c]
@@ -1392,6 +1462,7 @@ def _get_update_multitable_params(
value = compiler.process(value.self_group(), **kw)
accumulated_bind_names = ()
values.append((c, col_value, value, accumulated_bind_names))
# determine tables which are actually to be updated - process onupdate
# and server_onupdate for these
for t in affected_tables:
@@ -1437,6 +1508,7 @@ def _extend_values_for_multiparams(
values_0 = initial_values
values = [initial_values]
has_visiting_cte = kw.get("visiting_cte") is not None
mp = compile_state._multi_parameters
assert mp is not None
for i, row in enumerate(mp[1:]):
@@ -1453,7 +1525,8 @@ def _extend_values_for_multiparams(
compiler,
col,
row[key],
name="%s_m%d" % (col.key, i + 1),
name=("%s_m%d" % (col.key, i + 1)),
force_anonymous=has_visiting_cte,
**kw,
)
else:

View File

@@ -1,5 +1,5 @@
# sql/ddl.py
# Copyright (C) 2009-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2009-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -17,11 +17,14 @@ import contextlib
import typing
from typing import Any
from typing import Callable
from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence as typing_Sequence
from typing import Tuple
from typing import TypeVar
from typing import Union
from . import roles
from .base import _generative
@@ -38,10 +41,12 @@ if typing.TYPE_CHECKING:
from .compiler import Compiled
from .compiler import DDLCompiler
from .elements import BindParameter
from .schema import Column
from .schema import Constraint
from .schema import ForeignKeyConstraint
from .schema import Index
from .schema import SchemaItem
from .schema import Sequence
from .schema import Sequence as Sequence # noqa: F401
from .schema import Table
from .selectable import TableClause
from ..engine.base import Connection
@@ -50,6 +55,8 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import Dialect
from ..engine.interfaces import SchemaTranslateMapType
_SI = TypeVar("_SI", bound=Union["SchemaItem", str])
class BaseDDLElement(ClauseElement):
"""The root of DDL constructs, including those that are sub-elements
@@ -87,7 +94,7 @@ class DDLIfCallable(Protocol):
def __call__(
self,
ddl: BaseDDLElement,
target: SchemaItem,
target: Union[SchemaItem, str],
bind: Optional[Connection],
tables: Optional[List[Table]] = None,
state: Optional[Any] = None,
@@ -95,8 +102,7 @@ class DDLIfCallable(Protocol):
dialect: Dialect,
compiler: Optional[DDLCompiler] = ...,
checkfirst: bool,
) -> bool:
...
) -> bool: ...
class DDLIf(typing.NamedTuple):
@@ -107,7 +113,7 @@ class DDLIf(typing.NamedTuple):
def _should_execute(
self,
ddl: BaseDDLElement,
target: SchemaItem,
target: Union[SchemaItem, str],
bind: Optional[Connection],
compiler: Optional[DDLCompiler] = None,
**kw: Any,
@@ -156,8 +162,8 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement):
event.listen(
users,
'after_create',
AddConstraint(constraint).execute_if(dialect='postgresql')
"after_create",
AddConstraint(constraint).execute_if(dialect="postgresql"),
)
.. seealso::
@@ -173,7 +179,7 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement):
"""
_ddl_if: Optional[DDLIf] = None
target: Optional[SchemaItem] = None
target: Union[SchemaItem, str, None] = None
def _execute_on_connection(
self, connection, distilled_params, execution_options
@@ -232,20 +238,20 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement):
Used to provide a wrapper for event listening::
event.listen(
metadata,
'before_create',
DDL("my_ddl").execute_if(dialect='postgresql')
)
metadata,
"before_create",
DDL("my_ddl").execute_if(dialect="postgresql"),
)
:param dialect: May be a string or tuple of strings.
If a string, it will be compared to the name of the
executing database dialect::
DDL('something').execute_if(dialect='postgresql')
DDL("something").execute_if(dialect="postgresql")
If a tuple, specifies multiple dialect names::
DDL('something').execute_if(dialect=('postgresql', 'mysql'))
DDL("something").execute_if(dialect=("postgresql", "mysql"))
:param callable\_: A callable, which will be invoked with
three positional arguments as well as optional keyword
@@ -343,17 +349,19 @@ class DDL(ExecutableDDLElement):
from sqlalchemy import event, DDL
tbl = Table('users', metadata, Column('uid', Integer))
event.listen(tbl, 'before_create', DDL('DROP TRIGGER users_trigger'))
tbl = Table("users", metadata, Column("uid", Integer))
event.listen(tbl, "before_create", DDL("DROP TRIGGER users_trigger"))
spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE')
event.listen(tbl, 'after_create', spow.execute_if(dialect='somedb'))
spow = DDL("ALTER TABLE %(table)s SET secretpowers TRUE")
event.listen(tbl, "after_create", spow.execute_if(dialect="somedb"))
drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE')
drop_spow = DDL("ALTER TABLE users SET secretpowers FALSE")
connection.execute(drop_spow)
When operating on Table events, the following ``statement``
string substitutions are available::
string substitutions are available:
.. sourcecode:: text
%(table)s - the Table name, with any required quoting applied
%(schema)s - the schema name, with any required quoting applied
@@ -414,7 +422,7 @@ class DDL(ExecutableDDLElement):
)
class _CreateDropBase(ExecutableDDLElement):
class _CreateDropBase(ExecutableDDLElement, Generic[_SI]):
"""Base class for DDL constructs that represent CREATE and DROP or
equivalents.
@@ -424,15 +432,15 @@ class _CreateDropBase(ExecutableDDLElement):
"""
def __init__(
self,
element,
):
element: _SI
def __init__(self, element: _SI) -> None:
self.element = self.target = element
self._ddl_if = getattr(element, "_ddl_if", None)
@property
def stringify_dialect(self):
def stringify_dialect(self): # type: ignore[override]
assert not isinstance(self.element, str)
return self.element.create_drop_stringify_dialect
def _create_rule_disable(self, compiler):
@@ -446,19 +454,19 @@ class _CreateDropBase(ExecutableDDLElement):
return False
class _CreateBase(_CreateDropBase):
def __init__(self, element, if_not_exists=False):
class _CreateBase(_CreateDropBase[_SI]):
def __init__(self, element: _SI, if_not_exists: bool = False) -> None:
super().__init__(element)
self.if_not_exists = if_not_exists
class _DropBase(_CreateDropBase):
def __init__(self, element, if_exists=False):
class _DropBase(_CreateDropBase[_SI]):
def __init__(self, element: _SI, if_exists: bool = False) -> None:
super().__init__(element)
self.if_exists = if_exists
class CreateSchema(_CreateBase):
class CreateSchema(_CreateBase[str]):
"""Represent a CREATE SCHEMA statement.
The argument here is the string name of the schema.
@@ -471,15 +479,15 @@ class CreateSchema(_CreateBase):
def __init__(
self,
name,
if_not_exists=False,
):
name: str,
if_not_exists: bool = False,
) -> None:
"""Create a new :class:`.CreateSchema` construct."""
super().__init__(element=name, if_not_exists=if_not_exists)
class DropSchema(_DropBase):
class DropSchema(_DropBase[str]):
"""Represent a DROP SCHEMA statement.
The argument here is the string name of the schema.
@@ -492,17 +500,17 @@ class DropSchema(_DropBase):
def __init__(
self,
name,
cascade=False,
if_exists=False,
):
name: str,
cascade: bool = False,
if_exists: bool = False,
) -> None:
"""Create a new :class:`.DropSchema` construct."""
super().__init__(element=name, if_exists=if_exists)
self.cascade = cascade
class CreateTable(_CreateBase):
class CreateTable(_CreateBase["Table"]):
"""Represent a CREATE TABLE statement."""
__visit_name__ = "create_table"
@@ -514,7 +522,7 @@ class CreateTable(_CreateBase):
typing_Sequence[ForeignKeyConstraint]
] = None,
if_not_exists: bool = False,
):
) -> None:
"""Create a :class:`.CreateTable` construct.
:param element: a :class:`_schema.Table` that's the subject
@@ -536,7 +544,7 @@ class CreateTable(_CreateBase):
self.include_foreign_key_constraints = include_foreign_key_constraints
class _DropView(_DropBase):
class _DropView(_DropBase["Table"]):
"""Semi-public 'DROP VIEW' construct.
Used by the test suite for dialect-agnostic drops of views.
@@ -548,7 +556,9 @@ class _DropView(_DropBase):
class CreateConstraint(BaseDDLElement):
def __init__(self, element: Constraint):
element: Constraint
def __init__(self, element: Constraint) -> None:
self.element = element
@@ -569,6 +579,7 @@ class CreateColumn(BaseDDLElement):
from sqlalchemy import schema
from sqlalchemy.ext.compiler import compiles
@compiles(schema.CreateColumn)
def compile(element, compiler, **kw):
column = element.element
@@ -577,9 +588,9 @@ class CreateColumn(BaseDDLElement):
return compiler.visit_create_column(element, **kw)
text = "%s SPECIAL DIRECTIVE %s" % (
column.name,
compiler.type_compiler.process(column.type)
)
column.name,
compiler.type_compiler.process(column.type),
)
default = compiler.get_column_default_string(column)
if default is not None:
text += " DEFAULT " + default
@@ -589,8 +600,8 @@ class CreateColumn(BaseDDLElement):
if column.constraints:
text += " ".join(
compiler.process(const)
for const in column.constraints)
compiler.process(const) for const in column.constraints
)
return text
The above construct can be applied to a :class:`_schema.Table`
@@ -601,17 +612,21 @@ class CreateColumn(BaseDDLElement):
metadata = MetaData()
table = Table('mytable', MetaData(),
Column('x', Integer, info={"special":True}, primary_key=True),
Column('y', String(50)),
Column('z', String(20), info={"special":True})
)
table = Table(
"mytable",
MetaData(),
Column("x", Integer, info={"special": True}, primary_key=True),
Column("y", String(50)),
Column("z", String(20), info={"special": True}),
)
metadata.create_all(conn)
Above, the directives we've added to the :attr:`_schema.Column.info`
collection
will be detected by our custom compilation scheme::
will be detected by our custom compilation scheme:
.. sourcecode:: sql
CREATE TABLE mytable (
x SPECIAL DIRECTIVE INTEGER NOT NULL,
@@ -636,18 +651,21 @@ class CreateColumn(BaseDDLElement):
from sqlalchemy.schema import CreateColumn
@compiles(CreateColumn, "postgresql")
def skip_xmin(element, compiler, **kw):
if element.element.name == 'xmin':
if element.element.name == "xmin":
return None
else:
return compiler.visit_create_column(element, **kw)
my_table = Table('mytable', metadata,
Column('id', Integer, primary_key=True),
Column('xmin', Integer)
)
my_table = Table(
"mytable",
metadata,
Column("id", Integer, primary_key=True),
Column("xmin", Integer),
)
Above, a :class:`.CreateTable` construct will generate a ``CREATE TABLE``
which only includes the ``id`` column in the string; the ``xmin`` column
@@ -657,16 +675,18 @@ class CreateColumn(BaseDDLElement):
__visit_name__ = "create_column"
def __init__(self, element):
element: Column[Any]
def __init__(self, element: Column[Any]) -> None:
self.element = element
class DropTable(_DropBase):
class DropTable(_DropBase["Table"]):
"""Represent a DROP TABLE statement."""
__visit_name__ = "drop_table"
def __init__(self, element: Table, if_exists: bool = False):
def __init__(self, element: Table, if_exists: bool = False) -> None:
"""Create a :class:`.DropTable` construct.
:param element: a :class:`_schema.Table` that's the subject
@@ -681,30 +701,24 @@ class DropTable(_DropBase):
super().__init__(element, if_exists=if_exists)
class CreateSequence(_CreateBase):
class CreateSequence(_CreateBase["Sequence"]):
"""Represent a CREATE SEQUENCE statement."""
__visit_name__ = "create_sequence"
def __init__(self, element: Sequence, if_not_exists: bool = False):
super().__init__(element, if_not_exists=if_not_exists)
class DropSequence(_DropBase):
class DropSequence(_DropBase["Sequence"]):
"""Represent a DROP SEQUENCE statement."""
__visit_name__ = "drop_sequence"
def __init__(self, element: Sequence, if_exists: bool = False):
super().__init__(element, if_exists=if_exists)
class CreateIndex(_CreateBase):
class CreateIndex(_CreateBase["Index"]):
"""Represent a CREATE INDEX statement."""
__visit_name__ = "create_index"
def __init__(self, element, if_not_exists=False):
def __init__(self, element: Index, if_not_exists: bool = False) -> None:
"""Create a :class:`.Createindex` construct.
:param element: a :class:`_schema.Index` that's the subject
@@ -718,12 +732,12 @@ class CreateIndex(_CreateBase):
super().__init__(element, if_not_exists=if_not_exists)
class DropIndex(_DropBase):
class DropIndex(_DropBase["Index"]):
"""Represent a DROP INDEX statement."""
__visit_name__ = "drop_index"
def __init__(self, element, if_exists=False):
def __init__(self, element: Index, if_exists: bool = False) -> None:
"""Create a :class:`.DropIndex` construct.
:param element: a :class:`_schema.Index` that's the subject
@@ -737,38 +751,88 @@ class DropIndex(_DropBase):
super().__init__(element, if_exists=if_exists)
class AddConstraint(_CreateBase):
class AddConstraint(_CreateBase["Constraint"]):
"""Represent an ALTER TABLE ADD CONSTRAINT statement."""
__visit_name__ = "add_constraint"
def __init__(self, element):
def __init__(
self,
element: Constraint,
*,
isolate_from_table: bool = True,
) -> None:
"""Construct a new :class:`.AddConstraint` construct.
:param element: a :class:`.Constraint` object
:param isolate_from_table: optional boolean, defaults to True. Has
the effect of the incoming constraint being isolated from being
included in a CREATE TABLE sequence when associated with a
:class:`.Table`.
.. versionadded:: 2.0.39 - added
:paramref:`.AddConstraint.isolate_from_table`, defaulting
to True. Previously, the behavior of this parameter was implicitly
turned on in all cases.
"""
super().__init__(element)
element._create_rule = util.portable_instancemethod(
self._create_rule_disable
)
if isolate_from_table:
element._create_rule = util.portable_instancemethod(
self._create_rule_disable
)
class DropConstraint(_DropBase):
class DropConstraint(_DropBase["Constraint"]):
"""Represent an ALTER TABLE DROP CONSTRAINT statement."""
__visit_name__ = "drop_constraint"
def __init__(self, element, cascade=False, if_exists=False, **kw):
def __init__(
self,
element: Constraint,
*,
cascade: bool = False,
if_exists: bool = False,
isolate_from_table: bool = True,
**kw: Any,
) -> None:
"""Construct a new :class:`.DropConstraint` construct.
:param element: a :class:`.Constraint` object
:param cascade: optional boolean, indicates backend-specific
"CASCADE CONSTRAINT" directive should be rendered if available
:param if_exists: optional boolean, indicates backend-specific
"IF EXISTS" directive should be rendered if available
:param isolate_from_table: optional boolean, defaults to True. Has
the effect of the incoming constraint being isolated from being
included in a CREATE TABLE sequence when associated with a
:class:`.Table`.
.. versionadded:: 2.0.39 - added
:paramref:`.DropConstraint.isolate_from_table`, defaulting
to True. Previously, the behavior of this parameter was implicitly
turned on in all cases.
"""
self.cascade = cascade
super().__init__(element, if_exists=if_exists, **kw)
element._create_rule = util.portable_instancemethod(
self._create_rule_disable
)
if isolate_from_table:
element._create_rule = util.portable_instancemethod(
self._create_rule_disable
)
class SetTableComment(_CreateDropBase):
class SetTableComment(_CreateDropBase["Table"]):
"""Represent a COMMENT ON TABLE IS statement."""
__visit_name__ = "set_table_comment"
class DropTableComment(_CreateDropBase):
class DropTableComment(_CreateDropBase["Table"]):
"""Represent a COMMENT ON TABLE '' statement.
Note this varies a lot across database backends.
@@ -778,33 +842,34 @@ class DropTableComment(_CreateDropBase):
__visit_name__ = "drop_table_comment"
class SetColumnComment(_CreateDropBase):
class SetColumnComment(_CreateDropBase["Column[Any]"]):
"""Represent a COMMENT ON COLUMN IS statement."""
__visit_name__ = "set_column_comment"
class DropColumnComment(_CreateDropBase):
class DropColumnComment(_CreateDropBase["Column[Any]"]):
"""Represent a COMMENT ON COLUMN IS NULL statement."""
__visit_name__ = "drop_column_comment"
class SetConstraintComment(_CreateDropBase):
class SetConstraintComment(_CreateDropBase["Constraint"]):
"""Represent a COMMENT ON CONSTRAINT IS statement."""
__visit_name__ = "set_constraint_comment"
class DropConstraintComment(_CreateDropBase):
class DropConstraintComment(_CreateDropBase["Constraint"]):
"""Represent a COMMENT ON CONSTRAINT IS NULL statement."""
__visit_name__ = "drop_constraint_comment"
class InvokeDDLBase(SchemaVisitor):
def __init__(self, connection):
def __init__(self, connection, **kw):
self.connection = connection
assert not kw, f"Unexpected keywords: {kw.keys()}"
@contextlib.contextmanager
def with_ddl_events(self, target, **kw):
@@ -1021,10 +1086,12 @@ class SchemaDropper(InvokeDropDDLBase):
reversed(
sort_tables_and_constraints(
unsorted_tables,
filter_fn=lambda constraint: False
if not self.dialect.supports_alter
or constraint.name is None
else None,
filter_fn=lambda constraint: (
False
if not self.dialect.supports_alter
or constraint.name is None
else None
),
)
)
)

View File

@@ -1,12 +1,11 @@
# sql/default_comparator.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Default implementation of SQL comparison operations.
"""
"""Default implementation of SQL comparison operations."""
from __future__ import annotations
@@ -56,7 +55,6 @@ def _boolean_compare(
negate_op: Optional[OperatorType] = None,
reverse: bool = False,
_python_is_types: Tuple[Type[Any], ...] = (type(None), bool),
_any_all_expr: bool = False,
result_type: Optional[TypeEngine[bool]] = None,
**kwargs: Any,
) -> OperatorExpression[bool]:
@@ -90,7 +88,7 @@ def _boolean_compare(
negate=negate_op,
modifiers=kwargs,
)
elif _any_all_expr:
elif expr._is_collection_aggregate:
obj = coercions.expect(
roles.ConstExprRole, element=obj, operator=op, expr=expr
)
@@ -248,7 +246,7 @@ def _unsupported_impl(
expr: ColumnElement[Any], op: OperatorType, *arg: Any, **kw: Any
) -> NoReturn:
raise NotImplementedError(
"Operator '%s' is not supported on " "this expression" % op.__name__
"Operator '%s' is not supported on this expression" % op.__name__
)
@@ -297,9 +295,11 @@ def _match_impl(
operator=operators.match_op,
),
result_type=type_api.MATCHTYPE,
negate_op=operators.not_match_op
if op is operators.match_op
else operators.match_op,
negate_op=(
operators.not_match_op
if op is operators.match_op
else operators.match_op
),
**kw,
)
@@ -341,9 +341,11 @@ def _between_impl(
group=False,
),
op,
negate=operators.not_between_op
if op is operators.between_op
else operators.between_op,
negate=(
operators.not_between_op
if op is operators.between_op
else operators.between_op
),
modifiers=kw,
)

View File

@@ -1,5 +1,5 @@
# sql/dml.py
# Copyright (C) 2009-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2009-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -23,6 +23,7 @@ from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
@@ -42,6 +43,7 @@ from .base import _from_objects
from .base import _generative
from .base import _select_iterables
from .base import ColumnCollection
from .base import ColumnSet
from .base import CompileState
from .base import DialectKWArgs
from .base import Executable
@@ -91,14 +93,11 @@ if TYPE_CHECKING:
from .selectable import Select
from .selectable import Selectable
def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]:
...
def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: ...
def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]:
...
def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]: ...
def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]:
...
def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]: ...
else:
isupdate = operator.attrgetter("isupdate")
@@ -137,9 +136,11 @@ class DMLState(CompileState):
@classmethod
def get_entity_description(cls, statement: UpdateBase) -> Dict[str, Any]:
return {
"name": statement.table.name
if is_named_from_clause(statement.table)
else None,
"name": (
statement.table.name
if is_named_from_clause(statement.table)
else None
),
"table": statement.table,
}
@@ -163,8 +164,7 @@ class DMLState(CompileState):
if TYPE_CHECKING:
@classmethod
def get_plugin_class(cls, statement: Executable) -> Type[DMLState]:
...
def get_plugin_class(cls, statement: Executable) -> Type[DMLState]: ...
@classmethod
def _get_multi_crud_kv_pairs(
@@ -190,13 +190,15 @@ class DMLState(CompileState):
return [
(
coercions.expect(roles.DMLColumnRole, k),
v
if not needs_to_be_cacheable
else coercions.expect(
roles.ExpressionElementRole,
v,
type_=NullType(),
is_crud=True,
(
v
if not needs_to_be_cacheable
else coercions.expect(
roles.ExpressionElementRole,
v,
type_=NullType(),
is_crud=True,
)
),
)
for k, v in kv_iterator
@@ -306,12 +308,14 @@ class InsertDMLState(DMLState):
def _process_multi_values(self, statement: ValuesBase) -> None:
for parameters in statement._multi_values:
multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [
{
c.key: value
for c, value in zip(statement.table.c, parameter_set)
}
if isinstance(parameter_set, collections_abc.Sequence)
else parameter_set
(
{
c.key: value
for c, value in zip(statement.table.c, parameter_set)
}
if isinstance(parameter_set, collections_abc.Sequence)
else parameter_set
)
for parameter_set in parameters
]
@@ -396,9 +400,9 @@ class UpdateBase(
__visit_name__ = "update_base"
_hints: util.immutabledict[
Tuple[_DMLTableElement, str], str
] = util.EMPTY_DICT
_hints: util.immutabledict[Tuple[_DMLTableElement, str], str] = (
util.EMPTY_DICT
)
named_with_column = False
_label_style: SelectLabelStyle = (
@@ -407,19 +411,25 @@ class UpdateBase(
table: _DMLTableElement
_return_defaults = False
_return_defaults_columns: Optional[
Tuple[_ColumnsClauseElement, ...]
] = None
_return_defaults_columns: Optional[Tuple[_ColumnsClauseElement, ...]] = (
None
)
_supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None
_returning: Tuple[_ColumnsClauseElement, ...] = ()
is_dml = True
def _generate_fromclause_column_proxies(
self, fromclause: FromClause
self,
fromclause: FromClause,
columns: ColumnCollection[str, KeyedColumnElement[Any]],
primary_key: ColumnSet,
foreign_keys: Set[KeyedColumnElement[Any]],
) -> None:
fromclause._columns._populate_separate_keys(
col._make_proxy(fromclause)
columns._populate_separate_keys(
col._make_proxy(
fromclause, primary_key=primary_key, foreign_keys=foreign_keys
)
for col in self._all_selected_columns
if is_column_element(col)
)
@@ -523,11 +533,11 @@ class UpdateBase(
E.g.::
stmt = table.insert().values(data='newdata').return_defaults()
stmt = table.insert().values(data="newdata").return_defaults()
result = connection.execute(stmt)
server_created_at = result.returned_defaults['created_at']
server_created_at = result.returned_defaults["created_at"]
When used against an UPDATE statement
:meth:`.UpdateBase.return_defaults` instead looks for columns that
@@ -685,6 +695,16 @@ class UpdateBase(
return self
def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
"""Return ``True`` if this :class:`.ReturnsRows` is
'derived' from the given :class:`.FromClause`.
Since these are DMLs, we dont want such statements ever being adapted
so we return False for derives.
"""
return False
@_generative
def returning(
self,
@@ -1030,7 +1050,7 @@ class ValuesBase(UpdateBase):
users.insert().values(name="some name")
users.update().where(users.c.id==5).values(name="some name")
users.update().where(users.c.id == 5).values(name="some name")
:param \*args: As an alternative to passing key/value parameters,
a dictionary, tuple, or list of dictionaries or tuples can be passed
@@ -1060,13 +1080,17 @@ class ValuesBase(UpdateBase):
this syntax is supported on backends such as SQLite, PostgreSQL,
MySQL, but not necessarily others::
users.insert().values([
{"name": "some name"},
{"name": "some other name"},
{"name": "yet another name"},
])
users.insert().values(
[
{"name": "some name"},
{"name": "some other name"},
{"name": "yet another name"},
]
)
The above form would render a multiple VALUES statement similar to::
The above form would render a multiple VALUES statement similar to:
.. sourcecode:: sql
INSERT INTO users (name) VALUES
(:name_1),
@@ -1244,7 +1268,7 @@ class Insert(ValuesBase):
e.g.::
sel = select(table1.c.a, table1.c.b).where(table1.c.c > 5)
ins = table2.insert().from_select(['a', 'b'], sel)
ins = table2.insert().from_select(["a", "b"], sel)
:param names: a sequence of string column names or
:class:`_schema.Column`
@@ -1295,8 +1319,7 @@ class Insert(ValuesBase):
@overload
def returning(
self, __ent0: _TCCA[_T0], *, sort_by_parameter_order: bool = False
) -> ReturningInsert[Tuple[_T0]]:
...
) -> ReturningInsert[Tuple[_T0]]: ...
@overload
def returning(
@@ -1305,8 +1328,7 @@ class Insert(ValuesBase):
__ent1: _TCCA[_T1],
*,
sort_by_parameter_order: bool = False,
) -> ReturningInsert[Tuple[_T0, _T1]]:
...
) -> ReturningInsert[Tuple[_T0, _T1]]: ...
@overload
def returning(
@@ -1316,8 +1338,7 @@ class Insert(ValuesBase):
__ent2: _TCCA[_T2],
*,
sort_by_parameter_order: bool = False,
) -> ReturningInsert[Tuple[_T0, _T1, _T2]]:
...
) -> ReturningInsert[Tuple[_T0, _T1, _T2]]: ...
@overload
def returning(
@@ -1328,8 +1349,7 @@ class Insert(ValuesBase):
__ent3: _TCCA[_T3],
*,
sort_by_parameter_order: bool = False,
) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]:
...
) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]: ...
@overload
def returning(
@@ -1341,8 +1361,7 @@ class Insert(ValuesBase):
__ent4: _TCCA[_T4],
*,
sort_by_parameter_order: bool = False,
) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]:
...
) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]: ...
@overload
def returning(
@@ -1355,8 +1374,7 @@ class Insert(ValuesBase):
__ent5: _TCCA[_T5],
*,
sort_by_parameter_order: bool = False,
) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
...
) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ...
@overload
def returning(
@@ -1370,8 +1388,7 @@ class Insert(ValuesBase):
__ent6: _TCCA[_T6],
*,
sort_by_parameter_order: bool = False,
) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
...
) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ...
@overload
def returning(
@@ -1386,8 +1403,9 @@ class Insert(ValuesBase):
__ent7: _TCCA[_T7],
*,
sort_by_parameter_order: bool = False,
) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
...
) -> ReturningInsert[
Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]
]: ...
# END OVERLOADED FUNCTIONS self.returning
@@ -1397,16 +1415,14 @@ class Insert(ValuesBase):
*cols: _ColumnsClauseArgument[Any],
sort_by_parameter_order: bool = False,
**__kw: Any,
) -> ReturningInsert[Any]:
...
) -> ReturningInsert[Any]: ...
def returning(
self,
*cols: _ColumnsClauseArgument[Any],
sort_by_parameter_order: bool = False,
**__kw: Any,
) -> ReturningInsert[Any]:
...
) -> ReturningInsert[Any]: ...
class ReturningInsert(Insert, TypedReturnsRows[_TP]):
@@ -1541,9 +1557,7 @@ class Update(DMLWhereBase, ValuesBase):
E.g.::
stmt = table.update().ordered_values(
("name", "ed"), ("ident": "foo")
)
stmt = table.update().ordered_values(("name", "ed"), ("ident", "foo"))
.. seealso::
@@ -1556,7 +1570,7 @@ class Update(DMLWhereBase, ValuesBase):
:paramref:`_expression.update.preserve_parameter_order`
parameter, which will be removed in SQLAlchemy 2.0.
"""
""" # noqa: E501
if self._values:
raise exc.ArgumentError(
"This statement already has values present"
@@ -1596,20 +1610,19 @@ class Update(DMLWhereBase, ValuesBase):
# statically generated** by tools/generate_tuple_map_overloads.py
@overload
def returning(self, __ent0: _TCCA[_T0]) -> ReturningUpdate[Tuple[_T0]]:
...
def returning(
self, __ent0: _TCCA[_T0]
) -> ReturningUpdate[Tuple[_T0]]: ...
@overload
def returning(
self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
) -> ReturningUpdate[Tuple[_T0, _T1]]:
...
) -> ReturningUpdate[Tuple[_T0, _T1]]: ...
@overload
def returning(
self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]:
...
) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]: ...
@overload
def returning(
@@ -1618,8 +1631,7 @@ class Update(DMLWhereBase, ValuesBase):
__ent1: _TCCA[_T1],
__ent2: _TCCA[_T2],
__ent3: _TCCA[_T3],
) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]:
...
) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]: ...
@overload
def returning(
@@ -1629,8 +1641,7 @@ class Update(DMLWhereBase, ValuesBase):
__ent2: _TCCA[_T2],
__ent3: _TCCA[_T3],
__ent4: _TCCA[_T4],
) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]:
...
) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]: ...
@overload
def returning(
@@ -1641,8 +1652,7 @@ class Update(DMLWhereBase, ValuesBase):
__ent3: _TCCA[_T3],
__ent4: _TCCA[_T4],
__ent5: _TCCA[_T5],
) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
...
) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ...
@overload
def returning(
@@ -1654,8 +1664,7 @@ class Update(DMLWhereBase, ValuesBase):
__ent4: _TCCA[_T4],
__ent5: _TCCA[_T5],
__ent6: _TCCA[_T6],
) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
...
) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ...
@overload
def returning(
@@ -1668,21 +1677,20 @@ class Update(DMLWhereBase, ValuesBase):
__ent5: _TCCA[_T5],
__ent6: _TCCA[_T6],
__ent7: _TCCA[_T7],
) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
...
) -> ReturningUpdate[
Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]
]: ...
# END OVERLOADED FUNCTIONS self.returning
@overload
def returning(
self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
) -> ReturningUpdate[Any]:
...
) -> ReturningUpdate[Any]: ...
def returning(
self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
) -> ReturningUpdate[Any]:
...
) -> ReturningUpdate[Any]: ...
class ReturningUpdate(Update, TypedReturnsRows[_TP]):
@@ -1734,20 +1742,19 @@ class Delete(DMLWhereBase, UpdateBase):
# statically generated** by tools/generate_tuple_map_overloads.py
@overload
def returning(self, __ent0: _TCCA[_T0]) -> ReturningDelete[Tuple[_T0]]:
...
def returning(
self, __ent0: _TCCA[_T0]
) -> ReturningDelete[Tuple[_T0]]: ...
@overload
def returning(
self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
) -> ReturningDelete[Tuple[_T0, _T1]]:
...
) -> ReturningDelete[Tuple[_T0, _T1]]: ...
@overload
def returning(
self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
) -> ReturningDelete[Tuple[_T0, _T1, _T2]]:
...
) -> ReturningDelete[Tuple[_T0, _T1, _T2]]: ...
@overload
def returning(
@@ -1756,8 +1763,7 @@ class Delete(DMLWhereBase, UpdateBase):
__ent1: _TCCA[_T1],
__ent2: _TCCA[_T2],
__ent3: _TCCA[_T3],
) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]:
...
) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]: ...
@overload
def returning(
@@ -1767,8 +1773,7 @@ class Delete(DMLWhereBase, UpdateBase):
__ent2: _TCCA[_T2],
__ent3: _TCCA[_T3],
__ent4: _TCCA[_T4],
) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]:
...
) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]: ...
@overload
def returning(
@@ -1779,8 +1784,7 @@ class Delete(DMLWhereBase, UpdateBase):
__ent3: _TCCA[_T3],
__ent4: _TCCA[_T4],
__ent5: _TCCA[_T5],
) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
...
) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ...
@overload
def returning(
@@ -1792,8 +1796,7 @@ class Delete(DMLWhereBase, UpdateBase):
__ent4: _TCCA[_T4],
__ent5: _TCCA[_T5],
__ent6: _TCCA[_T6],
) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
...
) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ...
@overload
def returning(
@@ -1806,21 +1809,20 @@ class Delete(DMLWhereBase, UpdateBase):
__ent5: _TCCA[_T5],
__ent6: _TCCA[_T6],
__ent7: _TCCA[_T7],
) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
...
) -> ReturningDelete[
Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]
]: ...
# END OVERLOADED FUNCTIONS self.returning
@overload
def returning(
self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
) -> ReturningDelete[Any]:
...
) -> ReturningDelete[Any]: ...
def returning(
self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
) -> ReturningDelete[Any]:
...
) -> ReturningDelete[Any]: ...
class ReturningDelete(Update, TypedReturnsRows[_TP]):

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
# sqlalchemy/sql/events.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# sql/events.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -63,13 +63,14 @@ class DDLEvents(event.Events[SchemaEventTarget]):
from sqlalchemy import Table, Column, Metadata, Integer
m = MetaData()
some_table = Table('some_table', m, Column('data', Integer))
some_table = Table("some_table", m, Column("data", Integer))
@event.listens_for(some_table, "after_create")
def after_create(target, connection, **kw):
connection.execute(text(
"ALTER TABLE %s SET name=foo_%s" % (target.name, target.name)
))
connection.execute(
text("ALTER TABLE %s SET name=foo_%s" % (target.name, target.name))
)
some_engine = create_engine("postgresql://scott:tiger@host/test")
@@ -127,10 +128,11 @@ class DDLEvents(event.Events[SchemaEventTarget]):
as listener callables::
from sqlalchemy import DDL
event.listen(
some_table,
"after_create",
DDL("ALTER TABLE %(table)s SET name=foo_%(table)s")
DDL("ALTER TABLE %(table)s SET name=foo_%(table)s"),
)
**Event Propagation to MetaData Copies**
@@ -149,7 +151,7 @@ class DDLEvents(event.Events[SchemaEventTarget]):
some_table,
"after_create",
DDL("ALTER TABLE %(table)s SET name=foo_%(table)s"),
propagate=True
propagate=True,
)
new_metadata = MetaData()
@@ -169,7 +171,7 @@ class DDLEvents(event.Events[SchemaEventTarget]):
:ref:`schema_ddl_sequences`
"""
""" # noqa: E501
_target_class_doc = "SomeSchemaClassOrObject"
_dispatch_target = SchemaEventTarget
@@ -358,16 +360,17 @@ class DDLEvents(event.Events[SchemaEventTarget]):
metadata = MetaData()
@event.listens_for(metadata, 'column_reflect')
@event.listens_for(metadata, "column_reflect")
def receive_column_reflect(inspector, table, column_info):
# receives for all Table objects that are reflected
# under this MetaData
...
# will use the above event hook
my_table = Table("my_table", metadata, autoload_with=some_engine)
.. versionadded:: 1.4.0b2 The :meth:`_events.DDLEvents.column_reflect`
hook may now be applied to a :class:`_schema.MetaData` object as
well as the :class:`_schema.MetaData` class itself where it will
@@ -379,9 +382,11 @@ class DDLEvents(event.Events[SchemaEventTarget]):
from sqlalchemy import Table
@event.listens_for(Table, 'column_reflect')
@event.listens_for(Table, "column_reflect")
def receive_column_reflect(inspector, table, column_info):
# receives for all Table objects that are reflected
...
It can also be applied to a specific :class:`_schema.Table` at the
point that one is being reflected using the
@@ -390,9 +395,7 @@ class DDLEvents(event.Events[SchemaEventTarget]):
t1 = Table(
"my_table",
autoload_with=some_engine,
listeners=[
('column_reflect', receive_column_reflect)
]
listeners=[("column_reflect", receive_column_reflect)],
)
The dictionary of column information as returned by the

View File

@@ -1,14 +1,11 @@
# sql/expression.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Defines the public namespace for SQL expression constructs.
"""
"""Defines the public namespace for SQL expression constructs."""
from __future__ import annotations

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
# sql/lambdas.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -256,10 +256,7 @@ class LambdaElement(elements.ClauseElement):
self.closure_cache_key = cache_key
try:
rec = lambda_cache[tracker_key + cache_key]
except KeyError:
rec = None
rec = lambda_cache.get(tracker_key + cache_key)
else:
cache_key = _cache_key.NO_CACHE
rec = None
@@ -278,7 +275,7 @@ class LambdaElement(elements.ClauseElement):
rec = AnalyzedFunction(
tracker, self, apply_propagate_attrs, fn
)
rec.closure_bindparams = bindparams
rec.closure_bindparams = list(bindparams)
lambda_cache[key] = rec
else:
rec = lambda_cache[key]
@@ -303,7 +300,9 @@ class LambdaElement(elements.ClauseElement):
while lambda_element is not None:
rec = lambda_element._rec
if rec.bindparam_trackers:
tracker_instrumented_fn = rec.tracker_instrumented_fn
tracker_instrumented_fn = (
rec.tracker_instrumented_fn # type:ignore [union-attr] # noqa: E501
)
for tracker in rec.bindparam_trackers:
tracker(
lambda_element.fn,
@@ -407,9 +406,9 @@ class LambdaElement(elements.ClauseElement):
while parent is not None:
assert parent.closure_cache_key is not CacheConst.NO_CACHE
parent_closure_cache_key: Tuple[
Any, ...
] = parent.closure_cache_key
parent_closure_cache_key: Tuple[Any, ...] = (
parent.closure_cache_key
)
cache_key = (
(parent.fn.__code__,) + parent_closure_cache_key + cache_key
@@ -437,7 +436,7 @@ class DeferredLambdaElement(LambdaElement):
def __init__(
self,
fn: _LambdaType,
fn: _AnyLambdaType,
role: Type[roles.SQLRole],
opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
lambda_args: Tuple[Any, ...] = (),
@@ -518,7 +517,6 @@ class StatementLambdaElement(
stmt += lambda s: s.where(table.c.col == parameter)
.. versionadded:: 1.4
.. seealso::
@@ -535,8 +533,7 @@ class StatementLambdaElement(
role: Type[SQLRole],
opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
apply_propagate_attrs: Optional[ClauseElement] = None,
):
...
): ...
def __add__(
self, other: _StmtLambdaElementType[Any]
@@ -559,9 +556,7 @@ class StatementLambdaElement(
... stmt = lambda_stmt(
... lambda: select(table.c.x, table.c.y),
... )
... stmt = stmt.add_criteria(
... lambda: table.c.x > parameter
... )
... stmt = stmt.add_criteria(lambda: table.c.x > parameter)
... return stmt
The :meth:`_sql.StatementLambdaElement.add_criteria` method is
@@ -572,18 +567,15 @@ class StatementLambdaElement(
>>> def my_stmt(self, foo):
... stmt = lambda_stmt(
... lambda: select(func.max(foo.x, foo.y)),
... track_closure_variables=False
... )
... stmt = stmt.add_criteria(
... lambda: self.where_criteria,
... track_on=[self]
... track_closure_variables=False,
... )
... stmt = stmt.add_criteria(lambda: self.where_criteria, track_on=[self])
... return stmt
See :func:`_sql.lambda_stmt` for a description of the parameters
accepted.
"""
""" # noqa: E501
opts = self.opts + dict(
enable_tracking=enable_tracking,
@@ -612,7 +604,7 @@ class StatementLambdaElement(
return self._rec_expected_expr
@property
def _with_options(self):
def _with_options(self): # type: ignore[override]
return self._proxied._with_options
@property
@@ -620,7 +612,7 @@ class StatementLambdaElement(
return self._proxied._effective_plugin_target
@property
def _execution_options(self):
def _execution_options(self): # type: ignore[override]
return self._proxied._execution_options
@property
@@ -628,27 +620,27 @@ class StatementLambdaElement(
return self._proxied._all_selected_columns
@property
def is_select(self):
def is_select(self): # type: ignore[override]
return self._proxied.is_select
@property
def is_update(self):
def is_update(self): # type: ignore[override]
return self._proxied.is_update
@property
def is_insert(self):
def is_insert(self): # type: ignore[override]
return self._proxied.is_insert
@property
def is_text(self):
def is_text(self): # type: ignore[override]
return self._proxied.is_text
@property
def is_delete(self):
def is_delete(self): # type: ignore[override]
return self._proxied.is_delete
@property
def is_dml(self):
def is_dml(self): # type: ignore[override]
return self._proxied.is_dml
def spoil(self) -> NullLambdaStatement:
@@ -737,9 +729,9 @@ class AnalyzedCode:
"closure_trackers",
"build_py_wrappers",
)
_fns: weakref.WeakKeyDictionary[
CodeType, AnalyzedCode
] = weakref.WeakKeyDictionary()
_fns: weakref.WeakKeyDictionary[CodeType, AnalyzedCode] = (
weakref.WeakKeyDictionary()
)
_generation_mutex = threading.RLock()
@@ -1180,16 +1172,16 @@ class AnalyzedFunction:
closure_pywrappers.append(bind)
else:
value = fn.__globals__[name]
new_globals[name] = bind = PyWrapper(fn, name, value)
new_globals[name] = PyWrapper(fn, name, value)
# rewrite the original fn. things that look like they will
# become bound parameters are wrapped in a PyWrapper.
self.tracker_instrumented_fn = (
tracker_instrumented_fn
) = self._rewrite_code_obj(
fn,
[new_closure[name] for name in fn.__code__.co_freevars],
new_globals,
self.tracker_instrumented_fn = tracker_instrumented_fn = (
self._rewrite_code_obj(
fn,
[new_closure[name] for name in fn.__code__.co_freevars],
new_globals,
)
)
# now invoke the function. This will give us a new SQL

View File

@@ -1,15 +1,12 @@
# sqlalchemy/naming.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# sql/naming.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
"""Establish constraint and index naming conventions.
"""
"""Establish constraint and index naming conventions."""
from __future__ import annotations

View File

@@ -1,5 +1,5 @@
# sql/operators.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -77,8 +77,7 @@ class OperatorType(Protocol):
right: Optional[Any] = None,
*other: Any,
**kwargs: Any,
) -> ColumnElement[Any]:
...
) -> ColumnElement[Any]: ...
@overload
def __call__(
@@ -87,8 +86,7 @@ class OperatorType(Protocol):
right: Optional[Any] = None,
*other: Any,
**kwargs: Any,
) -> Operators:
...
) -> Operators: ...
def __call__(
self,
@@ -96,8 +94,7 @@ class OperatorType(Protocol):
right: Optional[Any] = None,
*other: Any,
**kwargs: Any,
) -> Operators:
...
) -> Operators: ...
add = cast(OperatorType, _uncast_add)
@@ -151,6 +148,7 @@ class Operators:
is equivalent to::
from sqlalchemy import and_
and_(a, b)
Care should be taken when using ``&`` regarding
@@ -175,6 +173,7 @@ class Operators:
is equivalent to::
from sqlalchemy import or_
or_(a, b)
Care should be taken when using ``|`` regarding
@@ -199,6 +198,7 @@ class Operators:
is equivalent to::
from sqlalchemy import not_
not_(a)
"""
@@ -227,7 +227,7 @@ class Operators:
This function can also be used to make bitwise operators explicit. For
example::
somecolumn.op('&')(0xff)
somecolumn.op("&")(0xFF)
is a bitwise AND of the value in ``somecolumn``.
@@ -278,7 +278,7 @@ class Operators:
e.g.::
>>> expr = column('x').op('+', python_impl=lambda a, b: a + b)('y')
>>> expr = column("x").op("+", python_impl=lambda a, b: a + b)("y")
The operator for the above expression will also work for non-SQL
left and right objects::
@@ -392,10 +392,9 @@ class custom_op(OperatorType, Generic[_T]):
from sqlalchemy.sql import operators
from sqlalchemy import Numeric
unary = UnaryExpression(table.c.somecolumn,
modifier=operators.custom_op("!"),
type_=Numeric)
unary = UnaryExpression(
table.c.somecolumn, modifier=operators.custom_op("!"), type_=Numeric
)
.. seealso::
@@ -403,7 +402,7 @@ class custom_op(OperatorType, Generic[_T]):
:meth:`.Operators.bool_op`
"""
""" # noqa: E501
__name__ = "custom_op"
@@ -466,8 +465,7 @@ class custom_op(OperatorType, Generic[_T]):
right: Optional[Any] = None,
*other: Any,
**kwargs: Any,
) -> ColumnElement[Any]:
...
) -> ColumnElement[Any]: ...
@overload
def __call__(
@@ -476,8 +474,7 @@ class custom_op(OperatorType, Generic[_T]):
right: Optional[Any] = None,
*other: Any,
**kwargs: Any,
) -> Operators:
...
) -> Operators: ...
def __call__(
self,
@@ -545,13 +542,11 @@ class ColumnOperators(Operators):
def operate(
self, op: OperatorType, *other: Any, **kwargs: Any
) -> ColumnOperators:
...
) -> ColumnOperators: ...
def reverse_operate(
self, op: OperatorType, other: Any, **kwargs: Any
) -> ColumnOperators:
...
) -> ColumnOperators: ...
def __lt__(self, other: Any) -> ColumnOperators:
"""Implement the ``<`` operator.
@@ -574,8 +569,7 @@ class ColumnOperators(Operators):
# https://docs.python.org/3/reference/datamodel.html#object.__hash__
if TYPE_CHECKING:
def __hash__(self) -> int:
...
def __hash__(self) -> int: ...
else:
__hash__ = Operators.__hash__
@@ -623,8 +617,7 @@ class ColumnOperators(Operators):
# deprecated 1.4; see #5435
if TYPE_CHECKING:
def isnot_distinct_from(self, other: Any) -> ColumnOperators:
...
def isnot_distinct_from(self, other: Any) -> ColumnOperators: ...
else:
isnot_distinct_from = is_not_distinct_from
@@ -707,14 +700,15 @@ class ColumnOperators(Operators):
) -> ColumnOperators:
r"""Implement the ``like`` operator.
In a column context, produces the expression::
In a column context, produces the expression:
.. sourcecode:: sql
a LIKE other
E.g.::
stmt = select(sometable).\
where(sometable.c.column.like("%foobar%"))
stmt = select(sometable).where(sometable.c.column.like("%foobar%"))
:param other: expression to be compared
:param escape: optional escape character, renders the ``ESCAPE``
@@ -734,18 +728,21 @@ class ColumnOperators(Operators):
) -> ColumnOperators:
r"""Implement the ``ilike`` operator, e.g. case insensitive LIKE.
In a column context, produces an expression either of the form::
In a column context, produces an expression either of the form:
.. sourcecode:: sql
lower(a) LIKE lower(other)
Or on backends that support the ILIKE operator::
Or on backends that support the ILIKE operator:
.. sourcecode:: sql
a ILIKE other
E.g.::
stmt = select(sometable).\
where(sometable.c.column.ilike("%foobar%"))
stmt = select(sometable).where(sometable.c.column.ilike("%foobar%"))
:param other: expression to be compared
:param escape: optional escape character, renders the ``ESCAPE``
@@ -757,7 +754,7 @@ class ColumnOperators(Operators):
:meth:`.ColumnOperators.like`
"""
""" # noqa: E501
return self.operate(ilike_op, other, escape=escape)
def bitwise_xor(self, other: Any) -> ColumnOperators:
@@ -851,12 +848,15 @@ class ColumnOperators(Operators):
The given parameter ``other`` may be:
* A list of literal values, e.g.::
* A list of literal values,
e.g.::
stmt.where(column.in_([1, 2, 3]))
In this calling form, the list of items is converted to a set of
bound parameters the same length as the list given::
bound parameters the same length as the list given:
.. sourcecode:: sql
WHERE COL IN (?, ?, ?)
@@ -864,16 +864,20 @@ class ColumnOperators(Operators):
:func:`.tuple_` containing multiple expressions::
from sqlalchemy import tuple_
stmt.where(tuple_(col1, col2).in_([(1, 10), (2, 20), (3, 30)]))
* An empty list, e.g.::
* An empty list,
e.g.::
stmt.where(column.in_([]))
In this calling form, the expression renders an "empty set"
expression. These expressions are tailored to individual backends
and are generally trying to get an empty SELECT statement as a
subquery. Such as on SQLite, the expression is::
subquery. Such as on SQLite, the expression is:
.. sourcecode:: sql
WHERE col IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1)
@@ -883,10 +887,12 @@ class ColumnOperators(Operators):
* A bound parameter, e.g. :func:`.bindparam`, may be used if it
includes the :paramref:`.bindparam.expanding` flag::
stmt.where(column.in_(bindparam('value', expanding=True)))
stmt.where(column.in_(bindparam("value", expanding=True)))
In this calling form, the expression renders a special non-SQL
placeholder expression that looks like::
placeholder expression that looks like:
.. sourcecode:: sql
WHERE COL IN ([EXPANDING_value])
@@ -896,7 +902,9 @@ class ColumnOperators(Operators):
connection.execute(stmt, {"value": [1, 2, 3]})
The database would be passed a bound parameter for each value::
The database would be passed a bound parameter for each value:
.. sourcecode:: sql
WHERE COL IN (?, ?, ?)
@@ -904,7 +912,9 @@ class ColumnOperators(Operators):
If an empty list is passed, a special "empty list" expression,
which is specific to the database in use, is rendered. On
SQLite this would be::
SQLite this would be:
.. sourcecode:: sql
WHERE COL IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1)
@@ -915,13 +925,12 @@ class ColumnOperators(Operators):
correlated scalar select::
stmt.where(
column.in_(
select(othertable.c.y).
where(table.c.x == othertable.c.x)
)
column.in_(select(othertable.c.y).where(table.c.x == othertable.c.x))
)
In this calling form, :meth:`.ColumnOperators.in_` renders as given::
In this calling form, :meth:`.ColumnOperators.in_` renders as given:
.. sourcecode:: sql
WHERE COL IN (SELECT othertable.y
FROM othertable WHERE othertable.x = table.x)
@@ -930,7 +939,7 @@ class ColumnOperators(Operators):
construct, or a :func:`.bindparam` construct that includes the
:paramref:`.bindparam.expanding` flag set to True.
"""
""" # noqa: E501
return self.operate(in_op, other)
def not_in(self, other: Any) -> ColumnOperators:
@@ -964,8 +973,7 @@ class ColumnOperators(Operators):
# deprecated 1.4; see #5429
if TYPE_CHECKING:
def notin_(self, other: Any) -> ColumnOperators:
...
def notin_(self, other: Any) -> ColumnOperators: ...
else:
notin_ = not_in
@@ -994,8 +1002,7 @@ class ColumnOperators(Operators):
def notlike(
self, other: Any, escape: Optional[str] = None
) -> ColumnOperators:
...
) -> ColumnOperators: ...
else:
notlike = not_like
@@ -1024,8 +1031,7 @@ class ColumnOperators(Operators):
def notilike(
self, other: Any, escape: Optional[str] = None
) -> ColumnOperators:
...
) -> ColumnOperators: ...
else:
notilike = not_ilike
@@ -1063,8 +1069,7 @@ class ColumnOperators(Operators):
# deprecated 1.4; see #5429
if TYPE_CHECKING:
def isnot(self, other: Any) -> ColumnOperators:
...
def isnot(self, other: Any) -> ColumnOperators: ...
else:
isnot = is_not
@@ -1078,14 +1083,15 @@ class ColumnOperators(Operators):
r"""Implement the ``startswith`` operator.
Produces a LIKE expression that tests against a match for the start
of a string value::
of a string value:
.. sourcecode:: sql
column LIKE <other> || '%'
E.g.::
stmt = select(sometable).\
where(sometable.c.column.startswith("foobar"))
stmt = select(sometable).where(sometable.c.column.startswith("foobar"))
Since the operator uses ``LIKE``, wildcard characters
``"%"`` and ``"_"`` that are present inside the <other> expression
@@ -1114,7 +1120,9 @@ class ColumnOperators(Operators):
somecolumn.startswith("foo%bar", autoescape=True)
Will render as::
Will render as:
.. sourcecode:: sql
somecolumn LIKE :param || '%' ESCAPE '/'
@@ -1130,7 +1138,9 @@ class ColumnOperators(Operators):
somecolumn.startswith("foo/%bar", escape="^")
Will render as::
Will render as:
.. sourcecode:: sql
somecolumn LIKE :param || '%' ESCAPE '^'
@@ -1150,7 +1160,7 @@ class ColumnOperators(Operators):
:meth:`.ColumnOperators.like`
"""
""" # noqa: E501
return self.operate(
startswith_op, other, escape=escape, autoescape=autoescape
)
@@ -1165,14 +1175,15 @@ class ColumnOperators(Operators):
version of :meth:`.ColumnOperators.startswith`.
Produces a LIKE expression that tests against an insensitive
match for the start of a string value::
match for the start of a string value:
.. sourcecode:: sql
lower(column) LIKE lower(<other>) || '%'
E.g.::
stmt = select(sometable).\
where(sometable.c.column.istartswith("foobar"))
stmt = select(sometable).where(sometable.c.column.istartswith("foobar"))
Since the operator uses ``LIKE``, wildcard characters
``"%"`` and ``"_"`` that are present inside the <other> expression
@@ -1201,7 +1212,9 @@ class ColumnOperators(Operators):
somecolumn.istartswith("foo%bar", autoescape=True)
Will render as::
Will render as:
.. sourcecode:: sql
lower(somecolumn) LIKE lower(:param) || '%' ESCAPE '/'
@@ -1217,7 +1230,9 @@ class ColumnOperators(Operators):
somecolumn.istartswith("foo/%bar", escape="^")
Will render as::
Will render as:
.. sourcecode:: sql
lower(somecolumn) LIKE lower(:param) || '%' ESCAPE '^'
@@ -1232,7 +1247,7 @@ class ColumnOperators(Operators):
.. seealso::
:meth:`.ColumnOperators.startswith`
"""
""" # noqa: E501
return self.operate(
istartswith_op, other, escape=escape, autoescape=autoescape
)
@@ -1246,14 +1261,15 @@ class ColumnOperators(Operators):
r"""Implement the 'endswith' operator.
Produces a LIKE expression that tests against a match for the end
of a string value::
of a string value:
.. sourcecode:: sql
column LIKE '%' || <other>
E.g.::
stmt = select(sometable).\
where(sometable.c.column.endswith("foobar"))
stmt = select(sometable).where(sometable.c.column.endswith("foobar"))
Since the operator uses ``LIKE``, wildcard characters
``"%"`` and ``"_"`` that are present inside the <other> expression
@@ -1282,7 +1298,9 @@ class ColumnOperators(Operators):
somecolumn.endswith("foo%bar", autoescape=True)
Will render as::
Will render as:
.. sourcecode:: sql
somecolumn LIKE '%' || :param ESCAPE '/'
@@ -1298,7 +1316,9 @@ class ColumnOperators(Operators):
somecolumn.endswith("foo/%bar", escape="^")
Will render as::
Will render as:
.. sourcecode:: sql
somecolumn LIKE '%' || :param ESCAPE '^'
@@ -1318,7 +1338,7 @@ class ColumnOperators(Operators):
:meth:`.ColumnOperators.like`
"""
""" # noqa: E501
return self.operate(
endswith_op, other, escape=escape, autoescape=autoescape
)
@@ -1333,14 +1353,15 @@ class ColumnOperators(Operators):
version of :meth:`.ColumnOperators.endswith`.
Produces a LIKE expression that tests against an insensitive match
for the end of a string value::
for the end of a string value:
.. sourcecode:: sql
lower(column) LIKE '%' || lower(<other>)
E.g.::
stmt = select(sometable).\
where(sometable.c.column.iendswith("foobar"))
stmt = select(sometable).where(sometable.c.column.iendswith("foobar"))
Since the operator uses ``LIKE``, wildcard characters
``"%"`` and ``"_"`` that are present inside the <other> expression
@@ -1369,7 +1390,9 @@ class ColumnOperators(Operators):
somecolumn.iendswith("foo%bar", autoescape=True)
Will render as::
Will render as:
.. sourcecode:: sql
lower(somecolumn) LIKE '%' || lower(:param) ESCAPE '/'
@@ -1385,7 +1408,9 @@ class ColumnOperators(Operators):
somecolumn.iendswith("foo/%bar", escape="^")
Will render as::
Will render as:
.. sourcecode:: sql
lower(somecolumn) LIKE '%' || lower(:param) ESCAPE '^'
@@ -1400,7 +1425,7 @@ class ColumnOperators(Operators):
.. seealso::
:meth:`.ColumnOperators.endswith`
"""
""" # noqa: E501
return self.operate(
iendswith_op, other, escape=escape, autoescape=autoescape
)
@@ -1409,14 +1434,15 @@ class ColumnOperators(Operators):
r"""Implement the 'contains' operator.
Produces a LIKE expression that tests against a match for the middle
of a string value::
of a string value:
.. sourcecode:: sql
column LIKE '%' || <other> || '%'
E.g.::
stmt = select(sometable).\
where(sometable.c.column.contains("foobar"))
stmt = select(sometable).where(sometable.c.column.contains("foobar"))
Since the operator uses ``LIKE``, wildcard characters
``"%"`` and ``"_"`` that are present inside the <other> expression
@@ -1445,7 +1471,9 @@ class ColumnOperators(Operators):
somecolumn.contains("foo%bar", autoescape=True)
Will render as::
Will render as:
.. sourcecode:: sql
somecolumn LIKE '%' || :param || '%' ESCAPE '/'
@@ -1461,7 +1489,9 @@ class ColumnOperators(Operators):
somecolumn.contains("foo/%bar", escape="^")
Will render as::
Will render as:
.. sourcecode:: sql
somecolumn LIKE '%' || :param || '%' ESCAPE '^'
@@ -1482,7 +1512,7 @@ class ColumnOperators(Operators):
:meth:`.ColumnOperators.like`
"""
""" # noqa: E501
return self.operate(contains_op, other, **kw)
def icontains(self, other: Any, **kw: Any) -> ColumnOperators:
@@ -1490,14 +1520,15 @@ class ColumnOperators(Operators):
version of :meth:`.ColumnOperators.contains`.
Produces a LIKE expression that tests against an insensitive match
for the middle of a string value::
for the middle of a string value:
.. sourcecode:: sql
lower(column) LIKE '%' || lower(<other>) || '%'
E.g.::
stmt = select(sometable).\
where(sometable.c.column.icontains("foobar"))
stmt = select(sometable).where(sometable.c.column.icontains("foobar"))
Since the operator uses ``LIKE``, wildcard characters
``"%"`` and ``"_"`` that are present inside the <other> expression
@@ -1526,7 +1557,9 @@ class ColumnOperators(Operators):
somecolumn.icontains("foo%bar", autoescape=True)
Will render as::
Will render as:
.. sourcecode:: sql
lower(somecolumn) LIKE '%' || lower(:param) || '%' ESCAPE '/'
@@ -1542,7 +1575,9 @@ class ColumnOperators(Operators):
somecolumn.icontains("foo/%bar", escape="^")
Will render as::
Will render as:
.. sourcecode:: sql
lower(somecolumn) LIKE '%' || lower(:param) || '%' ESCAPE '^'
@@ -1558,7 +1593,7 @@ class ColumnOperators(Operators):
:meth:`.ColumnOperators.contains`
"""
""" # noqa: E501
return self.operate(icontains_op, other, **kw)
def match(self, other: Any, **kwargs: Any) -> ColumnOperators:
@@ -1582,7 +1617,7 @@ class ColumnOperators(Operators):
:class:`_mysql.match` - MySQL specific construct with
additional features.
* Oracle - renders ``CONTAINS(x, y)``
* Oracle Database - renders ``CONTAINS(x, y)``
* other backends may provide special implementations.
* Backends without any special implementation will emit
the operator as "MATCH". This is compatible with SQLite, for
@@ -1599,7 +1634,7 @@ class ColumnOperators(Operators):
E.g.::
stmt = select(table.c.some_column).where(
table.c.some_column.regexp_match('^(b|c)')
table.c.some_column.regexp_match("^(b|c)")
)
:meth:`_sql.ColumnOperators.regexp_match` attempts to resolve to
@@ -1610,7 +1645,7 @@ class ColumnOperators(Operators):
Examples include:
* PostgreSQL - renders ``x ~ y`` or ``x !~ y`` when negated.
* Oracle - renders ``REGEXP_LIKE(x, y)``
* Oracle Database - renders ``REGEXP_LIKE(x, y)``
* SQLite - uses SQLite's ``REGEXP`` placeholder operator and calls into
the Python ``re.match()`` builtin.
* other backends may provide special implementations.
@@ -1618,9 +1653,9 @@ class ColumnOperators(Operators):
the operator as "REGEXP" or "NOT REGEXP". This is compatible with
SQLite and MySQL, for example.
Regular expression support is currently implemented for Oracle,
PostgreSQL, MySQL and MariaDB. Partial support is available for
SQLite. Support among third-party dialects may vary.
Regular expression support is currently implemented for Oracle
Database, PostgreSQL, MySQL and MariaDB. Partial support is available
for SQLite. Support among third-party dialects may vary.
:param pattern: The regular expression pattern string or column
clause.
@@ -1657,11 +1692,7 @@ class ColumnOperators(Operators):
E.g.::
stmt = select(
table.c.some_column.regexp_replace(
'b(..)',
'X\1Y',
flags='g'
)
table.c.some_column.regexp_replace("b(..)", "X\1Y", flags="g")
)
:meth:`_sql.ColumnOperators.regexp_replace` attempts to resolve to
@@ -1671,8 +1702,8 @@ class ColumnOperators(Operators):
**not backend agnostic**.
Regular expression replacement support is currently implemented for
Oracle, PostgreSQL, MySQL 8 or greater and MariaDB. Support among
third-party dialects may vary.
Oracle Database, PostgreSQL, MySQL 8 or greater and MariaDB. Support
among third-party dialects may vary.
:param pattern: The regular expression pattern string or column
clause.
@@ -1728,8 +1759,7 @@ class ColumnOperators(Operators):
# deprecated 1.4; see #5435
if TYPE_CHECKING:
def nullsfirst(self) -> ColumnOperators:
...
def nullsfirst(self) -> ColumnOperators: ...
else:
nullsfirst = nulls_first
@@ -1747,8 +1777,7 @@ class ColumnOperators(Operators):
# deprecated 1.4; see #5429
if TYPE_CHECKING:
def nullslast(self) -> ColumnOperators:
...
def nullslast(self) -> ColumnOperators: ...
else:
nullslast = nulls_last
@@ -1819,10 +1848,10 @@ class ColumnOperators(Operators):
See the documentation for :func:`_sql.any_` for examples.
.. note:: be sure to not confuse the newer
:meth:`_sql.ColumnOperators.any_` method with its older
:class:`_types.ARRAY`-specific counterpart, the
:meth:`_types.ARRAY.Comparator.any` method, which a different
calling syntax and usage pattern.
:meth:`_sql.ColumnOperators.any_` method with the **legacy**
version of this method, the :meth:`_types.ARRAY.Comparator.any`
method that's specific to :class:`_types.ARRAY`, which uses a
different calling style.
"""
return self.operate(any_op)
@@ -1834,10 +1863,10 @@ class ColumnOperators(Operators):
See the documentation for :func:`_sql.all_` for examples.
.. note:: be sure to not confuse the newer
:meth:`_sql.ColumnOperators.all_` method with its older
:class:`_types.ARRAY`-specific counterpart, the
:meth:`_types.ARRAY.Comparator.all` method, which a different
calling syntax and usage pattern.
:meth:`_sql.ColumnOperators.all_` method with the **legacy**
version of this method, the :meth:`_types.ARRAY.Comparator.all`
method that's specific to :class:`_types.ARRAY`, which uses a
different calling style.
"""
return self.operate(all_op)
@@ -1968,8 +1997,7 @@ def is_true(a: Any) -> Any:
if TYPE_CHECKING:
@_operator_fn
def istrue(a: Any) -> Any:
...
def istrue(a: Any) -> Any: ...
else:
istrue = is_true
@@ -1984,8 +2012,7 @@ def is_false(a: Any) -> Any:
if TYPE_CHECKING:
@_operator_fn
def isfalse(a: Any) -> Any:
...
def isfalse(a: Any) -> Any: ...
else:
isfalse = is_false
@@ -2007,8 +2034,7 @@ def is_not_distinct_from(a: Any, b: Any) -> Any:
if TYPE_CHECKING:
@_operator_fn
def isnot_distinct_from(a: Any, b: Any) -> Any:
...
def isnot_distinct_from(a: Any, b: Any) -> Any: ...
else:
isnot_distinct_from = is_not_distinct_from
@@ -2030,8 +2056,7 @@ def is_not(a: Any, b: Any) -> Any:
if TYPE_CHECKING:
@_operator_fn
def isnot(a: Any, b: Any) -> Any:
...
def isnot(a: Any, b: Any) -> Any: ...
else:
isnot = is_not
@@ -2063,8 +2088,7 @@ def not_like_op(a: Any, b: Any, escape: Optional[str] = None) -> Any:
if TYPE_CHECKING:
@_operator_fn
def notlike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any:
...
def notlike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: ...
else:
notlike_op = not_like_op
@@ -2086,8 +2110,7 @@ def not_ilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any:
if TYPE_CHECKING:
@_operator_fn
def notilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any:
...
def notilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: ...
else:
notilike_op = not_ilike_op
@@ -2109,8 +2132,9 @@ def not_between_op(a: Any, b: Any, c: Any, symmetric: bool = False) -> Any:
if TYPE_CHECKING:
@_operator_fn
def notbetween_op(a: Any, b: Any, c: Any, symmetric: bool = False) -> Any:
...
def notbetween_op(
a: Any, b: Any, c: Any, symmetric: bool = False
) -> Any: ...
else:
notbetween_op = not_between_op
@@ -2132,8 +2156,7 @@ def not_in_op(a: Any, b: Any) -> Any:
if TYPE_CHECKING:
@_operator_fn
def notin_op(a: Any, b: Any) -> Any:
...
def notin_op(a: Any, b: Any) -> Any: ...
else:
notin_op = not_in_op
@@ -2198,8 +2221,7 @@ if TYPE_CHECKING:
@_operator_fn
def notstartswith_op(
a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False
) -> Any:
...
) -> Any: ...
else:
notstartswith_op = not_startswith_op
@@ -2243,8 +2265,7 @@ if TYPE_CHECKING:
@_operator_fn
def notendswith_op(
a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False
) -> Any:
...
) -> Any: ...
else:
notendswith_op = not_endswith_op
@@ -2288,8 +2309,7 @@ if TYPE_CHECKING:
@_operator_fn
def notcontains_op(
a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False
) -> Any:
...
) -> Any: ...
else:
notcontains_op = not_contains_op
@@ -2346,8 +2366,7 @@ def not_match_op(a: Any, b: Any, **kw: Any) -> Any:
if TYPE_CHECKING:
@_operator_fn
def notmatch_op(a: Any, b: Any, **kw: Any) -> Any:
...
def notmatch_op(a: Any, b: Any, **kw: Any) -> Any: ...
else:
notmatch_op = not_match_op
@@ -2392,8 +2411,7 @@ def nulls_first_op(a: Any) -> Any:
if TYPE_CHECKING:
@_operator_fn
def nullsfirst_op(a: Any) -> Any:
...
def nullsfirst_op(a: Any) -> Any: ...
else:
nullsfirst_op = nulls_first_op
@@ -2408,8 +2426,7 @@ def nulls_last_op(a: Any) -> Any:
if TYPE_CHECKING:
@_operator_fn
def nullslast_op(a: Any) -> Any:
...
def nullslast_op(a: Any) -> Any: ...
else:
nullslast_op = nulls_last_op
@@ -2501,6 +2518,12 @@ def is_associative(op: OperatorType) -> bool:
return op in _associative
def is_order_by_modifier(op: Optional[OperatorType]) -> bool:
return op in _order_by_modifier
_order_by_modifier = {desc_op, asc_op, nulls_first_op, nulls_last_op}
_natural_self_precedent = _associative.union(
[getitem, json_getitem_op, json_path_getitem_op]
)
@@ -2582,9 +2605,13 @@ _PRECEDENCE: Dict[OperatorType, int] = {
}
def is_precedent(operator: OperatorType, against: OperatorType) -> bool:
def is_precedent(
operator: OperatorType, against: Optional[OperatorType]
) -> bool:
if operator is against and is_natural_self_precedent(operator):
return False
elif against is None:
return True
else:
return bool(
_PRECEDENCE.get(

View File

@@ -1,5 +1,5 @@
# sql/roles.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -227,8 +227,7 @@ class AnonymizedFromClauseRole(StrictFromClauseRole):
def _anonymous_fromclause(
self, *, name: Optional[str] = None, flat: bool = False
) -> FromClause:
...
) -> FromClause: ...
class ReturnsRowsRole(SQLRole):
@@ -246,8 +245,7 @@ class StatementRole(SQLRole):
if TYPE_CHECKING:
@util.memoized_property
def _propagate_attrs(self) -> _PropagateAttrsType:
...
def _propagate_attrs(self) -> _PropagateAttrsType: ...
else:
_propagate_attrs = util.EMPTY_DICT

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
# sql/traversals.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -80,16 +80,13 @@ class HasShallowCopy(HasTraverseInternals):
if typing.TYPE_CHECKING:
def _generated_shallow_copy_traversal(self, other: Self) -> None:
...
def _generated_shallow_copy_traversal(self, other: Self) -> None: ...
def _generated_shallow_from_dict_traversal(
self, d: Dict[str, Any]
) -> None:
...
) -> None: ...
def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]:
...
def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: ...
@classmethod
def _generate_shallow_copy(
@@ -312,9 +309,11 @@ class _CopyInternalsTraversal(HasTraversalDispatch):
# sequence of 2-tuples
return [
(
clone(key, **kw)
if hasattr(key, "__clause_element__")
else key,
(
clone(key, **kw)
if hasattr(key, "__clause_element__")
else key
),
clone(value, **kw),
)
for key, value in element
@@ -336,9 +335,11 @@ class _CopyInternalsTraversal(HasTraversalDispatch):
def copy(elem):
if isinstance(elem, (list, tuple)):
return [
clone(value, **kw)
if hasattr(value, "__clause_element__")
else value
(
clone(value, **kw)
if hasattr(value, "__clause_element__")
else value
)
for value in elem
]
elif isinstance(elem, dict):
@@ -561,6 +562,8 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
return False
else:
continue
elif right_child is None:
return False
comparison = dispatch(
left_attrname, left, left_child, right, right_child, **kw

View File

@@ -1,18 +1,15 @@
# sql/types_api.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# sql/type_api.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Base types API.
"""
"""Base types API."""
from __future__ import annotations
from enum import Enum
from types import ModuleType
import typing
from typing import Any
from typing import Callable
@@ -39,6 +36,7 @@ from .. import exc
from .. import util
from ..util.typing import Protocol
from ..util.typing import Self
from ..util.typing import TypeAliasType
from ..util.typing import TypedDict
from ..util.typing import TypeGuard
@@ -57,6 +55,7 @@ if typing.TYPE_CHECKING:
from .sqltypes import NUMERICTYPE as NUMERICTYPE # noqa: F401
from .sqltypes import STRINGTYPE as STRINGTYPE # noqa: F401
from .sqltypes import TABLEVALUE as TABLEVALUE # noqa: F401
from ..engine.interfaces import DBAPIModule
from ..engine.interfaces import Dialect
from ..util.typing import GenericProtocol
@@ -66,8 +65,11 @@ _T_con = TypeVar("_T_con", bound=Any, contravariant=True)
_O = TypeVar("_O", bound=object)
_TE = TypeVar("_TE", bound="TypeEngine[Any]")
_CT = TypeVar("_CT", bound=Any)
_RT = TypeVar("_RT", bound=Any)
_MatchedOnType = Union["GenericProtocol[Any]", NewType, Type[Any]]
_MatchedOnType = Union[
"GenericProtocol[Any]", TypeAliasType, NewType, Type[Any]
]
class _NoValueInList(Enum):
@@ -80,23 +82,19 @@ _NO_VALUE_IN_LIST = _NoValueInList.NO_VALUE_IN_LIST
class _LiteralProcessorType(Protocol[_T_co]):
def __call__(self, value: Any) -> str:
...
def __call__(self, value: Any) -> str: ...
class _BindProcessorType(Protocol[_T_con]):
def __call__(self, value: Optional[_T_con]) -> Any:
...
def __call__(self, value: Optional[_T_con]) -> Any: ...
class _ResultProcessorType(Protocol[_T_co]):
def __call__(self, value: Any) -> Optional[_T_co]:
...
def __call__(self, value: Any) -> Optional[_T_co]: ...
class _SentinelProcessorType(Protocol[_T_co]):
def __call__(self, value: Any) -> Optional[_T_co]:
...
def __call__(self, value: Any) -> Optional[_T_co]: ...
class _BaseTypeMemoDict(TypedDict):
@@ -112,8 +110,9 @@ class _TypeMemoDict(_BaseTypeMemoDict, total=False):
class _ComparatorFactory(Protocol[_T]):
def __call__(self, expr: ColumnElement[_T]) -> TypeEngine.Comparator[_T]:
...
def __call__(
self, expr: ColumnElement[_T]
) -> TypeEngine.Comparator[_T]: ...
class TypeEngine(Visitable, Generic[_T]):
@@ -183,10 +182,27 @@ class TypeEngine(Visitable, Generic[_T]):
self.expr = expr
self.type = expr.type
def __reduce__(self) -> Any:
return self.__class__, (self.expr,)
@overload
def operate(
self,
op: OperatorType,
*other: Any,
result_type: Type[TypeEngine[_RT]],
**kwargs: Any,
) -> ColumnElement[_RT]: ...
@overload
def operate(
self, op: OperatorType, *other: Any, **kwargs: Any
) -> ColumnElement[_CT]: ...
@util.preload_module("sqlalchemy.sql.default_comparator")
def operate(
self, op: OperatorType, *other: Any, **kwargs: Any
) -> ColumnElement[_CT]:
) -> ColumnElement[Any]:
default_comparator = util.preloaded.sql_default_comparator
op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__]
if kwargs:
@@ -297,9 +313,9 @@ class TypeEngine(Visitable, Generic[_T]):
"""
_variant_mapping: util.immutabledict[
str, TypeEngine[Any]
] = util.EMPTY_DICT
_variant_mapping: util.immutabledict[str, TypeEngine[Any]] = (
util.EMPTY_DICT
)
def evaluates_none(self) -> Self:
"""Return a copy of this type which has the
@@ -308,11 +324,13 @@ class TypeEngine(Visitable, Generic[_T]):
E.g.::
Table(
'some_table', metadata,
"some_table",
metadata,
Column(
String(50).evaluates_none(),
nullable=True,
server_default='no value')
server_default="no value",
),
)
The ORM uses this flag to indicate that a positive value of ``None``
@@ -371,7 +389,7 @@ class TypeEngine(Visitable, Generic[_T]):
as the sole positional argument and will return a string representation
to be rendered in a SQL statement.
.. note::
.. tip::
This method is only called relative to a **dialect specific type
object**, which is often **private to a dialect in use** and is not
@@ -405,7 +423,7 @@ class TypeEngine(Visitable, Generic[_T]):
If processing is not necessary, the method should return ``None``.
.. note::
.. tip::
This method is only called relative to a **dialect specific type
object**, which is often **private to a dialect in use** and is not
@@ -441,7 +459,7 @@ class TypeEngine(Visitable, Generic[_T]):
If processing is not necessary, the method should return ``None``.
.. note::
.. tip::
This method is only called relative to a **dialect specific type
object**, which is often **private to a dialect in use** and is not
@@ -480,11 +498,19 @@ class TypeEngine(Visitable, Generic[_T]):
It is the SQL analogue of the :meth:`.TypeEngine.result_processor`
method.
.. note:: The :func:`.TypeEngine.column_expression` method is applied
only to the **outermost columns clause** of a SELECT statement, that
is, the columns that are to be delivered directly into the returned
result rows. It does **not** apply to the columns clause inside
of subqueries. This necessarily avoids double conversions against
the column and only runs the conversion when ready to be returned
to the client.
This method is called during the **SQL compilation** phase of a
statement, when rendering a SQL string. It is **not** called
against specific values.
.. note::
.. tip::
This method is only called relative to a **dialect specific type
object**, which is often **private to a dialect in use** and is not
@@ -574,18 +600,6 @@ class TypeEngine(Visitable, Generic[_T]):
"""
return None
def _sentinel_value_resolver(
self, dialect: Dialect
) -> Optional[_SentinelProcessorType[_T]]:
"""Return an optional callable that will match parameter values
(post-bind processing) to result values
(pre-result-processing), for use in the "sentinel" feature.
.. versionadded:: 2.0.10
"""
return None
@util.memoized_property
def _has_bind_expression(self) -> bool:
"""memoized boolean, check if bind_expression is implemented.
@@ -606,7 +620,7 @@ class TypeEngine(Visitable, Generic[_T]):
return x == y # type: ignore[no-any-return]
def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]:
def get_dbapi_type(self, dbapi: DBAPIModule) -> Optional[Any]:
"""Return the corresponding type object from the underlying DB-API, if
any.
@@ -650,7 +664,7 @@ class TypeEngine(Visitable, Generic[_T]):
string_type = String()
string_type = string_type.with_variant(
mysql.VARCHAR(collation='foo'), 'mysql', 'mariadb'
mysql.VARCHAR(collation="foo"), "mysql", "mariadb"
)
The variant mapping indicates that when this type is
@@ -767,6 +781,10 @@ class TypeEngine(Visitable, Generic[_T]):
return self
def _with_collation(self, collation: str) -> Self:
"""set up error handling for the collate expression"""
raise NotImplementedError("this datatype does not support collation")
@util.ro_memoized_property
def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]:
"""Return a rudimental 'affinity' value expressing the general class
@@ -933,18 +951,6 @@ class TypeEngine(Visitable, Generic[_T]):
d["result"][coltype] = rp
return rp
def _cached_sentinel_value_processor(
self, dialect: Dialect
) -> Optional[_SentinelProcessorType[_T]]:
try:
return dialect._type_memos[self]["sentinel"]
except KeyError:
pass
d = self._dialect_info(dialect)
d["sentinel"] = bp = d["impl"]._sentinel_value_resolver(dialect)
return bp
def _cached_custom_processor(
self, dialect: Dialect, key: str, fn: Callable[[TypeEngine[_T]], _O]
) -> _O:
@@ -999,9 +1005,11 @@ class TypeEngine(Visitable, Generic[_T]):
return (self.__class__,) + tuple(
(
k,
self.__dict__[k]._static_cache_key
if isinstance(self.__dict__[k], TypeEngine)
else self.__dict__[k],
(
self.__dict__[k]._static_cache_key
if isinstance(self.__dict__[k], TypeEngine)
else self.__dict__[k]
),
)
for k in names
if k in self.__dict__
@@ -1010,12 +1018,12 @@ class TypeEngine(Visitable, Generic[_T]):
)
@overload
def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
...
def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ...
@overload
def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
...
def adapt(
self, cls: Type[TypeEngineMixin], **kw: Any
) -> TypeEngine[Any]: ...
def adapt(
self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
@@ -1027,9 +1035,11 @@ class TypeEngine(Visitable, Generic[_T]):
types with "implementation" types that are specific to a particular
dialect.
"""
return util.constructor_copy(
typ = util.constructor_copy(
self, cast(Type[TypeEngine[Any]], cls), **kw
)
typ._variant_mapping = self._variant_mapping
return typ
def coerce_compared_value(
self, op: Optional[OperatorType], value: Any
@@ -1108,26 +1118,21 @@ class TypeEngineMixin:
@util.memoized_property
def _static_cache_key(
self,
) -> Union[CacheConst, Tuple[Any, ...]]:
...
) -> Union[CacheConst, Tuple[Any, ...]]: ...
@overload
def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
...
def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ...
@overload
def adapt(
self, cls: Type[TypeEngineMixin], **kw: Any
) -> TypeEngine[Any]:
...
) -> TypeEngine[Any]: ...
def adapt(
self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
) -> TypeEngine[Any]:
...
) -> TypeEngine[Any]: ...
def dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
...
def dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: ...
class ExternalType(TypeEngineMixin):
@@ -1146,7 +1151,7 @@ class ExternalType(TypeEngineMixin):
"""
cache_ok: Optional[bool] = None
"""Indicate if statements using this :class:`.ExternalType` are "safe to
'''Indicate if statements using this :class:`.ExternalType` are "safe to
cache".
The default value ``None`` will emit a warning and then not allow caching
@@ -1187,12 +1192,12 @@ class ExternalType(TypeEngineMixin):
series of tuples. Given a previously un-cacheable type as::
class LookupType(UserDefinedType):
'''a custom type that accepts a dictionary as a parameter.
"""a custom type that accepts a dictionary as a parameter.
this is the non-cacheable version, as "self.lookup" is not
hashable.
'''
"""
def __init__(self, lookup):
self.lookup = lookup
@@ -1200,8 +1205,7 @@ class ExternalType(TypeEngineMixin):
def get_col_spec(self, **kw):
return "VARCHAR(255)"
def bind_processor(self, dialect):
# ... works with "self.lookup" ...
def bind_processor(self, dialect): ... # works with "self.lookup" ...
Where "lookup" is a dictionary. The type will not be able to generate
a cache key::
@@ -1237,7 +1241,7 @@ class ExternalType(TypeEngineMixin):
to the ".lookup" attribute::
class LookupType(UserDefinedType):
'''a custom type that accepts a dictionary as a parameter.
"""a custom type that accepts a dictionary as a parameter.
The dictionary is stored both as itself in a private variable,
and published in a public variable as a sorted tuple of tuples,
@@ -1245,7 +1249,7 @@ class ExternalType(TypeEngineMixin):
two equivalent dictionaries. Note it assumes the keys and
values of the dictionary are themselves hashable.
'''
"""
cache_ok = True
@@ -1254,15 +1258,12 @@ class ExternalType(TypeEngineMixin):
# assume keys/values of "lookup" are hashable; otherwise
# they would also need to be converted in some way here
self.lookup = tuple(
(key, lookup[key]) for key in sorted(lookup)
)
self.lookup = tuple((key, lookup[key]) for key in sorted(lookup))
def get_col_spec(self, **kw):
return "VARCHAR(255)"
def bind_processor(self, dialect):
# ... works with "self._lookup" ...
def bind_processor(self, dialect): ... # works with "self._lookup" ...
Where above, the cache key for ``LookupType({"a": 10, "b": 20})`` will be::
@@ -1280,7 +1281,7 @@ class ExternalType(TypeEngineMixin):
:ref:`sql_caching`
""" # noqa: E501
''' # noqa: E501
@util.non_memoized_property
def _static_cache_key(
@@ -1322,10 +1323,11 @@ class UserDefinedType(
import sqlalchemy.types as types
class MyType(types.UserDefinedType):
cache_ok = True
def __init__(self, precision = 8):
def __init__(self, precision=8):
self.precision = precision
def get_col_spec(self, **kw):
@@ -1334,19 +1336,23 @@ class UserDefinedType(
def bind_processor(self, dialect):
def process(value):
return value
return process
def result_processor(self, dialect, coltype):
def process(value):
return value
return process
Once the type is made, it's immediately usable::
table = Table('foo', metadata_obj,
Column('id', Integer, primary_key=True),
Column('data', MyType(16))
)
table = Table(
"foo",
metadata_obj,
Column("id", Integer, primary_key=True),
Column("data", MyType(16)),
)
The ``get_col_spec()`` method will in most cases receive a keyword
argument ``type_expression`` which refers to the owning expression
@@ -1391,6 +1397,10 @@ class UserDefinedType(
return self
if TYPE_CHECKING:
def get_col_spec(self, **kw: Any) -> str: ...
class Emulated(TypeEngineMixin):
"""Mixin for base types that emulate the behavior of a DB-native type.
@@ -1429,12 +1439,12 @@ class Emulated(TypeEngineMixin):
return super().adapt(impltype, **kw)
@overload
def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
...
def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ...
@overload
def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
...
def adapt(
self, cls: Type[TypeEngineMixin], **kw: Any
) -> TypeEngine[Any]: ...
def adapt(
self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any
@@ -1511,7 +1521,7 @@ class NativeForEmulated(TypeEngineMixin):
class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
"""Allows the creation of types which add additional functionality
'''Allows the creation of types which add additional functionality
to an existing type.
This method is preferred to direct subclassing of SQLAlchemy's
@@ -1522,10 +1532,11 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
import sqlalchemy.types as types
class MyType(types.TypeDecorator):
'''Prefixes Unicode values with "PREFIX:" on the way in and
"""Prefixes Unicode values with "PREFIX:" on the way in and
strips it off on the way out.
'''
"""
impl = types.Unicode
@@ -1578,6 +1589,8 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
class MyEpochType(types.TypeDecorator):
impl = types.Integer
cache_ok = True
epoch = datetime.date(1970, 1, 1)
def process_bind_param(self, value, dialect):
@@ -1615,6 +1628,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
from sqlalchemy import JSON
from sqlalchemy import TypeDecorator
class MyJsonType(TypeDecorator):
impl = JSON
@@ -1635,6 +1649,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
from sqlalchemy import ARRAY
from sqlalchemy import TypeDecorator
class MyArrayType(TypeDecorator):
impl = ARRAY
@@ -1643,8 +1658,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
def coerce_compared_value(self, op, value):
return self.impl.coerce_compared_value(op, value)
"""
'''
__visit_name__ = "type_decorator"
@@ -1740,20 +1754,48 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
return super().reverse_operate(op, other, **kwargs)
@staticmethod
def _reduce_td_comparator(
impl: TypeEngine[Any], expr: ColumnElement[_T]
) -> Any:
return TypeDecorator._create_td_comparator_type(impl)(expr)
@staticmethod
def _create_td_comparator_type(
impl: TypeEngine[Any],
) -> _ComparatorFactory[Any]:
def __reduce__(self: TypeDecorator.Comparator[Any]) -> Any:
return (TypeDecorator._reduce_td_comparator, (impl, self.expr))
return type(
"TDComparator",
(TypeDecorator.Comparator, impl.comparator_factory), # type: ignore # noqa: E501
{"__reduce__": __reduce__},
)
@property
def comparator_factory( # type: ignore # mypy properties bug
self,
) -> _ComparatorFactory[Any]:
if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__: # type: ignore # noqa: E501
return self.impl.comparator_factory
return self.impl_instance.comparator_factory
else:
# reconcile the Comparator class on the impl with that
# of TypeDecorator
return type(
"TDComparator",
(TypeDecorator.Comparator, self.impl.comparator_factory), # type: ignore # noqa: E501
{},
# of TypeDecorator.
# the use of multiple staticmethods is to support repeated
# pickling of the Comparator itself
return TypeDecorator._create_td_comparator_type(self.impl_instance)
def _copy_with_check(self) -> Self:
tt = self.copy()
if not isinstance(tt, self.__class__):
raise AssertionError(
"Type object %s does not properly "
"implement the copy() method, it must "
"return an object of type %s" % (self, self.__class__)
)
return tt
def _gen_dialect_impl(self, dialect: Dialect) -> TypeEngine[_T]:
if dialect.name in self._variant_mapping:
@@ -1769,16 +1811,17 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
# to a copy of this TypeDecorator and return
# that.
typedesc = self.load_dialect_impl(dialect).dialect_impl(dialect)
tt = self.copy()
if not isinstance(tt, self.__class__):
raise AssertionError(
"Type object %s does not properly "
"implement the copy() method, it must "
"return an object of type %s" % (self, self.__class__)
)
tt = self._copy_with_check()
tt.impl = tt.impl_instance = typedesc
return tt
def _with_collation(self, collation: str) -> Self:
tt = self._copy_with_check()
tt.impl = tt.impl_instance = self.impl_instance._with_collation(
collation
)
return tt
@util.ro_non_memoized_property
def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
return self.impl_instance._type_affinity
@@ -2233,7 +2276,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
instance.__dict__.update(self.__dict__)
return instance
def get_dbapi_type(self, dbapi: ModuleType) -> Optional[Any]:
def get_dbapi_type(self, dbapi: DBAPIModule) -> Optional[Any]:
"""Return the DBAPI type object represented by this
:class:`.TypeDecorator`.
@@ -2280,13 +2323,13 @@ class Variant(TypeDecorator[_T]):
@overload
def to_instance(typeobj: Union[Type[_TE], _TE], *arg: Any, **kw: Any) -> _TE:
...
def to_instance(
typeobj: Union[Type[_TE], _TE], *arg: Any, **kw: Any
) -> _TE: ...
@overload
def to_instance(typeobj: None, *arg: Any, **kw: Any) -> TypeEngine[None]:
...
def to_instance(typeobj: None, *arg: Any, **kw: Any) -> TypeEngine[None]: ...
def to_instance(
@@ -2302,11 +2345,10 @@ def to_instance(
def adapt_type(
typeobj: TypeEngine[Any],
typeobj: _TypeEngineArgument[Any],
colspecs: Mapping[Type[Any], Type[TypeEngine[Any]]],
) -> TypeEngine[Any]:
if isinstance(typeobj, type):
typeobj = typeobj()
typeobj = to_instance(typeobj)
for t in typeobj.__class__.__mro__[0:-1]:
try:
impltype = colspecs[t]

View File

@@ -1,14 +1,12 @@
# sql/util.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
"""High level utilities which build upon other modules here.
"""
"""High level utilities which build upon other modules here."""
from __future__ import annotations
from collections import deque
@@ -106,7 +104,7 @@ def join_condition(
would produce an expression along the lines of::
tablea.c.id==tableb.c.tablea_id
tablea.c.id == tableb.c.tablea_id
The join is determined based on the foreign key relationships
between the two selectables. If there are multiple ways
@@ -268,7 +266,7 @@ def visit_binary_product(
The function is of the form::
def my_fn(binary, left, right)
def my_fn(binary, left, right): ...
For each binary expression located which has a
comparison operator, the product of "left" and
@@ -277,12 +275,11 @@ def visit_binary_product(
Hence an expression like::
and_(
(a + b) == q + func.sum(e + f),
j == r
)
and_((a + b) == q + func.sum(e + f), j == r)
would have the traversal::
would have the traversal:
.. sourcecode:: text
a <eq> q
a <eq> e
@@ -350,9 +347,9 @@ def find_tables(
] = _visitors["lateral"] = tables.append
if include_crud:
_visitors["insert"] = _visitors["update"] = _visitors[
"delete"
] = lambda ent: tables.append(ent.table)
_visitors["insert"] = _visitors["update"] = _visitors["delete"] = (
lambda ent: tables.append(ent.table)
)
if check_columns:
@@ -367,7 +364,7 @@ def find_tables(
return tables
def unwrap_order_by(clause):
def unwrap_order_by(clause: Any) -> Any:
"""Break up an 'order by' expression into individual column-expressions,
without DESC/ASC/NULLS FIRST/NULLS LAST"""
@@ -481,7 +478,7 @@ def surface_selectables(clause):
stack.append(elem.element)
def surface_selectables_only(clause):
def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]:
stack = [clause]
while stack:
elem = stack.pop()
@@ -528,9 +525,7 @@ def bind_values(clause):
E.g.::
>>> expr = and_(
... table.c.foo==5, table.c.foo==7
... )
>>> expr = and_(table.c.foo == 5, table.c.foo == 7)
>>> bind_values(expr)
[5, 7]
"""
@@ -878,8 +873,7 @@ def reduce_columns(
columns: Iterable[ColumnElement[Any]],
*clauses: Optional[ClauseElement],
**kw: bool,
) -> Sequence[ColumnElement[Any]]:
...
) -> Sequence[ColumnElement[Any]]: ...
@overload
@@ -887,8 +881,7 @@ def reduce_columns(
columns: _SelectIterable,
*clauses: Optional[ClauseElement],
**kw: bool,
) -> Sequence[Union[ColumnElement[Any], TextClause]]:
...
) -> Sequence[Union[ColumnElement[Any], TextClause]]: ...
def reduce_columns(
@@ -1043,20 +1036,24 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
E.g.::
table1 = Table('sometable', metadata,
Column('col1', Integer),
Column('col2', Integer)
)
table2 = Table('someothertable', metadata,
Column('col1', Integer),
Column('col2', Integer)
)
table1 = Table(
"sometable",
metadata,
Column("col1", Integer),
Column("col2", Integer),
)
table2 = Table(
"someothertable",
metadata,
Column("col1", Integer),
Column("col2", Integer),
)
condition = table1.c.col1 == table2.c.col1
make an alias of table1::
s = table1.alias('foo')
s = table1.alias("foo")
calling ``ClauseAdapter(s).traverse(condition)`` converts
condition to read::
@@ -1099,8 +1096,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
if TYPE_CHECKING:
@overload
def traverse(self, obj: Literal[None]) -> None:
...
def traverse(self, obj: Literal[None]) -> None: ...
# note this specializes the ReplacingExternalTraversal.traverse()
# method to state
@@ -1111,13 +1107,11 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
# FromClause but Mypy is not accepting those as compatible with
# the base ReplacingExternalTraversal
@overload
def traverse(self, obj: _ET) -> _ET:
...
def traverse(self, obj: _ET) -> _ET: ...
def traverse(
self, obj: Optional[ExternallyTraversible]
) -> Optional[ExternallyTraversible]:
...
) -> Optional[ExternallyTraversible]: ...
def _corresponding_column(
self, col, require_embedded, _seen=util.EMPTY_SET
@@ -1177,7 +1171,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
# we are an alias of a table and we are not derived from an
# alias of a table (which nonetheless may be the same table
# as ours) so, same thing
return col # type: ignore
return col
else:
# other cases where we are a selectable and the element
# is another join or selectable that contains a table which our
@@ -1219,23 +1213,18 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
class _ColumnLookup(Protocol):
@overload
def __getitem__(self, key: None) -> None:
...
def __getitem__(self, key: None) -> None: ...
@overload
def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]:
...
def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ...
@overload
def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]:
...
def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ...
@overload
def __getitem__(self, key: _ET) -> _ET:
...
def __getitem__(self, key: _ET) -> _ET: ...
def __getitem__(self, key: Any) -> Any:
...
def __getitem__(self, key: Any) -> Any: ...
class ColumnAdapter(ClauseAdapter):
@@ -1333,12 +1322,10 @@ class ColumnAdapter(ClauseAdapter):
return ac
@overload
def traverse(self, obj: Literal[None]) -> None:
...
def traverse(self, obj: Literal[None]) -> None: ...
@overload
def traverse(self, obj: _ET) -> _ET:
...
def traverse(self, obj: _ET) -> _ET: ...
def traverse(
self, obj: Optional[ExternallyTraversible]
@@ -1353,8 +1340,7 @@ class ColumnAdapter(ClauseAdapter):
if TYPE_CHECKING:
@property
def visitor_iterator(self) -> Iterator[ColumnAdapter]:
...
def visitor_iterator(self) -> Iterator[ColumnAdapter]: ...
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process

View File

@@ -1,14 +1,11 @@
# sql/visitors.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Visitor/traversal interface and library functions.
"""
"""Visitor/traversal interface and library functions."""
from __future__ import annotations
@@ -72,8 +69,7 @@ __all__ = [
class _CompilerDispatchType(Protocol):
def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any:
...
def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any: ...
class Visitable:
@@ -100,8 +96,7 @@ class Visitable:
if typing.TYPE_CHECKING:
def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str:
...
def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str: ...
def __init_subclass__(cls) -> None:
if "__visit_name__" in cls.__dict__:
@@ -493,8 +488,7 @@ class HasTraverseInternals:
class _InternalTraversalDispatchType(Protocol):
def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any:
...
def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any: ...
class HasTraversalDispatch:
@@ -602,13 +596,11 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
if typing.TYPE_CHECKING:
def _annotate(self, values: _AnnotationDict) -> Self:
...
def _annotate(self, values: _AnnotationDict) -> Self: ...
def get_children(
self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
) -> Iterable[ExternallyTraversible]:
...
) -> Iterable[ExternallyTraversible]: ...
def _clone(self, **kw: Any) -> Self:
"""clone this element"""
@@ -638,13 +630,11 @@ _TraverseCallableType = Callable[[_ET], None]
class _CloneCallableType(Protocol):
def __call__(self, element: _ET, **kw: Any) -> _ET:
...
def __call__(self, element: _ET, **kw: Any) -> _ET: ...
class _TraverseTransformCallableType(Protocol[_ET]):
def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]:
...
def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]: ...
_ExtT = TypeVar("_ExtT", bound="ExternalTraversal")
@@ -680,12 +670,12 @@ class ExternalTraversal(util.MemoizedSlots):
return iterate(obj, self.__traverse_options__)
@overload
def traverse(self, obj: Literal[None]) -> None:
...
def traverse(self, obj: Literal[None]) -> None: ...
@overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
...
def traverse(
self, obj: ExternallyTraversible
) -> ExternallyTraversible: ...
def traverse(
self, obj: Optional[ExternallyTraversible]
@@ -746,12 +736,12 @@ class CloningExternalTraversal(ExternalTraversal):
return [self.traverse(x) for x in list_]
@overload
def traverse(self, obj: Literal[None]) -> None:
...
def traverse(self, obj: Literal[None]) -> None: ...
@overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
...
def traverse(
self, obj: ExternallyTraversible
) -> ExternallyTraversible: ...
def traverse(
self, obj: Optional[ExternallyTraversible]
@@ -786,12 +776,12 @@ class ReplacingExternalTraversal(CloningExternalTraversal):
return None
@overload
def traverse(self, obj: Literal[None]) -> None:
...
def traverse(self, obj: Literal[None]) -> None: ...
@overload
def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible:
...
def traverse(
self, obj: ExternallyTraversible
) -> ExternallyTraversible: ...
def traverse(
self, obj: Optional[ExternallyTraversible]
@@ -866,8 +856,7 @@ def traverse_using(
iterator: Iterable[ExternallyTraversible],
obj: Literal[None],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> None:
...
) -> None: ...
@overload
@@ -875,8 +864,7 @@ def traverse_using(
iterator: Iterable[ExternallyTraversible],
obj: ExternallyTraversible,
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
...
) -> ExternallyTraversible: ...
def traverse_using(
@@ -920,8 +908,7 @@ def traverse(
obj: Literal[None],
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> None:
...
) -> None: ...
@overload
@@ -929,8 +916,7 @@ def traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> ExternallyTraversible:
...
) -> ExternallyTraversible: ...
def traverse(
@@ -945,11 +931,13 @@ def traverse(
from sqlalchemy.sql import visitors
stmt = select(some_table).where(some_table.c.foo == 'bar')
stmt = select(some_table).where(some_table.c.foo == "bar")
def visit_bindparam(bind_param):
print("found bound value: %s" % bind_param.value)
visitors.traverse(stmt, {}, {"bindparam": visit_bindparam})
The iteration of objects uses the :func:`.visitors.iterate` function,
@@ -975,8 +963,7 @@ def cloned_traverse(
obj: Literal[None],
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> None:
...
) -> None: ...
# a bit of controversy here, as the clone of the lead element
@@ -988,8 +975,7 @@ def cloned_traverse(
obj: _ET,
opts: Mapping[str, Any],
visitors: Mapping[str, _TraverseCallableType[Any]],
) -> _ET:
...
) -> _ET: ...
def cloned_traverse(
@@ -1088,8 +1074,7 @@ def replacement_traverse(
obj: Literal[None],
opts: Mapping[str, Any],
replace: _TraverseTransformCallableType[Any],
) -> None:
...
) -> None: ...
@overload
@@ -1097,8 +1082,7 @@ def replacement_traverse(
obj: _CE,
opts: Mapping[str, Any],
replace: _TraverseTransformCallableType[Any],
) -> _CE:
...
) -> _CE: ...
@overload
@@ -1106,8 +1090,7 @@ def replacement_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
replace: _TraverseTransformCallableType[Any],
) -> ExternallyTraversible:
...
) -> ExternallyTraversible: ...
def replacement_traverse(