@@ -1,5 +1,5 @@
# postgresql/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -8,6 +8,7 @@

from types import ModuleType

from . import array as arraylib  # noqa # keep above base and other dialects
from . import asyncpg  # noqa
from . import base
from . import pg8000  # noqa
@@ -56,12 +57,14 @@ from .named_types import ENUM
from .named_types import NamedType
from .ranges import AbstractMultiRange
from .ranges import AbstractRange
from .ranges import AbstractSingleRange
from .ranges import DATEMULTIRANGE
from .ranges import DATERANGE
from .ranges import INT4MULTIRANGE
from .ranges import INT4RANGE
from .ranges import INT8MULTIRANGE
from .ranges import INT8RANGE
from .ranges import MultiRange
from .ranges import NUMMULTIRANGE
from .ranges import NUMRANGE
from .ranges import Range
@@ -86,6 +89,7 @@ from .types import TIMESTAMP
from .types import TSQUERY
from .types import TSVECTOR


# Alias psycopg also as psycopg_async
psycopg_async = type(
    "psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async}
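
Note: a minimal usage sketch of what this alias enables, assuming the dialect
registry resolves a ``postgresql+psycopg_async`` driver name to this module's
``dialect`` attribute (the URL below is illustrative)::

    from sqlalchemy.ext.asyncio import create_async_engine

    # the alias makes the async variant of the psycopg dialect importable
    # under the name "psycopg_async"
    engine = create_async_engine(
        "postgresql+psycopg_async://user:pass@hostname/dbname"
    )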
@@ -1,4 +1,5 @@
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/_psycopg_common.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -169,8 +170,10 @@ class _PGDialect_common_psycopg(PGDialect):
    def _do_autocommit(self, connection, value):
        connection.autocommit = value

    def detect_autocommit_setting(self, dbapi_connection):
        return bool(dbapi_connection.autocommit)

    def do_ping(self, dbapi_connection):
        cursor = None
        before_autocommit = dbapi_connection.autocommit

        if not before_autocommit:
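
The hunk above is truncated; as a hedged sketch, the save/restore pattern it
sets up around the ping looks roughly like the following (assuming a
psycopg-style DBAPI connection with a writable ``autocommit`` attribute)::

    def do_ping(dbapi_connection):
        # remember the caller's autocommit setting, ping with autocommit
        # enabled so no transaction is left open, then restore the setting
        before_autocommit = dbapi_connection.autocommit
        if not before_autocommit:
            dbapi_connection.autocommit = True
        try:
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("SELECT 1")
            finally:
                cursor.close()
        finally:
            if not before_autocommit:
                dbapi_connection.autocommit = False
        return True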
@@ -1,18 +1,21 @@
# postgresql/array.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/array.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

import re
from typing import Any
from typing import Any as typing_Any
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union

from .operators import CONTAINED_BY
from .operators import CONTAINS
@@ -21,32 +24,55 @@ from ... import types as sqltypes
from ... import util
from ...sql import expression
from ...sql import operators
from ...sql._typing import _TypeEngineArgument
from ...sql.visitors import InternalTraversal

if TYPE_CHECKING:
    from ...engine.interfaces import Dialect
    from ...sql._typing import _ColumnExpressionArgument
    from ...sql._typing import _TypeEngineArgument
    from ...sql.elements import ColumnElement
    from ...sql.elements import Grouping
    from ...sql.expression import BindParameter
    from ...sql.operators import OperatorType
    from ...sql.selectable import _SelectIterable
    from ...sql.type_api import _BindProcessorType
    from ...sql.type_api import _LiteralProcessorType
    from ...sql.type_api import _ResultProcessorType
    from ...sql.type_api import TypeEngine
    from ...sql.visitors import _TraverseInternalsType
    from ...util.typing import Self


_T = TypeVar("_T", bound=Any)
_T = TypeVar("_T", bound=typing_Any)


def Any(other, arrexpr, operator=operators.eq):
def Any(
    other: typing_Any,
    arrexpr: _ColumnExpressionArgument[_T],
    operator: OperatorType = operators.eq,
) -> ColumnElement[bool]:
    """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
    See that method for details.

    """

    return arrexpr.any(other, operator)
    return arrexpr.any(other, operator)  # type: ignore[no-any-return, union-attr]  # noqa: E501


def All(other, arrexpr, operator=operators.eq):
def All(
    other: typing_Any,
    arrexpr: _ColumnExpressionArgument[_T],
    operator: OperatorType = operators.eq,
) -> ColumnElement[bool]:
    """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
    See that method for details.

    """

    return arrexpr.all(other, operator)
    return arrexpr.all(other, operator)  # type: ignore[no-any-return, union-attr]  # noqa: E501

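A short usage sketch of the typed ``Any``/``All`` synonyms above (the table
and column are hypothetical)::

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.dialects import postgresql

    metadata = MetaData()
    t = Table("t", metadata, Column("ids", postgresql.ARRAY(Integer)))

    # renders: SELECT ... WHERE %(param_1)s = ANY (t.ids)
    stmt = select(t).where(postgresql.Any(7, t.c.ids))
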
class array(expression.ExpressionClauseList[_T]):
    """A PostgreSQL ARRAY literal.

    This is used to produce ARRAY literals in SQL expressions, e.g.::
@@ -55,20 +81,43 @@ class array(expression.ExpressionClauseList[_T]):
        from sqlalchemy.dialects import postgresql
        from sqlalchemy import select, func

        stmt = select(array([1,2]) + array([3,4,5]))
        stmt = select(array([1, 2]) + array([3, 4, 5]))

        print(stmt.compile(dialect=postgresql.dialect()))

    Produces the SQL::
    Produces the SQL:

    .. sourcecode:: sql

        SELECT ARRAY[%(param_1)s, %(param_2)s] ||
            ARRAY[%(param_3)s, %(param_4)s, %(param_5)s] AS anon_1

    An instance of :class:`.array` will always have the datatype
    :class:`_types.ARRAY`. The "inner" type of the array is inferred from
    the values present, unless the ``type_`` keyword argument is passed::
    :class:`_types.ARRAY`. The "inner" type of the array is inferred from the
    values present, unless the :paramref:`_postgresql.array.type_` keyword
    argument is passed::

        array(['foo', 'bar'], type_=CHAR)
        array(["foo", "bar"], type_=CHAR)

    When constructing an empty array, the :paramref:`_postgresql.array.type_`
    argument is particularly important as PostgreSQL server typically requires
    a cast to be rendered for the inner type in order to render an empty array.
    SQLAlchemy's compilation for the empty array will produce this cast so
    that::

        stmt = array([], type_=Integer)
        print(stmt.compile(dialect=postgresql.dialect()))

    Produces:

    .. sourcecode:: sql

        ARRAY[]::INTEGER[]

    As required by PostgreSQL for empty arrays.

    .. versionadded:: 2.0.40 added support to render empty PostgreSQL array
       literals with a required cast.

    Multidimensional arrays are produced by nesting :class:`.array` constructs.
    The dimensionality of the final :class:`_types.ARRAY`
@@ -77,16 +126,21 @@ class array(expression.ExpressionClauseList[_T]):
    type::

        stmt = select(
            array([
                array([1, 2]), array([3, 4]), array([column('q'), column('x')])
            ])
            array(
                [array([1, 2]), array([3, 4]), array([column("q"), column("x")])]
            )
        )
        print(stmt.compile(dialect=postgresql.dialect()))

    Produces::
    Produces:

        SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s],
        ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1
    .. sourcecode:: sql

        SELECT ARRAY[
            ARRAY[%(param_1)s, %(param_2)s],
            ARRAY[%(param_3)s, %(param_4)s],
            ARRAY[q, x]
        ] AS anon_1

    .. versionadded:: 1.3.6 added support for multidimensional array literals

@@ -94,42 +148,63 @@ class array(expression.ExpressionClauseList[_T]):

        :class:`_postgresql.ARRAY`

    """
    """  # noqa: E501

    __visit_name__ = "array"

    stringify_dialect = "postgresql"
    inherit_cache = True

    def __init__(self, clauses, **kw):
        type_arg = kw.pop("type_", None)
    _traverse_internals: _TraverseInternalsType = [
        ("clauses", InternalTraversal.dp_clauseelement_tuple),
        ("type", InternalTraversal.dp_type),
    ]

    def __init__(
        self,
        clauses: Iterable[_T],
        *,
        type_: Optional[_TypeEngineArgument[_T]] = None,
        **kw: typing_Any,
    ):
        r"""Construct an ARRAY literal.

        :param clauses: iterable, such as a list, containing elements to be
         rendered in the array
        :param type\_: optional type. If omitted, the type is inferred
         from the contents of the array.

        """
        super().__init__(operators.comma_op, *clauses, **kw)

        self._type_tuple = [arg.type for arg in self.clauses]

        main_type = (
            type_arg
            if type_arg is not None
            else self._type_tuple[0]
            if self._type_tuple
            else sqltypes.NULLTYPE
            type_
            if type_ is not None
            else self.clauses[0].type if self.clauses else sqltypes.NULLTYPE
        )

        if isinstance(main_type, ARRAY):
            self.type = ARRAY(
                main_type.item_type,
                dimensions=main_type.dimensions + 1
                if main_type.dimensions is not None
                else 2,
            )
                dimensions=(
                    main_type.dimensions + 1
                    if main_type.dimensions is not None
                    else 2
                ),
            )  # type: ignore[assignment]
        else:
            self.type = ARRAY(main_type)
            self.type = ARRAY(main_type)  # type: ignore[assignment]

    @property
    def _select_iterable(self):
    def _select_iterable(self) -> _SelectIterable:
        return (self,)

    def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
    def _bind_param(
        self,
        operator: OperatorType,
        obj: typing_Any,
        type_: Optional[TypeEngine[_T]] = None,
        _assume_scalar: bool = False,
    ) -> BindParameter[_T]:
        if _assume_scalar or operator is operators.getitem:
            return expression.BindParameter(
                None,
@@ -148,16 +223,18 @@ class array(expression.ExpressionClauseList[_T]):
                    )
                    for o in obj
                ]
            )
            )  # type: ignore[return-value]

    def self_group(self, against=None):
    def self_group(
        self, against: Optional[OperatorType] = None
    ) -> Union[Self, Grouping[_T]]:
        if against in (operators.any_op, operators.all_op, operators.getitem):
            return expression.Grouping(self)
        else:
            return self

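The dimension-inference logic in ``__init__`` above can be observed directly;
a minimal sketch::

    from sqlalchemy.dialects.postgresql import array

    a = array([array([1, 2]), array([3, 4])])
    # the inner constructs carry ARRAY(Integer) with no explicit dimensions,
    # so nesting bumps the outer type to two dimensions
    print(a.type.dimensions)  # 2
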
class ARRAY(sqltypes.ARRAY):
class ARRAY(sqltypes.ARRAY[_T]):
    """PostgreSQL ARRAY type.

    The :class:`_postgresql.ARRAY` type is constructed in the same way
@@ -167,9 +244,11 @@ class ARRAY(sqltypes.ARRAY):

        from sqlalchemy.dialects import postgresql

        mytable = Table("mytable", metadata,
                Column("data", postgresql.ARRAY(Integer, dimensions=2))
            )
        mytable = Table(
            "mytable",
            metadata,
            Column("data", postgresql.ARRAY(Integer, dimensions=2)),
        )

    The :class:`_postgresql.ARRAY` type provides all operations defined on the
    core :class:`_types.ARRAY` type, including support for "dimensions",
@@ -184,8 +263,9 @@ class ARRAY(sqltypes.ARRAY):

        mytable.c.data.contains([1, 2])

    The :class:`_postgresql.ARRAY` type may not be supported on all
    PostgreSQL DBAPIs; it is currently known to work on psycopg2 only.
    Indexed access is one-based by default, to match that of PostgreSQL;
    for zero-based indexed access, set
    :paramref:`_postgresql.ARRAY.zero_indexes`.

    Additionally, the :class:`_postgresql.ARRAY`
    type does not work directly in
@@ -204,6 +284,7 @@ class ARRAY(sqltypes.ARRAY):
        from sqlalchemy.dialects.postgresql import ARRAY
        from sqlalchemy.ext.mutable import MutableList


        class SomeOrmClass(Base):
            # ...

@@ -225,45 +306,9 @@ class ARRAY(sqltypes.ARRAY):

    """

    class Comparator(sqltypes.ARRAY.Comparator):
        """Define comparison operations for :class:`_types.ARRAY`.

        Note that these operations are in addition to those provided
        by the base :class:`.types.ARRAY.Comparator` class, including
        :meth:`.types.ARRAY.Comparator.any` and
        :meth:`.types.ARRAY.Comparator.all`.

        """

        def contains(self, other, **kwargs):
            """Boolean expression. Test if elements are a superset of the
            elements of the argument array expression.

            kwargs may be ignored by this operator but are required for API
            conformance.
            """
            return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)

        def contained_by(self, other):
            """Boolean expression. Test if elements are a proper subset of the
            elements of the argument array expression.
            """
            return self.operate(
                CONTAINED_BY, other, result_type=sqltypes.Boolean
            )

        def overlap(self, other):
            """Boolean expression. Test if array has elements in common with
            an argument array expression.
            """
            return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)

    comparator_factory = Comparator

    def __init__(
        self,
        item_type: _TypeEngineArgument[Any],
        item_type: _TypeEngineArgument[_T],
        as_tuple: bool = False,
        dimensions: Optional[int] = None,
        zero_indexes: bool = False,
@@ -272,7 +317,7 @@ class ARRAY(sqltypes.ARRAY):

        E.g.::

            Column('myarray', ARRAY(Integer))
            Column("myarray", ARRAY(Integer))

        Arguments are:

@@ -312,35 +357,63 @@ class ARRAY(sqltypes.ARRAY):
        self.dimensions = dimensions
        self.zero_indexes = zero_indexes

    @property
    def hashable(self):
        return self.as_tuple
    class Comparator(sqltypes.ARRAY.Comparator[_T]):
        """Define comparison operations for :class:`_types.ARRAY`.

    @property
    def python_type(self):
        return list
        Note that these operations are in addition to those provided
        by the base :class:`.types.ARRAY.Comparator` class, including
        :meth:`.types.ARRAY.Comparator.any` and
        :meth:`.types.ARRAY.Comparator.all`.

    def compare_values(self, x, y):
        return x == y
        """

        def contains(
            self, other: typing_Any, **kwargs: typing_Any
        ) -> ColumnElement[bool]:
            """Boolean expression. Test if elements are a superset of the
            elements of the argument array expression.

            kwargs may be ignored by this operator but are required for API
            conformance.
            """
            return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)

        def contained_by(self, other: typing_Any) -> ColumnElement[bool]:
            """Boolean expression. Test if elements are a proper subset of the
            elements of the argument array expression.
            """
            return self.operate(
                CONTAINED_BY, other, result_type=sqltypes.Boolean
            )

        def overlap(self, other: typing_Any) -> ColumnElement[bool]:
            """Boolean expression. Test if array has elements in common with
            an argument array expression.
            """
            return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)

    comparator_factory = Comparator

    @util.memoized_property
    def _against_native_enum(self):
    def _against_native_enum(self) -> bool:
        return (
            isinstance(self.item_type, sqltypes.Enum)
            and self.item_type.native_enum
        )

    def literal_processor(self, dialect):
    def literal_processor(
        self, dialect: Dialect
    ) -> Optional[_LiteralProcessorType[_T]]:
        item_proc = self.item_type.dialect_impl(dialect).literal_processor(
            dialect
        )
        if item_proc is None:
            return None

        def to_str(elements):
        def to_str(elements: Iterable[typing_Any]) -> str:
            return f"ARRAY[{', '.join(elements)}]"

        def process(value):
        def process(value: Sequence[typing_Any]) -> str:
            inner = self._apply_item_processor(
                value, item_proc, self.dimensions, to_str
            )
@@ -348,12 +421,16 @@ class ARRAY(sqltypes.ARRAY):

        return process

    def bind_processor(self, dialect):
    def bind_processor(
        self, dialect: Dialect
    ) -> Optional[_BindProcessorType[Sequence[typing_Any]]]:
        item_proc = self.item_type.dialect_impl(dialect).bind_processor(
            dialect
        )

        def process(value):
        def process(
            value: Optional[Sequence[typing_Any]],
        ) -> Optional[list[typing_Any]]:
            if value is None:
                return value
            else:
@@ -363,12 +440,16 @@ class ARRAY(sqltypes.ARRAY):

        return process

    def result_processor(self, dialect, coltype):
    def result_processor(
        self, dialect: Dialect, coltype: object
    ) -> _ResultProcessorType[Sequence[typing_Any]]:
        item_proc = self.item_type.dialect_impl(dialect).result_processor(
            dialect, coltype
        )

        def process(value):
        def process(
            value: Sequence[typing_Any],
        ) -> Optional[Sequence[typing_Any]]:
            if value is None:
                return value
            else:
@@ -383,11 +464,13 @@ class ARRAY(sqltypes.ARRAY):
            super_rp = process
            pattern = re.compile(r"^{(.*)}$")

            def handle_raw_string(value):
                inner = pattern.match(value).group(1)
            def handle_raw_string(value: str) -> list[str]:
                inner = pattern.match(value).group(1)  # type: ignore[union-attr]  # noqa: E501
                return _split_enum_values(inner)

            def process(value):
            def process(
                value: Sequence[typing_Any],
            ) -> Optional[Sequence[typing_Any]]:
                if value is None:
                    return value
                # isinstance(value, str) is required to handle
@@ -402,7 +485,7 @@ class ARRAY(sqltypes.ARRAY):
        return process


def _split_enum_values(array_string):
def _split_enum_values(array_string: str) -> list[str]:
    if '"' not in array_string:
        # no escape char is present so it can just split on the comma
        return array_string.split(",") if array_string else []

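To exercise the ``literal_processor`` path above, a hedged sketch using
literal binds::

    from sqlalchemy import Integer, literal
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import ARRAY

    expr = literal([1, 2, 3], ARRAY(Integer))
    print(
        expr.compile(
            dialect=postgresql.dialect(),
            compile_kwargs={"literal_binds": True},
        )
    )
    # ARRAY[1, 2, 3]
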
@@ -1,5 +1,5 @@
# postgresql/asyncpg.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# dialects/postgresql/asyncpg.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
@@ -23,18 +23,10 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

    from sqlalchemy.ext.asyncio import create_async_engine
    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname")

The dialect can also be run as a "synchronous" dialect within the
:func:`_sa.create_engine` function, which will pass "await" calls into
an ad-hoc event loop. This mode of operation is of **limited use**
and is for special testing scenarios only. The mode can be enabled by
adding the SQLAlchemy-specific flag ``async_fallback`` to the URL
in conjunction with :func:`_sa.create_engine`::

    # for testing purposes only; do not use in production!
    engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true")

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname"
    )

.. versionadded:: 1.4

@@ -89,11 +81,15 @@ asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect
argument)::


    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500")
    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500"
    )

To disable the prepared statement cache, use a value of zero::

    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0")
    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0"
    )

.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.

@@ -123,8 +119,8 @@ To disable the prepared statement cache, use a value of zero::

.. _asyncpg_prepared_statement_name:

Prepared Statement Name
-----------------------
Prepared Statement Name with PGBouncer
--------------------------------------

By default, asyncpg enumerates prepared statements in numeric order, which
can lead to errors if a name has already been taken for another prepared
@@ -139,10 +135,10 @@ a prepared statement is prepared::
    from uuid import uuid4

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname",
        "postgresql+asyncpg://user:pass@somepgbouncer/dbname",
        poolclass=NullPool,
        connect_args={
            'prepared_statement_name_func': lambda: f'__asyncpg_{uuid4()}__',
            "prepared_statement_name_func": lambda: f"__asyncpg_{uuid4()}__",
        },
    )

@@ -152,7 +148,7 @@ a prepared statement is prepared::

    https://github.com/sqlalchemy/sqlalchemy/issues/6467

.. warning:: To prevent a buildup of useless prepared statements in
.. warning:: When using PGBouncer, to prevent a buildup of useless prepared statements in
   your application, it's important to use the :class:`.NullPool` pool
   class, and to configure PgBouncer to use `DISCARD <https://www.postgresql.org/docs/current/sql-discard.html>`_
   when returning connections. The DISCARD command is used to release resources held by the db connection,
@@ -182,13 +178,11 @@ client using this setting passed to :func:`_asyncio.create_async_engine`::

from __future__ import annotations

import collections
from collections import deque
import decimal
import json as _py_json
import re
import time
from typing import cast
from typing import TYPE_CHECKING

from . import json
from . import ranges
@@ -218,9 +212,6 @@ from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only

if TYPE_CHECKING:
    from typing import Iterable


class AsyncpgARRAY(PGARRAY):
    render_bind_cast = True
@@ -274,20 +265,20 @@ class AsyncpgInteger(sqltypes.Integer):
    render_bind_cast = True


class AsyncpgSmallInteger(sqltypes.SmallInteger):
    render_bind_cast = True


class AsyncpgBigInteger(sqltypes.BigInteger):
    render_bind_cast = True


class AsyncpgJSON(json.JSON):
    render_bind_cast = True

    def result_processor(self, dialect, coltype):
        return None


class AsyncpgJSONB(json.JSONB):
    render_bind_cast = True

    def result_processor(self, dialect, coltype):
        return None

@@ -372,7 +363,7 @@ class AsyncpgCHAR(sqltypes.CHAR):
    render_bind_cast = True


class _AsyncpgRange(ranges.AbstractRangeImpl):
class _AsyncpgRange(ranges.AbstractSingleRangeImpl):
    def bind_processor(self, dialect):
        asyncpg_Range = dialect.dbapi.asyncpg.Range

@@ -426,10 +417,7 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
                )
            return value

            return [
                to_range(element)
                for element in cast("Iterable[ranges.Range]", value)
            ]
            return [to_range(element) for element in value]

        return to_range

@@ -448,7 +436,7 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
            return rvalue

        if value is not None:
            value = [to_range(elem) for elem in value]
            value = ranges.MultiRange(to_range(elem) for elem in value)

        return value

@@ -506,7 +494,7 @@ class AsyncAdapt_asyncpg_cursor:
    def __init__(self, adapt_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self._rows = []
        self._rows = deque()
        self._cursor = None
        self.description = None
        self.arraysize = 1
@@ -514,7 +502,7 @@ class AsyncAdapt_asyncpg_cursor:
        self._invalidate_schema_cache_asof = 0

    def close(self):
        self._rows[:] = []
        self._rows.clear()

    def _handle_exception(self, error):
        self._adapt_connection._handle_exception(error)
@@ -554,11 +542,12 @@ class AsyncAdapt_asyncpg_cursor:
                self._cursor = await prepared_stmt.cursor(*parameters)
                self.rowcount = -1
            else:
                self._rows = await prepared_stmt.fetch(*parameters)
                self._rows = deque(await prepared_stmt.fetch(*parameters))
                status = prepared_stmt.get_statusmsg()

                reg = re.match(
                    r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status
                    r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)",
                    status or "",
                )
                if reg:
                    self.rowcount = int(reg.group(1))
@@ -602,11 +591,11 @@ class AsyncAdapt_asyncpg_cursor:

    def __iter__(self):
        while self._rows:
            yield self._rows.pop(0)
            yield self._rows.popleft()

    def fetchone(self):
        if self._rows:
            return self._rows.pop(0)
            return self._rows.popleft()
        else:
            return None

@@ -614,13 +603,12 @@ class AsyncAdapt_asyncpg_cursor:
        if size is None:
            size = self.arraysize

        retval = self._rows[0:size]
        self._rows[:] = self._rows[size:]
        return retval
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self):
        retval = self._rows[:]
        self._rows[:] = []
        retval = list(self._rows)
        self._rows.clear()
        return retval

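An aside on why the switch from ``list`` to ``deque`` matters here: popping
from the head of a list shifts every remaining element, while
``deque.popleft()`` is constant time. A standalone sketch (not part of the
dialect) that makes the difference visible::

    import timeit
    from collections import deque

    rows = list(range(100_000))

    def drain_list():
        r = rows.copy()
        while r:
            r.pop(0)  # O(n) per pop; quadratic overall

    def drain_deque():
        r = deque(rows)
        while r:
            r.popleft()  # O(1) per pop; linear overall

    print(timeit.timeit(drain_list, number=1))
    print(timeit.timeit(drain_deque, number=1))
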
@@ -630,23 +618,21 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):

    def __init__(self, adapt_connection):
        super().__init__(adapt_connection)
        self._rowbuffer = None
        self._rowbuffer = deque()

    def close(self):
        self._cursor = None
        self._rowbuffer = None
        self._rowbuffer.clear()

    def _buffer_rows(self):
        assert self._cursor is not None
        new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
        self._rowbuffer = collections.deque(new_rows)
        self._rowbuffer.extend(new_rows)

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self._rowbuffer:
            self._buffer_rows()

        while True:
        while self._rowbuffer:
            yield self._rowbuffer.popleft()
@@ -669,21 +655,19 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
        if not self._rowbuffer:
            self._buffer_rows()

        buf = list(self._rowbuffer)
        lb = len(buf)
        assert self._cursor is not None
        rb = self._rowbuffer
        lb = len(rb)
        if size > lb:
            buf.extend(
            rb.extend(
                self._adapt_connection.await_(self._cursor.fetch(size - lb))
            )

        result = buf[0:size]
        self._rowbuffer = collections.deque(buf[size:])
        return result
        return [rb.popleft() for _ in range(min(size, len(rb)))]

    def fetchall(self):
        ret = list(self._rowbuffer) + list(
            self._adapt_connection.await_(self._all())
        )
        ret = list(self._rowbuffer)
        ret.extend(self._adapt_connection.await_(self._all()))
        self._rowbuffer.clear()
        return ret

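The server-side cursor adapted above is exercised when streaming results; a
hedged usage sketch with the asyncio extension (the URL is illustrative)::

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    async def main():
        engine = create_async_engine(
            "postgresql+asyncpg://user:pass@hostname/dbname"
        )
        async with engine.connect() as conn:
            # stream() uses a server-side cursor; rows arrive in batches
            # through the deque-based buffer shown above
            result = await conn.stream(
                text("SELECT generate_series(1, 10000)")
            )
            async for row in result:
                pass
        await engine.dispose()

    asyncio.run(main())
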
@@ -733,7 +717,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
    ):
        self.dbapi = dbapi
        self._connection = connection
        self.isolation_level = self._isolation_setting = "read_committed"
        self.isolation_level = self._isolation_setting = None
        self.readonly = False
        self.deferrable = False
        self._transaction = None
@@ -802,9 +786,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
            translated_error = exception_mapping[super_](
                "%s: %s" % (type(error), error)
            )
            translated_error.pgcode = (
                translated_error.sqlstate
            ) = getattr(error, "sqlstate", None)
            translated_error.pgcode = translated_error.sqlstate = (
                getattr(error, "sqlstate", None)
            )
            raise translated_error from error
        else:
            raise error
@@ -868,25 +852,45 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
        else:
            return AsyncAdapt_asyncpg_cursor(self)

    async def _rollback_and_discard(self):
        try:
            await self._transaction.rollback()
        finally:
            # if asyncpg .rollback() was actually called, then whether or
            # not it raised or succeeded, the transaction is done, discard it
            self._transaction = None
            self._started = False

    async def _commit_and_discard(self):
        try:
            await self._transaction.commit()
        finally:
            # if asyncpg .commit() was actually called, then whether or
            # not it raised or succeeded, the transaction is done, discard it
            self._transaction = None
            self._started = False

    def rollback(self):
        if self._started:
            try:
                self.await_(self._transaction.rollback())
            except Exception as error:
                self._handle_exception(error)
            finally:
                self.await_(self._rollback_and_discard())
                self._transaction = None
                self._started = False
            except Exception as error:
                # don't dereference asyncpg transaction if we didn't
                # actually try to call rollback() on it
                self._handle_exception(error)

    def commit(self):
        if self._started:
            try:
                self.await_(self._transaction.commit())
            except Exception as error:
                self._handle_exception(error)
            finally:
                self.await_(self._commit_and_discard())
                self._transaction = None
                self._started = False
            except Exception as error:
                # don't dereference asyncpg transaction if we didn't
                # actually try to call commit() on it
                self._handle_exception(error)

    def close(self):
        self.rollback()
@@ -894,7 +898,31 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
        self.await_(self._connection.close())

    def terminate(self):
        self._connection.terminate()
        if util.concurrency.in_greenlet():
            # in a greenlet; this is the case where the connection
            # was invalidated.
            try:
                # try to gracefully close; see #10717
                # timeout added in asyncpg 0.14.0 December 2017
                self.await_(asyncio.shield(self._connection.close(timeout=2)))
            except (
                asyncio.TimeoutError,
                asyncio.CancelledError,
                OSError,
                self.dbapi.asyncpg.PostgresError,
            ) as e:
                # in the case where we are recycling an old connection
                # that may have already been disconnected, close() will
                # fail with the above timeout. in this case, terminate
                # the connection without any further waiting.
                # see issue #8419
                self._connection.terminate()
                if isinstance(e, asyncio.CancelledError):
                    # re-raise CancelledError if we were cancelled
                    raise
        else:
            # not in a greenlet; this is the gc cleanup case
            self._connection.terminate()
        self._started = False

    @staticmethod
@@ -1031,6 +1059,7 @@ class PGDialect_asyncpg(PGDialect):
            INTERVAL: AsyncPgInterval,
            sqltypes.Boolean: AsyncpgBoolean,
            sqltypes.Integer: AsyncpgInteger,
            sqltypes.SmallInteger: AsyncpgSmallInteger,
            sqltypes.BigInteger: AsyncpgBigInteger,
            sqltypes.Numeric: AsyncpgNumeric,
            sqltypes.Float: AsyncpgFloat,
@@ -1045,7 +1074,7 @@ class PGDialect_asyncpg(PGDialect):
            OID: AsyncpgOID,
            REGCLASS: AsyncpgREGCLASS,
            sqltypes.CHAR: AsyncpgCHAR,
            ranges.AbstractRange: _AsyncpgRange,
            ranges.AbstractSingleRange: _AsyncpgRange,
            ranges.AbstractMultiRange: _AsyncpgMultiRange,
        },
    )
@@ -1088,6 +1117,9 @@ class PGDialect_asyncpg(PGDialect):
    def set_isolation_level(self, dbapi_connection, level):
        dbapi_connection.set_isolation_level(self._isolation_lookup[level])

    def detect_autocommit_setting(self, dbapi_conn) -> bool:
        return bool(dbapi_conn.autocommit)

    def set_readonly(self, connection, value):
        connection.readonly = value


File diff suppressed because it is too large
@@ -1,5 +1,5 @@
# postgresql/dml.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/dml.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,7 +7,10 @@
from __future__ import annotations

from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from . import ext
from .._typing import _OnConflictConstraintT
@@ -26,7 +29,9 @@ from ...sql.base import ColumnCollection
from ...sql.base import ReadOnlyColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.elements import ColumnElement
from ...sql.elements import KeyedColumnElement
from ...sql.elements import TextClause
from ...sql.expression import alias
from ...util.typing import Self

@@ -153,11 +158,10 @@ class Insert(StandardInsert):
            :paramref:`.Insert.on_conflict_do_update.set_` dictionary.

        :param where:
          Optional argument. If present, can be a literal SQL
          string or an acceptable expression for a ``WHERE`` clause
          that restricts the rows affected by ``DO UPDATE SET``. Rows
          not meeting the ``WHERE`` condition will not be updated
          (effectively a ``DO NOTHING`` for those rows).
          Optional argument. An expression object representing a ``WHERE``
          clause that restricts the rows affected by ``DO UPDATE SET``. Rows not
          meeting the ``WHERE`` condition will not be updated (effectively a
          ``DO NOTHING`` for those rows).


        .. seealso::

@@ -212,8 +216,10 @@ class OnConflictClause(ClauseElement):
    stringify_dialect = "postgresql"

    constraint_target: Optional[str]
    inferred_target_elements: _OnConflictIndexElementsT
    inferred_target_whereclause: _OnConflictIndexWhereT
    inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]]
    inferred_target_whereclause: Optional[
        Union[ColumnElement[Any], TextClause]
    ]

    def __init__(
        self,
@@ -254,12 +260,28 @@ class OnConflictClause(ClauseElement):

        if index_elements is not None:
            self.constraint_target = None
            self.inferred_target_elements = index_elements
            self.inferred_target_whereclause = index_where
            self.inferred_target_elements = [
                coercions.expect(roles.DDLConstraintColumnRole, column)
                for column in index_elements
            ]

            self.inferred_target_whereclause = (
                coercions.expect(
                    (
                        roles.StatementOptionRole
                        if isinstance(constraint, ext.ExcludeConstraint)
                        else roles.WhereHavingRole
                    ),
                    index_where,
                )
                if index_where is not None
                else None
            )

        elif constraint is None:
            self.constraint_target = (
                self.inferred_target_elements
            ) = self.inferred_target_whereclause = None
            self.constraint_target = self.inferred_target_elements = (
                self.inferred_target_whereclause
            ) = None


class OnConflictDoNothing(OnConflictClause):
@@ -269,6 +291,9 @@ class OnConflictDoNothing(OnConflictClause):
class OnConflictDoUpdate(OnConflictClause):
    __visit_name__ = "on_conflict_do_update"

    update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]]
    update_whereclause: Optional[ColumnElement[Any]]

    def __init__(
        self,
        constraint: _OnConflictConstraintT = None,
@@ -307,4 +332,8 @@ class OnConflictDoUpdate(OnConflictClause):
            (coercions.expect(roles.DMLColumnRole, key), value)
            for key, value in set_.items()
        ]
        self.update_whereclause = where
        self.update_whereclause = (
            coercions.expect(roles.WhereHavingRole, where)
            if where is not None
            else None
        )

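A short sketch of the ``where`` parameter as now documented, passing an
expression object rather than a plain string (the table is hypothetical)::

    from sqlalchemy.dialects.postgresql import insert

    stmt = insert(my_table).values(id=1, data="inserted value")
    stmt = stmt.on_conflict_do_update(
        index_elements=[my_table.c.id],
        set_=dict(data=stmt.excluded.data),
        # an expression object; coerced via WhereHavingRole as shown above
        where=(my_table.c.data != stmt.excluded.data),
    )
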
@@ -1,5 +1,5 @@
# postgresql/ext.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/ext.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -8,6 +8,10 @@
from __future__ import annotations

from typing import Any
from typing import Iterable
from typing import List
from typing import Optional
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar

@@ -23,34 +27,44 @@ from ...sql.schema import ColumnCollectionConstraint
from ...sql.sqltypes import TEXT
from ...sql.visitors import InternalTraversal

_T = TypeVar("_T", bound=Any)

if TYPE_CHECKING:
    from ...sql._typing import _ColumnExpressionArgument
    from ...sql.elements import ClauseElement
    from ...sql.elements import ColumnElement
    from ...sql.operators import OperatorType
    from ...sql.selectable import FromClause
    from ...sql.visitors import _CloneCallableType
    from ...sql.visitors import _TraverseInternalsType

_T = TypeVar("_T", bound=Any)


class aggregate_order_by(expression.ColumnElement):
class aggregate_order_by(expression.ColumnElement[_T]):
    """Represent a PostgreSQL aggregate order by expression.

    E.g.::

        from sqlalchemy.dialects.postgresql import aggregate_order_by

        expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc()))
        stmt = select(expr)

    would represent the expression::
    would represent the expression:

    .. sourcecode:: sql

        SELECT array_agg(a ORDER BY b DESC) FROM table;

    Similarly::

        expr = func.string_agg(
            table.c.a,
            aggregate_order_by(literal_column("','"), table.c.a)
            table.c.a, aggregate_order_by(literal_column("','"), table.c.a)
        )
        stmt = select(expr)

    Would represent::
    Would represent:

    .. sourcecode:: sql

        SELECT string_agg(a, ',' ORDER BY a) FROM table;

@@ -71,11 +85,32 @@ class aggregate_order_by(expression.ColumnElement):
        ("order_by", InternalTraversal.dp_clauseelement),
    ]

    def __init__(self, target, *order_by):
        self.target = coercions.expect(roles.ExpressionElementRole, target)
    @overload
    def __init__(
        self,
        target: ColumnElement[_T],
        *order_by: _ColumnExpressionArgument[Any],
    ): ...

    @overload
    def __init__(
        self,
        target: _ColumnExpressionArgument[_T],
        *order_by: _ColumnExpressionArgument[Any],
    ): ...

    def __init__(
        self,
        target: _ColumnExpressionArgument[_T],
        *order_by: _ColumnExpressionArgument[Any],
    ):
        self.target: ClauseElement = coercions.expect(
            roles.ExpressionElementRole, target
        )
        self.type = self.target.type

        _lob = len(order_by)
        self.order_by: ClauseElement
        if _lob == 0:
            raise TypeError("at least one ORDER BY element is required")
        elif _lob == 1:
@@ -87,18 +122,22 @@ class aggregate_order_by(expression.ColumnElement):
            *order_by, _literal_as_text_role=roles.ExpressionElementRole
        )

    def self_group(self, against=None):
    def self_group(
        self, against: Optional[OperatorType] = None
    ) -> ClauseElement:
        return self

    def get_children(self, **kwargs):
    def get_children(self, **kwargs: Any) -> Iterable[ClauseElement]:
        return self.target, self.order_by

    def _copy_internals(self, clone=elements._clone, **kw):
    def _copy_internals(
        self, clone: _CloneCallableType = elements._clone, **kw: Any
    ) -> None:
        self.target = clone(self.target, **kw)
        self.order_by = clone(self.order_by, **kw)

    @property
    def _from_objects(self):
    def _from_objects(self) -> List[FromClause]:
        return self.target._from_objects + self.order_by._from_objects

@@ -131,10 +170,10 @@ class ExcludeConstraint(ColumnCollectionConstraint):
    E.g.::

        const = ExcludeConstraint(
            (Column('period'), '&&'),
            (Column('group'), '='),
            where=(Column('group') != 'some group'),
            ops={'group': 'my_operator_class'}
            (Column("period"), "&&"),
            (Column("group"), "="),
            where=(Column("group") != "some group"),
            ops={"group": "my_operator_class"},
        )

    The constraint is normally embedded into the :class:`_schema.Table`
@@ -142,19 +181,20 @@ class ExcludeConstraint(ColumnCollectionConstraint):
    directly, or added later using :meth:`.append_constraint`::

        some_table = Table(
            'some_table', metadata,
            Column('id', Integer, primary_key=True),
            Column('period', TSRANGE()),
            Column('group', String)
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("period", TSRANGE()),
            Column("group", String),
        )

        some_table.append_constraint(
            ExcludeConstraint(
                (some_table.c.period, '&&'),
                (some_table.c.group, '='),
                where=some_table.c.group != 'some group',
                name='some_table_excl_const',
                ops={'group': 'my_operator_class'}
                (some_table.c.period, "&&"),
                (some_table.c.group, "="),
                where=some_table.c.group != "some group",
                name="some_table_excl_const",
                ops={"group": "my_operator_class"},
            )
        )

@@ -1,5 +1,5 @@
# postgresql/hstore.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/hstore.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -28,28 +28,29 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):

    The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::

        data_table = Table('data_table', metadata,
            Column('id', Integer, primary_key=True),
            Column('data', HSTORE)
        data_table = Table(
            "data_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", HSTORE),
        )

        with engine.connect() as conn:
            conn.execute(
                data_table.insert(),
                data = {"key1": "value1", "key2": "value2"}
                data_table.insert(), data={"key1": "value1", "key2": "value2"}
            )

    :class:`.HSTORE` provides for a wide range of operations, including:

    * Index operations::

        data_table.c.data['some key'] == 'some value'
        data_table.c.data["some key"] == "some value"

    * Containment operations::

        data_table.c.data.has_key('some key')
        data_table.c.data.has_key("some key")

        data_table.c.data.has_all(['one', 'two', 'three'])
        data_table.c.data.has_all(["one", "two", "three"])

    * Concatenation::

@@ -72,17 +73,19 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):

        from sqlalchemy.ext.mutable import MutableDict


        class MyClass(Base):
            __tablename__ = 'data_table'
            __tablename__ = "data_table"

            id = Column(Integer, primary_key=True)
            data = Column(MutableDict.as_mutable(HSTORE))


        my_object = session.query(MyClass).one()

        # in-place mutation, requires Mutable extension
        # in order for the ORM to detect
        my_object.data['some_key'] = 'some value'
        my_object.data["some_key"] = "some value"

        session.commit()

@@ -96,7 +99,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
        :class:`.hstore` - render the PostgreSQL ``hstore()`` function.


    """
    """  # noqa: E501

    __visit_name__ = "HSTORE"
    hashable = False
@@ -192,6 +195,9 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
    comparator_factory = Comparator

    def bind_processor(self, dialect):
        # note that dialect-specific types like that of psycopg and
        # psycopg2 will override this method to allow driver-level conversion
        # instead, see _PsycopgHStore
        def process(value):
            if isinstance(value, dict):
                return _serialize_hstore(value)
@@ -201,6 +207,9 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
        return process

    def result_processor(self, dialect, coltype):
        # note that dialect-specific types like that of psycopg and
        # psycopg2 will override this method to allow driver-level conversion
        # instead, see _PsycopgHStore
        def process(value):
            if value is not None:
                return _parse_hstore(value)
@@ -221,12 +230,12 @@ class hstore(sqlfunc.GenericFunction):

        from sqlalchemy.dialects.postgresql import array, hstore

        select(hstore('key1', 'value1'))
        select(hstore("key1", "value1"))

        select(
            hstore(
                array(['key1', 'key2', 'key3']),
                array(['value1', 'value2', 'value3'])
                array(["key1", "key2", "key3"]),
                array(["value1", "value2", "value3"]),
            )
        )

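As a compile-level sketch of the containment operators described above (the
metadata setup mirrors the docstring example)::

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import HSTORE

    metadata = MetaData()
    data_table = Table(
        "data_table",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("data", HSTORE),
    )

    # has_key() renders the hstore "?" operator
    stmt = select(data_table).where(data_table.c.data.has_key("some key"))
    print(stmt.compile(dialect=postgresql.dialect()))
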
@@ -1,11 +1,18 @@
|
||||
# postgresql/json.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# dialects/postgresql/json.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from .array import ARRAY
|
||||
from .array import array as _pg_array
|
||||
@@ -21,13 +28,23 @@ from .operators import PATH_EXISTS
|
||||
from .operators import PATH_MATCH
|
||||
from ... import types as sqltypes
|
||||
from ...sql import cast
|
||||
from ...sql._typing import _T
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...engine.interfaces import Dialect
|
||||
from ...sql.elements import ColumnElement
|
||||
from ...sql.type_api import _BindProcessorType
|
||||
from ...sql.type_api import _LiteralProcessorType
|
||||
from ...sql.type_api import TypeEngine
|
||||
|
||||
__all__ = ("JSON", "JSONB")
|
||||
|
||||
|
||||
class JSONPathType(sqltypes.JSON.JSONPathType):
|
||||
def _processor(self, dialect, super_proc):
|
||||
def process(value):
|
||||
def _processor(
|
||||
self, dialect: Dialect, super_proc: Optional[Callable[[Any], Any]]
|
||||
) -> Callable[[Any], Any]:
|
||||
def process(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
# If it's already a string assume that it's in json path
|
||||
# format. This allows using cast with json paths literals
|
||||
@@ -44,11 +61,13 @@ class JSONPathType(sqltypes.JSON.JSONPathType):
|
||||
|
||||
return process
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return self._processor(dialect, self.string_bind_processor(dialect))
|
||||
def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
|
||||
return self._processor(dialect, self.string_bind_processor(dialect)) # type: ignore[return-value] # noqa: E501
|
||||
|
||||
def literal_processor(self, dialect):
|
||||
return self._processor(dialect, self.string_literal_processor(dialect))
|
||||
def literal_processor(
|
||||
self, dialect: Dialect
|
||||
) -> _LiteralProcessorType[Any]:
|
||||
return self._processor(dialect, self.string_literal_processor(dialect)) # type: ignore[return-value] # noqa: E501
|
||||
|
||||
|
||||
class JSONPATH(JSONPathType):
|
||||
@@ -90,14 +109,14 @@ class JSON(sqltypes.JSON):
|
||||
|
||||
* Index operations (the ``->`` operator)::
|
||||
|
||||
data_table.c.data['some key']
|
||||
data_table.c.data["some key"]
|
||||
|
||||
data_table.c.data[5]
|
||||
|
||||
* Index operations returning text
|
||||
(the ``->>`` operator)::
|
||||
|
||||
* Index operations returning text (the ``->>`` operator)::
|
||||
|
||||
data_table.c.data['some key'].astext == 'some value'
|
||||
data_table.c.data["some key"].astext == "some value"
|
||||
|
||||
Note that equivalent functionality is available via the
|
||||
:attr:`.JSON.Comparator.as_string` accessor.
|
||||
@@ -105,18 +124,20 @@ class JSON(sqltypes.JSON):
|
||||
* Index operations with CAST
|
||||
(equivalent to ``CAST(col ->> ['some key'] AS <type>)``)::
|
||||
|
||||
data_table.c.data['some key'].astext.cast(Integer) == 5
|
||||
data_table.c.data["some key"].astext.cast(Integer) == 5
|
||||
|
||||
Note that equivalent functionality is available via the
|
||||
:attr:`.JSON.Comparator.as_integer` and similar accessors.
|
||||
|
||||
* Path index operations (the ``#>`` operator)::
|
||||
|
||||
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]
|
||||
data_table.c.data[("key_1", "key_2", 5, ..., "key_n")]
|
||||
|
||||
* Path index operations returning text (the ``#>>`` operator)::
|
||||
|
||||
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value'
|
||||
data_table.c.data[
|
||||
("key_1", "key_2", 5, ..., "key_n")
|
||||
].astext == "some value"
|
||||
|
||||
Index operations return an expression object whose type defaults to
|
||||
:class:`_types.JSON` by default,
|
||||
@@ -128,10 +149,11 @@ class JSON(sqltypes.JSON):
|
||||
using psycopg2, the DBAPI only allows serializers at the per-cursor
|
||||
or per-connection level. E.g.::
|
||||
|
||||
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
|
||||
json_serializer=my_serialize_fn,
|
||||
json_deserializer=my_deserialize_fn
|
||||
)
|
||||
engine = create_engine(
|
||||
"postgresql+psycopg2://scott:tiger@localhost/test",
|
||||
json_serializer=my_serialize_fn,
|
||||
json_deserializer=my_deserialize_fn,
|
||||
)
|
||||
|
||||
When using the psycopg2 dialect, the json_deserializer is registered
|
||||
against the database using ``psycopg2.extras.register_default_json``.
|
||||
@@ -144,9 +166,14 @@ class JSON(sqltypes.JSON):
|
||||
|
||||
""" # noqa
|
||||
|
||||
astext_type = sqltypes.Text()
|
||||
    render_bind_cast = True
    astext_type: TypeEngine[str] = sqltypes.Text()

    def __init__(self, none_as_null=False, astext_type=None):
    def __init__(
        self,
        none_as_null: bool = False,
        astext_type: Optional[TypeEngine[str]] = None,
    ):
        """Construct a :class:`_types.JSON` type.

        :param none_as_null: if True, persist the value ``None`` as a
@@ -155,7 +182,8 @@ class JSON(sqltypes.JSON):
          be used to persist a NULL value::

            from sqlalchemy import null
            conn.execute(table.insert(), data=null())

            conn.execute(table.insert(), {"data": null()})

          .. seealso::

@@ -170,17 +198,19 @@ class JSON(sqltypes.JSON):
        if astext_type is not None:
            self.astext_type = astext_type

    class Comparator(sqltypes.JSON.Comparator):
    class Comparator(sqltypes.JSON.Comparator[_T]):
        """Define comparison operations for :class:`_types.JSON`."""

        type: JSON

        @property
        def astext(self):
        def astext(self) -> ColumnElement[str]:
            """On an indexed expression, use the "astext" (e.g. "->>")
            conversion when rendered in SQL.

            E.g.::

                select(data_table.c.data['some key'].astext)
                select(data_table.c.data["some key"].astext)

            .. seealso::

@@ -188,13 +218,13 @@ class JSON(sqltypes.JSON):

            """
            if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
                return self.expr.left.operate(
                return self.expr.left.operate(  # type: ignore[no-any-return]
                    JSONPATH_ASTEXT,
                    self.expr.right,
                    result_type=self.type.astext_type,
                )
            else:
                return self.expr.left.operate(
                return self.expr.left.operate(  # type: ignore[no-any-return]
                    ASTEXT, self.expr.right, result_type=self.type.astext_type
                )

@@ -207,15 +237,16 @@ class JSONB(JSON):
    The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data,
    e.g.::

        data_table = Table('data_table', metadata,
            Column('id', Integer, primary_key=True),
            Column('data', JSONB)
        data_table = Table(
            "data_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", JSONB),
        )

        with engine.connect() as conn:
            conn.execute(
                data_table.insert(),
                data = {"key1": "value1", "key2": "value2"}
                data_table.insert(), data={"key1": "value1", "key2": "value2"}
            )

    The :class:`_postgresql.JSONB` type includes all operations provided by
@@ -252,43 +283,53 @@ class JSONB(JSON):

    __visit_name__ = "JSONB"

    class Comparator(JSON.Comparator):
    class Comparator(JSON.Comparator[_T]):
        """Define comparison operations for :class:`_types.JSON`."""

        def has_key(self, other):
            """Boolean expression.  Test for presence of a key.  Note that the
            key may be a SQLA expression.
        type: JSONB

        def has_key(self, other: Any) -> ColumnElement[bool]:
            """Boolean expression. Test for presence of a key (equivalent of
            the ``?`` operator).  Note that the key may be a SQLA expression.
            """
            return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)

        def has_all(self, other):
            """Boolean expression.  Test for presence of all keys in jsonb"""
        def has_all(self, other: Any) -> ColumnElement[bool]:
            """Boolean expression. Test for presence of all keys in jsonb
            (equivalent of the ``?&`` operator)
            """
            return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)

        def has_any(self, other):
            """Boolean expression.  Test for presence of any key in jsonb"""
        def has_any(self, other: Any) -> ColumnElement[bool]:
            """Boolean expression. Test for presence of any key in jsonb
            (equivalent of the ``?|`` operator)
            """
            return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)

        def contains(self, other, **kwargs):
        def contains(self, other: Any, **kwargs: Any) -> ColumnElement[bool]:
            """Boolean expression.  Test if keys (or array) are a superset
            of/contained the keys of the argument jsonb expression.
            of/contained the keys of the argument jsonb expression
            (equivalent of the ``@>`` operator).

            kwargs may be ignored by this operator but are required for API
            conformance.
            """
            return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)

        def contained_by(self, other):
        def contained_by(self, other: Any) -> ColumnElement[bool]:
            """Boolean expression.  Test if keys are a proper subset of the
            keys of the argument jsonb expression.
            keys of the argument jsonb expression
            (equivalent of the ``<@`` operator).
            """
            return self.operate(
                CONTAINED_BY, other, result_type=sqltypes.Boolean
            )

        def delete_path(self, array):
        def delete_path(
            self, array: Union[List[str], _pg_array[str]]
        ) -> ColumnElement[JSONB]:
            """JSONB expression. Deletes field or array element specified in
            the argument array.
            the argument array (equivalent of the ``#-`` operator).

            The input may be a list of strings that will be coerced to an
            ``ARRAY`` or an instance of :meth:`_postgres.array`.
@@ -300,9 +341,9 @@ class JSONB(JSON):
            right_side = cast(array, ARRAY(sqltypes.TEXT))
            return self.operate(DELETE_PATH, right_side, result_type=JSONB)

        def path_exists(self, other):
        def path_exists(self, other: Any) -> ColumnElement[bool]:
            """Boolean expression. Test for presence of item given by the
            argument JSONPath expression.
            argument JSONPath expression (equivalent of the ``@?`` operator).

            .. versionadded:: 2.0
            """
@@ -310,9 +351,10 @@ class JSONB(JSON):
                PATH_EXISTS, other, result_type=sqltypes.Boolean
            )

        def path_match(self, other):
        def path_match(self, other: Any) -> ColumnElement[bool]:
            """Boolean expression. Test if JSONPath predicate given by the
            argument JSONPath expression matches.
            argument JSONPath expression matches
            (equivalent of the ``@@`` operator).

            Only the first item of the result is taken into account.

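The operator equivalences noted in the JSONB comparator docstrings above (``?``, ``?&``, ``?|``, ``@>``, ``<@``, ``#-``, ``@?``, ``@@``) map one-to-one onto comparator methods. A minimal sketch of how a few of them render, assuming a ``data_table`` with a JSONB ``data`` column as in the docstring example::

    from sqlalchemy import select

    # data @> '{"key1": "value1"}'  -- containment, the "@>" operator
    stmt = select(data_table).where(
        data_table.c.data.contains({"key1": "value1"})
    )

    # data ? 'key1'  -- key presence, the "?" operator
    stmt = select(data_table).where(data_table.c.data.has_key("key1"))

    # data #- '{key1}'  -- path deletion, the "#-" operator
    stmt = select(data_table.c.data.delete_path(["key1"]))
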
@@ -1,5 +1,5 @@
# postgresql/named_types.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/named_types.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,7 +7,9 @@
# mypy: ignore-errors
from __future__ import annotations

from types import ModuleType
from typing import Any
from typing import Dict
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
@@ -25,10 +27,11 @@ from ...sql.ddl import InvokeCreateDDLBase
from ...sql.ddl import InvokeDropDDLBase

if TYPE_CHECKING:
    from ...sql._typing import _CreateDropBind
    from ...sql._typing import _TypeEngineArgument


class NamedType(sqltypes.TypeEngine):
class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
    """Base for named types."""

    __abstract__ = True
@@ -36,7 +39,9 @@ class NamedType(sqltypes.TypeEngine):
    DDLDropper: Type[NamedTypeDropper]
    create_type: bool

    def create(self, bind, checkfirst=True, **kw):
    def create(
        self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
    ) -> None:
        """Emit ``CREATE`` DDL for this type.

        :param bind: a connectable :class:`_engine.Engine`,
@@ -50,7 +55,9 @@ class NamedType(sqltypes.TypeEngine):
        """
        bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)

    def drop(self, bind, checkfirst=True, **kw):
    def drop(
        self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
    ) -> None:
        """Emit ``DROP`` DDL for this type.

        :param bind: a connectable :class:`_engine.Engine`,
@@ -63,7 +70,9 @@ class NamedType(sqltypes.TypeEngine):
        """
        bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)

    def _check_for_name_in_memos(self, checkfirst, kw):
    def _check_for_name_in_memos(
        self, checkfirst: bool, kw: Dict[str, Any]
    ) -> bool:
        """Look in the 'ddl runner' for 'memos', then
        note our name in that collection.

@@ -87,7 +96,13 @@ class NamedType(sqltypes.TypeEngine):
        else:
            return False

    def _on_table_create(self, target, bind, checkfirst=False, **kw):
    def _on_table_create(
        self,
        target: Any,
        bind: _CreateDropBind,
        checkfirst: bool = False,
        **kw: Any,
    ) -> None:
        if (
            checkfirst
            or (
@@ -97,7 +112,13 @@ class NamedType(sqltypes.TypeEngine):
        ) and not self._check_for_name_in_memos(checkfirst, kw):
            self.create(bind=bind, checkfirst=checkfirst)

    def _on_table_drop(self, target, bind, checkfirst=False, **kw):
    def _on_table_drop(
        self,
        target: Any,
        bind: _CreateDropBind,
        checkfirst: bool = False,
        **kw: Any,
    ) -> None:
        if (
            not self.metadata
            and not kw.get("_is_metadata_operation", False)
@@ -105,11 +126,23 @@ class NamedType(sqltypes.TypeEngine):
        ):
            self.drop(bind=bind, checkfirst=checkfirst)

    def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
    def _on_metadata_create(
        self,
        target: Any,
        bind: _CreateDropBind,
        checkfirst: bool = False,
        **kw: Any,
    ) -> None:
        if not self._check_for_name_in_memos(checkfirst, kw):
            self.create(bind=bind, checkfirst=checkfirst)

    def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
    def _on_metadata_drop(
        self,
        target: Any,
        bind: _CreateDropBind,
        checkfirst: bool = False,
        **kw: Any,
    ) -> None:
        if not self._check_for_name_in_memos(checkfirst, kw):
            self.drop(bind=bind, checkfirst=checkfirst)

@@ -163,7 +196,6 @@ class EnumDropper(NamedTypeDropper):


class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):

    """PostgreSQL ENUM type.

    This is a subclass of :class:`_types.Enum` which includes
@@ -186,8 +218,10 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
    :meth:`_schema.Table.drop`
    methods are called::

        table = Table('sometable', metadata,
            Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
        table = Table(
            "sometable",
            metadata,
            Column("some_enum", ENUM("a", "b", "c", name="myenum")),
        )

        table.create(engine)  # will emit CREATE ENUM and CREATE TABLE
@@ -198,21 +232,17 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
    :class:`_postgresql.ENUM` independently, and associate it with the
    :class:`_schema.MetaData` object itself::

        my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
        my_enum = ENUM("a", "b", "c", name="myenum", metadata=metadata)

        t1 = Table('sometable_one', metadata,
            Column('some_enum', myenum)
        )
        t1 = Table("sometable_one", metadata, Column("some_enum", myenum))

        t2 = Table('sometable_two', metadata,
            Column('some_enum', myenum)
        )
        t2 = Table("sometable_two", metadata, Column("some_enum", myenum))

    When this pattern is used, care must still be taken at the level
    of individual table creates.  Emitting CREATE TABLE without also
    specifying ``checkfirst=True`` will still cause issues::

        t1.create(engine) # will fail: no such type 'myenum'
        t1.create(engine)  # will fail: no such type 'myenum'

    If we specify ``checkfirst=True``, the individual table-level create
    operation will check for the ``ENUM`` and create if not exists::
@@ -317,7 +347,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):

        return cls(**kw)

    def create(self, bind=None, checkfirst=True):
    def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
        """Emit ``CREATE TYPE`` for this
        :class:`_postgresql.ENUM`.

@@ -338,7 +368,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):

        super().create(bind, checkfirst=checkfirst)

    def drop(self, bind=None, checkfirst=True):
    def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
        """Emit ``DROP TYPE`` for this
        :class:`_postgresql.ENUM`.

@@ -358,7 +388,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):

        super().drop(bind, checkfirst=checkfirst)

    def get_dbapi_type(self, dbapi):
    def get_dbapi_type(self, dbapi: ModuleType) -> None:
        """don't return dbapi.STRING for ENUM in PostgreSQL, since that's
        a different type"""

@@ -388,14 +418,12 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
    A domain is essentially a data type with optional constraints
    that restrict the allowed set of values. E.g.::

        PositiveInt = DOMAIN(
            "pos_int", Integer, check="VALUE > 0", not_null=True
        )
        PositiveInt = DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True)

        UsPostalCode = DOMAIN(
            "us_postal_code",
            Text,
            check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'"
            check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'",
        )

    See the `PostgreSQL documentation`__ for additional details
@@ -404,7 +432,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType):

    .. versionadded:: 2.0

    """
    """  # noqa: E501

    DDLGenerator = DomainGenerator
    DDLDropper = DomainDropper
@@ -417,10 +445,10 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
        data_type: _TypeEngineArgument[Any],
        *,
        collation: Optional[str] = None,
        default: Optional[Union[str, elements.TextClause]] = None,
        default: Union[elements.TextClause, str, None] = None,
        constraint_name: Optional[str] = None,
        not_null: Optional[bool] = None,
        check: Optional[str] = None,
        check: Union[elements.TextClause, str, None] = None,
        create_type: bool = True,
        **kw: Any,
    ):
@@ -464,7 +492,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
        self.default = default
        self.collation = collation
        self.constraint_name = constraint_name
        self.not_null = not_null
        self.not_null = bool(not_null)
        if check is not None:
            check = coercions.expect(roles.DDLExpressionRole, check)
        self.check = check

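A short sketch of the metadata-level pattern described in the ENUM docstring above, using the now-typed ``create()``/``drop()`` signatures (the engine URL is a placeholder)::

    from sqlalchemy import create_engine, MetaData
    from sqlalchemy.dialects.postgresql import ENUM

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    metadata = MetaData()
    my_enum = ENUM("a", "b", "c", name="myenum", metadata=metadata)

    # emits CREATE TYPE only if the type does not already exist
    my_enum.create(engine, checkfirst=True)

    # emits DROP TYPE only if the type exists
    my_enum.drop(engine, checkfirst=True)
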
@@ -1,5 +1,5 @@
# postgresql/operators.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/operators.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

@@ -1,5 +1,5 @@
# postgresql/pg8000.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# dialects/postgresql/pg8000.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
@@ -27,19 +27,21 @@ PostgreSQL ``client_encoding`` parameter; by default this is the value in
the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
Typically, this can be changed to ``utf-8``, as a more useful default::

    #client_encoding = sql_ascii # actually, defaults to database
    # encoding
    # client_encoding = sql_ascii  # actually, defaults to database encoding
    client_encoding = utf8

The ``client_encoding`` can be overridden for a session by executing the SQL:

SET CLIENT_ENCODING TO 'utf8';
.. sourcecode:: sql

    SET CLIENT_ENCODING TO 'utf8';

SQLAlchemy will execute this SQL on all new connections based on the value
passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter::

    engine = create_engine(
        "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
        "postgresql+pg8000://user:pass@host/dbname", client_encoding="utf8"
    )

.. _pg8000_ssl:

@@ -50,6 +52,7 @@ pg8000 accepts a Python ``SSLContext`` object which may be specified using the
:paramref:`_sa.create_engine.connect_args` dictionary::

    import ssl

    ssl_context = ssl.create_default_context()
    engine = sa.create_engine(
        "postgresql+pg8000://scott:tiger@192.168.0.199/test",
@@ -61,6 +64,7 @@ or does not match the host name (as seen from the client), it may also be
necessary to disable hostname checking::

    import ssl

    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
@@ -253,7 +257,7 @@ class _PGOIDVECTOR(_SpaceVector, OIDVECTOR):
    pass


class _Pg8000Range(ranges.AbstractRangeImpl):
class _Pg8000Range(ranges.AbstractSingleRangeImpl):
    def bind_processor(self, dialect):
        pg8000_Range = dialect.dbapi.Range

@@ -304,15 +308,13 @@ class _Pg8000MultiRange(ranges.AbstractMultiRangeImpl):
        def to_multirange(value):
            if value is None:
                return None

            mr = []
            for v in value:
                mr.append(
            else:
                return ranges.MultiRange(
                    ranges.Range(
                        v.lower, v.upper, bounds=v.bounds, empty=v.is_empty
                    )
                    for v in value
                )
            return mr

        return to_multirange

@@ -538,6 +540,9 @@ class PGDialect_pg8000(PGDialect):
        cursor.execute("COMMIT")
        cursor.close()

    def detect_autocommit_setting(self, dbapi_conn) -> bool:
        return bool(dbapi_conn.autocommit)

    def set_readonly(self, connection, value):
        cursor = connection.cursor()
        try:
@@ -584,8 +589,8 @@ class PGDialect_pg8000(PGDialect):
        cursor = dbapi_connection.cursor()
        cursor.execute(
            f"""SET CLIENT_ENCODING TO '{
            client_encoding.replace("'", "''")
            }'"""
                client_encoding.replace("'", "''")
            }'"""
        )
        cursor.execute("COMMIT")
        cursor.close()

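The ``SET CLIENT_ENCODING`` statement above interpolates the encoding name into a SQL string literal, doubling any embedded single quotes so the value cannot break out of the literal. The same idea in isolation (a standalone sketch, not the dialect's actual helper)::

    def quote_sql_string(value: str) -> str:
        # double embedded single quotes: O'Brien -> 'O''Brien'
        return "'" + value.replace("'", "''") + "'"

    assert quote_sql_string("utf8") == "'utf8'"
    assert quote_sql_string("bad'enc") == "'bad''enc'"
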
@@ -1,10 +1,16 @@
# postgresql/pg_catalog.py
# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
# dialects/postgresql/pg_catalog.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

from typing import Any
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING

from .array import ARRAY
from .types import OID
@@ -23,31 +29,37 @@ from ...types import String
from ...types import Text
from ...types import TypeDecorator

if TYPE_CHECKING:
    from ...engine.interfaces import Dialect
    from ...sql.type_api import _ResultProcessorType


# types
class NAME(TypeDecorator):
class NAME(TypeDecorator[str]):
    impl = String(64, collation="C")
    cache_ok = True


class PG_NODE_TREE(TypeDecorator):
class PG_NODE_TREE(TypeDecorator[str]):
    impl = Text(collation="C")
    cache_ok = True


class INT2VECTOR(TypeDecorator):
class INT2VECTOR(TypeDecorator[Sequence[int]]):
    impl = ARRAY(SmallInteger)
    cache_ok = True


class OIDVECTOR(TypeDecorator):
class OIDVECTOR(TypeDecorator[Sequence[int]]):
    impl = ARRAY(OID)
    cache_ok = True


class _SpaceVector:
    def result_processor(self, dialect, coltype):
        def process(value):
    def result_processor(
        self, dialect: Dialect, coltype: object
    ) -> _ResultProcessorType[list[int]]:
        def process(value: Any) -> Optional[list[int]]:
            if value is None:
                return value
            return [int(p) for p in value.split(" ")]
@@ -77,7 +89,7 @@ RELKINDS_MAT_VIEW = ("m",)
RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW

# tables
pg_catalog_meta = MetaData()
pg_catalog_meta = MetaData(schema="pg_catalog")

pg_namespace = Table(
    "pg_namespace",
@@ -85,7 +97,6 @@ pg_namespace = Table(
    Column("oid", OID),
    Column("nspname", NAME),
    Column("nspowner", OID),
    schema="pg_catalog",
)

pg_class = Table(
@@ -120,7 +131,6 @@ pg_class = Table(
    Column("relispartition", Boolean, info={"server_version": (10,)}),
    Column("relrewrite", OID, info={"server_version": (11,)}),
    Column("reloptions", ARRAY(Text)),
    schema="pg_catalog",
)

pg_type = Table(
@@ -155,7 +165,6 @@ pg_type = Table(
    Column("typndims", Integer),
    Column("typcollation", OID, info={"server_version": (9, 1)}),
    Column("typdefault", Text),
    schema="pg_catalog",
)

pg_index = Table(
@@ -182,7 +191,6 @@ pg_index = Table(
    Column("indoption", INT2VECTOR),
    Column("indexprs", PG_NODE_TREE),
    Column("indpred", PG_NODE_TREE),
    schema="pg_catalog",
)

pg_attribute = Table(
@@ -209,7 +217,6 @@ pg_attribute = Table(
    Column("attislocal", Boolean),
    Column("attinhcount", Integer),
    Column("attcollation", OID, info={"server_version": (9, 1)}),
    schema="pg_catalog",
)

pg_constraint = Table(
@@ -235,7 +242,6 @@ pg_constraint = Table(
    Column("connoinherit", Boolean, info={"server_version": (9, 2)}),
    Column("conkey", ARRAY(SmallInteger)),
    Column("confkey", ARRAY(SmallInteger)),
    schema="pg_catalog",
)

pg_sequence = Table(
@@ -249,7 +255,6 @@ pg_sequence = Table(
    Column("seqmin", BigInteger),
    Column("seqcache", BigInteger),
    Column("seqcycle", Boolean),
    schema="pg_catalog",
    info={"server_version": (10,)},
)

@@ -260,7 +265,6 @@ pg_attrdef = Table(
    Column("adrelid", OID),
    Column("adnum", SmallInteger),
    Column("adbin", PG_NODE_TREE),
    schema="pg_catalog",
)

pg_description = Table(
@@ -270,7 +274,6 @@ pg_description = Table(
    Column("classoid", OID),
    Column("objsubid", Integer),
    Column("description", Text(collation="C")),
    schema="pg_catalog",
)

pg_enum = Table(
@@ -280,7 +283,6 @@ pg_enum = Table(
    Column("enumtypid", OID),
    Column("enumsortorder", Float(), info={"server_version": (9, 1)}),
    Column("enumlabel", NAME),
    schema="pg_catalog",
)

pg_am = Table(
@@ -290,5 +292,35 @@ pg_am = Table(
    Column("amname", NAME),
    Column("amhandler", REGPROC, info={"server_version": (9, 6)}),
    Column("amtype", CHAR, info={"server_version": (9, 6)}),
    schema="pg_catalog",
)

pg_collation = Table(
    "pg_collation",
    pg_catalog_meta,
    Column("oid", OID, info={"server_version": (9, 3)}),
    Column("collname", NAME),
    Column("collnamespace", OID),
    Column("collowner", OID),
    Column("collprovider", CHAR, info={"server_version": (10,)}),
    Column("collisdeterministic", Boolean, info={"server_version": (12,)}),
    Column("collencoding", Integer),
    Column("collcollate", Text),
    Column("collctype", Text),
    Column("colliculocale", Text),
    Column("collicurules", Text, info={"server_version": (16,)}),
    Column("collversion", Text, info={"server_version": (10,)}),
)

pg_opclass = Table(
    "pg_opclass",
    pg_catalog_meta,
    Column("oid", OID, info={"server_version": (9, 3)}),
    Column("opcmethod", NAME),
    Column("opcname", NAME),
    Column("opsnamespace", OID),
    Column("opsowner", OID),
    Column("opcfamily", OID),
    Column("opcintype", OID),
    Column("opcdefault", Boolean),
    Column("opckeytype", OID),
)

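With ``pg_catalog_meta`` now carrying ``schema="pg_catalog"`` itself, the per-table ``schema=`` arguments become redundant and are dropped. The table objects are used for reflection queries along these lines (a sketch; ``conn`` is assumed to be an open connection, and ``relnamespace`` is a real ``pg_class`` column defined outside the hunks shown here)::

    from sqlalchemy import select
    from sqlalchemy.dialects.postgresql import pg_catalog

    # relation names in the "public" schema
    stmt = (
        select(pg_catalog.pg_class.c.relname)
        .join(
            pg_catalog.pg_namespace,
            pg_catalog.pg_class.c.relnamespace
            == pg_catalog.pg_namespace.c.oid,
        )
        .where(pg_catalog.pg_namespace.c.nspname == "public")
    )
    names = conn.execute(stmt).scalars().all()
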
@@ -1,3 +1,9 @@
# dialects/postgresql/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

import time
@@ -91,7 +97,7 @@ def drop_all_schema_objects_pre_tables(cfg, eng):
        for xid in conn.exec_driver_sql(
            "select gid from pg_prepared_xacts"
        ).scalars():
            conn.execute("ROLLBACK PREPARED '%s'" % xid)
            conn.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)


@drop_all_schema_objects_post_tables.for_db("postgresql")

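The fix above matters because in SQLAlchemy 2.0 ``Connection.execute()`` only accepts executable constructs such as ``text()``; raw driver-level SQL strings go through ``exec_driver_sql()``. A minimal illustration (the engine URL is a placeholder)::

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    with engine.connect() as conn:
        conn.execute(text("SELECT 1"))    # executable construct: ok
        conn.exec_driver_sql("SELECT 1")  # raw driver-level string: ok
        # conn.execute("SELECT 1")        # raises ObjectNotExecutableError
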
@@ -1,5 +1,5 @@
# postgresql/psycopg2.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/psycopg.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -29,20 +29,29 @@ selected depending on how the engine is created:
  automatically select the sync version, e.g.::

    from sqlalchemy import create_engine
    sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")

    sync_engine = create_engine(
        "postgresql+psycopg://scott:tiger@localhost/test"
    )

* calling :func:`_asyncio.create_async_engine` with
  ``postgresql+psycopg://...`` will automatically select the async version,
  e.g.::

    from sqlalchemy.ext.asyncio import create_async_engine
    asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test")

    asyncio_engine = create_async_engine(
        "postgresql+psycopg://scott:tiger@localhost/test"
    )

The asyncio version of the dialect may also be specified explicitly using the
``psycopg_async`` suffix, as::

    from sqlalchemy.ext.asyncio import create_async_engine
    asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test")

    asyncio_engine = create_async_engine(
        "postgresql+psycopg_async://scott:tiger@localhost/test"
    )

.. seealso::

@@ -50,9 +59,42 @@ The asyncio version of the dialect may also be specified explicitly using the
    dialect shares most of its behavior with the ``psycopg2`` dialect.
    Further documentation is available there.

Using a different Cursor class
------------------------------

One of the differences between ``psycopg`` and the older ``psycopg2``
is how bound parameters are handled: ``psycopg2`` would bind them
client side, while ``psycopg`` by default will bind them server side.

It's possible to configure ``psycopg`` to do client side binding by
specifying the ``cursor_factory`` to be ``ClientCursor`` when creating
the engine::

    from psycopg import ClientCursor

    client_side_engine = create_engine(
        "postgresql+psycopg://...",
        connect_args={"cursor_factory": ClientCursor},
    )

Similarly when using an async engine the ``AsyncClientCursor`` can be
specified::

    from psycopg import AsyncClientCursor

    client_side_engine = create_async_engine(
        "postgresql+psycopg://...",
        connect_args={"cursor_factory": AsyncClientCursor},
    )

.. seealso::

    `Client-side-binding cursors <https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#client-side-binding-cursors>`_

"""  # noqa
from __future__ import annotations

from collections import deque
import logging
import re
from typing import cast
@@ -79,6 +121,8 @@ from ...util.concurrency import await_only
if TYPE_CHECKING:
    from typing import Iterable

    from psycopg import AsyncConnection

logger = logging.getLogger("sqlalchemy.dialects.postgresql")


@@ -91,8 +135,6 @@ class _PGREGCONFIG(REGCONFIG):


class _PGJSON(JSON):
    render_bind_cast = True

    def bind_processor(self, dialect):
        return self._make_bind_processor(None, dialect._psycopg_Json)

@@ -101,8 +143,6 @@ class _PGJSON(JSON):


class _PGJSONB(JSONB):
    render_bind_cast = True

    def bind_processor(self, dialect):
        return self._make_bind_processor(None, dialect._psycopg_Jsonb)

@@ -162,7 +202,7 @@ class _PGBoolean(sqltypes.Boolean):
    render_bind_cast = True


class _PsycopgRange(ranges.AbstractRangeImpl):
class _PsycopgRange(ranges.AbstractSingleRangeImpl):
    def bind_processor(self, dialect):
        psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range

@@ -218,8 +258,10 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):

    def result_processor(self, dialect, coltype):
        def to_range(value):
            if value is not None:
                value = [
            if value is None:
                return None
            else:
                return ranges.MultiRange(
                    ranges.Range(
                        elem._lower,
                        elem._upper,
@@ -227,9 +269,7 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):
                        empty=not elem._bounds,
                    )
                    for elem in value
                ]

            return value
                )

        return to_range

@@ -286,7 +326,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
            sqltypes.Integer: _PGInteger,
            sqltypes.SmallInteger: _PGSmallInteger,
            sqltypes.BigInteger: _PGBigInteger,
            ranges.AbstractRange: _PsycopgRange,
            ranges.AbstractSingleRange: _PsycopgRange,
            ranges.AbstractMultiRange: _PsycopgMultiRange,
        },
    )
@@ -366,10 +406,12 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):

            # register the adapter for connections made subsequent to
            # this one
            assert self._psycopg_adapters_map
            register_hstore(info, self._psycopg_adapters_map)

            # register the adapter for this connection
            register_hstore(info, connection.connection)
            assert connection.connection
            register_hstore(info, connection.connection.driver_connection)

    @classmethod
    def import_dbapi(cls):
@@ -530,7 +572,7 @@ class AsyncAdapt_psycopg_cursor:
    def __init__(self, cursor, await_) -> None:
        self._cursor = cursor
        self.await_ = await_
        self._rows = []
        self._rows = deque()

    def __getattr__(self, name):
        return getattr(self._cursor, name)
@@ -557,24 +599,19 @@ class AsyncAdapt_psycopg_cursor:
        # eq/ne
        if res and res.status == self._psycopg_ExecStatus.TUPLES_OK:
            rows = self.await_(self._cursor.fetchall())
            if not isinstance(rows, list):
                self._rows = list(rows)
            else:
                self._rows = rows
            self._rows = deque(rows)
        return result

    def executemany(self, query, params_seq):
        return self.await_(self._cursor.executemany(query, params_seq))

    def __iter__(self):
        # TODO: try to avoid pop(0) on a list
        while self._rows:
            yield self._rows.pop(0)
            yield self._rows.popleft()

    def fetchone(self):
        if self._rows:
            # TODO: try to avoid pop(0) on a list
            return self._rows.pop(0)
            return self._rows.popleft()
        else:
            return None

@@ -582,13 +619,12 @@ class AsyncAdapt_psycopg_cursor:
        if size is None:
            size = self._cursor.arraysize

        retval = self._rows[0:size]
        self._rows = self._rows[size:]
        return retval
        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self):
        retval = self._rows
        self._rows = []
        retval = list(self._rows)
        self._rows.clear()
        return retval


@@ -619,6 +655,7 @@ class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor):


class AsyncAdapt_psycopg_connection(AdaptedConnection):
    _connection: AsyncConnection
    __slots__ = ()
    await_ = staticmethod(await_only)

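The switch from a plain list to ``collections.deque`` in the adapted cursor above turns each ``fetchone()`` into an O(1) ``popleft()`` instead of an O(n) ``list.pop(0)``, which also resolves the two "TODO: try to avoid pop(0)" comments. The buffering pattern in isolation (a sketch, not the adapter itself)::

    from collections import deque

    class BufferedRows:
        def __init__(self, rows):
            self._rows = deque(rows)

        def fetchone(self):
            # popleft() is O(1); list.pop(0) would shift every element
            return self._rows.popleft() if self._rows else None

        def fetchmany(self, size):
            rr = self._rows
            return [rr.popleft() for _ in range(min(size, len(rr)))]

        def fetchall(self):
            retval = list(self._rows)
            self._rows.clear()
            return retval
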
@@ -1,5 +1,5 @@
# postgresql/psycopg2.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/psycopg2.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -88,7 +88,6 @@ connection URI::
        "postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require"
    )


Unix Domain Connections
------------------------

@@ -103,13 +102,17 @@ in ``/tmp``, or whatever socket directory was specified when PostgreSQL
was built.  This value can be overridden by passing a pathname to psycopg2,
using ``host`` as an additional keyword argument::

    create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql")
    create_engine(
        "postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql"
    )

.. warning::  The format accepted here allows for a hostname in the main URL
   in addition to the "host" query string argument.  **When using this URL
   format, the initial host is silently ignored**.  That is, this URL::

      engine = create_engine("postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2")
      engine = create_engine(
          "postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2"
      )

   Above, the hostname ``myhost1`` is **silently ignored and discarded.**  The
   host which is connected is the ``myhost2`` host.
@@ -190,7 +193,7 @@ any or all elements of the connection string.
For this form, the URL can be passed without any elements other than the
initial scheme::

    engine = create_engine('postgresql+psycopg2://')
    engine = create_engine("postgresql+psycopg2://")

In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()``
function which in turn represents an empty DSN passed to libpq.
@@ -242,7 +245,7 @@ Psycopg2 Fast Execution Helpers

Modern versions of psycopg2 include a feature known as
`Fast Execution Helpers \
<https://initd.org/psycopg/docs/extras.html#fast-execution-helpers>`_, which
<https://www.psycopg.org/docs/extras.html#fast-execution-helpers>`_, which
have been shown in benchmarking to improve psycopg2's executemany()
performance, primarily with INSERT statements, by at least
an order of magnitude.
@@ -264,8 +267,8 @@ used feature. The use of this extension may be enabled using the

    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@host/dbname",
        executemany_mode='values_plus_batch')

        executemany_mode="values_plus_batch",
    )

Possible options for ``executemany_mode`` include:

@@ -311,8 +314,10 @@ is below::

    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@host/dbname",
        executemany_mode='values_plus_batch',
        insertmanyvalues_page_size=5000, executemany_batch_page_size=500)
        executemany_mode="values_plus_batch",
        insertmanyvalues_page_size=5000,
        executemany_batch_page_size=500,
    )

.. seealso::

@@ -338,7 +343,9 @@ in the following ways:
  passed in the database URL; this parameter is consumed by the underlying
  ``libpq`` PostgreSQL client library::

    engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8")
    engine = create_engine(
        "postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8"
    )

  Alternatively, the above ``client_encoding`` value may be passed using
  :paramref:`_sa.create_engine.connect_args` for programmatic establishment with
@@ -346,7 +353,7 @@ in the following ways:

    engine = create_engine(
        "postgresql+psycopg2://user:pass@host/dbname",
        connect_args={'client_encoding': 'utf8'}
        connect_args={"client_encoding": "utf8"},
    )

* For all PostgreSQL versions, psycopg2 supports a client-side encoding
@@ -355,8 +362,7 @@ in the following ways:
  ``client_encoding`` parameter passed to :func:`_sa.create_engine`::

    engine = create_engine(
        "postgresql+psycopg2://user:pass@host/dbname",
        client_encoding="utf8"
        "postgresql+psycopg2://user:pass@host/dbname", client_encoding="utf8"
    )

.. tip:: The above ``client_encoding`` parameter admittedly is very similar
@@ -375,11 +381,9 @@ in the following ways:
    # postgresql.conf file

    # client_encoding = sql_ascii # actually, defaults to database
    # encoding
                                  # encoding
    client_encoding = utf8



Transactions
------------

@@ -426,15 +430,15 @@ is set to the ``logging.INFO`` level, notice messages will be logged::

    import logging

    logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
    logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO)

Above, it is assumed that logging is configured externally.  If this is not
the case, configuration such as ``logging.basicConfig()`` must be utilized::

    import logging

    logging.basicConfig()  # log messages to stdout
    logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
    logging.basicConfig()  # log messages to stdout
    logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO)

.. seealso::

@@ -471,8 +475,10 @@ textual HSTORE expression. If this behavior is not desired, disable the
use of the hstore extension by setting ``use_native_hstore`` to ``False`` as
follows::

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
                use_native_hstore=False)
    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test",
        use_native_hstore=False,
    )

The ``HSTORE`` type is **still supported** when the
``psycopg2.extensions.register_hstore()`` extension is not used.  It merely
@@ -513,7 +519,7 @@ class _PGJSONB(JSONB):
        return None


class _Psycopg2Range(ranges.AbstractRangeImpl):
class _Psycopg2Range(ranges.AbstractSingleRangeImpl):
    _psycopg2_range_cls = "none"

    def bind_processor(self, dialect):
@@ -844,33 +850,43 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg):
        # checks based on strings.  in the case that .closed
        # didn't cut it, fall back onto these.
        str_e = str(e).partition("\n")[0]
        for msg in [
            # these error messages from libpq: interfaces/libpq/fe-misc.c
            # and interfaces/libpq/fe-secure.c.
            "terminating connection",
            "closed the connection",
            "connection not open",
            "could not receive data from server",
            "could not send data to server",
            # psycopg2 client errors, psycopg2/connection.h,
            # psycopg2/cursor.h
            "connection already closed",
            "cursor already closed",
            # not sure where this path is originally from, it may
            # be obsolete.   It really says "losed", not "closed".
            "losed the connection unexpectedly",
            # these can occur in newer SSL
            "connection has been closed unexpectedly",
            "SSL error: decryption failed or bad record mac",
            "SSL SYSCALL error: Bad file descriptor",
            "SSL SYSCALL error: EOF detected",
            "SSL SYSCALL error: Operation timed out",
            "SSL SYSCALL error: Bad address",
        ]:
        for msg in self._is_disconnect_messages:
            idx = str_e.find(msg)
            if idx >= 0 and '"' not in str_e[:idx]:
                return True
        return False

    @util.memoized_property
    def _is_disconnect_messages(self):
        return (
            # these error messages from libpq: interfaces/libpq/fe-misc.c
            # and interfaces/libpq/fe-secure.c.
            "terminating connection",
            "closed the connection",
            "connection not open",
            "could not receive data from server",
            "could not send data to server",
            # psycopg2 client errors, psycopg2/connection.h,
            # psycopg2/cursor.h
            "connection already closed",
            "cursor already closed",
            # not sure where this path is originally from, it may
            # be obsolete.   It really says "losed", not "closed".
            "losed the connection unexpectedly",
            # these can occur in newer SSL
            "connection has been closed unexpectedly",
            "SSL error: decryption failed or bad record mac",
            "SSL SYSCALL error: Bad file descriptor",
            "SSL SYSCALL error: EOF detected",
            "SSL SYSCALL error: Operation timed out",
            "SSL SYSCALL error: Bad address",
            # This can occur in OpenSSL 1 when an unexpected EOF occurs.
            # https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html#BUGS
            # It may also occur in newer OpenSSL for a non-recoverable I/O
            # error as a result of a system call that does not set 'errno'
            # in libc.
            "SSL SYSCALL error: Success",
        )


dialect = PGDialect_psycopg2

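Hoisting the disconnect-message list into a ``@util.memoized_property`` means the tuple is built once per dialect instance rather than on every ``is_disconnect()`` call. The standard library's ``functools.cached_property`` gives the same effect; a sketch under that assumption (``Checker`` and its names are illustrative only)::

    from functools import cached_property

    class Checker:
        @cached_property
        def _messages(self) -> tuple:
            # constructed on first access, then cached on the instance
            return ("terminating connection", "connection already closed")

        def is_disconnect(self, error_text: str) -> bool:
            return any(msg in error_text for msg in self._messages)
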
@@ -1,5 +1,5 @@
# testing/engines.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/psycopg2cffi.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

@@ -1,4 +1,5 @@
# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/ranges.py
# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -14,8 +15,10 @@ from decimal import Decimal
from typing import Any
from typing import cast
from typing import Generic
from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
@@ -151,8 +154,8 @@ class Range(Generic[_T]):
        return not self.empty and self.upper is None

    @property
    def __sa_type_engine__(self) -> AbstractRange[Range[_T]]:
        return AbstractRange()
    def __sa_type_engine__(self) -> AbstractSingleRange[_T]:
        return AbstractSingleRange()

    def _contains_value(self, value: _T) -> bool:
        """Return True if this range contains the given value."""
@@ -268,9 +271,9 @@ class Range(Generic[_T]):
                value2 += step
                value2_inc = False

        if value1 < value2:  # type: ignore
        if value1 < value2:
            return -1
        elif value1 > value2:  # type: ignore
        elif value1 > value2:
            return 1
        elif only_values:
            return 0
@@ -357,6 +360,8 @@ class Range(Generic[_T]):
        else:
            return self._contains_value(value)

    __contains__ = contains

    def overlaps(self, other: Range[_T]) -> bool:
        "Determine whether this range overlaps with `other`."

@@ -707,27 +712,46 @@ class Range(Generic[_T]):
        return f"{b0}{l},{r}{b1}"


class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
    """
    Base for PostgreSQL RANGE types.
class MultiRange(List[Range[_T]]):
    """Represents a multirange sequence.

    This list subclass is a utility to allow automatic type inference of
    the proper multi-range SQL type depending on the single range values.
    This is useful when operating on literal multi-ranges::

        import sqlalchemy as sa
        from sqlalchemy.dialects.postgresql import MultiRange, Range

        value = literal(MultiRange([Range(2, 4)]))

        select(tbl).where(tbl.c.value.op("@")(MultiRange([Range(-3, 7)])))

    .. versionadded:: 2.0.26

    .. seealso::

    `PostgreSQL range functions <https://www.postgresql.org/docs/current/static/functions-range.html>`_
        - :ref:`postgresql_multirange_list_use`.
    """

    """  # noqa: E501
    @property
    def __sa_type_engine__(self) -> AbstractMultiRange[_T]:
        return AbstractMultiRange()


class AbstractRange(sqltypes.TypeEngine[_T]):
    """Base class for single and multi Range SQL types."""

    render_bind_cast = True

    __abstract__ = True

    @overload
    def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
        ...
    def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ...

    @overload
    def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
        ...
    def adapt(
        self, cls: Type[TypeEngineMixin], **kw: Any
    ) -> TypeEngine[Any]: ...

    def adapt(
        self,
@@ -741,7 +765,10 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
        and also render as ``INT4RANGE`` in SQL and DDL.

        """
        if issubclass(cls, AbstractRangeImpl) and cls is not self.__class__:
        if (
            issubclass(cls, (AbstractSingleRangeImpl, AbstractMultiRangeImpl))
            and cls is not self.__class__
        ):
            # two ways to do this are:  1. create a new type on the fly
            # or 2. have AbstractRangeImpl(visit_name) constructor and a
            # visit_abstract_range_impl() method in the PG compiler.
@@ -760,21 +787,6 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
        else:
            return super().adapt(cls)

    def _resolve_for_literal(self, value: Any) -> Any:
        spec = value.lower if value.lower is not None else value.upper

        if isinstance(spec, int):
            return INT8RANGE()
        elif isinstance(spec, (Decimal, float)):
            return NUMRANGE()
        elif isinstance(spec, datetime):
            return TSRANGE() if not spec.tzinfo else TSTZRANGE()
        elif isinstance(spec, date):
            return DATERANGE()
        else:
            # empty Range, SQL datatype can't be determined here
            return sqltypes.NULLTYPE

    class comparator_factory(TypeEngine.Comparator[Range[Any]]):
        """Define comparison operations for range types."""

@@ -856,91 +868,164 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
            return self.expr.operate(operators.mul, other)


class AbstractRangeImpl(AbstractRange[Range[_T]]):
    """Marker for AbstractRange that will apply a subclass-specific
    adaptation"""
class AbstractSingleRange(AbstractRange[Range[_T]]):
    """Base for PostgreSQL RANGE types.

    These are types that return a single :class:`_postgresql.Range` object.

class AbstractMultiRange(AbstractRange[Range[_T]]):
    """base for PostgreSQL MULTIRANGE types"""
    .. seealso::

        `PostgreSQL range functions <https://www.postgresql.org/docs/current/static/functions-range.html>`_

    """  # noqa: E501

    __abstract__ = True

    def _resolve_for_literal(self, value: Range[Any]) -> Any:
        spec = value.lower if value.lower is not None else value.upper

class AbstractMultiRangeImpl(
    AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]]
):
    """Marker for AbstractRange that will apply a subclass-specific
        if isinstance(spec, int):
            # pg is unreasonably picky here: the query
            # "select 1::INTEGER <@ '[1, 4)'::INT8RANGE" raises
            # "operator does not exist: integer <@ int8range" as of pg 16
            if _is_int32(value):
                return INT4RANGE()
            else:
                return INT8RANGE()
        elif isinstance(spec, (Decimal, float)):
            return NUMRANGE()
        elif isinstance(spec, datetime):
            return TSRANGE() if not spec.tzinfo else TSTZRANGE()
        elif isinstance(spec, date):
            return DATERANGE()
        else:
            # empty Range, SQL datatype can't be determined here
            return sqltypes.NULLTYPE


class AbstractSingleRangeImpl(AbstractSingleRange[_T]):
    """Marker for AbstractSingleRange that will apply a subclass-specific
    adaptation"""


class INT4RANGE(AbstractRange[Range[int]]):
class AbstractMultiRange(AbstractRange[Sequence[Range[_T]]]):
    """Base for PostgreSQL MULTIRANGE types.

    These are types that return a sequence of :class:`_postgresql.Range`
    objects.

    """

    __abstract__ = True

    def _resolve_for_literal(self, value: Sequence[Range[Any]]) -> Any:
        if not value:
            # empty MultiRange, SQL datatype can't be determined here
            return sqltypes.NULLTYPE
        first = value[0]
        spec = first.lower if first.lower is not None else first.upper

        if isinstance(spec, int):
            # pg is unreasonably picky here: the query
            # "select 1::INTEGER <@ '{[1, 4),[6,19)}'::INT8MULTIRANGE" raises
            # "operator does not exist: integer <@ int8multirange" as of pg 16
            if all(_is_int32(r) for r in value):
                return INT4MULTIRANGE()
            else:
                return INT8MULTIRANGE()
        elif isinstance(spec, (Decimal, float)):
            return NUMMULTIRANGE()
        elif isinstance(spec, datetime):
            return TSMULTIRANGE() if not spec.tzinfo else TSTZMULTIRANGE()
        elif isinstance(spec, date):
            return DATEMULTIRANGE()
        else:
            # empty Range, SQL datatype can't be determined here
            return sqltypes.NULLTYPE


class AbstractMultiRangeImpl(AbstractMultiRange[_T]):
    """Marker for AbstractMultiRange that will apply a subclass-specific
    adaptation"""


class INT4RANGE(AbstractSingleRange[int]):
    """Represent the PostgreSQL INT4RANGE type."""

    __visit_name__ = "INT4RANGE"


class INT8RANGE(AbstractRange[Range[int]]):
class INT8RANGE(AbstractSingleRange[int]):
    """Represent the PostgreSQL INT8RANGE type."""

    __visit_name__ = "INT8RANGE"


class NUMRANGE(AbstractRange[Range[Decimal]]):
class NUMRANGE(AbstractSingleRange[Decimal]):
    """Represent the PostgreSQL NUMRANGE type."""

    __visit_name__ = "NUMRANGE"


class DATERANGE(AbstractRange[Range[date]]):
class DATERANGE(AbstractSingleRange[date]):
    """Represent the PostgreSQL DATERANGE type."""

    __visit_name__ = "DATERANGE"


class TSRANGE(AbstractRange[Range[datetime]]):
class TSRANGE(AbstractSingleRange[datetime]):
    """Represent the PostgreSQL TSRANGE type."""

    __visit_name__ = "TSRANGE"


class TSTZRANGE(AbstractRange[Range[datetime]]):
class TSTZRANGE(AbstractSingleRange[datetime]):
    """Represent the PostgreSQL TSTZRANGE type."""

    __visit_name__ = "TSTZRANGE"


class INT4MULTIRANGE(AbstractMultiRange[Range[int]]):
class INT4MULTIRANGE(AbstractMultiRange[int]):
    """Represent the PostgreSQL INT4MULTIRANGE type."""

    __visit_name__ = "INT4MULTIRANGE"


class INT8MULTIRANGE(AbstractMultiRange[Range[int]]):
class INT8MULTIRANGE(AbstractMultiRange[int]):
    """Represent the PostgreSQL INT8MULTIRANGE type."""

    __visit_name__ = "INT8MULTIRANGE"


class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]):
class NUMMULTIRANGE(AbstractMultiRange[Decimal]):
    """Represent the PostgreSQL NUMMULTIRANGE type."""

    __visit_name__ = "NUMMULTIRANGE"


class DATEMULTIRANGE(AbstractMultiRange[Range[date]]):
class DATEMULTIRANGE(AbstractMultiRange[date]):
    """Represent the PostgreSQL DATEMULTIRANGE type."""

    __visit_name__ = "DATEMULTIRANGE"


class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]):
class TSMULTIRANGE(AbstractMultiRange[datetime]):
    """Represent the PostgreSQL TSRANGE type."""

    __visit_name__ = "TSMULTIRANGE"


class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]):
class TSTZMULTIRANGE(AbstractMultiRange[datetime]):
    """Represent the PostgreSQL TSTZRANGE type."""

    __visit_name__ = "TSTZMULTIRANGE"


_max_int_32 = 2**31 - 1
_min_int_32 = -(2**31)


def _is_int32(r: Range[int]) -> bool:
    return (r.lower is None or _min_int_32 <= r.lower <= _max_int_32) and (
        r.upper is None or _min_int_32 <= r.upper <= _max_int_32
    )

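The ``_is_int32()`` helper above drives the new literal inference: an integer range whose bounds both fit in 32 bits now infers ``INT4RANGE``/``INT4MULTIRANGE``, falling back to the 8-byte variants otherwise. A rough sketch of the resulting behavior (assuming literal type resolution proceeds through ``__sa_type_engine__`` and ``_resolve_for_literal`` as defined in this file)::

    from sqlalchemy import literal
    from sqlalchemy.dialects.postgresql import MultiRange, Range

    literal(Range(1, 4)).type        # INT4RANGE: bounds fit in 32 bits
    literal(Range(1, 2**40)).type    # INT8RANGE: upper bound exceeds int32
    literal(MultiRange([Range(1, 4)])).type  # INT4MULTIRANGE
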
@@ -1,4 +1,5 @@
# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors
# dialects/postgresql/types.py
# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -37,43 +38,52 @@ class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]):
    @overload
    def __init__(
        self: PGUuid[_python_UUID], as_uuid: Literal[True] = ...
    ) -> None:
        ...
    ) -> None: ...

    @overload
    def __init__(self: PGUuid[str], as_uuid: Literal[False] = ...) -> None:
        ...
    def __init__(
        self: PGUuid[str], as_uuid: Literal[False] = ...
    ) -> None: ...

    def __init__(self, as_uuid: bool = True) -> None:
        ...
    def __init__(self, as_uuid: bool = True) -> None: ...


class BYTEA(sqltypes.LargeBinary):
    __visit_name__ = "BYTEA"


class INET(sqltypes.TypeEngine[str]):
class _NetworkAddressTypeMixin:

    def coerce_compared_value(
        self, op: Optional[OperatorType], value: Any
    ) -> TypeEngine[Any]:
        if TYPE_CHECKING:
            assert isinstance(self, TypeEngine)
        return self


class INET(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
    __visit_name__ = "INET"


PGInet = INET


class CIDR(sqltypes.TypeEngine[str]):
class CIDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
    __visit_name__ = "CIDR"


PGCidr = CIDR


class MACADDR(sqltypes.TypeEngine[str]):
class MACADDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
    __visit_name__ = "MACADDR"


PGMacAddr = MACADDR


class MACADDR8(sqltypes.TypeEngine[str]):
class MACADDR8(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
    __visit_name__ = "MACADDR8"


@@ -94,12 +104,11 @@ class MONEY(sqltypes.TypeEngine[str]):
        from sqlalchemy import Dialect
        from sqlalchemy import TypeDecorator


        class NumericMoney(TypeDecorator):
            impl = MONEY

            def process_result_value(
                self, value: Any, dialect: Dialect
            ) -> None:
            def process_result_value(self, value: Any, dialect: Dialect) -> None:
                if value is not None:
                    # adjust this for the currency and numeric
                    m = re.match(r"\$([\d.]+)", value)
@@ -114,6 +123,7 @@ class MONEY(sqltypes.TypeEngine[str]):
        from sqlalchemy import cast
        from sqlalchemy import TypeDecorator


        class NumericMoney(TypeDecorator):
            impl = MONEY

@@ -122,20 +132,18 @@ class MONEY(sqltypes.TypeEngine[str]):

    .. versionadded:: 1.2

    """
    """  # noqa: E501

    __visit_name__ = "MONEY"


class OID(sqltypes.TypeEngine[int]):

    """Provide the PostgreSQL OID type."""

    __visit_name__ = "OID"


class REGCONFIG(sqltypes.TypeEngine[str]):

    """Provide the PostgreSQL REGCONFIG type.

    .. versionadded:: 2.0.0rc1
@@ -146,7 +154,6 @@ class REGCONFIG(sqltypes.TypeEngine[str]):


class TSQUERY(sqltypes.TypeEngine[str]):

    """Provide the PostgreSQL TSQUERY type.

    .. versionadded:: 2.0.0rc1
@@ -157,7 +164,6 @@ class TSQUERY(sqltypes.TypeEngine[str]):


class REGCLASS(sqltypes.TypeEngine[str]):

    """Provide the PostgreSQL REGCLASS type.

    .. versionadded:: 1.2.7
@@ -168,7 +174,6 @@ class REGCLASS(sqltypes.TypeEngine[str]):


class TIMESTAMP(sqltypes.TIMESTAMP):

    """Provide the PostgreSQL TIMESTAMP type."""

    __visit_name__ = "TIMESTAMP"
@@ -189,7 +194,6 @@ class TIMESTAMP(sqltypes.TIMESTAMP):


class TIME(sqltypes.TIME):

    """PostgreSQL TIME type."""

    __visit_name__ = "TIME"
@@ -210,7 +214,6 @@ class TIME(sqltypes.TIME):


class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval):

    """PostgreSQL INTERVAL type."""

    __visit_name__ = "INTERVAL"
@@ -280,7 +283,6 @@ PGBit = BIT


class TSVECTOR(sqltypes.TypeEngine[str]):

    """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
    text search type TSVECTOR.

@@ -297,7 +299,6 @@ class TSVECTOR(sqltypes.TypeEngine[str]):


class CITEXT(sqltypes.TEXT):

    """Provide the PostgreSQL CITEXT type.

    .. versionadded:: 2.0.7

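The effect of ``_NetworkAddressTypeMixin.coerce_compared_value()`` above is that comparing an ``INET`` (or ``CIDR``/``MACADDR``/``MACADDR8``) column against a plain string keeps the network type on the bound parameter, rather than coercing it to a generic string type. Roughly::

    from sqlalchemy import Column, MetaData, Table
    from sqlalchemy.dialects.postgresql import INET

    hosts = Table("hosts", MetaData(), Column("addr", INET))

    expr = hosts.c.addr == "10.0.0.0/24"
    # the right-hand bound parameter is typed as INET, not VARCHAR
    assert isinstance(expr.right.type, INET)
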