main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -1,5 +1,5 @@
# dialects/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,7 +7,6 @@
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Optional
from typing import Type
@@ -40,7 +39,7 @@ def _auto_fn(name: str) -> Optional[Callable[[], Type[Dialect]]]:
# hardcoded. if mysql / mariadb etc were third party dialects
# they would just publish all the entrypoints, which would actually
# look much nicer.
module: Any = __import__(
module = __import__(
"sqlalchemy.dialects.mysql.mariadb"
).dialects.mysql.mariadb
return module.loader(driver) # type: ignore

View File

@@ -1,9 +1,3 @@
# dialects/_typing.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
@@ -12,19 +6,14 @@ from typing import Mapping
from typing import Optional
from typing import Union
from ..sql import roles
from ..sql.base import ColumnCollection
from ..sql.schema import Column
from ..sql._typing import _DDLColumnArgument
from ..sql.elements import DQLDMLClauseElement
from ..sql.schema import ColumnCollectionConstraint
from ..sql.schema import Index
_OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None]
_OnConflictIndexElementsT = Optional[
Iterable[Union[Column[Any], str, roles.DDLConstraintColumnRole]]
]
_OnConflictIndexWhereT = Optional[roles.WhereHavingRole]
_OnConflictSetT = Optional[
Union[Mapping[Any, Any], ColumnCollection[Any, Any]]
]
_OnConflictWhereT = Optional[roles.WhereHavingRole]
_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]]
_OnConflictIndexWhereT = Optional[DQLDMLClauseElement]
_OnConflictSetT = Optional[Mapping[Any, Any]]
_OnConflictWhereT = Union[DQLDMLClauseElement, str, None]

View File

@@ -1,5 +1,5 @@
# dialects/mssql/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# dialects/mssql/aioodbc.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/aioodbc.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -32,12 +32,13 @@ This dialect should normally be used only with the
styles are otherwise equivalent to those documented in the pyodbc section::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine(
"mssql+aioodbc://scott:tiger@mssql2017:1433/test?"
"driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
)
"""
from __future__ import annotations

View File

@@ -1,5 +1,5 @@
# dialects/mssql/base.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/base.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -9,6 +9,7 @@
"""
.. dialect:: mssql
:name: Microsoft SQL Server
:full_support: 2017
:normal_support: 2012+
:best_effort: 2005+
@@ -39,12 +40,9 @@ considered to be the identity column - unless it is associated with a
from sqlalchemy import Table, MetaData, Column, Integer
m = MetaData()
t = Table(
"t",
m,
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
t = Table('t', m,
Column('id', Integer, primary_key=True),
Column('x', Integer))
m.create_all(engine)
The above example will generate DDL as:
@@ -62,12 +60,9 @@ specify ``False`` for the :paramref:`_schema.Column.autoincrement` flag,
on the first integer primary key column::
m = MetaData()
t = Table(
"t",
m,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", Integer),
)
t = Table('t', m,
Column('id', Integer, primary_key=True, autoincrement=False),
Column('x', Integer))
m.create_all(engine)
To add the ``IDENTITY`` keyword to a non-primary key column, specify
@@ -77,12 +72,9 @@ To add the ``IDENTITY`` keyword to a non-primary key column, specify
is set to ``False`` on any integer primary key column::
m = MetaData()
t = Table(
"t",
m,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", Integer, autoincrement=True),
)
t = Table('t', m,
Column('id', Integer, primary_key=True, autoincrement=False),
Column('x', Integer, autoincrement=True))
m.create_all(engine)
.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct
@@ -145,12 +137,14 @@ parameters passed to the :class:`_schema.Identity` object::
from sqlalchemy import Table, Integer, Column, Identity
test = Table(
"test",
metadata,
'test', metadata,
Column(
"id", Integer, primary_key=True, Identity(start=100, increment=10)
'id',
Integer,
primary_key=True,
Identity(start=100, increment=10)
),
Column("name", String(20)),
Column('name', String(20))
)
The CREATE TABLE for the above :class:`_schema.Table` object would be:
@@ -160,7 +154,7 @@ The CREATE TABLE for the above :class:`_schema.Table` object would be:
CREATE TABLE test (
id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
name VARCHAR(20) NULL,
)
)
.. note::
@@ -193,7 +187,6 @@ type deployed to the SQL Server database can be specified as ``Numeric`` using
Base = declarative_base()
class TestTable(Base):
__tablename__ = "test"
id = Column(
@@ -219,9 +212,8 @@ integer values in Python 3), use :class:`_types.TypeDecorator` as follows::
from sqlalchemy import TypeDecorator
class NumericAsInteger(TypeDecorator):
"normalize floating point return values into ints"
'''normalize floating point return values into ints'''
impl = Numeric(10, 0, asdecimal=False)
cache_ok = True
@@ -231,7 +223,6 @@ integer values in Python 3), use :class:`_types.TypeDecorator` as follows::
value = int(value)
return value
class TestTable(Base):
__tablename__ = "test"
id = Column(
@@ -280,11 +271,11 @@ The process for fetching this value has several variants:
fetched in order to receive the value. Given a table as::
t = Table(
"t",
't',
metadata,
Column("id", Integer, primary_key=True),
Column("x", Integer),
implicit_returning=False,
Column('id', Integer, primary_key=True),
Column('x', Integer),
implicit_returning=False
)
an INSERT will look like:
@@ -310,13 +301,12 @@ statement proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the
execution. Given this example::
m = MetaData()
t = Table(
"t", m, Column("id", Integer, primary_key=True), Column("x", Integer)
)
t = Table('t', m, Column('id', Integer, primary_key=True),
Column('x', Integer))
m.create_all(engine)
with engine.begin() as conn:
conn.execute(t.insert(), {"id": 1, "x": 1}, {"id": 2, "x": 2})
conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2})
The above column will be created with IDENTITY, however the INSERT statement
we emit is specifying explicit values. In the echo output we can see
@@ -352,11 +342,7 @@ The :class:`.Sequence` object creates "real" sequences, i.e.,
>>> from sqlalchemy import Sequence
>>> from sqlalchemy.schema import CreateSequence
>>> from sqlalchemy.dialects import mssql
>>> print(
... CreateSequence(Sequence("my_seq", start=1)).compile(
... dialect=mssql.dialect()
... )
... )
>>> print(CreateSequence(Sequence("my_seq", start=1)).compile(dialect=mssql.dialect()))
{printsql}CREATE SEQUENCE my_seq START WITH 1
For integer primary key generation, SQL Server's ``IDENTITY`` construct should
@@ -390,12 +376,12 @@ more than one backend without using dialect-specific types.
To build a SQL Server VARCHAR or NVARCHAR with MAX length, use None::
my_table = Table(
"my_table",
metadata,
Column("my_data", VARCHAR(None)),
Column("my_n_data", NVARCHAR(None)),
'my_table', metadata,
Column('my_data', VARCHAR(None)),
Column('my_n_data', NVARCHAR(None))
)
Collation Support
-----------------
@@ -403,13 +389,10 @@ Character collations are supported by the base string types,
specified by the string argument "collation"::
from sqlalchemy import VARCHAR
Column("login", VARCHAR(32, collation="Latin1_General_CI_AS"))
Column('login', VARCHAR(32, collation='Latin1_General_CI_AS'))
When such a column is associated with a :class:`_schema.Table`, the
CREATE TABLE statement for this column will yield:
.. sourcecode:: sql
CREATE TABLE statement for this column will yield::
login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL
@@ -429,9 +412,7 @@ versions when no OFFSET clause is present. A statement such as::
select(some_table).limit(5)
will render similarly to:
.. sourcecode:: sql
will render similarly to::
SELECT TOP 5 col1, col2.. FROM table
@@ -441,9 +422,7 @@ LIMIT and OFFSET, or just OFFSET alone, will be rendered using the
select(some_table).order_by(some_table.c.col3).limit(5).offset(10)
will render similarly to:
.. sourcecode:: sql
will render similarly to::
SELECT anon_1.col1, anon_1.col2 FROM (SELECT col1, col2,
ROW_NUMBER() OVER (ORDER BY col3) AS
@@ -496,13 +475,16 @@ each new connection.
To set isolation level using :func:`_sa.create_engine`::
engine = create_engine(
"mssql+pyodbc://scott:tiger@ms_2008", isolation_level="REPEATABLE READ"
"mssql+pyodbc://scott:tiger@ms_2008",
isolation_level="REPEATABLE READ"
)
To set using per-connection execution options::
connection = engine.connect()
connection = connection.execution_options(isolation_level="READ COMMITTED")
connection = connection.execution_options(
isolation_level="READ COMMITTED"
)
Valid values for ``isolation_level`` include:
@@ -552,6 +534,7 @@ will remain consistent with the state of the transaction::
mssql_engine = create_engine(
"mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server",
# disable default reset-on-return scheme
pool_reset_on_return=None,
)
@@ -580,17 +563,13 @@ Nullability
-----------
MSSQL has support for three levels of column nullability. The default
nullability allows nulls and is explicit in the CREATE TABLE
construct:
.. sourcecode:: sql
construct::
name VARCHAR(20) NULL
If ``nullable=None`` is specified then no specification is made. In
other words the database's configured default is used. This will
render:
.. sourcecode:: sql
render::
name VARCHAR(20)
@@ -646,9 +625,8 @@ behavior of this flag is as follows:
* The flag can be set to either ``True`` or ``False`` when the dialect
is created, typically via :func:`_sa.create_engine`::
eng = create_engine(
"mssql+pymssql://user:pass@host/db", deprecate_large_types=True
)
eng = create_engine("mssql+pymssql://user:pass@host/db",
deprecate_large_types=True)
* Complete control over whether the "old" or "new" types are rendered is
available in all SQLAlchemy versions by using the UPPERCASE type objects
@@ -670,10 +648,9 @@ at once using the :paramref:`_schema.Table.schema` argument of
:class:`_schema.Table`::
Table(
"some_table",
metadata,
"some_table", metadata,
Column("q", String(50)),
schema="mydatabase.dbo",
schema="mydatabase.dbo"
)
When performing operations such as table or component reflection, a schema
@@ -685,10 +662,9 @@ components will be quoted separately for case sensitive names and other
special characters. Given an argument as below::
Table(
"some_table",
metadata,
"some_table", metadata,
Column("q", String(50)),
schema="MyDataBase.dbo",
schema="MyDataBase.dbo"
)
The above schema would be rendered as ``[MyDataBase].dbo``, and also in
@@ -701,22 +677,21 @@ Below, the "owner" will be considered as ``MyDataBase.dbo`` and the
"database" will be None::
Table(
"some_table",
metadata,
"some_table", metadata,
Column("q", String(50)),
schema="[MyDataBase.dbo]",
schema="[MyDataBase.dbo]"
)
To individually specify both database and owner name with special characters
or embedded dots, use two sets of brackets::
Table(
"some_table",
metadata,
"some_table", metadata,
Column("q", String(50)),
schema="[MyDataBase.Period].[MyOwner.Dot]",
schema="[MyDataBase.Period].[MyOwner.Dot]"
)
.. versionchanged:: 1.2 the SQL Server dialect now treats brackets as
identifier delimiters splitting the schema into separate database
and owner tokens, to allow dots within either name itself.
@@ -731,11 +706,10 @@ schema-qualified table would be auto-aliased when used in a
SELECT statement; given a table::
account_table = Table(
"account",
metadata,
Column("id", Integer, primary_key=True),
Column("info", String(100)),
schema="customer_schema",
'account', metadata,
Column('id', Integer, primary_key=True),
Column('info', String(100)),
schema="customer_schema"
)
this legacy mode of rendering would assume that "customer_schema.account"
@@ -778,55 +752,37 @@ which renders the index as ``CREATE CLUSTERED INDEX my_index ON table (x)``.
To generate a clustered primary key use::
Table(
"my_table",
metadata,
Column("x", ...),
Column("y", ...),
PrimaryKeyConstraint("x", "y", mssql_clustered=True),
)
Table('my_table', metadata,
Column('x', ...),
Column('y', ...),
PrimaryKeyConstraint("x", "y", mssql_clustered=True))
which will render the table, for example, as:
which will render the table, for example, as::
.. sourcecode:: sql
CREATE TABLE my_table (
x INTEGER NOT NULL,
y INTEGER NOT NULL,
PRIMARY KEY CLUSTERED (x, y)
)
CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL,
PRIMARY KEY CLUSTERED (x, y))
Similarly, we can generate a clustered unique constraint using::
Table(
"my_table",
metadata,
Column("x", ...),
Column("y", ...),
PrimaryKeyConstraint("x"),
UniqueConstraint("y", mssql_clustered=True),
)
Table('my_table', metadata,
Column('x', ...),
Column('y', ...),
PrimaryKeyConstraint("x"),
UniqueConstraint("y", mssql_clustered=True),
)
To explicitly request a non-clustered primary key (for example, when
a separate clustered index is desired), use::
Table(
"my_table",
metadata,
Column("x", ...),
Column("y", ...),
PrimaryKeyConstraint("x", "y", mssql_clustered=False),
)
Table('my_table', metadata,
Column('x', ...),
Column('y', ...),
PrimaryKeyConstraint("x", "y", mssql_clustered=False))
which will render the table, for example, as:
which will render the table, for example, as::
.. sourcecode:: sql
CREATE TABLE my_table (
x INTEGER NOT NULL,
y INTEGER NOT NULL,
PRIMARY KEY NONCLUSTERED (x, y)
)
CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL,
PRIMARY KEY NONCLUSTERED (x, y))
Columnstore Index Support
-------------------------
@@ -864,7 +820,7 @@ INCLUDE
The ``mssql_include`` option renders INCLUDE(colname) for the given string
names::
Index("my_index", table.c.x, mssql_include=["y"])
Index("my_index", table.c.x, mssql_include=['y'])
would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)``
@@ -919,19 +875,18 @@ To disable the usage of OUTPUT INSERTED on a per-table basis,
specify ``implicit_returning=False`` for each :class:`_schema.Table`
which has triggers::
Table(
"mytable",
metadata,
Column("id", Integer, primary_key=True),
Table('mytable', metadata,
Column('id', Integer, primary_key=True),
# ...,
implicit_returning=False,
implicit_returning=False
)
Declarative form::
class MyClass(Base):
# ...
__table_args__ = {"implicit_returning": False}
__table_args__ = {'implicit_returning':False}
.. _mssql_rowcount_versioning:
@@ -965,9 +920,7 @@ isolation mode that locks entire tables, and causes even mildly concurrent
applications to have long held locks and frequent deadlocks.
Enabling snapshot isolation for the database as a whole is recommended
for modern levels of concurrency support. This is accomplished via the
following ALTER DATABASE commands executed at the SQL prompt:
.. sourcecode:: sql
following ALTER DATABASE commands executed at the SQL prompt::
ALTER DATABASE MyDatabase SET ALLOW_SNAPSHOT_ISOLATION ON
@@ -1473,6 +1426,7 @@ class ROWVERSION(TIMESTAMP):
class NTEXT(sqltypes.UnicodeText):
"""MSSQL NTEXT type, for variable-length unicode text up to 2^30
characters."""
@@ -1597,11 +1551,44 @@ class MSUUid(sqltypes.Uuid):
def process(value):
return f"""'{
value.replace("-", "").replace("'", "''")
}'"""
value.replace("-", "").replace("'", "''")
}'"""
return process
def _sentinel_value_resolver(self, dialect):
"""Return a callable that will receive the uuid object or string
as it is normally passed to the DB in the parameter set, after
bind_processor() is called. Convert this value to match
what it would be as coming back from an INSERT..OUTPUT inserted.
for the UUID type, there are four varieties of settings so here
we seek to convert to the string or UUID representation that comes
back from the driver.
"""
character_based_uuid = (
not dialect.supports_native_uuid or not self.native_uuid
)
if character_based_uuid:
if self.native_uuid:
# for pyodbc, uuid.uuid() objects are accepted for incoming
# data, as well as strings. but the driver will always return
# uppercase strings in result sets.
def process(value):
return str(value).upper()
else:
def process(value):
return str(value)
return process
else:
# for pymssql, we get uuid.uuid() objects back.
return None
class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]):
__visit_name__ = "UNIQUEIDENTIFIER"
@@ -1609,12 +1596,12 @@ class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]):
@overload
def __init__(
self: UNIQUEIDENTIFIER[_python_UUID], as_uuid: Literal[True] = ...
): ...
):
...
@overload
def __init__(
self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ...
): ...
def __init__(self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ...):
...
def __init__(self, as_uuid: bool = True):
"""Construct a :class:`_mssql.UNIQUEIDENTIFIER` type.
@@ -1865,6 +1852,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
_enable_identity_insert = False
_select_lastrowid = False
_lastrowid = None
_rowcount = None
dialect: MSDialect
@@ -1984,6 +1972,13 @@ class MSExecutionContext(default.DefaultExecutionContext):
def get_lastrowid(self):
return self._lastrowid
@property
def rowcount(self):
if self._rowcount is not None:
return self._rowcount
else:
return self.cursor.rowcount
def handle_dbapi_exception(self, e):
if self._enable_identity_insert:
try:
@@ -2035,10 +2030,6 @@ class MSSQLCompiler(compiler.SQLCompiler):
self.tablealiases = {}
super().__init__(*args, **kwargs)
def _format_frame_clause(self, range_, **kw):
kw["literal_execute"] = True
return super()._format_frame_clause(range_, **kw)
def _with_legacy_schema_aliasing(fn):
def decorate(self, *arg, **kw):
if self.dialect.legacy_schema_aliasing:
@@ -2492,12 +2483,10 @@ class MSSQLCompiler(compiler.SQLCompiler):
type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS %s)" % (
self.process(binary.left, **kw),
self.process(binary.right, **kw),
(
"FLOAT"
if isinstance(binary.type, sqltypes.Float)
else "NUMERIC(%s, %s)"
% (binary.type.precision, binary.type.scale)
),
"FLOAT"
if isinstance(binary.type, sqltypes.Float)
else "NUMERIC(%s, %s)"
% (binary.type.precision, binary.type.scale),
)
elif binary.type._type_affinity is sqltypes.Boolean:
# the NULL handling is particularly weird with boolean, so
@@ -2533,6 +2522,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
class MSSQLStrictCompiler(MSSQLCompiler):
"""A subclass of MSSQLCompiler which disables the usage of bind
parameters where not allowed natively by MS-SQL.
@@ -3632,36 +3622,27 @@ where
@reflection.cache
@_db_plus_owner
def get_columns(self, connection, tablename, dbname, owner, schema, **kw):
sys_columns = ischema.sys_columns
sys_types = ischema.sys_types
sys_default_constraints = ischema.sys_default_constraints
computed_cols = ischema.computed_columns
identity_cols = ischema.identity_columns
extended_properties = ischema.extended_properties
# to access sys tables, need an object_id.
# object_id() can normally match to the unquoted name even if it
# has special characters. however it also accepts quoted names,
# which means for the special case that the name itself has
# "quotes" (e.g. brackets for SQL Server) we need to "quote" (e.g.
# bracket) that name anyway. Fixed as part of #12654
is_temp_table = tablename.startswith("#")
if is_temp_table:
owner, tablename = self._get_internal_temp_table_name(
connection, tablename
)
object_id_tokens = [self.identifier_preparer.quote(tablename)]
columns = ischema.mssql_temp_table_columns
else:
columns = ischema.columns
computed_cols = ischema.computed_columns
identity_cols = ischema.identity_columns
if owner:
object_id_tokens.insert(0, self.identifier_preparer.quote(owner))
if is_temp_table:
object_id_tokens.insert(0, "tempdb")
object_id = func.object_id(".".join(object_id_tokens))
whereclause = sys_columns.c.object_id == object_id
whereclause = sql.and_(
columns.c.table_name == tablename,
columns.c.table_schema == owner,
)
full_name = columns.c.table_schema + "." + columns.c.table_name
else:
whereclause = columns.c.table_name == tablename
full_name = columns.c.table_name
if self._supports_nvarchar_max:
computed_definition = computed_cols.c.definition
@@ -3671,112 +3652,92 @@ where
computed_cols.c.definition, NVARCHAR(4000)
)
object_id = func.object_id(full_name)
s = (
sql.select(
sys_columns.c.name,
sys_types.c.name,
sys_columns.c.is_nullable,
sys_columns.c.max_length,
sys_columns.c.precision,
sys_columns.c.scale,
sys_default_constraints.c.definition,
sys_columns.c.collation_name,
columns.c.column_name,
columns.c.data_type,
columns.c.is_nullable,
columns.c.character_maximum_length,
columns.c.numeric_precision,
columns.c.numeric_scale,
columns.c.column_default,
columns.c.collation_name,
computed_definition,
computed_cols.c.is_persisted,
identity_cols.c.is_identity,
identity_cols.c.seed_value,
identity_cols.c.increment_value,
extended_properties.c.value.label("comment"),
)
.select_from(sys_columns)
.join(
sys_types,
onclause=sys_columns.c.user_type_id
== sys_types.c.user_type_id,
)
.outerjoin(
sys_default_constraints,
sql.and_(
sys_default_constraints.c.object_id
== sys_columns.c.default_object_id,
sys_default_constraints.c.parent_column_id
== sys_columns.c.column_id,
),
ischema.extended_properties.c.value.label("comment"),
)
.select_from(columns)
.outerjoin(
computed_cols,
onclause=sql.and_(
computed_cols.c.object_id == sys_columns.c.object_id,
computed_cols.c.column_id == sys_columns.c.column_id,
computed_cols.c.object_id == object_id,
computed_cols.c.name
== columns.c.column_name.collate("DATABASE_DEFAULT"),
),
)
.outerjoin(
identity_cols,
onclause=sql.and_(
identity_cols.c.object_id == sys_columns.c.object_id,
identity_cols.c.column_id == sys_columns.c.column_id,
identity_cols.c.object_id == object_id,
identity_cols.c.name
== columns.c.column_name.collate("DATABASE_DEFAULT"),
),
)
.outerjoin(
extended_properties,
ischema.extended_properties,
onclause=sql.and_(
extended_properties.c["class"] == 1,
extended_properties.c.name == "MS_Description",
sys_columns.c.object_id == extended_properties.c.major_id,
sys_columns.c.column_id == extended_properties.c.minor_id,
ischema.extended_properties.c["class"] == 1,
ischema.extended_properties.c.major_id == object_id,
ischema.extended_properties.c.minor_id
== columns.c.ordinal_position,
ischema.extended_properties.c.name == "MS_Description",
),
)
.where(whereclause)
.order_by(sys_columns.c.column_id)
.order_by(columns.c.ordinal_position)
)
if is_temp_table:
exec_opts = {"schema_translate_map": {"sys": "tempdb.sys"}}
else:
exec_opts = {"schema_translate_map": {}}
c = connection.execution_options(**exec_opts).execute(s)
c = connection.execution_options(future_result=True).execute(s)
cols = []
for row in c.mappings():
name = row[sys_columns.c.name]
type_ = row[sys_types.c.name]
nullable = row[sys_columns.c.is_nullable] == 1
maxlen = row[sys_columns.c.max_length]
numericprec = row[sys_columns.c.precision]
numericscale = row[sys_columns.c.scale]
default = row[sys_default_constraints.c.definition]
collation = row[sys_columns.c.collation_name]
name = row[columns.c.column_name]
type_ = row[columns.c.data_type]
nullable = row[columns.c.is_nullable] == "YES"
charlen = row[columns.c.character_maximum_length]
numericprec = row[columns.c.numeric_precision]
numericscale = row[columns.c.numeric_scale]
default = row[columns.c.column_default]
collation = row[columns.c.collation_name]
definition = row[computed_definition]
is_persisted = row[computed_cols.c.is_persisted]
is_identity = row[identity_cols.c.is_identity]
identity_start = row[identity_cols.c.seed_value]
identity_increment = row[identity_cols.c.increment_value]
comment = row[extended_properties.c.value]
comment = row[ischema.extended_properties.c.value]
coltype = self.ischema_names.get(type_, None)
kwargs = {}
if coltype in (
MSString,
MSChar,
MSNVarchar,
MSNChar,
MSText,
MSNText,
MSBinary,
MSVarBinary,
sqltypes.LargeBinary,
):
kwargs["length"] = maxlen if maxlen != -1 else None
elif coltype in (
MSString,
MSChar,
MSText,
):
kwargs["length"] = maxlen if maxlen != -1 else None
if collation:
kwargs["collation"] = collation
elif coltype in (
MSNVarchar,
MSNChar,
MSNText,
):
kwargs["length"] = maxlen // 2 if maxlen != -1 else None
if charlen == -1:
charlen = None
kwargs["length"] = charlen
if collation:
kwargs["collation"] = collation
@@ -4020,8 +3981,10 @@ index_info AS (
)
# group rows by constraint ID, to handle multi-column FKs
fkeys = util.defaultdict(
lambda: {
fkeys = []
def fkey_rec():
return {
"name": None,
"constrained_columns": [],
"referred_schema": None,
@@ -4029,7 +3992,8 @@ index_info AS (
"referred_columns": [],
"options": {},
}
)
fkeys = util.defaultdict(fkey_rec)
for r in connection.execute(s).all():
(

View File

@@ -1,5 +1,5 @@
# dialects/mssql/information_schema.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/information_schema.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -88,41 +88,23 @@ columns = Table(
schema="INFORMATION_SCHEMA",
)
sys_columns = Table(
"columns",
mssql_temp_table_columns = Table(
"COLUMNS",
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("column_id", Integer),
Column("default_object_id", Integer),
Column("user_type_id", Integer),
Column("is_nullable", Integer),
Column("ordinal_position", Integer),
Column("max_length", Integer),
Column("precision", Integer),
Column("scale", Integer),
Column("collation_name", String),
schema="sys",
)
sys_types = Table(
"types",
ischema,
Column("name", CoerceUnicode, key="name"),
Column("system_type_id", Integer, key="system_type_id"),
Column("user_type_id", Integer, key="user_type_id"),
Column("schema_id", Integer, key="schema_id"),
Column("max_length", Integer, key="max_length"),
Column("precision", Integer, key="precision"),
Column("scale", Integer, key="scale"),
Column("collation_name", CoerceUnicode, key="collation_name"),
Column("is_nullable", Boolean, key="is_nullable"),
Column("is_user_defined", Boolean, key="is_user_defined"),
Column("is_assembly_type", Boolean, key="is_assembly_type"),
Column("default_object_id", Integer, key="default_object_id"),
Column("rule_object_id", Integer, key="rule_object_id"),
Column("is_table_type", Boolean, key="is_table_type"),
schema="sys",
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("IS_NULLABLE", Integer, key="is_nullable"),
Column("DATA_TYPE", String, key="data_type"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
Column(
"CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
),
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
Column("COLUMN_DEFAULT", Integer, key="column_default"),
Column("COLLATION_NAME", String, key="collation_name"),
schema="tempdb.INFORMATION_SCHEMA",
)
constraints = Table(
@@ -135,17 +117,6 @@ constraints = Table(
schema="INFORMATION_SCHEMA",
)
sys_default_constraints = Table(
"default_constraints",
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("schema_id", Integer),
Column("parent_column_id", Integer),
Column("definition", CoerceUnicode),
schema="sys",
)
column_constraints = Table(
"CONSTRAINT_COLUMN_USAGE",
ischema,
@@ -211,7 +182,6 @@ computed_columns = Table(
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("column_id", Integer),
Column("is_computed", Boolean),
Column("is_persisted", Boolean),
Column("definition", CoerceUnicode),
@@ -237,7 +207,6 @@ class NumericSqlVariant(TypeDecorator):
int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the
correct value as string.
"""
impl = Unicode
cache_ok = True
@@ -250,7 +219,6 @@ identity_columns = Table(
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("column_id", Integer),
Column("is_identity", Boolean),
Column("seed_value", NumericSqlVariant),
Column("increment_value", NumericSqlVariant),

View File

@@ -1,9 +1,3 @@
# dialects/mssql/json.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import types as sqltypes
@@ -54,7 +48,9 @@ class JSON(sqltypes.JSON):
dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor
should be used::
stmt = select(data_table.c.data["some key"].as_json()).where(
stmt = select(
data_table.c.data["some key"].as_json()
).where(
data_table.c.data["some key"].as_json() == {"sub": "structure"}
)
@@ -65,7 +61,9 @@ class JSON(sqltypes.JSON):
:meth:`_types.JSON.Comparator.as_integer`,
:meth:`_types.JSON.Comparator.as_float`::
stmt = select(data_table.c.data["some key"].as_string()).where(
stmt = select(
data_table.c.data["some key"].as_string()
).where(
data_table.c.data["some key"].as_string() == "some string"
)

View File

@@ -1,9 +1,3 @@
# dialects/mssql/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from sqlalchemy import inspect
@@ -22,17 +16,10 @@ from ...testing.provision import generate_driver_url
from ...testing.provision import get_temp_table_name
from ...testing.provision import log
from ...testing.provision import normalize_sequence
from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
from ...testing.provision import temp_table_keyword_args
@post_configure_engine.for_db("mssql")
def post_configure_engine(url, engine, follower_ident):
if engine.driver == "pyodbc":
engine.dialect.dbapi.pooling = False
@generate_driver_url.for_db("mssql")
def generate_driver_url(url, driver, query_str):
backend = url.get_backend_name()
@@ -42,9 +29,6 @@ def generate_driver_url(url, driver, query_str):
if driver not in ("pyodbc", "aioodbc"):
new_url = new_url.set(query="")
if driver == "aioodbc":
new_url = new_url.update_query_dict({"MARS_Connection": "Yes"})
if query_str:
new_url = new_url.update_query_string(query_str)

View File

@@ -1,5 +1,5 @@
# dialects/mssql/pymssql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/pymssql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -103,7 +103,6 @@ class MSDialect_pymssql(MSDialect):
"message 20006", # Write to the server failed
"message 20017", # Unexpected EOF from the server
"message 20047", # DBPROCESS is dead or not enabled
"The server failed to resume the transaction",
):
if msg in str(e):
return True

View File

@@ -1,5 +1,5 @@
# dialects/mssql/pyodbc.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/pyodbc.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -30,9 +30,7 @@ is configured on the client, a basic DSN-based connection looks like::
engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn")
Which above, will pass the following connection string to PyODBC:
.. sourcecode:: text
Which above, will pass the following connection string to PyODBC::
DSN=some_dsn;UID=scott;PWD=tiger
@@ -51,9 +49,7 @@ When using a hostname connection, the driver name must also be specified in the
query parameters of the URL. As these names usually have spaces in them, the
name must be URL encoded which means using plus signs for spaces::
engine = create_engine(
"mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server"
)
engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server")
The ``driver`` keyword is significant to the pyodbc dialect and must be
specified in lowercase.
@@ -73,7 +69,6 @@ internally::
The equivalent URL can be constructed using :class:`_sa.engine.URL`::
from sqlalchemy.engine import URL
connection_url = URL.create(
"mssql+pyodbc",
username="scott",
@@ -88,6 +83,7 @@ The equivalent URL can be constructed using :class:`_sa.engine.URL`::
},
)
Pass through exact Pyodbc string
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -98,11 +94,8 @@ using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object
can help make this easier::
from sqlalchemy.engine import URL
connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password"
connection_url = URL.create(
"mssql+pyodbc", query={"odbc_connect": connection_string}
)
connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string})
engine = create_engine(connection_url)
@@ -134,8 +127,7 @@ database using Azure credentials::
from sqlalchemy.engine.url import URL
from azure import identity
# Connection option for access tokens, as defined in msodbcsql.h
SQL_COPT_SS_ACCESS_TOKEN = 1256
SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h
TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database
connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server"
@@ -144,19 +136,14 @@ database using Azure credentials::
azure_credentials = identity.DefaultAzureCredential()
@event.listens_for(engine, "do_connect")
def provide_token(dialect, conn_rec, cargs, cparams):
# remove the "Trusted_Connection" parameter that SQLAlchemy adds
cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
# create token credential
raw_token = azure_credentials.get_token(TOKEN_URL).token.encode(
"utf-16-le"
)
token_struct = struct.pack(
f"<I{len(raw_token)}s", len(raw_token), raw_token
)
raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le")
token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)
# apply it to keyword arguments
cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct}
@@ -189,9 +176,7 @@ emit a ``.rollback()`` after an operation had a failure of some kind.
This specific case can be handled by passing ``ignore_no_transaction_on_rollback=True`` to
the SQL Server dialect via the :func:`_sa.create_engine` function as follows::
engine = create_engine(
connection_url, ignore_no_transaction_on_rollback=True
)
engine = create_engine(connection_url, ignore_no_transaction_on_rollback=True)
Using the above parameter, the dialect will catch ``ProgrammingError``
exceptions raised during ``connection.rollback()`` and emit a warning
@@ -251,6 +236,7 @@ behavior and pass long strings as varchar(max)/nvarchar(max) using the
},
)
Pyodbc Pooling / connection close behavior
------------------------------------------
@@ -315,8 +301,7 @@ Server dialect supports this parameter by passing the
engine = create_engine(
"mssql+pyodbc://scott:tiger@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server",
fast_executemany=True,
)
fast_executemany=True)
.. versionchanged:: 2.0.9 - the ``fast_executemany`` parameter now has its
intended effect of this PyODBC feature taking effect for all INSERT
@@ -384,6 +369,7 @@ from ...engine import cursor as _cursor
class _ms_numeric_pyodbc:
"""Turns Decimals with adjusted() < 0 or > 7 into strings.
The routines here are needed for older pyodbc versions

View File

@@ -1,5 +1,5 @@
# dialects/mysql/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -53,8 +53,7 @@ from .base import YEAR
from .dml import Insert
from .dml import insert
from .expression import match
from .mariadb import INET4
from .mariadb import INET6
from ...util import compat
# default dialect
base.dialect = dialect = mysqldb.dialect
@@ -72,8 +71,6 @@ __all__ = (
"DOUBLE",
"ENUM",
"FLOAT",
"INET4",
"INET6",
"INTEGER",
"INTEGER",
"JSON",

View File

@@ -1,9 +1,10 @@
# dialects/mysql/aiomysql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# mysql/aiomysql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+aiomysql
@@ -22,108 +23,207 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")
engine = create_async_engine(
"mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4"
)
""" # noqa
from __future__ import annotations
from types import ModuleType
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
from ...connectors.asyncio import AsyncAdapt_dbapi_module
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
if TYPE_CHECKING:
from ...connectors.asyncio import AsyncIODBAPIConnection
from ...connectors.asyncio import AsyncIODBAPICursor
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL
class AsyncAdapt_aiomysql_cursor:
# TODO: base on connectors/asyncio.py
# see #10415
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor):
__slots__ = ()
cursor = self._connection.cursor(adapt_connection.dbapi.Cursor)
def _make_new_cursor(
self, connection: AsyncIODBAPIConnection
) -> AsyncIODBAPICursor:
return connection.cursor(self._adapt_connection.dbapi.Cursor)
# see https://github.com/aio-libs/aiomysql/issues/543
self._cursor = self.await_(cursor.__aenter__())
self._rows = []
@property
def description(self):
return self._cursor.description
class AsyncAdapt_aiomysql_ss_cursor(
AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_aiomysql_cursor
):
__slots__ = ()
@property
def rowcount(self):
return self._cursor.rowcount
def _make_new_cursor(
self, connection: AsyncIODBAPIConnection
) -> AsyncIODBAPICursor:
return connection.cursor(
self._adapt_connection.dbapi.aiomysql.cursors.SSCursor
@property
def arraysize(self):
return self._cursor.arraysize
@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value
@property
def lastrowid(self):
return self._cursor.lastrowid
def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. to allow this to be async
# we would need the Result to change how it does "Safe close cursor".
# MySQL "cursors" don't actually have state to be "closed" besides
# exhausting rows, which we already have done for sync cursor.
# another option would be to emulate aiosqlite dialect and assign
# cursor only if we are doing server side cursor operation.
self._rows[:] = []
def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))
def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)
async def _execute_async(self, operation, parameters):
async with self._adapt_connection._execute_mutex:
result = await self._cursor.execute(operation, parameters)
class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
if not self.server_side:
# aiomysql has a "fake" async result, so we have to pull it out
# of that here since our default result is not async.
# we could just as easily grab "_rows" here and be done with it
# but this is safer.
self._rows = list(await self._cursor.fetchall())
return result
async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._execute_mutex:
return await self._cursor.executemany(operation, seq_of_parameters)
def setinputsizes(self, *inputsizes):
pass
def __iter__(self):
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = ()
server_side = True
_cursor_cls = AsyncAdapt_aiomysql_cursor
_ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
def ping(self, reconnect: bool) -> None:
assert not reconnect
self.await_(self._connection.ping(reconnect))
cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor)
def character_set_name(self) -> Optional[str]:
return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
self._cursor = self.await_(cursor.__aenter__())
def autocommit(self, value: Any) -> None:
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_aiomysql_connection(AdaptedConnection):
# TODO: base on connectors/asyncio.py
# see #10415
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_execute_mutex")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = asyncio.Lock()
def ping(self, reconnect):
return self.await_(self._connection.ping(reconnect))
def character_set_name(self):
return self._connection.character_set_name()
def autocommit(self, value):
self.await_(self._connection.autocommit(value))
def get_autocommit(self) -> bool:
return self._connection.get_autocommit() # type: ignore
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_aiomysql_ss_cursor(self)
else:
return AsyncAdapt_aiomysql_cursor(self)
def terminate(self) -> None:
def rollback(self):
self.await_(self._connection.rollback())
def commit(self):
self.await_(self._connection.commit())
def close(self):
# it's not awaitable.
self._connection.close()
def close(self) -> None:
self.await_(self._connection.ensure_closed())
class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = ()
await_ = staticmethod(await_fallback)
class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
def __init__(self, aiomysql: ModuleType, pymysql: ModuleType):
class AsyncAdapt_aiomysql_dbapi:
def __init__(self, aiomysql, pymysql):
self.aiomysql = aiomysql
self.pymysql = pymysql
self.paramstyle = "format"
self._init_dbapi_attributes()
self.Cursor, self.SSCursor = self._init_cursors_subclasses()
def _init_dbapi_attributes(self) -> None:
def _init_dbapi_attributes(self):
for name in (
"Warning",
"Error",
@@ -149,7 +249,7 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
):
setattr(self, name, getattr(self.pymysql, name))
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection:
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
@@ -164,23 +264,17 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
await_only(creator_fn(*arg, **kw)),
)
def _init_cursors_subclasses(
self,
) -> Tuple[AsyncIODBAPICursor, AsyncIODBAPICursor]:
def _init_cursors_subclasses(self):
# suppress unconditional warning emitted by aiomysql
class Cursor(self.aiomysql.Cursor): # type: ignore[misc, name-defined]
async def _show_warnings(
self, conn: AsyncIODBAPIConnection
) -> None:
class Cursor(self.aiomysql.Cursor):
async def _show_warnings(self, conn):
pass
class SSCursor(self.aiomysql.SSCursor): # type: ignore[misc, name-defined] # noqa: E501
async def _show_warnings(
self, conn: AsyncIODBAPIConnection
) -> None:
class SSCursor(self.aiomysql.SSCursor):
async def _show_warnings(self, conn):
pass
return Cursor, SSCursor # type: ignore[return-value]
return Cursor, SSCursor
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
@@ -191,16 +285,15 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
_sscursor = AsyncAdapt_aiomysql_ss_cursor
is_async = True
has_terminate = True
@classmethod
def import_dbapi(cls) -> AsyncAdapt_aiomysql_dbapi:
def import_dbapi(cls):
return AsyncAdapt_aiomysql_dbapi(
__import__("aiomysql"), __import__("pymysql")
)
@classmethod
def get_pool_class(cls, url: URL) -> type:
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
@@ -208,37 +301,25 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
else:
return pool.AsyncAdaptedQueuePool
def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
dbapi_connection.terminate()
def create_connect_args(
self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
) -> ConnectArgsType:
def create_connect_args(self, url):
return super().create_connect_args(
url, _translate_args=dict(username="user", database="db")
)
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
else:
str_e = str(e).lower()
return "not connected" in str_e
def _found_rows_client_flag(self) -> int:
from pymysql.constants import CLIENT # type: ignore
def _found_rows_client_flag(self):
from pymysql.constants import CLIENT
return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
return CLIENT.FOUND_ROWS
def get_driver_connection(
self, connection: DBAPIConnection
) -> AsyncIODBAPIConnection:
return connection._connection # type: ignore[no-any-return]
def get_driver_connection(self, connection):
return connection._connection
dialect = MySQLDialect_aiomysql

View File

@@ -1,9 +1,10 @@
# dialects/mysql/asyncmy.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# mysql/asyncmy.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+asyncmy
@@ -20,100 +21,210 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")
engine = create_async_engine(
"mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4"
)
""" # noqa
from __future__ import annotations
from types import ModuleType
from typing import Any
from typing import NoReturn
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from contextlib import asynccontextmanager
from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
from ...connectors.asyncio import AsyncAdapt_dbapi_module
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
if TYPE_CHECKING:
from ...connectors.asyncio import AsyncIODBAPIConnection
from ...connectors.asyncio import AsyncIODBAPICursor
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL
class AsyncAdapt_asyncmy_cursor:
# TODO: base on connectors/asyncio.py
# see #10415
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)
class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor):
__slots__ = ()
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor()
class AsyncAdapt_asyncmy_ss_cursor(
AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncmy_cursor
):
__slots__ = ()
self._cursor = self.await_(cursor.__aenter__())
self._rows = []
def _make_new_cursor(
self, connection: AsyncIODBAPIConnection
) -> AsyncIODBAPICursor:
return connection.cursor(
self._adapt_connection.dbapi.asyncmy.cursors.SSCursor
@property
def description(self):
return self._cursor.description
@property
def rowcount(self):
return self._cursor.rowcount
@property
def arraysize(self):
return self._cursor.arraysize
@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value
@property
def lastrowid(self):
return self._cursor.lastrowid
def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. to allow this to be async
# we would need the Result to change how it does "Safe close cursor".
# MySQL "cursors" don't actually have state to be "closed" besides
# exhausting rows, which we already have done for sync cursor.
# another option would be to emulate aiosqlite dialect and assign
# cursor only if we are doing server side cursor operation.
self._rows[:] = []
def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))
def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)
async def _execute_async(self, operation, parameters):
async with self._adapt_connection._mutex_and_adapt_errors():
if parameters is None:
result = await self._cursor.execute(operation)
else:
result = await self._cursor.execute(operation, parameters)
class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
if not self.server_side:
# asyncmy has a "fake" async result, so we have to pull it out
# of that here since our default result is not async.
# we could just as easily grab "_rows" here and be done with it
# but this is safer.
self._rows = list(await self._cursor.fetchall())
return result
async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._mutex_and_adapt_errors():
return await self._cursor.executemany(operation, seq_of_parameters)
def setinputsizes(self, *inputsizes):
pass
def __iter__(self):
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = ()
server_side = True
_cursor_cls = AsyncAdapt_asyncmy_cursor
_ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
def _handle_exception(self, error: Exception) -> NoReturn:
if isinstance(error, AttributeError):
raise self.dbapi.InternalError(
"network operation failed due to asyncmy attribute error"
)
cursor = self._connection.cursor(
adapt_connection.dbapi.asyncmy.cursors.SSCursor
)
raise error
self._cursor = self.await_(cursor.__aenter__())
def ping(self, reconnect: bool) -> None:
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_asyncmy_connection(AdaptedConnection):
# TODO: base on connectors/asyncio.py
# see #10415
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_execute_mutex")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = asyncio.Lock()
@asynccontextmanager
async def _mutex_and_adapt_errors(self):
async with self._execute_mutex:
try:
yield
except AttributeError:
raise self.dbapi.InternalError(
"network operation failed due to asyncmy attribute error"
)
def ping(self, reconnect):
assert not reconnect
return self.await_(self._do_ping())
async def _do_ping(self) -> None:
try:
async with self._execute_mutex:
await self._connection.ping(False)
except Exception as error:
self._handle_exception(error)
async def _do_ping(self):
async with self._mutex_and_adapt_errors():
return await self._connection.ping(False)
def character_set_name(self) -> Optional[str]:
return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
def character_set_name(self):
return self._connection.character_set_name()
def autocommit(self, value: Any) -> None:
def autocommit(self, value):
self.await_(self._connection.autocommit(value))
def get_autocommit(self) -> bool:
return self._connection.get_autocommit() # type: ignore
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_asyncmy_ss_cursor(self)
else:
return AsyncAdapt_asyncmy_cursor(self)
def terminate(self) -> None:
def rollback(self):
self.await_(self._connection.rollback())
def commit(self):
self.await_(self._connection.commit())
def close(self):
# it's not awaitable.
self._connection.close()
def close(self) -> None:
self.await_(self._connection.ensure_closed())
class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
__slots__ = ()
@@ -121,13 +232,18 @@ class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
await_ = staticmethod(await_fallback)
class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
def __init__(self, asyncmy: ModuleType):
def _Binary(x):
"""Return x as a binary type."""
return bytes(x)
class AsyncAdapt_asyncmy_dbapi:
def __init__(self, asyncmy):
self.asyncmy = asyncmy
self.paramstyle = "format"
self._init_dbapi_attributes()
def _init_dbapi_attributes(self) -> None:
def _init_dbapi_attributes(self):
for name in (
"Warning",
"Error",
@@ -148,9 +264,9 @@ class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
BINARY = util.symbol("BINARY")
DATETIME = util.symbol("DATETIME")
TIMESTAMP = util.symbol("TIMESTAMP")
Binary = staticmethod(bytes)
Binary = staticmethod(_Binary)
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_asyncmy_connection:
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
@@ -174,14 +290,13 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
_sscursor = AsyncAdapt_asyncmy_ss_cursor
is_async = True
has_terminate = True
@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
@classmethod
def get_pool_class(cls, url: URL) -> type:
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
@@ -189,20 +304,12 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
else:
return pool.AsyncAdaptedQueuePool
def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
dbapi_connection.terminate()
def create_connect_args(self, url: URL) -> ConnectArgsType: # type: ignore[override] # noqa: E501
def create_connect_args(self, url):
return super().create_connect_args(
url, _translate_args=dict(username="user", database="db")
)
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
else:
@@ -211,15 +318,13 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
"not connected" in str_e or "network operation failed" in str_e
)
def _found_rows_client_flag(self) -> int:
from asyncmy.constants import CLIENT # type: ignore
def _found_rows_client_flag(self):
from asyncmy.constants import CLIENT
return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
return CLIENT.FOUND_ROWS
def get_driver_connection(
self, connection: DBAPIConnection
) -> AsyncIODBAPIConnection:
return connection._connection # type: ignore[no-any-return]
def get_driver_connection(self, connection):
return connection._connection
dialect = MySQLDialect_asyncmy

View File

@@ -1,9 +1,10 @@
# dialects/mysql/cymysql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/cymysql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
@@ -20,36 +21,18 @@ r"""
dialects are mysqlclient and PyMySQL.
""" # noqa
from __future__ import annotations
from typing import Any
from typing import Iterable
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from .base import BIT
from .base import MySQLDialect
from .mysqldb import MySQLDialect_mysqldb
from .types import BIT
from ... import util
if TYPE_CHECKING:
from ...engine.base import Connection
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import Dialect
from ...engine.interfaces import PoolProxiedConnection
from ...sql.type_api import _ResultProcessorType
class _cymysqlBIT(BIT):
def result_processor(
self, dialect: Dialect, coltype: object
) -> Optional[_ResultProcessorType[Any]]:
def result_processor(self, dialect, coltype):
"""Convert MySQL's 64 bit, variable length binary string to a long."""
def process(value: Optional[Iterable[int]]) -> Optional[int]:
def process(value):
if value is not None:
v = 0
for i in iter(value):
@@ -72,22 +55,17 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return __import__("cymysql")
def _detect_charset(self, connection: Connection) -> str:
return connection.connection.charset # type: ignore[no-any-return]
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
return exception.errno # type: ignore[no-any-return]
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
if isinstance(e, self.loaded_dbapi.OperationalError):
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError):
return self._extract_error_code(e) in (
2006,
2013,
@@ -95,7 +73,7 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
2045,
2055,
)
elif isinstance(e, self.loaded_dbapi.InterfaceError):
elif isinstance(e, self.dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
return True

View File

@@ -1,5 +1,5 @@
# dialects/mysql/dml.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/dml.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,7 +7,6 @@
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
@@ -142,11 +141,7 @@ class Insert(StandardInsert):
in :ref:`tutorial_parameter_ordered_updates`::
insert().on_duplicate_key_update(
[
("name", "some name"),
("value", "some value"),
]
)
[("name", "some name"), ("value", "some value")])
.. versionchanged:: 1.3 parameters can be specified as a dictionary
or list of 2-tuples; the latter form provides for parameter
@@ -186,7 +181,6 @@ class OnDuplicateClause(ClauseElement):
_parameter_ordering: Optional[List[str]] = None
update: Dict[str, Any]
stringify_dialect = "mysql"
def __init__(

View File

@@ -1,51 +1,34 @@
# dialects/mysql/enumerated.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/enumerated.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import enum
import re
from typing import Any
from typing import Dict
from typing import Optional
from typing import Set
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from .types import _StringType
from ... import exc
from ... import sql
from ... import util
from ...sql import sqltypes
from ...sql import type_api
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.elements import ColumnElement
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _ResultProcessorType
from ...sql.type_api import TypeEngine
from ...sql.type_api import TypeEngineMixin
class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
"""MySQL ENUM type."""
__visit_name__ = "ENUM"
native_enum = True
def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None:
def __init__(self, *enums, **kw):
"""Construct an ENUM.
E.g.::
Column("myenum", ENUM("foo", "bar", "baz"))
Column('myenum', ENUM("foo", "bar", "baz"))
:param enums: The range of valid values for this ENUM. Values in
enums are not quoted, they will be escaped and surrounded by single
@@ -79,27 +62,21 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
"""
kw.pop("strict", None)
self._enum_init(enums, kw) # type: ignore[arg-type]
self._enum_init(enums, kw)
_StringType.__init__(self, length=self.length, **kw)
@classmethod
def adapt_emulated_to_native(
cls,
impl: Union[TypeEngine[Any], TypeEngineMixin],
**kw: Any,
) -> ENUM:
def adapt_emulated_to_native(cls, impl, **kw):
"""Produce a MySQL native :class:`.mysql.ENUM` from plain
:class:`.Enum`.
"""
if TYPE_CHECKING:
assert isinstance(impl, ENUM)
kw.setdefault("validate_strings", impl.validate_strings)
kw.setdefault("values_callable", impl.values_callable)
kw.setdefault("omit_aliases", impl._omit_aliases)
return cls(**kw)
def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]:
def _object_value_for_elem(self, elem):
# mysql sends back a blank string for any value that
# was persisted that was not in the enums; that is, it does no
# validation on the incoming data, it "truncates" it to be
@@ -109,27 +86,24 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
else:
return super()._object_value_for_elem(elem)
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
)
# TODO: SET is a string as far as configuration but does not act like
# a string at the python level. We either need to make a py-type agnostic
# version of String as a base to be used for this, make this some kind of
# TypeDecorator, or just vendor it out as its own type.
class SET(_StringType):
"""MySQL SET type."""
__visit_name__ = "SET"
def __init__(self, *values: str, **kw: Any):
def __init__(self, *values, **kw):
"""Construct a SET.
E.g.::
Column("myset", SET("foo", "bar", "baz"))
Column('myset', SET("foo", "bar", "baz"))
The list of potential values is required in the case that this
set will be used to generate DDL for a table, or if the
@@ -177,19 +151,17 @@ class SET(_StringType):
"setting retrieve_as_bitwise=True"
)
if self.retrieve_as_bitwise:
self._inversed_bitmap: Dict[str, int] = {
self._bitmap = {
value: 2**idx for idx, value in enumerate(self.values)
}
self._bitmap: Dict[int, str] = {
2**idx: value for idx, value in enumerate(self.values)
}
self._bitmap.update(
(2**idx, value) for idx, value in enumerate(self.values)
)
length = max([len(v) for v in values] + [0])
kw.setdefault("length", length)
super().__init__(**kw)
def column_expression(
self, colexpr: ColumnElement[Any]
) -> ColumnElement[Any]:
def column_expression(self, colexpr):
if self.retrieve_as_bitwise:
return sql.type_coerce(
sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
@@ -197,12 +169,10 @@ class SET(_StringType):
else:
return colexpr
def result_processor(
self, dialect: Dialect, coltype: Any
) -> Optional[_ResultProcessorType[Any]]:
def result_processor(self, dialect, coltype):
if self.retrieve_as_bitwise:
def process(value: Union[str, int, None]) -> Optional[Set[str]]:
def process(value):
if value is not None:
value = int(value)
@@ -213,14 +183,11 @@ class SET(_StringType):
else:
super_convert = super().result_processor(dialect, coltype)
def process(value: Union[str, Set[str], None]) -> Optional[Set[str]]: # type: ignore[misc] # noqa: E501
def process(value):
if isinstance(value, str):
# MySQLdb returns a string, let's parse
if super_convert:
value = super_convert(value)
assert value is not None
if TYPE_CHECKING:
assert isinstance(value, str)
return set(re.findall(r"[^,]+", value))
else:
# mysql-connector-python does a naive
@@ -231,48 +198,43 @@ class SET(_StringType):
return process
def bind_processor(
self, dialect: Dialect
) -> _BindProcessorType[Union[str, int]]:
def bind_processor(self, dialect):
super_convert = super().bind_processor(dialect)
if self.retrieve_as_bitwise:
def process(
value: Union[str, int, set[str], None],
) -> Union[str, int, None]:
def process(value):
if value is None:
return None
elif isinstance(value, (int, str)):
if super_convert:
return super_convert(value) # type: ignore[arg-type, no-any-return] # noqa: E501
return super_convert(value)
else:
return value
else:
int_value = 0
for v in value:
int_value |= self._inversed_bitmap[v]
int_value |= self._bitmap[v]
return int_value
else:
def process(
value: Union[str, int, set[str], None],
) -> Union[str, int, None]:
def process(value):
# accept strings and int (actually bitflag) values directly
if value is not None and not isinstance(value, (int, str)):
value = ",".join(value)
if super_convert:
return super_convert(value) # type: ignore
return super_convert(value)
else:
return value
return process
def adapt(self, cls: type, **kw: Any) -> Any:
def adapt(self, impltype, **kw):
kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
return util.constructor_copy(self, cls, *self.values, **kw)
return util.constructor_copy(self, impltype, *self.values, **kw)
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self,
to_inspect=[SET, _StringType],

View File

@@ -1,13 +1,10 @@
# dialects/mysql/expression.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from typing import Any
from ... import exc
from ... import util
@@ -20,7 +17,7 @@ from ...sql.base import Generative
from ...util.typing import Self
class match(Generative, elements.BinaryExpression[Any]):
class match(Generative, elements.BinaryExpression):
"""Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
E.g.::
@@ -40,9 +37,7 @@ class match(Generative, elements.BinaryExpression[Any]):
.order_by(desc(match_expr))
)
Would produce SQL resembling:
.. sourcecode:: sql
Would produce SQL resembling::
SELECT id, firstname, lastname
FROM user
@@ -75,9 +70,8 @@ class match(Generative, elements.BinaryExpression[Any]):
__visit_name__ = "mysql_match"
inherit_cache = True
modifiers: util.immutabledict[str, Any]
def __init__(self, *cols: elements.ColumnElement[Any], **kw: Any):
def __init__(self, *cols, **kw):
if not cols:
raise exc.ArgumentError("columns are required")

View File

@@ -1,21 +1,13 @@
# dialects/mysql/json.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/json.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from typing import TYPE_CHECKING
# mypy: ignore-errors
from ... import types as sqltypes
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _LiteralProcessorType
class JSON(sqltypes.JSON):
"""MySQL JSON type.
@@ -42,13 +34,13 @@ class JSON(sqltypes.JSON):
class _FormatTypeMixin:
def _format_value(self, value: Any) -> str:
def _format_value(self, value):
raise NotImplementedError()
def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value: Any) -> Any:
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
@@ -56,31 +48,29 @@ class _FormatTypeMixin:
return process
def literal_processor(
self, dialect: Dialect
) -> _LiteralProcessorType[Any]:
super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value: Any) -> str:
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value # type: ignore[no-any-return]
return value
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value: Any) -> str:
def _format_value(self, value):
if isinstance(value, int):
formatted_value = "$[%s]" % value
value = "$[%s]" % value
else:
formatted_value = '$."%s"' % value
return formatted_value
value = '$."%s"' % value
return value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value: Any) -> str:
def _format_value(self, value):
return "$%s" % (
"".join(
[

View File

@@ -1,73 +1,32 @@
# dialects/mysql/mariadb.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mariadb.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from typing import Callable
# mypy: ignore-errors
from .base import MariaDBIdentifierPreparer
from .base import MySQLDialect
from .base import MySQLIdentifierPreparer
from .base import MySQLTypeCompiler
from ...sql import sqltypes
class INET4(sqltypes.TypeEngine[str]):
"""INET4 column type for MariaDB
.. versionadded:: 2.0.37
"""
__visit_name__ = "INET4"
class INET6(sqltypes.TypeEngine[str]):
"""INET6 column type for MariaDB
.. versionadded:: 2.0.37
"""
__visit_name__ = "INET6"
class MariaDBTypeCompiler(MySQLTypeCompiler):
def visit_INET4(self, type_: INET4, **kwargs: Any) -> str:
return "INET4"
def visit_INET6(self, type_: INET6, **kwargs: Any) -> str:
return "INET6"
class MariaDBDialect(MySQLDialect):
is_mariadb = True
supports_statement_cache = True
name = "mariadb"
preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer
type_compiler_cls = MariaDBTypeCompiler
preparer = MariaDBIdentifierPreparer
def loader(driver: str) -> Callable[[], type[MariaDBDialect]]:
dialect_mod = __import__(
def loader(driver):
driver_mod = __import__(
"sqlalchemy.dialects.mysql.%s" % driver
).dialects.mysql
driver_cls = getattr(driver_mod, driver).dialect
driver_mod = getattr(dialect_mod, driver)
if hasattr(driver_mod, "mariadb_dialect"):
driver_cls = driver_mod.mariadb_dialect
return driver_cls # type: ignore[no-any-return]
else:
driver_cls = driver_mod.dialect
return type(
"MariaDBDialect_%s" % driver,
(
MariaDBDialect,
driver_cls,
),
{"supports_statement_cache": True},
)
return type(
"MariaDBDialect_%s" % driver,
(
MariaDBDialect,
driver_cls,
),
{"supports_statement_cache": True},
)

View File

@@ -1,9 +1,11 @@
# dialects/mysql/mariadbconnector.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mariadbconnector.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""
@@ -27,15 +29,7 @@ be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
.. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
""" # noqa
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from uuid import UUID as _python_UUID
from .base import MySQLCompiler
@@ -45,19 +39,6 @@ from ... import sql
from ... import util
from ...sql import sqltypes
if TYPE_CHECKING:
from ...engine.base import Connection
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import Dialect
from ...engine.interfaces import IsolationLevel
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL
from ...sql.compiler import SQLCompiler
from ...sql.type_api import _ResultProcessorType
mariadb_cpy_minimum_version = (1, 0, 1)
@@ -66,12 +47,10 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
# work around JIRA issue
# https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed,
# this type can be removed.
def result_processor(
self, dialect: Dialect, coltype: object
) -> Optional[_ResultProcessorType[Any]]:
def result_processor(self, dialect, coltype):
if self.as_uuid:
def process(value: Any) -> Any:
def process(value):
if value is not None:
if hasattr(value, "decode"):
value = value.decode("ascii")
@@ -81,7 +60,7 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
return process
else:
def process(value: Any) -> Any:
def process(value):
if value is not None:
if hasattr(value, "decode"):
value = value.decode("ascii")
@@ -92,27 +71,30 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
_lastrowid: Optional[int] = None
_lastrowid = None
def create_server_side_cursor(self) -> DBAPICursor:
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(buffered=False)
def create_default_cursor(self) -> DBAPICursor:
def create_default_cursor(self):
return self._dbapi_connection.cursor(buffered=True)
def post_exec(self) -> None:
def post_exec(self):
super().post_exec()
self._rowcount = self.cursor.rowcount
if TYPE_CHECKING:
assert isinstance(self.compiled, SQLCompiler)
if self.isinsert and self.compiled.postfetch_lastrowid:
self._lastrowid = self.cursor.lastrowid
def get_lastrowid(self) -> int:
if TYPE_CHECKING:
assert self._lastrowid is not None
@property
def rowcount(self):
if self._rowcount is not None:
return self._rowcount
else:
return self.cursor.rowcount
def get_lastrowid(self):
return self._lastrowid
@@ -151,7 +133,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)
@util.memoized_property
def _dbapi_version(self) -> Tuple[int, ...]:
def _dbapi_version(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
return tuple(
[
@@ -164,7 +146,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
else:
return (99, 99, 99)
def __init__(self, **kwargs: Any) -> None:
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.paramstyle = "qmark"
if self.dbapi is not None:
@@ -176,26 +158,20 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)
@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return __import__("mariadb")
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
elif isinstance(e, self.loaded_dbapi.Error):
elif isinstance(e, self.dbapi.Error):
str_e = str(e).lower()
return "not connected" in str_e or "isn't valid" in str_e
else:
return False
def create_connect_args(self, url: URL) -> ConnectArgsType:
def create_connect_args(self, url):
opts = url.translate_connect_args()
opts.update(url.query)
int_params = [
"connect_timeout",
@@ -210,7 +186,6 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
"ssl_verify_cert",
"ssl",
"pool_reset_connection",
"compress",
]
for key in int_params:
@@ -230,21 +205,19 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
except (AttributeError, ImportError):
self.supports_sane_rowcount = False
opts["client_flag"] = client_flag
return [], opts
return [[], opts]
def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
def _extract_error_code(self, exception):
try:
rc: int = exception.errno
rc = exception.errno
except:
rc = -1
return rc
def _detect_charset(self, connection: Connection) -> str:
def _detect_charset(self, connection):
return "utf8mb4"
def get_isolation_level_values(
self, dbapi_conn: DBAPIConnection
) -> Sequence[IsolationLevel]:
def get_isolation_level_values(self, dbapi_connection):
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
@@ -253,26 +226,21 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
"AUTOCOMMIT",
)
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
return bool(dbapi_conn.autocommit)
def set_isolation_level(
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
) -> None:
def set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
dbapi_connection.autocommit = True
connection.autocommit = True
else:
dbapi_connection.autocommit = False
super().set_isolation_level(dbapi_connection, level)
connection.autocommit = False
super().set_isolation_level(connection, level)
def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
def do_begin_twophase(self, connection, xid):
connection.execute(
sql.text("XA BEGIN :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
def do_prepare_twophase(self, connection, xid):
connection.execute(
sql.text("XA END :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
@@ -285,12 +253,8 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)
def do_rollback_twophase(
self,
connection: Connection,
xid: Any,
is_prepared: bool = True,
recover: bool = False,
) -> None:
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
connection.execute(
sql.text("XA END :xid").bindparams(
@@ -304,12 +268,8 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)
def do_commit_twophase(
self,
connection: Connection,
xid: Any,
is_prepared: bool = True,
recover: bool = False,
) -> None:
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
connection.execute(

View File

@@ -1,9 +1,10 @@
# dialects/mysql/mysqlconnector.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mysqlconnector.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
@@ -13,85 +14,26 @@ r"""
:connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://pypi.org/project/mysql-connector-python/
Driver Status
-------------
MySQL Connector/Python is supported as of SQLAlchemy 2.0.39 to the
degree which the driver is functional. There are still ongoing issues
with features such as server side cursors which remain disabled until
upstream issues are repaired.
.. warning:: The MySQL Connector/Python driver published by Oracle is subject
to frequent, major regressions of essential functionality such as being able
to correctly persist simple binary strings which indicate it is not well
tested. The SQLAlchemy project is not able to maintain this dialect fully as
regressions in the driver prevent it from being included in continuous
integration.
.. versionchanged:: 2.0.39
The MySQL Connector/Python dialect has been updated to support the
latest version of this DBAPI. Previously, MySQL Connector/Python
was not fully supported. However, support remains limited due to ongoing
regressions introduced in this driver.
Connecting to MariaDB with MySQL Connector/Python
--------------------------------------------------
MySQL Connector/Python may attempt to pass an incompatible collation to the
database when connecting to MariaDB. Experimentation has shown that using
``?charset=utf8mb4&collation=utfmb4_general_ci`` or similar MariaDB-compatible
charset/collation will allow connectivity.
.. note::
The MySQL Connector/Python DBAPI has had many issues since its release,
some of which may remain unresolved, and the mysqlconnector dialect is
**not tested as part of SQLAlchemy's continuous integration**.
The recommended MySQL dialects are mysqlclient and PyMySQL.
""" # noqa
from __future__ import annotations
import re
from typing import Any
from typing import cast
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from .base import MariaDBIdentifierPreparer
from .base import BIT
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .base import MySQLIdentifierPreparer
from .mariadb import MariaDBDialect
from .types import BIT
from ... import util
if TYPE_CHECKING:
from ...engine.base import Connection
from ...engine.cursor import CursorResult
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import IsolationLevel
from ...engine.interfaces import PoolProxiedConnection
from ...engine.row import Row
from ...engine.url import URL
from ...sql.elements import BinaryExpression
class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
def create_server_side_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor(buffered=False)
def create_default_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor(buffered=True)
class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod_binary(
self, binary: BinaryExpression[Any], operator: Any, **kw: Any
) -> str:
def visit_mod_binary(self, binary, operator, **kw):
return (
self.process(binary.left, **kw)
+ " % "
@@ -99,37 +41,22 @@ class MySQLCompiler_mysqlconnector(MySQLCompiler):
)
class IdentifierPreparerCommon_mysqlconnector:
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
@property
def _double_percents(self) -> bool:
def _double_percents(self):
return False
@_double_percents.setter
def _double_percents(self, value: Any) -> None:
def _double_percents(self, value):
pass
def _escape_identifier(self, value: str) -> str:
value = value.replace(
self.escape_quote, # type:ignore[attr-defined]
self.escape_to_quote, # type:ignore[attr-defined]
)
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value
class MySQLIdentifierPreparer_mysqlconnector(
IdentifierPreparerCommon_mysqlconnector, MySQLIdentifierPreparer
):
pass
class MariaDBIdentifierPreparer_mysqlconnector(
IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer
):
pass
class _myconnpyBIT(BIT):
def result_processor(self, dialect: Any, coltype: Any) -> None:
def result_processor(self, dialect, coltype):
"""MySQL-connector already converts mysql bits, so."""
return None
@@ -144,31 +71,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
supports_native_decimal = True
supports_native_bit = True
# not until https://bugs.mysql.com/bug.php?id=117548
supports_server_side_cursors = False
default_paramstyle = "format"
statement_compiler = MySQLCompiler_mysqlconnector
execution_ctx_cls = MySQLExecutionContext_mysqlconnector
preparer: type[MySQLIdentifierPreparer] = (
MySQLIdentifierPreparer_mysqlconnector
)
preparer = MySQLIdentifierPreparer_mysqlconnector
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
@classmethod
def import_dbapi(cls) -> DBAPIModule:
return cast("DBAPIModule", __import__("mysql.connector").connector)
def import_dbapi(cls):
from mysql import connector
def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
return connector
def do_ping(self, dbapi_connection):
dbapi_connection.ping(False)
return True
def create_connect_args(self, url: URL) -> ConnectArgsType:
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
opts.update(url.query)
@@ -176,7 +96,6 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
util.coerce_kw_type(opts, "allow_local_infile", bool)
util.coerce_kw_type(opts, "autocommit", bool)
util.coerce_kw_type(opts, "buffered", bool)
util.coerce_kw_type(opts, "client_flag", int)
util.coerce_kw_type(opts, "compress", bool)
util.coerce_kw_type(opts, "connection_timeout", int)
util.coerce_kw_type(opts, "connect_timeout", int)
@@ -191,21 +110,15 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
util.coerce_kw_type(opts, "use_pure", bool)
util.coerce_kw_type(opts, "use_unicode", bool)
# note that "buffered" is set to False by default in MySQL/connector
# python. If you set it to True, then there is no way to get a server
# side cursor because the logic is written to disallow that.
# leaving this at True until
# https://bugs.mysql.com/bug.php?id=117548 can be fixed
opts["buffered"] = True
# unfortunately, MySQL/connector python refuses to release a
# cursor without reading fully, so non-buffered isn't an option
opts.setdefault("buffered", True)
# FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount.
if self.dbapi is not None:
try:
from mysql.connector import constants # type: ignore
ClientFlag = constants.ClientFlag
from mysql.connector.constants import ClientFlag
client_flags = opts.get(
"client_flags", ClientFlag.get_default()
@@ -214,35 +127,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
opts["client_flags"] = client_flags
except Exception:
pass
return [], opts
return [[], opts]
@util.memoized_property
def _mysqlconnector_version_info(self) -> Optional[Tuple[int, ...]]:
def _mysqlconnector_version_info(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
return None
def _detect_charset(self, connection: Connection) -> str:
return connection.connection.charset # type: ignore
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception: BaseException) -> int:
return exception.errno # type: ignore
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(
self,
e: Exception,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (
self.loaded_dbapi.OperationalError, #
self.loaded_dbapi.InterfaceError,
self.loaded_dbapi.ProgrammingError,
)
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
if isinstance(e, exceptions):
return (
e.errno in errnos
@@ -252,51 +154,26 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
else:
return False
def _compat_fetchall(
self,
rp: CursorResult[Tuple[Any, ...]],
charset: Optional[str] = None,
) -> Sequence[Row[Tuple[Any, ...]]]:
def _compat_fetchall(self, rp, charset=None):
return rp.fetchall()
def _compat_fetchone(
self,
rp: CursorResult[Tuple[Any, ...]],
charset: Optional[str] = None,
) -> Optional[Row[Tuple[Any, ...]]]:
def _compat_fetchone(self, rp, charset=None):
return rp.fetchone()
def get_isolation_level_values(
self, dbapi_conn: DBAPIConnection
) -> Sequence[IsolationLevel]:
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"AUTOCOMMIT",
)
_isolation_lookup = {
"SERIALIZABLE",
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"AUTOCOMMIT",
}
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
return bool(dbapi_conn.autocommit)
def set_isolation_level(
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
) -> None:
def _set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
dbapi_connection.autocommit = True
connection.autocommit = True
else:
dbapi_connection.autocommit = False
super().set_isolation_level(dbapi_connection, level)
class MariaDBDialect_mysqlconnector(
MariaDBDialect, MySQLDialect_mysqlconnector
):
supports_statement_cache = True
_allows_uuid_binds = False
preparer = MariaDBIdentifierPreparer_mysqlconnector
connection.autocommit = False
super()._set_isolation_level(connection, level)
dialect = MySQLDialect_mysqlconnector
mariadb_dialect = MariaDBDialect_mysqlconnector

View File

@@ -1,9 +1,11 @@
# dialects/mysql/mysqldb.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mysqldb.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
"""
@@ -46,9 +48,9 @@ key "ssl", which may be specified using the
"ssl": {
"ca": "/home/gord/client-ssl/ca.pem",
"cert": "/home/gord/client-ssl/client-cert.pem",
"key": "/home/gord/client-ssl/client-key.pem",
"key": "/home/gord/client-ssl/client-key.pem"
}
},
}
)
For convenience, the following keys may also be specified inline within the URL
@@ -72,9 +74,7 @@ Using MySQLdb with Google Cloud SQL
-----------------------------------
Google Cloud SQL now recommends use of the MySQLdb dialect. Connect
using a URL like the following:
.. sourcecode:: text
using a URL like the following::
mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
@@ -84,39 +84,25 @@ Server Side Cursors
The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
"""
from __future__ import annotations
import re
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .base import MySQLIdentifierPreparer
from .base import TEXT
from ... import sql
from ... import util
from ...util.typing import Literal
if TYPE_CHECKING:
from ...engine.base import Connection
from ...engine.interfaces import _DBAPIMultiExecuteParams
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import ExecutionContext
from ...engine.interfaces import IsolationLevel
from ...engine.url import URL
class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
pass
@property
def rowcount(self):
if hasattr(self, "_rowcount"):
return self._rowcount
else:
return self.cursor.rowcount
class MySQLCompiler_mysqldb(MySQLCompiler):
@@ -136,9 +122,8 @@ class MySQLDialect_mysqldb(MySQLDialect):
execution_ctx_cls = MySQLExecutionContext_mysqldb
statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer
server_version_info: Tuple[int, ...]
def __init__(self, **kwargs: Any):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._mysql_dbapi_version = (
self._parse_dbapi_version(self.dbapi.__version__)
@@ -146,7 +131,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
else (0, 0, 0)
)
def _parse_dbapi_version(self, version: str) -> Tuple[int, ...]:
def _parse_dbapi_version(self, version):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
@@ -154,7 +139,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
return (0, 0, 0)
@util.langhelpers.memoized_property
def supports_server_side_cursors(self) -> bool:
def supports_server_side_cursors(self):
try:
cursors = __import__("MySQLdb.cursors").cursors
self._sscursor = cursors.SSCursor
@@ -163,13 +148,13 @@ class MySQLDialect_mysqldb(MySQLDialect):
return False
@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return __import__("MySQLdb")
def on_connect(self) -> Callable[[DBAPIConnection], None]:
def on_connect(self):
super_ = super().on_connect()
def on_connect(conn: DBAPIConnection) -> None:
def on_connect(conn):
if super_ is not None:
super_(conn)
@@ -182,24 +167,43 @@ class MySQLDialect_mysqldb(MySQLDialect):
return on_connect
def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:
def do_ping(self, dbapi_connection):
dbapi_connection.ping()
return True
def do_executemany(
self,
cursor: DBAPICursor,
statement: str,
parameters: _DBAPIMultiExecuteParams,
context: Optional[ExecutionContext] = None,
) -> None:
def do_executemany(self, cursor, statement, parameters, context=None):
rowcount = cursor.executemany(statement, parameters)
if context is not None:
cast(MySQLExecutionContext, context)._rowcount = rowcount
context._rowcount = rowcount
def create_connect_args(
self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
) -> ConnectArgsType:
def _check_unicode_returns(self, connection):
# work around issue fixed in
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
# specific issue w/ the utf8mb4_bin collation and unicode returns
collation = connection.exec_driver_sql(
"show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
% (
self.identifier_preparer.quote("Charset"),
self.identifier_preparer.quote("Collation"),
)
).scalar()
has_utf8mb4_bin = self.server_version_info > (5,) and collation
if has_utf8mb4_bin:
additional_tests = [
sql.collate(
sql.cast(
sql.literal_column("'test collated returns'"),
TEXT(charset="utf8mb4"),
),
"utf8mb4_bin",
)
]
else:
additional_tests = []
return super()._check_unicode_returns(connection, additional_tests)
def create_connect_args(self, url, _translate_args=None):
if _translate_args is None:
_translate_args = dict(
database="db", username="user", password="passwd"
@@ -213,7 +217,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
util.coerce_kw_type(opts, "read_timeout", int)
util.coerce_kw_type(opts, "write_timeout", int)
util.coerce_kw_type(opts, "client_flag", int)
util.coerce_kw_type(opts, "local_infile", bool)
util.coerce_kw_type(opts, "local_infile", int)
# Note: using either of the below will cause all strings to be
# returned as Unicode, both in raw SQL operations and with column
# types like String and MSString.
@@ -248,9 +252,9 @@ class MySQLDialect_mysqldb(MySQLDialect):
if client_flag_found_rows is not None:
client_flag |= client_flag_found_rows
opts["client_flag"] = client_flag
return [], opts
return [[], opts]
def _found_rows_client_flag(self) -> Optional[int]:
def _found_rows_client_flag(self):
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
@@ -259,23 +263,20 @@ class MySQLDialect_mysqldb(MySQLDialect):
except (AttributeError, ImportError):
return None
else:
return CLIENT_FLAGS.FOUND_ROWS # type: ignore
return CLIENT_FLAGS.FOUND_ROWS
else:
return None
def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
return exception.args[0] # type: ignore[no-any-return]
def _extract_error_code(self, exception):
return exception.args[0]
def _detect_charset(self, connection: Connection) -> str:
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
try:
# note: the SQL here would be
# "SHOW VARIABLES LIKE 'character_set%%'"
cset_name: Callable[[], str] = (
connection.connection.character_set_name
)
cset_name = connection.connection.character_set_name
except AttributeError:
util.warn(
"No 'character_set_name' can be detected with "
@@ -287,9 +288,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
else:
return cset_name()
def get_isolation_level_values(
self, dbapi_conn: DBAPIConnection
) -> Tuple[IsolationLevel, ...]:
def get_isolation_level_values(self, dbapi_connection):
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
@@ -298,12 +297,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
"AUTOCOMMIT",
)
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
return dbapi_conn.get_autocommit() # type: ignore[no-any-return]
def set_isolation_level(
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
) -> None:
def set_isolation_level(self, dbapi_connection, level):
if level == "AUTOCOMMIT":
dbapi_connection.autocommit(True)
else:

View File

@@ -1,10 +1,5 @@
# dialects/mysql/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import exc
from ...testing.provision import configure_follower
from ...testing.provision import create_db
@@ -39,13 +34,6 @@ def generate_driver_url(url, driver, query_str):
drivername="%s+%s" % (backend, driver)
).update_query_string(query_str)
if driver == "mariadbconnector":
new_url = new_url.difference_update_query(["charset"])
elif driver == "mysqlconnector":
new_url = new_url.update_query_pairs(
[("collation", "utf8mb4_general_ci")]
)
try:
new_url.get_dialect()
except exc.NoSuchModuleError:

View File

@@ -1,9 +1,11 @@
# dialects/mysql/pymysql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/pymysql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
@@ -39,6 +41,7 @@ necessary to indicate ``ssl_check_hostname=false`` in PyMySQL::
"&ssl_check_hostname=false"
)
MySQL-Python Compatibility
--------------------------
@@ -47,26 +50,9 @@ and targets 100% compatibility. Most behavioral notes for MySQL-python apply
to the pymysql driver as well.
""" # noqa
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from .mysqldb import MySQLDialect_mysqldb
from ...util import langhelpers
from ...util.typing import Literal
if TYPE_CHECKING:
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
@@ -76,7 +62,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
description_encoding = None
@langhelpers.memoized_property
def supports_server_side_cursors(self) -> bool:
def supports_server_side_cursors(self):
try:
cursors = __import__("pymysql.cursors").cursors
self._sscursor = cursors.SSCursor
@@ -85,11 +71,11 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
return False
@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return __import__("pymysql")
@langhelpers.memoized_property
def _send_false_to_ping(self) -> bool:
def _send_false_to_ping(self):
"""determine if pymysql has deprecated, changed the default of,
or removed the 'reconnect' argument of connection.ping().
@@ -100,9 +86,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
""" # noqa: E501
try:
Connection = __import__(
"pymysql.connections"
).connections.Connection
Connection = __import__("pymysql.connections").Connection
except (ImportError, AttributeError):
return True
else:
@@ -116,7 +100,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
not insp.defaults or insp.defaults[0] is not False
)
def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:
def do_ping(self, dbapi_connection):
if self._send_false_to_ping:
dbapi_connection.ping(False)
else:
@@ -124,24 +108,17 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
return True
def create_connect_args(
self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
) -> ConnectArgsType:
def create_connect_args(self, url, _translate_args=None):
if _translate_args is None:
_translate_args = dict(username="user")
return super().create_connect_args(
url, _translate_args=_translate_args
)
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
elif isinstance(e, self.loaded_dbapi.Error):
elif isinstance(e, self.dbapi.Error):
str_e = str(e).lower()
return (
"already closed" in str_e or "connection was killed" in str_e
@@ -149,7 +126,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
else:
return False
def _extract_error_code(self, exception: BaseException) -> Any:
def _extract_error_code(self, exception):
if isinstance(exception.args[0], Exception):
exception = exception.args[0]
return exception.args[0]

View File

@@ -1,13 +1,15 @@
# dialects/mysql/pyodbc.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/pyodbc.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+pyodbc
:name: PyODBC
:dbapi: pyodbc
@@ -28,30 +30,21 @@ r"""
Pass through exact pyodbc connection string::
import urllib
connection_string = (
"DRIVER=MySQL ODBC 8.0 ANSI Driver;"
"SERVER=localhost;"
"PORT=3307;"
"DATABASE=mydb;"
"UID=root;"
"PWD=(whatever);"
"charset=utf8mb4;"
'DRIVER=MySQL ODBC 8.0 ANSI Driver;'
'SERVER=localhost;'
'PORT=3307;'
'DATABASE=mydb;'
'UID=root;'
'PWD=(whatever);'
'charset=utf8mb4;'
)
params = urllib.parse.quote_plus(connection_string)
connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params
""" # noqa
from __future__ import annotations
import datetime
import re
from typing import Any
from typing import Callable
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from .base import MySQLDialect
from .base import MySQLExecutionContext
@@ -61,31 +54,23 @@ from ... import util
from ...connectors.pyodbc import PyODBCConnector
from ...sql.sqltypes import Time
if TYPE_CHECKING:
from ...engine import Connection
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import Dialect
from ...sql.type_api import _ResultProcessorType
class _pyodbcTIME(TIME):
def result_processor(
self, dialect: Dialect, coltype: object
) -> _ResultProcessorType[datetime.time]:
def process(value: Any) -> Union[datetime.time, None]:
def result_processor(self, dialect, coltype):
def process(value):
# pyodbc returns a datetime.time object; no need to convert
return value # type: ignore[no-any-return]
return value
return process
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
def get_lastrowid(self) -> int:
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
lastrowid = cursor.fetchone()[0] # type: ignore[index]
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid # type: ignore[no-any-return]
return lastrowid
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
@@ -96,7 +81,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
pyodbc_driver_name = "MySQL"
def _detect_charset(self, connection: Connection) -> str:
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
@@ -121,25 +106,21 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
)
return "latin1"
def _get_server_version_info(
self, connection: Connection
) -> Tuple[int, ...]:
def _get_server_version_info(self, connection):
return MySQLDialect._get_server_version_info(self, connection)
def _extract_error_code(self, exception: BaseException) -> Optional[int]:
def _extract_error_code(self, exception):
m = re.compile(r"\((\d+)\)").search(str(exception.args))
if m is None:
return None
c: Optional[str] = m.group(1)
c = m.group(1)
if c:
return int(c)
else:
return None
def on_connect(self) -> Callable[[DBAPIConnection], None]:
def on_connect(self):
super_ = super().on_connect()
def on_connect(conn: DBAPIConnection) -> None:
def on_connect(conn):
if super_ is not None:
super_(conn)

View File

@@ -1,65 +1,46 @@
# dialects/mysql/reflection.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/reflection.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
# mypy: ignore-errors
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from .enumerated import ENUM
from .enumerated import SET
from .types import DATETIME
from .types import TIME
from .types import TIMESTAMP
from ... import log
from ... import types as sqltypes
from ... import util
from ...util.typing import Literal
if TYPE_CHECKING:
from .base import MySQLDialect
from .base import MySQLIdentifierPreparer
from ...engine.interfaces import ReflectedColumn
class ReflectedState:
"""Stores raw information about a SHOW CREATE TABLE statement."""
charset: Optional[str]
def __init__(self) -> None:
self.columns: List[ReflectedColumn] = []
self.table_options: Dict[str, str] = {}
self.table_name: Optional[str] = None
self.keys: List[Dict[str, Any]] = []
self.fk_constraints: List[Dict[str, Any]] = []
self.ck_constraints: List[Dict[str, Any]] = []
def __init__(self):
self.columns = []
self.table_options = {}
self.table_name = None
self.keys = []
self.fk_constraints = []
self.ck_constraints = []
@log.class_logger
class MySQLTableDefinitionParser:
"""Parses the results of a SHOW CREATE TABLE statement."""
def __init__(
self, dialect: MySQLDialect, preparer: MySQLIdentifierPreparer
):
def __init__(self, dialect, preparer):
self.dialect = dialect
self.preparer = preparer
self._prep_regexes()
def parse(
self, show_create: str, charset: Optional[str]
) -> ReflectedState:
def parse(self, show_create, charset):
state = ReflectedState()
state.charset = charset
for line in re.split(r"\r?\n", show_create):
@@ -84,11 +65,11 @@ class MySQLTableDefinitionParser:
if type_ is None:
util.warn("Unknown schema content: %r" % line)
elif type_ == "key":
state.keys.append(spec) # type: ignore[arg-type]
state.keys.append(spec)
elif type_ == "fk_constraint":
state.fk_constraints.append(spec) # type: ignore[arg-type]
state.fk_constraints.append(spec)
elif type_ == "ck_constraint":
state.ck_constraints.append(spec) # type: ignore[arg-type]
state.ck_constraints.append(spec)
else:
pass
return state
@@ -96,13 +77,7 @@ class MySQLTableDefinitionParser:
def _check_view(self, sql: str) -> bool:
return bool(self._re_is_view.match(sql))
def _parse_constraints(self, line: str) -> Union[
Tuple[None, str],
Tuple[Literal["partition"], str],
Tuple[
Literal["ck_constraint", "fk_constraint", "key"], Dict[str, str]
],
]:
def _parse_constraints(self, line):
"""Parse a KEY or CONSTRAINT line.
:param line: A line of SHOW CREATE TABLE output
@@ -152,7 +127,7 @@ class MySQLTableDefinitionParser:
# No match.
return (None, line)
def _parse_table_name(self, line: str, state: ReflectedState) -> None:
def _parse_table_name(self, line, state):
"""Extract the table name.
:param line: The first line of SHOW CREATE TABLE
@@ -163,7 +138,7 @@ class MySQLTableDefinitionParser:
if m:
state.table_name = cleanup(m.group("name"))
def _parse_table_options(self, line: str, state: ReflectedState) -> None:
def _parse_table_options(self, line, state):
"""Build a dictionary of all reflected table-level options.
:param line: The final line of SHOW CREATE TABLE output.
@@ -189,9 +164,7 @@ class MySQLTableDefinitionParser:
for opt, val in options.items():
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
def _parse_partition_options(
self, line: str, state: ReflectedState
) -> None:
def _parse_partition_options(self, line, state):
options = {}
new_line = line[:]
@@ -247,7 +220,7 @@ class MySQLTableDefinitionParser:
else:
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
def _parse_column(self, line: str, state: ReflectedState) -> None:
def _parse_column(self, line, state):
"""Extract column details.
Falls back to a 'minimal support' variant if full parse fails.
@@ -310,16 +283,13 @@ class MySQLTableDefinitionParser:
type_instance = col_type(*type_args, **type_kw)
col_kw: Dict[str, Any] = {}
col_kw = {}
# NOT NULL
col_kw["nullable"] = True
# this can be "NULL" in the case of TIMESTAMP
if spec.get("notnull", False) == "NOT NULL":
col_kw["nullable"] = False
# For generated columns, the nullability is marked in a different place
if spec.get("notnull_generated", False) == "NOT NULL":
col_kw["nullable"] = False
# AUTO_INCREMENT
if spec.get("autoincr", False):
@@ -351,13 +321,9 @@ class MySQLTableDefinitionParser:
name=name, type=type_instance, default=default, comment=comment
)
col_d.update(col_kw)
state.columns.append(col_d) # type: ignore[arg-type]
state.columns.append(col_d)
def _describe_to_create(
self,
table_name: str,
columns: Sequence[Tuple[str, str, str, str, str, str]],
) -> str:
def _describe_to_create(self, table_name, columns):
"""Re-format DESCRIBE output as a SHOW CREATE TABLE string.
DESCRIBE is a much simpler reflection and is sufficient for
@@ -410,9 +376,7 @@ class MySQLTableDefinitionParser:
]
)
def _parse_keyexprs(
self, identifiers: str
) -> List[Tuple[str, Optional[int], str]]:
def _parse_keyexprs(self, identifiers):
"""Unpack '"col"(2),"col" ASC'-ish strings into components."""
return [
@@ -422,12 +386,11 @@ class MySQLTableDefinitionParser:
)
]
def _prep_regexes(self) -> None:
def _prep_regexes(self):
"""Pre-compile regular expressions."""
self._pr_options: List[
Tuple[re.Pattern[Any], Optional[Callable[[str], str]]]
] = []
self._re_columns = []
self._pr_options = []
_final = self.preparer.final_quote
@@ -485,13 +448,11 @@ class MySQLTableDefinitionParser:
r"(?: +COLLATE +(?P<collate>[\w_]+))?"
r"(?: +(?P<notnull>(?:NOT )?NULL))?"
r"(?: +DEFAULT +(?P<default>"
r"(?:NULL|'(?:''|[^'])*'|\(.+?\)|[\-\w\.\(\)]+"
r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
r"))?"
r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
r".*\))? ?(?P<persistence>VIRTUAL|STORED)?"
r"(?: +(?P<notnull_generated>(?:NOT )?NULL))?"
r")?"
r".*\))? ?(?P<persistence>VIRTUAL|STORED)?)?"
r"(?: +(?P<autoincr>AUTO_INCREMENT))?"
r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?"
r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?"
@@ -539,7 +500,7 @@ class MySQLTableDefinitionParser:
#
# unique constraints come back as KEYs
kw = quotes.copy()
kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT"
kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
self._re_fk_constraint = _re_compile(
r" "
r"CONSTRAINT +"
@@ -616,21 +577,21 @@ class MySQLTableDefinitionParser:
_optional_equals = r"(?:\s*(?:=\s*)|\s+)"
def _add_option_string(self, directive: str) -> None:
def _add_option_string(self, directive):
regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
re.escape(directive),
self._optional_equals,
)
self._pr_options.append(_pr_compile(regex, cleanup_text))
def _add_option_word(self, directive: str) -> None:
def _add_option_word(self, directive):
regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
re.escape(directive),
self._optional_equals,
)
self._pr_options.append(_pr_compile(regex))
def _add_partition_option_word(self, directive: str) -> None:
def _add_partition_option_word(self, directive):
if directive == "PARTITION BY" or directive == "SUBPARTITION BY":
regex = r"(?<!\S)(?P<directive>%s)%s" r"(?P<val>\w+.*)" % (
re.escape(directive),
@@ -645,7 +606,7 @@ class MySQLTableDefinitionParser:
regex = r"(?<!\S)(?P<directive>%s)(?!\S)" % (re.escape(directive),)
self._pr_options.append(_pr_compile(regex))
def _add_option_regex(self, directive: str, regex: str) -> None:
def _add_option_regex(self, directive, regex):
regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
re.escape(directive),
self._optional_equals,
@@ -663,35 +624,21 @@ _options_of_type_string = (
)
@overload
def _pr_compile(
regex: str, cleanup: Callable[[str], str]
) -> Tuple[re.Pattern[Any], Callable[[str], str]]: ...
@overload
def _pr_compile(
regex: str, cleanup: None = None
) -> Tuple[re.Pattern[Any], None]: ...
def _pr_compile(
regex: str, cleanup: Optional[Callable[[str], str]] = None
) -> Tuple[re.Pattern[Any], Optional[Callable[[str], str]]]:
def _pr_compile(regex, cleanup=None):
"""Prepare a 2-tuple of compiled regex and callable."""
return (_re_compile(regex), cleanup)
def _re_compile(regex: str) -> re.Pattern[Any]:
def _re_compile(regex):
"""Compile a string to regex, I and UNICODE."""
return re.compile(regex, re.I | re.UNICODE)
def _strip_values(values: Sequence[str]) -> List[str]:
def _strip_values(values):
"Strip reflected values quotes"
strip_values: List[str] = []
strip_values = []
for a in values:
if a[0:1] == '"' or a[0:1] == "'":
# strip enclosing quotes and unquote interior
@@ -703,9 +650,7 @@ def _strip_values(values: Sequence[str]) -> List[str]:
def cleanup_text(raw_text: str) -> str:
if "\\" in raw_text:
raw_text = re.sub(
_control_char_regexp,
lambda s: _control_char_map[s[0]], # type: ignore[index]
raw_text,
_control_char_regexp, lambda s: _control_char_map[s[0]], raw_text
)
return raw_text.replace("''", "'")

View File

@@ -1,5 +1,5 @@
# dialects/mysql/reserved_words.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/reserved_words.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -11,6 +11,7 @@
# https://mariadb.com/kb/en/reserved-words/
# includes: Reserved Words, Oracle Mode (separate set unioned)
# excludes: Exceptions, Function Names
# mypy: ignore-errors
RESERVED_WORDS_MARIADB = {
"accessible",
@@ -281,7 +282,6 @@ RESERVED_WORDS_MARIADB = {
}
)
# https://dev.mysql.com/doc/refman/8.3/en/keywords.html
# https://dev.mysql.com/doc/refman/8.0/en/keywords.html
# https://dev.mysql.com/doc/refman/5.7/en/keywords.html
# https://dev.mysql.com/doc/refman/5.6/en/keywords.html
@@ -403,7 +403,6 @@ RESERVED_WORDS_MYSQL = {
"int4",
"int8",
"integer",
"intersect",
"interval",
"into",
"io_after_gtids",
@@ -469,7 +468,6 @@ RESERVED_WORDS_MYSQL = {
"outfile",
"over",
"parse_gcol_expr",
"parallel",
"partition",
"percent_rank",
"persist",
@@ -478,7 +476,6 @@ RESERVED_WORDS_MYSQL = {
"primary",
"procedure",
"purge",
"qualify",
"range",
"rank",
"read",

View File

@@ -1,30 +1,18 @@
# dialects/mysql/types.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/types.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
# mypy: ignore-errors
import datetime
import decimal
from typing import Any
from typing import Iterable
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from ... import exc
from ... import util
from ...sql import sqltypes
if TYPE_CHECKING:
from .base import MySQLDialect
from ...engine.interfaces import Dialect
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _ResultProcessorType
from ...sql.type_api import TypeEngine
class _NumericType:
"""Base for MySQL numeric types.
@@ -34,27 +22,19 @@ class _NumericType:
"""
def __init__(
self, unsigned: bool = False, zerofill: bool = False, **kw: Any
):
def __init__(self, unsigned=False, zerofill=False, **kw):
self.unsigned = unsigned
self.zerofill = zerofill
super().__init__(**kw)
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_NumericType, sqltypes.Numeric]
)
class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]):
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = True,
**kw: Any,
):
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
if isinstance(self, (REAL, DOUBLE)) and (
(precision is None and scale is not None)
or (precision is not None and scale is None)
@@ -66,18 +46,18 @@ class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]):
super().__init__(precision=precision, asdecimal=asdecimal, **kw)
self.scale = scale
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_FloatType, _NumericType, sqltypes.Float]
)
class _IntegerType(_NumericType, sqltypes.Integer):
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
self.display_width = display_width
super().__init__(**kw)
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]
)
@@ -88,13 +68,13 @@ class _StringType(sqltypes.String):
def __init__(
self,
charset: Optional[str] = None,
collation: Optional[str] = None,
ascii: bool = False, # noqa
binary: bool = False,
unicode: bool = False,
national: bool = False,
**kw: Any,
charset=None,
collation=None,
ascii=False, # noqa
binary=False,
unicode=False,
national=False,
**kw,
):
self.charset = charset
@@ -107,33 +87,25 @@ class _StringType(sqltypes.String):
self.national = national
super().__init__(**kw)
def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_StringType, sqltypes.String]
)
class _MatchType(
sqltypes.Float[Union[decimal.Decimal, float]], sqltypes.MatchType
):
def __init__(self, **kw: Any):
class _MatchType(sqltypes.Float, sqltypes.MatchType):
def __init__(self, **kw):
# TODO: float arguments?
sqltypes.Float.__init__(self) # type: ignore[arg-type]
sqltypes.Float.__init__(self)
sqltypes.MatchType.__init__(self)
class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]):
class NUMERIC(_NumericType, sqltypes.NUMERIC):
"""MySQL NUMERIC type."""
__visit_name__ = "NUMERIC"
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = True,
**kw: Any,
):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a NUMERIC.
:param precision: Total digits in this number. If scale and precision
@@ -154,18 +126,12 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]):
)
class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]):
class DECIMAL(_NumericType, sqltypes.DECIMAL):
"""MySQL DECIMAL type."""
__visit_name__ = "DECIMAL"
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = True,
**kw: Any,
):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DECIMAL.
:param precision: Total digits in this number. If scale and precision
@@ -186,18 +152,12 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]):
)
class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]):
class DOUBLE(_FloatType, sqltypes.DOUBLE):
"""MySQL DOUBLE type."""
__visit_name__ = "DOUBLE"
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = True,
**kw: Any,
):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DOUBLE.
.. note::
@@ -226,18 +186,12 @@ class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]):
)
class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]):
class REAL(_FloatType, sqltypes.REAL):
"""MySQL REAL type."""
__visit_name__ = "REAL"
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = True,
**kw: Any,
):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a REAL.
.. note::
@@ -266,18 +220,12 @@ class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]):
)
class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]):
class FLOAT(_FloatType, sqltypes.FLOAT):
"""MySQL FLOAT type."""
__visit_name__ = "FLOAT"
def __init__(
self,
precision: Optional[int] = None,
scale: Optional[int] = None,
asdecimal: bool = False,
**kw: Any,
):
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
"""Construct a FLOAT.
:param precision: Total digits in this number. If scale and precision
@@ -297,9 +245,7 @@ class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]):
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
def bind_processor(
self, dialect: Dialect
) -> Optional[_BindProcessorType[Union[decimal.Decimal, float]]]:
def bind_processor(self, dialect):
return None
@@ -308,7 +254,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER):
__visit_name__ = "INTEGER"
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
"""Construct an INTEGER.
:param display_width: Optional, maximum display width for this number.
@@ -329,7 +275,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT):
__visit_name__ = "BIGINT"
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
"""Construct a BIGINTEGER.
:param display_width: Optional, maximum display width for this number.
@@ -350,7 +296,7 @@ class MEDIUMINT(_IntegerType):
__visit_name__ = "MEDIUMINT"
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
"""Construct a MEDIUMINTEGER
:param display_width: Optional, maximum display width for this number.
@@ -371,7 +317,7 @@ class TINYINT(_IntegerType):
__visit_name__ = "TINYINT"
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
"""Construct a TINYINT.
:param display_width: Optional, maximum display width for this number.
@@ -386,19 +332,13 @@ class TINYINT(_IntegerType):
"""
super().__init__(display_width=display_width, **kw)
def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool:
return (
self._type_affinity is other._type_affinity
or other._type_affinity is sqltypes.Boolean
)
class SMALLINT(_IntegerType, sqltypes.SMALLINT):
"""MySQL SMALLINTEGER type."""
__visit_name__ = "SMALLINT"
def __init__(self, display_width: Optional[int] = None, **kw: Any):
def __init__(self, display_width=None, **kw):
"""Construct a SMALLINTEGER.
:param display_width: Optional, maximum display width for this number.
@@ -414,7 +354,7 @@ class SMALLINT(_IntegerType, sqltypes.SMALLINT):
super().__init__(display_width=display_width, **kw)
class BIT(sqltypes.TypeEngine[Any]):
class BIT(sqltypes.TypeEngine):
"""MySQL BIT type.
This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
@@ -425,7 +365,7 @@ class BIT(sqltypes.TypeEngine[Any]):
__visit_name__ = "BIT"
def __init__(self, length: Optional[int] = None):
def __init__(self, length=None):
"""Construct a BIT.
:param length: Optional, number of bits.
@@ -433,19 +373,20 @@ class BIT(sqltypes.TypeEngine[Any]):
"""
self.length = length
def result_processor(
self, dialect: MySQLDialect, coltype: object # type: ignore[override]
) -> Optional[_ResultProcessorType[Any]]:
"""Convert a MySQL's 64 bit, variable length binary string to a
long."""
def result_processor(self, dialect, coltype):
"""Convert a MySQL's 64 bit, variable length binary string to a long.
if dialect.supports_native_bit:
return None
TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector
already do this, so this logic should be moved to those dialects.
def process(value: Optional[Iterable[int]]) -> Optional[int]:
"""
def process(value):
if value is not None:
v = 0
for i in value:
if not isinstance(i, int):
i = ord(i) # convert byte to int on Python 2
v = v << 8 | i
return v
return value
@@ -458,7 +399,7 @@ class TIME(sqltypes.TIME):
__visit_name__ = "TIME"
def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIME type.
:param timezone: not used by the MySQL dialect.
@@ -477,12 +418,10 @@ class TIME(sqltypes.TIME):
super().__init__(timezone=timezone)
self.fsp = fsp
def result_processor(
self, dialect: Dialect, coltype: object
) -> _ResultProcessorType[datetime.time]:
def result_processor(self, dialect, coltype):
time = datetime.time
def process(value: Any) -> Optional[datetime.time]:
def process(value):
# convert from a timedelta value
if value is not None:
microseconds = value.microseconds
@@ -505,7 +444,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
__visit_name__ = "TIMESTAMP"
def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIMESTAMP type.
:param timezone: not used by the MySQL dialect.
@@ -530,7 +469,7 @@ class DATETIME(sqltypes.DATETIME):
__visit_name__ = "DATETIME"
def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL DATETIME type.
:param timezone: not used by the MySQL dialect.
@@ -550,26 +489,26 @@ class DATETIME(sqltypes.DATETIME):
self.fsp = fsp
class YEAR(sqltypes.TypeEngine[Any]):
class YEAR(sqltypes.TypeEngine):
"""MySQL YEAR type, for single byte storage of years 1901-2155."""
__visit_name__ = "YEAR"
def __init__(self, display_width: Optional[int] = None):
def __init__(self, display_width=None):
self.display_width = display_width
class TEXT(_StringType, sqltypes.TEXT):
"""MySQL TEXT type, for character storage encoded up to 2^16 bytes."""
"""MySQL TEXT type, for text up to 2^16 characters."""
__visit_name__ = "TEXT"
def __init__(self, length: Optional[int] = None, **kw: Any):
def __init__(self, length=None, **kw):
"""Construct a TEXT.
:param length: Optional, if provided the server may optimize storage
by substituting the smallest TEXT type sufficient to store
``length`` bytes of characters.
``length`` characters.
:param charset: Optional, a column-level character set for this string
value. Takes precedence to 'ascii' or 'unicode' short-hand.
@@ -596,11 +535,11 @@ class TEXT(_StringType, sqltypes.TEXT):
class TINYTEXT(_StringType):
"""MySQL TINYTEXT type, for character storage encoded up to 2^8 bytes."""
"""MySQL TINYTEXT type, for text up to 2^8 characters."""
__visit_name__ = "TINYTEXT"
def __init__(self, **kwargs: Any):
def __init__(self, **kwargs):
"""Construct a TINYTEXT.
:param charset: Optional, a column-level character set for this string
@@ -628,12 +567,11 @@ class TINYTEXT(_StringType):
class MEDIUMTEXT(_StringType):
"""MySQL MEDIUMTEXT type, for character storage encoded up
to 2^24 bytes."""
"""MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
__visit_name__ = "MEDIUMTEXT"
def __init__(self, **kwargs: Any):
def __init__(self, **kwargs):
"""Construct a MEDIUMTEXT.
:param charset: Optional, a column-level character set for this string
@@ -661,11 +599,11 @@ class MEDIUMTEXT(_StringType):
class LONGTEXT(_StringType):
"""MySQL LONGTEXT type, for character storage encoded up to 2^32 bytes."""
"""MySQL LONGTEXT type, for text up to 2^32 characters."""
__visit_name__ = "LONGTEXT"
def __init__(self, **kwargs: Any):
def __init__(self, **kwargs):
"""Construct a LONGTEXT.
:param charset: Optional, a column-level character set for this string
@@ -697,7 +635,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR):
__visit_name__ = "VARCHAR"
def __init__(self, length: Optional[int] = None, **kwargs: Any) -> None:
def __init__(self, length=None, **kwargs):
"""Construct a VARCHAR.
:param charset: Optional, a column-level character set for this string
@@ -729,7 +667,7 @@ class CHAR(_StringType, sqltypes.CHAR):
__visit_name__ = "CHAR"
def __init__(self, length: Optional[int] = None, **kwargs: Any):
def __init__(self, length=None, **kwargs):
"""Construct a CHAR.
:param length: Maximum data length, in characters.
@@ -745,7 +683,7 @@ class CHAR(_StringType, sqltypes.CHAR):
super().__init__(length=length, **kwargs)
@classmethod
def _adapt_string_for_cast(cls, type_: sqltypes.String) -> sqltypes.CHAR:
def _adapt_string_for_cast(self, type_):
# copy the given string type into a CHAR
# for the purposes of rendering a CAST expression
type_ = sqltypes.to_instance(type_)
@@ -774,7 +712,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR):
__visit_name__ = "NVARCHAR"
def __init__(self, length: Optional[int] = None, **kwargs: Any):
def __init__(self, length=None, **kwargs):
"""Construct an NVARCHAR.
:param length: Maximum data length, in characters.
@@ -800,7 +738,7 @@ class NCHAR(_StringType, sqltypes.NCHAR):
__visit_name__ = "NCHAR"
def __init__(self, length: Optional[int] = None, **kwargs: Any):
def __init__(self, length=None, **kwargs):
"""Construct an NCHAR.
:param length: Maximum data length, in characters.

View File

@@ -1,11 +1,11 @@
# dialects/oracle/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# oracle/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from types import ModuleType
from . import base # noqa
from . import cx_oracle # noqa
@@ -32,18 +32,7 @@ from .base import ROWID
from .base import TIMESTAMP
from .base import VARCHAR
from .base import VARCHAR2
from .base import VECTOR
from .base import VectorIndexConfig
from .base import VectorIndexType
from .vector import SparseVector
from .vector import VectorDistanceType
from .vector import VectorStorageFormat
from .vector import VectorStorageType
# Alias oracledb also as oracledb_async
oracledb_async = type(
"oracledb_async", (ModuleType,), {"dialect": oracledb.dialect_async}
)
base.dialect = dialect = cx_oracle.dialect
@@ -71,11 +60,4 @@ __all__ = (
"NVARCHAR2",
"ROWID",
"REAL",
"VECTOR",
"VectorDistanceType",
"VectorIndexType",
"VectorIndexConfig",
"VectorStorageFormat",
"VectorStorageType",
"SparseVector",
)

View File

@@ -1,5 +1,4 @@
# dialects/oracle/cx_oracle.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,18 +6,13 @@
# mypy: ignore-errors
r""".. dialect:: oracle+cx_oracle
r"""
.. dialect:: oracle+cx_oracle
:name: cx-Oracle
:dbapi: cx_oracle
:connectstring: oracle+cx_oracle://user:pass@hostname:port[/dbname][?service_name=<service>[&key=value&key=value...]]
:url: https://oracle.github.io/python-cx_Oracle/
Description
-----------
cx_Oracle was the original driver for Oracle Database. It was superseded by
python-oracledb which should be used instead.
DSN vs. Hostname connections
-----------------------------
@@ -28,41 +22,27 @@ dialect translates from a series of different URL forms.
Hostname Connections with Easy Connect Syntax
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Given a hostname, port and service name of the target database, for example
from Oracle Database's Easy Connect syntax then connect in SQLAlchemy using the
``service_name`` query string parameter::
Given a hostname, port and service name of the target Oracle Database, for
example from Oracle's `Easy Connect syntax
<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#easy-connect-syntax-for-connection-strings>`_,
then connect in SQLAlchemy using the ``service_name`` query string parameter::
engine = create_engine(
"oracle+cx_oracle://scott:tiger@hostname:port?service_name=myservice&encoding=UTF-8&nencoding=UTF-8"
)
engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:port/?service_name=myservice&encoding=UTF-8&nencoding=UTF-8")
Note that the default driver value for encoding and nencoding was changed to
“UTF-8” in cx_Oracle 8.0 so these parameters can be omitted when using that
version, or later.
The `full Easy Connect syntax
<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-B0437826-43C1-49EC-A94D-B650B6A4A6EE>`_
is not supported. Instead, use a ``tnsnames.ora`` file and connect using a
DSN.
To use a full Easy Connect string, pass it as the ``dsn`` key value in a
:paramref:`_sa.create_engine.connect_args` dictionary::
Connections with tnsnames.ora or Oracle Cloud
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import cx_Oracle
e = create_engine(
"oracle+cx_oracle://@",
connect_args={
"user": "scott",
"password": "tiger",
"dsn": "hostname:port/myservice?transport_connect_timeout=30&expire_time=60",
},
)
Connections with tnsnames.ora or to Oracle Autonomous Database
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Alternatively, if no port, database name, or service name is provided, the
dialect will use an Oracle Database DSN "connection string". This takes the
"hostname" portion of the URL as the data source name. For example, if the
``tnsnames.ora`` file contains a TNS Alias of ``myalias`` as below:
.. sourcecode:: text
Alternatively, if no port, database name, or ``service_name`` is provided, the
dialect will use an Oracle DSN "connection string". This takes the "hostname"
portion of the URL as the data source name. For example, if the
``tnsnames.ora`` file contains a `Net Service Name
<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#net-service-names-for-connection-strings>`_
of ``myalias`` as below::
myalias =
(DESCRIPTION =
@@ -77,22 +57,19 @@ The cx_Oracle dialect connects to this database service when ``myalias`` is the
hostname portion of the URL, without specifying a port, database name or
``service_name``::
engine = create_engine("oracle+cx_oracle://scott:tiger@myalias")
engine = create_engine("oracle+cx_oracle://scott:tiger@myalias/?encoding=UTF-8&nencoding=UTF-8")
Users of Oracle Autonomous Database should use this syntax. If the database is
configured for mutural TLS ("mTLS"), then you must also configure the cloud
Users of Oracle Cloud should use this syntax and also configure the cloud
wallet as shown in cx_Oracle documentation `Connecting to Autononmous Databases
<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#autonomousdb>`_.
<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#connecting-to-autononmous-databases>`_.
SID Connections
^^^^^^^^^^^^^^^
To use Oracle Database's obsolete System Identifier connection syntax, the SID
can be passed in a "database name" portion of the URL::
To use Oracle's obsolete SID connection syntax, the SID can be passed in a
"database name" portion of the URL as below::
engine = create_engine(
"oracle+cx_oracle://scott:tiger@hostname:port/dbname"
)
engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:1521/dbname?encoding=UTF-8&nencoding=UTF-8")
Above, the DSN passed to cx_Oracle is created by ``cx_Oracle.makedsn()`` as
follows::
@@ -101,23 +78,17 @@ follows::
>>> cx_Oracle.makedsn("hostname", 1521, sid="dbname")
'(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=hostname)(PORT=1521))(CONNECT_DATA=(SID=dbname)))'
Note that although the SQLAlchemy syntax ``hostname:port/dbname`` looks like
Oracle's Easy Connect syntax it is different. It uses a SID in place of the
service name required by Easy Connect. The Easy Connect syntax does not
support SIDs.
Passing cx_Oracle connect arguments
-----------------------------------
Additional connection arguments can usually be passed via the URL query string;
particular symbols like ``SYSDBA`` are intercepted and converted to the correct
symbol::
Additional connection arguments can usually be passed via the URL
query string; particular symbols like ``cx_Oracle.SYSDBA`` are intercepted
and converted to the correct symbol::
e = create_engine(
"oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true"
)
"oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true")
.. versionchanged:: 1.3 the cx_Oracle dialect now accepts all argument names
.. versionchanged:: 1.3 the cx_oracle dialect now accepts all argument names
within the URL string itself, to be passed to the cx_Oracle DBAPI. As
was the case earlier but not correctly documented, the
:paramref:`_sa.create_engine.connect_args` parameter also accepts all
@@ -128,20 +99,19 @@ string, use the :paramref:`_sa.create_engine.connect_args` dictionary.
Any cx_Oracle parameter value and/or constant may be passed, such as::
import cx_Oracle
e = create_engine(
"oracle+cx_oracle://user:pass@dsn",
connect_args={
"encoding": "UTF-8",
"nencoding": "UTF-8",
"mode": cx_Oracle.SYSDBA,
"events": True,
},
"events": True
}
)
Note that the default driver value for ``encoding`` and ``nencoding`` was
changed to "UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when
using that version, or later.
Note that the default value for ``encoding`` and ``nencoding`` was changed to
"UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when using that
version, or later.
Options consumed by the SQLAlchemy cx_Oracle dialect outside of the driver
--------------------------------------------------------------------------
@@ -151,19 +121,14 @@ itself. These options are always passed directly to :func:`_sa.create_engine`
, such as::
e = create_engine(
"oracle+cx_oracle://user:pass@dsn", coerce_to_decimal=False
)
"oracle+cx_oracle://user:pass@dsn", coerce_to_decimal=False)
The parameters accepted by the cx_oracle dialect are as follows:
* ``arraysize`` - set the cx_oracle.arraysize value on cursors; defaults
to ``None``, indicating that the driver default should be used (typically
the value is 100). This setting controls how many rows are buffered when
fetching rows, and can have a significant effect on performance when
modified.
.. versionchanged:: 2.0.26 - changed the default value from 50 to None,
to use the default value of the driver itself.
* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulted
to 50. This setting is significant with cx_Oracle as the contents of LOB
objects are only readable within a "live" row (e.g. within a batch of
50 rows).
* ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`.
@@ -176,16 +141,10 @@ The parameters accepted by the cx_oracle dialect are as follows:
Using cx_Oracle SessionPool
---------------------------
The cx_Oracle driver provides its own connection pool implementation that may
be used in place of SQLAlchemy's pooling functionality. The driver pool
supports Oracle Database features such dead connection detection, connection
draining for planned database downtime, support for Oracle Application
Continuity and Transparent Application Continuity, and gives support for
Database Resident Connection Pooling (DRCP).
Using the driver pool can be achieved by using the
:paramref:`_sa.create_engine.creator` parameter to provide a function that
returns a new connection, along with setting
The cx_Oracle library provides its own connection pool implementation that may
be used in place of SQLAlchemy's pooling functionality. This can be achieved
by using the :paramref:`_sa.create_engine.creator` parameter to provide a
function that returns a new connection, along with setting
:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable
SQLAlchemy's pooling::
@@ -194,41 +153,32 @@ SQLAlchemy's pooling::
from sqlalchemy.pool import NullPool
pool = cx_Oracle.SessionPool(
user="scott",
password="tiger",
dsn="orclpdb",
min=1,
max=4,
increment=1,
threaded=True,
encoding="UTF-8",
nencoding="UTF-8",
user="scott", password="tiger", dsn="orclpdb",
min=2, max=5, increment=1, threaded=True,
encoding="UTF-8", nencoding="UTF-8"
)
engine = create_engine(
"oracle+cx_oracle://", creator=pool.acquire, poolclass=NullPool
)
engine = create_engine("oracle+cx_oracle://", creator=pool.acquire, poolclass=NullPool)
The above engine may then be used normally where cx_Oracle's pool handles
connection pooling::
with engine.connect() as conn:
print(conn.scalar("select 1 from dual"))
print(conn.scalar("select 1 FROM dual"))
As well as providing a scalable solution for multi-user applications, the
cx_Oracle session pool supports some Oracle features such as DRCP and
`Application Continuity
<https://cx-oracle.readthedocs.io/en/latest/user_guide/ha.html#application-continuity-ac>`_.
Note that the pool creation parameters ``threaded``, ``encoding`` and
``nencoding`` were deprecated in later cx_Oracle releases.
Using Oracle Database Resident Connection Pooling (DRCP)
--------------------------------------------------------
When using Oracle Database's DRCP, the best practice is to pass a connection
class and "purity" when acquiring a connection from the SessionPool. Refer to
the `cx_Oracle DRCP documentation
When using Oracle's `DRCP
<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-015CA8C1-2386-4626-855D-CC546DDC1086>`_,
the best practice is to pass a connection class and "purity" when acquiring a
connection from the SessionPool. Refer to the `cx_Oracle DRCP documentation
<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#database-resident-connection-pooling-drcp>`_.
This can be achieved by wrapping ``pool.acquire()``::
@@ -238,33 +188,21 @@ This can be achieved by wrapping ``pool.acquire()``::
from sqlalchemy.pool import NullPool
pool = cx_Oracle.SessionPool(
user="scott",
password="tiger",
dsn="orclpdb",
min=2,
max=5,
increment=1,
threaded=True,
encoding="UTF-8",
nencoding="UTF-8",
user="scott", password="tiger", dsn="orclpdb",
min=2, max=5, increment=1, threaded=True,
encoding="UTF-8", nencoding="UTF-8"
)
def creator():
return pool.acquire(
cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF
)
return pool.acquire(cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF)
engine = create_engine(
"oracle+cx_oracle://", creator=creator, poolclass=NullPool
)
engine = create_engine("oracle+cx_oracle://", creator=creator, poolclass=NullPool)
The above engine may then be used normally where cx_Oracle handles session
pooling and Oracle Database additionally uses DRCP::
with engine.connect() as conn:
print(conn.scalar("select 1 from dual"))
print(conn.scalar("select 1 FROM dual"))
.. _cx_oracle_unicode:
@@ -272,28 +210,24 @@ Unicode
-------
As is the case for all DBAPIs under Python 3, all strings are inherently
Unicode strings. In all cases however, the driver requires an explicit
Unicode strings. In all cases however, the driver requires an explicit
encoding configuration.
Ensuring the Correct Client Encoding
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The long accepted standard for establishing client encoding for nearly all
Oracle Database related software is via the `NLS_LANG
<https://www.oracle.com/database/technologies/faq-nls-lang.html>`_ environment
variable. Older versions of cx_Oracle use this environment variable as the
source of its encoding configuration. The format of this variable is
Territory_Country.CharacterSet; a typical value would be
``AMERICAN_AMERICA.AL32UTF8``. cx_Oracle version 8 and later use the character
set "UTF-8" by default, and ignore the character set component of NLS_LANG.
Oracle related software is via the `NLS_LANG <https://www.oracle.com/database/technologies/faq-nls-lang.html>`_
environment variable. cx_Oracle like most other Oracle drivers will use
this environment variable as the source of its encoding configuration. The
format of this variable is idiosyncratic; a typical value would be
``AMERICAN_AMERICA.AL32UTF8``.
The cx_Oracle driver also supported a programmatic alternative which is to pass
the ``encoding`` and ``nencoding`` parameters directly to its ``.connect()``
function. These can be present in the URL as follows::
The cx_Oracle driver also supports a programmatic alternative which is to
pass the ``encoding`` and ``nencoding`` parameters directly to its
``.connect()`` function. These can be present in the URL as follows::
engine = create_engine(
"oracle+cx_oracle://scott:tiger@tnsalias?encoding=UTF-8&nencoding=UTF-8"
)
engine = create_engine("oracle+cx_oracle://scott:tiger@orclpdb/?encoding=UTF-8&nencoding=UTF-8")
For the meaning of the ``encoding`` and ``nencoding`` parameters, please
consult
@@ -308,24 +242,25 @@ consult
Unicode-specific Column datatypes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The Core expression language handles unicode data by use of the
:class:`.Unicode` and :class:`.UnicodeText` datatypes. These types correspond
to the VARCHAR2 and CLOB Oracle Database datatypes by default. When using
these datatypes with Unicode data, it is expected that the database is
configured with a Unicode-aware character set, as well as that the ``NLS_LANG``
environment variable is set appropriately (this applies to older versions of
cx_Oracle), so that the VARCHAR2 and CLOB datatypes can accommodate the data.
The Core expression language handles unicode data by use of the :class:`.Unicode`
and :class:`.UnicodeText`
datatypes. These types correspond to the VARCHAR2 and CLOB Oracle datatypes by
default. When using these datatypes with Unicode data, it is expected that
the Oracle database is configured with a Unicode-aware character set, as well
as that the ``NLS_LANG`` environment variable is set appropriately, so that
the VARCHAR2 and CLOB datatypes can accommodate the data.
In the case that Oracle Database is not configured with a Unicode character
In the case that the Oracle database is not configured with a Unicode character
set, the two options are to use the :class:`_types.NCHAR` and
:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag
``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, which will cause
the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` /
``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`,
which will cause the
SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` /
:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB.
.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText`
datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle Database
datatypes unless the ``use_nchar_for_unicode=True`` is passed to the dialect
.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText`
datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle datatypes
unless the ``use_nchar_for_unicode=True`` is passed to the dialect
when :func:`_sa.create_engine` is called.
@@ -334,7 +269,7 @@ the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` /
Encoding Errors
^^^^^^^^^^^^^^^
For the unusual case that data in Oracle Database is present with a broken
For the unusual case that data in the Oracle database is present with a broken
encoding, the dialect accepts a parameter ``encoding_errors`` which will be
passed to Unicode decoding functions in order to affect how decoding errors are
handled. The value is ultimately consumed by the Python `decode
@@ -352,13 +287,13 @@ Fine grained control over cx_Oracle data binding performance with setinputsizes
-------------------------------------------------------------------------------
The cx_Oracle DBAPI has a deep and fundamental reliance upon the usage of the
DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the
DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the
datatypes that are bound to a SQL statement for Python values being passed as
parameters. While virtually no other DBAPI assigns any use to the
``setinputsizes()`` call, the cx_Oracle DBAPI relies upon it heavily in its
interactions with the Oracle Database client interface, and in some scenarios
it is not possible for SQLAlchemy to know exactly how data should be bound, as
some settings can cause profoundly different performance characteristics, while
interactions with the Oracle client interface, and in some scenarios it is not
possible for SQLAlchemy to know exactly how data should be bound, as some
settings can cause profoundly different performance characteristics, while
altering the type coercion behavior at the same time.
Users of the cx_Oracle dialect are **strongly encouraged** to read through
@@ -387,16 +322,13 @@ objects which have a ``.key`` and a ``.type`` attribute::
engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe")
@event.listens_for(engine, "do_setinputsizes")
def _log_setinputsizes(inputsizes, cursor, statement, parameters, context):
for bindparam, dbapitype in inputsizes.items():
log.info(
"Bound parameter name: %s SQLAlchemy type: %r DBAPI object: %s",
bindparam.key,
bindparam.type,
dbapitype,
)
log.info(
"Bound parameter name: %s SQLAlchemy type: %r "
"DBAPI object: %s",
bindparam.key, bindparam.type, dbapitype)
Example 2 - remove all bindings to CLOB
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -410,28 +342,12 @@ series. This setting can be modified as follows::
engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe")
@event.listens_for(engine, "do_setinputsizes")
def _remove_clob(inputsizes, cursor, statement, parameters, context):
for bindparam, dbapitype in list(inputsizes.items()):
if dbapitype is CLOB:
del inputsizes[bindparam]
.. _cx_oracle_lob:
LOB Datatypes
--------------
LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and
BLOB. Modern versions of cx_Oracle is optimized for these datatypes to be
delivered as a single buffer. As such, SQLAlchemy makes use of these newer type
handlers by default.
To disable the use of newer type handlers and deliver LOB objects as classic
buffered objects with a ``read()`` method, the parameter
``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`,
which takes place only engine-wide.
.. _cx_oracle_returning:
RETURNING Support
@@ -440,12 +356,29 @@ RETURNING Support
The cx_Oracle dialect implements RETURNING using OUT parameters.
The dialect supports RETURNING fully.
Two Phase Transactions Not Supported
------------------------------------
.. _cx_oracle_lob:
Two phase transactions are **not supported** under cx_Oracle due to poor driver
support. The newer :ref:`oracledb` dialect however **does** support two phase
transactions.
LOB Datatypes
--------------
LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and
BLOB. Modern versions of cx_Oracle and oracledb are optimized for these
datatypes to be delivered as a single buffer. As such, SQLAlchemy makes use of
these newer type handlers by default.
To disable the use of newer type handlers and deliver LOB objects as classic
buffered objects with a ``read()`` method, the parameter
``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`,
which takes place only engine-wide.
Two Phase Transactions Not Supported
-------------------------------------
Two phase transactions are **not supported** under cx_Oracle due to poor
driver support. As of cx_Oracle 6.0b1, the interface for
two phase transactions has been changed to be more of a direct pass-through
to the underlying OCI layer with less automation. The additional logic
to support this system is not implemented in SQLAlchemy.
.. _cx_oracle_numeric:
@@ -456,21 +389,20 @@ SQLAlchemy's numeric types can handle receiving and returning values as Python
``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a
subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in
use, the :paramref:`.Numeric.asdecimal` flag determines if values should be
coerced to ``Decimal`` upon return, or returned as float objects. To make
matters more complicated under Oracle Database, the ``NUMBER`` type can also
represent integer values if the "scale" is zero, so the Oracle
Database-specific :class:`_oracle.NUMBER` type takes this into account as well.
coerced to ``Decimal`` upon return, or returned as float objects. To make
matters more complicated under Oracle, Oracle's ``NUMBER`` type can also
represent integer values if the "scale" is zero, so the Oracle-specific
:class:`_oracle.NUMBER` type takes this into account as well.
The cx_Oracle dialect makes extensive use of connection- and cursor-level
"outputtypehandler" callables in order to coerce numeric values as requested.
These callables are specific to the specific flavor of :class:`.Numeric` in
use, as well as if no SQLAlchemy typing objects are present. There are
observed scenarios where Oracle Database may send incomplete or ambiguous
information about the numeric types being returned, such as a query where the
numeric types are buried under multiple levels of subquery. The type handlers
do their best to make the right decision in all cases, deferring to the
underlying cx_Oracle DBAPI for all those cases where the driver can make the
best decision.
use, as well as if no SQLAlchemy typing objects are present. There are
observed scenarios where Oracle may sends incomplete or ambiguous information
about the numeric types being returned, such as a query where the numeric types
are buried under multiple levels of subquery. The type handlers do their best
to make the right decision in all cases, deferring to the underlying cx_Oracle
DBAPI for all those cases where the driver can make the best decision.
When no typing objects are present, as when executing plain SQL strings, a
default "outputtypehandler" is present which will generally return numeric
@@ -882,8 +814,6 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
out_parameters[name] = self.cursor.var(
dbtype,
# this is fine also in oracledb_async since
# the driver will await the read coroutine
outconverter=lambda value: value.read(),
arraysize=len_params,
)
@@ -902,9 +832,9 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
)
for param in self.parameters:
param[quoted_bind_names.get(name, name)] = (
out_parameters[name]
)
param[
quoted_bind_names.get(name, name)
] = out_parameters[name]
def _generate_cursor_outputtype_handler(self):
output_handlers = {}
@@ -1100,7 +1030,7 @@ class OracleDialect_cx_oracle(OracleDialect):
self,
auto_convert_lobs=True,
coerce_to_decimal=True,
arraysize=None,
arraysize=50,
encoding_errors=None,
threaded=None,
**kwargs,
@@ -1234,9 +1164,6 @@ class OracleDialect_cx_oracle(OracleDialect):
with dbapi_connection.cursor() as cursor:
cursor.execute(f"ALTER SESSION SET ISOLATION_LEVEL={level}")
def detect_autocommit_setting(self, dbapi_conn) -> bool:
return bool(dbapi_conn.autocommit)
def _detect_decimal_char(self, connection):
# we have the option to change this setting upon connect,
# or just look at what it is upon connect and convert.
@@ -1356,13 +1283,8 @@ class OracleDialect_cx_oracle(OracleDialect):
cx_Oracle.CLOB,
cx_Oracle.NCLOB,
):
typ = (
cx_Oracle.DB_TYPE_VARCHAR
if default_type is cx_Oracle.CLOB
else cx_Oracle.DB_TYPE_NVARCHAR
)
return cursor.var(
typ,
cx_Oracle.DB_TYPE_NVARCHAR,
_CX_ORACLE_MAGIC_LOB_SIZE,
cursor.arraysize,
**dialect._cursor_var_unicode_kwargs,
@@ -1493,6 +1415,13 @@ class OracleDialect_cx_oracle(OracleDialect):
return False
def create_xid(self):
"""create a two-phase transaction ID.
this id will be passed to do_begin_twophase(), do_rollback_twophase(),
do_commit_twophase(). its format is unspecified.
"""
id_ = random.randint(0, 2**128)
return (0x1234, "%032x" % id_, "%032x" % 9)

View File

@@ -1,5 +1,4 @@
# dialects/oracle/dictionary.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,639 +1,68 @@
# dialects/oracle/oracledb.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r""".. dialect:: oracle+oracledb
r"""
.. dialect:: oracle+oracledb
:name: python-oracledb
:dbapi: oracledb
:connectstring: oracle+oracledb://user:pass@hostname:port[/dbname][?service_name=<service>[&key=value&key=value...]]
:url: https://oracle.github.io/python-oracledb/
Description
-----------
python-oracledb is released by Oracle to supersede the cx_Oracle driver.
It is fully compatible with cx_Oracle and features both a "thin" client
mode that requires no dependencies, as well as a "thick" mode that uses
the Oracle Client Interface in the same way as cx_Oracle.
Python-oracledb is the Oracle Database driver for Python. It features a default
"thin" client mode that requires no dependencies, and an optional "thick" mode
that uses Oracle Client libraries. It supports SQLAlchemy features including
two phase transactions and Asyncio.
.. seealso::
Python-oracle is the renamed, updated cx_Oracle driver. Oracle is no longer
doing any releases in the cx_Oracle namespace.
The SQLAlchemy ``oracledb`` dialect provides both a sync and an async
implementation under the same dialect name. The proper version is
selected depending on how the engine is created:
* calling :func:`_sa.create_engine` with ``oracle+oracledb://...`` will
automatically select the sync version::
from sqlalchemy import create_engine
sync_engine = create_engine(
"oracle+oracledb://scott:tiger@localhost?service_name=FREEPDB1"
)
* calling :func:`_asyncio.create_async_engine` with ``oracle+oracledb://...``
will automatically select the async version::
from sqlalchemy.ext.asyncio import create_async_engine
asyncio_engine = create_async_engine(
"oracle+oracledb://scott:tiger@localhost?service_name=FREEPDB1"
)
The asyncio version of the dialect may also be specified explicitly using the
``oracledb_async`` suffix::
from sqlalchemy.ext.asyncio import create_async_engine
asyncio_engine = create_async_engine(
"oracle+oracledb_async://scott:tiger@localhost?service_name=FREEPDB1"
)
.. versionadded:: 2.0.25 added support for the async version of oracledb.
:ref:`cx_oracle` - all of cx_Oracle's notes apply to the oracledb driver
as well.
Thick mode support
------------------
By default, the python-oracledb driver runs in a "thin" mode that does not
require Oracle Client libraries to be installed. The driver also supports a
"thick" mode that uses Oracle Client libraries to get functionality such as
Oracle Application Continuity.
By default the ``python-oracledb`` is started in thin mode, that does not
require oracle client libraries to be installed in the system. The
``python-oracledb`` driver also support a "thick" mode, that behaves
similarly to ``cx_oracle`` and requires that Oracle Client Interface (OCI)
is installed.
To enable thick mode, call `oracledb.init_oracle_client()
<https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.init_oracle_client>`_
explicitly, or pass the parameter ``thick_mode=True`` to
:func:`_sa.create_engine`. To pass custom arguments to
``init_oracle_client()``, like the ``lib_dir`` path, a dict may be passed, for
example::
To enable this mode, the user may call ``oracledb.init_oracle_client``
manually, or by passing the parameter ``thick_mode=True`` to
:func:`_sa.create_engine`. To pass custom arguments to ``init_oracle_client``,
like the ``lib_dir`` path, a dict may be passed to this parameter, as in::
engine = sa.create_engine(
"oracle+oracledb://...",
thick_mode={
"lib_dir": "/path/to/oracle/client/lib",
"config_dir": "/path/to/network_config_file_directory",
"driver_name": "my-app : 1.0.0",
},
)
Note that passing a ``lib_dir`` path should only be done on macOS or
Windows. On Linux it does not behave as you might expect.
engine = sa.create_engine("oracle+oracledb://...", thick_mode={
"lib_dir": "/path/to/oracle/client/lib", "driver_name": "my-app"
})
.. seealso::
python-oracledb documentation `Enabling python-oracledb Thick mode
<https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html#enabling-python-oracledb-thick-mode>`_
https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.init_oracle_client
Connecting to Oracle Database
-----------------------------
python-oracledb provides several methods of indicating the target database.
The dialect translates from a series of different URL forms.
Given the hostname, port and service name of the target database, you can
connect in SQLAlchemy using the ``service_name`` query string parameter::
engine = create_engine(
"oracle+oracledb://scott:tiger@hostname:port?service_name=myservice"
)
Connecting with Easy Connect strings
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
You can pass any valid python-oracledb connection string as the ``dsn`` key
value in a :paramref:`_sa.create_engine.connect_args` dictionary. See
python-oracledb documentation `Oracle Net Services Connection Strings
<https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#oracle-net-services-connection-strings>`_.
For example to use an `Easy Connect string
<https://download.oracle.com/ocomdocs/global/Oracle-Net-Easy-Connect-Plus.pdf>`_
with a timeout to prevent connection establishment from hanging if the network
transport to the database cannot be establishd in 30 seconds, and also setting
a keep-alive time of 60 seconds to stop idle network connections from being
terminated by a firewall::
e = create_engine(
"oracle+oracledb://@",
connect_args={
"user": "scott",
"password": "tiger",
"dsn": "hostname:port/myservice?transport_connect_timeout=30&expire_time=60",
},
)
The Easy Connect syntax has been enhanced during the life of Oracle Database.
Review the documentation for your database version. The current documentation
is at `Understanding the Easy Connect Naming Method
<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-B0437826-43C1-49EC-A94D-B650B6A4A6EE>`_.
The general syntax is similar to:
.. sourcecode:: text
[[protocol:]//]host[:port][/[service_name]][?parameter_name=value{&parameter_name=value}]
Note that although the SQLAlchemy URL syntax ``hostname:port/dbname`` looks
like Oracle's Easy Connect syntax, it is different. SQLAlchemy's URL requires a
system identifier (SID) for the ``dbname`` component::
engine = create_engine("oracle+oracledb://scott:tiger@hostname:port/sid")
Easy Connect syntax does not support SIDs. It uses services names, which are
the preferred choice for connecting to Oracle Database.
Passing python-oracledb connect arguments
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Other python-oracledb driver `connection options
<https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.connect>`_
can be passed in ``connect_args``. For example::
e = create_engine(
"oracle+oracledb://@",
connect_args={
"user": "scott",
"password": "tiger",
"dsn": "hostname:port/myservice",
"events": True,
"mode": oracledb.AUTH_MODE_SYSDBA,
},
)
Connecting with tnsnames.ora TNS aliases
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If no port, database name, or service name is provided, the dialect will use an
Oracle Database DSN "connection string". This takes the "hostname" portion of
the URL as the data source name. For example, if the ``tnsnames.ora`` file
contains a `TNS Alias
<https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#tns-aliases-for-connection-strings>`_
of ``myalias`` as below:
.. sourcecode:: text
myalias =
(DESCRIPTION =
(ADDRESS = (PROTOCOL = TCP)(HOST = mymachine.example.com)(PORT = 1521))
(CONNECT_DATA =
(SERVER = DEDICATED)
(SERVICE_NAME = orclpdb1)
)
)
The python-oracledb dialect connects to this database service when ``myalias`` is the
hostname portion of the URL, without specifying a port, database name or
``service_name``::
engine = create_engine("oracle+oracledb://scott:tiger@myalias")
Connecting to Oracle Autonomous Database
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Users of Oracle Autonomous Database should use either use the TNS Alias URL
shown above, or pass the TNS Alias as the ``dsn`` key value in a
:paramref:`_sa.create_engine.connect_args` dictionary.
If Oracle Autonomous Database is configured for mutual TLS ("mTLS")
connections, then additional configuration is required as shown in `Connecting
to Oracle Cloud Autonomous Databases
<https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#connecting-to-oracle-cloud-autonomous-databases>`_. In
summary, Thick mode users should configure file locations and set the wallet
path in ``sqlnet.ora`` appropriately::
e = create_engine(
"oracle+oracledb://@",
thick_mode={
# directory containing tnsnames.ora and cwallet.so
"config_dir": "/opt/oracle/wallet_dir",
},
connect_args={
"user": "scott",
"password": "tiger",
"dsn": "mydb_high",
},
)
Thin mode users of mTLS should pass the appropriate directories and PEM wallet
password when creating the engine, similar to::
e = create_engine(
"oracle+oracledb://@",
connect_args={
"user": "scott",
"password": "tiger",
"dsn": "mydb_high",
"config_dir": "/opt/oracle/wallet_dir", # directory containing tnsnames.ora
"wallet_location": "/opt/oracle/wallet_dir", # directory containing ewallet.pem
"wallet_password": "top secret", # password for the PEM file
},
)
Typically ``config_dir`` and ``wallet_location`` are the same directory, which
is where the Oracle Autonomous Database wallet zip file was extracted. Note
this directory should be protected.
Connection Pooling
------------------
Applications with multiple concurrent users should use connection pooling. A
minimal sized connection pool is also beneficial for long-running, single-user
applications that do not frequently use a connection.
The python-oracledb driver provides its own connection pool implementation that
may be used in place of SQLAlchemy's pooling functionality. The driver pool
gives support for high availability features such as dead connection detection,
connection draining for planned database downtime, support for Oracle
Application Continuity and Transparent Application Continuity, and gives
support for `Database Resident Connection Pooling (DRCP)
<https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#database-resident-connection-pooling-drcp>`_.
To take advantage of python-oracledb's pool, use the
:paramref:`_sa.create_engine.creator` parameter to provide a function that
returns a new connection, along with setting
:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable
SQLAlchemy's pooling::
import oracledb
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.pool import NullPool
# Uncomment to use the optional python-oracledb Thick mode.
# Review the python-oracledb doc for the appropriate parameters
# oracledb.init_oracle_client(<your parameters>)
pool = oracledb.create_pool(
user="scott",
password="tiger",
dsn="localhost:1521/freepdb1",
min=1,
max=4,
increment=1,
)
engine = create_engine(
"oracle+oracledb://", creator=pool.acquire, poolclass=NullPool
)
The above engine may then be used normally. Internally, python-oracledb handles
connection pooling::
with engine.connect() as conn:
print(conn.scalar(text("select 1 from dual")))
Refer to the python-oracledb documentation for `oracledb.create_pool()
<https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.create_pool>`_
for the arguments that can be used when creating a connection pool.
.. _drcp:
Using Oracle Database Resident Connection Pooling (DRCP)
--------------------------------------------------------
When using Oracle Database's Database Resident Connection Pooling (DRCP), the
best practice is to specify a connection class and "purity". Refer to the
`python-oracledb documentation on DRCP
<https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#database-resident-connection-pooling-drcp>`_.
For example::
import oracledb
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.pool import NullPool
# Uncomment to use the optional python-oracledb Thick mode.
# Review the python-oracledb doc for the appropriate parameters
# oracledb.init_oracle_client(<your parameters>)
pool = oracledb.create_pool(
user="scott",
password="tiger",
dsn="localhost:1521/freepdb1",
min=1,
max=4,
increment=1,
cclass="MYCLASS",
purity=oracledb.PURITY_SELF,
)
engine = create_engine(
"oracle+oracledb://", creator=pool.acquire, poolclass=NullPool
)
The above engine may then be used normally where python-oracledb handles
application connection pooling and Oracle Database additionally uses DRCP::
with engine.connect() as conn:
print(conn.scalar(text("select 1 from dual")))
If you wish to use different connection classes or purities for different
connections, then wrap ``pool.acquire()``::
import oracledb
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.pool import NullPool
# Uncomment to use python-oracledb Thick mode.
# Review the python-oracledb doc for the appropriate parameters
# oracledb.init_oracle_client(<your parameters>)
pool = oracledb.create_pool(
user="scott",
password="tiger",
dsn="localhost:1521/freepdb1",
min=1,
max=4,
increment=1,
cclass="MYCLASS",
purity=oracledb.PURITY_SELF,
)
def creator():
return pool.acquire(cclass="MYOTHERCLASS", purity=oracledb.PURITY_NEW)
engine = create_engine(
"oracle+oracledb://", creator=creator, poolclass=NullPool
)
Engine Options consumed by the SQLAlchemy oracledb dialect outside of the driver
--------------------------------------------------------------------------------
There are also options that are consumed by the SQLAlchemy oracledb dialect
itself. These options are always passed directly to :func:`_sa.create_engine`,
such as::
e = create_engine("oracle+oracledb://user:pass@tnsalias", arraysize=500)
The parameters accepted by the oracledb dialect are as follows:
* ``arraysize`` - set the driver cursor.arraysize value. It defaults to
``None``, indicating that the driver default value of 100 should be used.
This setting controls how many rows are buffered when fetching rows, and can
have a significant effect on performance if increased for queries that return
large numbers of rows.
.. versionchanged:: 2.0.26 - changed the default value from 50 to None,
to use the default value of the driver itself.
* ``auto_convert_lobs`` - defaults to True; See :ref:`oracledb_lob`.
* ``coerce_to_decimal`` - see :ref:`oracledb_numeric` for detail.
* ``encoding_errors`` - see :ref:`oracledb_unicode_encoding_errors` for detail.
.. _oracledb_unicode:
Unicode
-------
As is the case for all DBAPIs under Python 3, all strings are inherently
Unicode strings.
Ensuring the Correct Client Encoding
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In python-oracledb, the encoding used for all character data is "UTF-8".
Unicode-specific Column datatypes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The Core expression language handles unicode data by use of the
:class:`.Unicode` and :class:`.UnicodeText` datatypes. These types correspond
to the VARCHAR2 and CLOB Oracle Database datatypes by default. When using
these datatypes with Unicode data, it is expected that the database is
configured with a Unicode-aware character set so that the VARCHAR2 and CLOB
datatypes can accommodate the data.
In the case that Oracle Database is not configured with a Unicode character
set, the two options are to use the :class:`_types.NCHAR` and
:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag
``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, which will cause
the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` /
:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB.
.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText`
datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle Database
datatypes unless the ``use_nchar_for_unicode=True`` is passed to the dialect
when :func:`_sa.create_engine` is called.
.. _oracledb_unicode_encoding_errors:
Encoding Errors
^^^^^^^^^^^^^^^
For the unusual case that data in Oracle Database is present with a broken
encoding, the dialect accepts a parameter ``encoding_errors`` which will be
passed to Unicode decoding functions in order to affect how decoding errors are
handled. The value is ultimately consumed by the Python `decode
<https://docs.python.org/3/library/stdtypes.html#bytes.decode>`_ function, and
is passed both via python-oracledb's ``encodingErrors`` parameter consumed by
``Cursor.var()``, as well as SQLAlchemy's own decoding function, as the
python-oracledb dialect makes use of both under different circumstances.
.. versionadded:: 1.3.11
.. _oracledb_setinputsizes:
Fine grained control over python-oracledb data binding with setinputsizes
-------------------------------------------------------------------------
The python-oracle DBAPI has a deep and fundamental reliance upon the usage of
the DBAPI ``setinputsizes()`` call. The purpose of this call is to establish
the datatypes that are bound to a SQL statement for Python values being passed
as parameters. While virtually no other DBAPI assigns any use to the
``setinputsizes()`` call, the python-oracledb DBAPI relies upon it heavily in
its interactions with the Oracle Database, and in some scenarios it is not
possible for SQLAlchemy to know exactly how data should be bound, as some
settings can cause profoundly different performance characteristics, while
altering the type coercion behavior at the same time.
Users of the oracledb dialect are **strongly encouraged** to read through
python-oracledb's list of built-in datatype symbols at `Database Types
<https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#database-types>`_
Note that in some cases, significant performance degradation can occur when
using these types vs. not.
On the SQLAlchemy side, the :meth:`.DialectEvents.do_setinputsizes` event can
be used both for runtime visibility (e.g. logging) of the setinputsizes step as
well as to fully control how ``setinputsizes()`` is used on a per-statement
basis.
.. versionadded:: 1.2.9 Added :meth:`.DialectEvents.setinputsizes`
Example 1 - logging all setinputsizes calls
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The following example illustrates how to log the intermediary values from a
SQLAlchemy perspective before they are converted to the raw ``setinputsizes()``
parameter dictionary. The keys of the dictionary are :class:`.BindParameter`
objects which have a ``.key`` and a ``.type`` attribute::
from sqlalchemy import create_engine, event
engine = create_engine(
"oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1"
)
@event.listens_for(engine, "do_setinputsizes")
def _log_setinputsizes(inputsizes, cursor, statement, parameters, context):
for bindparam, dbapitype in inputsizes.items():
log.info(
"Bound parameter name: %s SQLAlchemy type: %r DBAPI object: %s",
bindparam.key,
bindparam.type,
dbapitype,
)
Example 2 - remove all bindings to CLOB
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
For performance, fetching LOB datatypes from Oracle Database is set by default
for the ``Text`` type within SQLAlchemy. This setting can be modified as
follows::
from sqlalchemy import create_engine, event
from oracledb import CLOB
engine = create_engine(
"oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1"
)
@event.listens_for(engine, "do_setinputsizes")
def _remove_clob(inputsizes, cursor, statement, parameters, context):
for bindparam, dbapitype in list(inputsizes.items()):
if dbapitype is CLOB:
del inputsizes[bindparam]
.. _oracledb_lob:
LOB Datatypes
--------------
LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and
BLOB. Oracle Database can efficiently return these datatypes as a single
buffer. SQLAlchemy makes use of type handlers to do this by default.
To disable the use of the type handlers and deliver LOB objects as classic
buffered objects with a ``read()`` method, the parameter
``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`.
.. _oracledb_returning:
RETURNING Support
-----------------
The oracledb dialect implements RETURNING using OUT parameters. The dialect
supports RETURNING fully.
Two Phase Transaction Support
-----------------------------
Two phase transactions are fully supported with python-oracledb. (Thin mode
requires python-oracledb 2.3). APIs for two phase transactions are provided at
the Core level via :meth:`_engine.Connection.begin_twophase` and
:paramref:`_orm.Session.twophase` for transparent ORM use.
.. versionchanged:: 2.0.32 added support for two phase transactions
.. _oracledb_numeric:
Precision Numerics
------------------
SQLAlchemy's numeric types can handle receiving and returning values as Python
``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a
subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in
use, the :paramref:`.Numeric.asdecimal` flag determines if values should be
coerced to ``Decimal`` upon return, or returned as float objects. To make
matters more complicated under Oracle Database, the ``NUMBER`` type can also
represent integer values if the "scale" is zero, so the Oracle
Database-specific :class:`_oracle.NUMBER` type takes this into account as well.
The oracledb dialect makes extensive use of connection- and cursor-level
"outputtypehandler" callables in order to coerce numeric values as requested.
These callables are specific to the specific flavor of :class:`.Numeric` in
use, as well as if no SQLAlchemy typing objects are present. There are
observed scenarios where Oracle Database may send incomplete or ambiguous
information about the numeric types being returned, such as a query where the
numeric types are buried under multiple levels of subquery. The type handlers
do their best to make the right decision in all cases, deferring to the
underlying python-oracledb DBAPI for all those cases where the driver can make
the best decision.
When no typing objects are present, as when executing plain SQL strings, a
default "outputtypehandler" is present which will generally return numeric
values which specify precision and scale as Python ``Decimal`` objects. To
disable this coercion to decimal for performance reasons, pass the flag
``coerce_to_decimal=False`` to :func:`_sa.create_engine`::
engine = create_engine(
"oracle+oracledb://scott:tiger@tnsalias", coerce_to_decimal=False
)
The ``coerce_to_decimal`` flag only impacts the results of plain string
SQL statements that are not otherwise associated with a :class:`.Numeric`
SQLAlchemy type (or a subclass of such).
.. versionchanged:: 1.2 The numeric handling system for the oracle dialects has
been reworked to take advantage of newer driver features as well as better
integration of outputtypehandlers.
.. versionadded:: 2.0.0 added support for the python-oracledb driver.
.. versionadded:: 2.0.0 added support for oracledb driver.
""" # noqa
from __future__ import annotations
import collections
import re
from typing import Any
from typing import TYPE_CHECKING
from . import cx_oracle as _cx_oracle
from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle
from ... import exc
from ... import pool
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
from ...engine import default
from ...util import asbool
from ...util import await_fallback
from ...util import await_only
if TYPE_CHECKING:
from oracledb import AsyncConnection
from oracledb import AsyncCursor
class OracleExecutionContext_oracledb(
_cx_oracle.OracleExecutionContext_cx_oracle
):
pass
class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle):
class OracleDialect_oracledb(_OracleDialect_cx_oracle):
supports_statement_cache = True
execution_ctx_cls = OracleExecutionContext_oracledb
driver = "oracledb"
_min_version = (1,)
def __init__(
self,
auto_convert_lobs=True,
coerce_to_decimal=True,
arraysize=None,
arraysize=50,
encoding_errors=None,
thick_mode=None,
**kwargs,
@@ -662,10 +91,6 @@ class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle):
def is_thin_mode(cls, connection):
return connection.connection.dbapi_connection.thin
@classmethod
def get_async_dialect_cls(cls, url):
return OracleDialectAsync_oracledb
def _load_version(self, dbapi_module):
version = (0, 0, 0)
if dbapi_module is not None:
@@ -675,273 +100,10 @@ class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle):
int(x) for x in m.group(1, 2, 3) if x is not None
)
self.oracledb_ver = version
if (
self.oracledb_ver > (0, 0, 0)
and self.oracledb_ver < self._min_version
):
if self.oracledb_ver < (1,) and self.oracledb_ver > (0, 0, 0):
raise exc.InvalidRequestError(
f"oracledb version {self._min_version} and above are supported"
"oracledb version 1 and above are supported"
)
def do_begin_twophase(self, connection, xid):
conn_xis = connection.connection.xid(*xid)
connection.connection.tpc_begin(conn_xis)
connection.connection.info["oracledb_xid"] = conn_xis
def do_prepare_twophase(self, connection, xid):
should_commit = connection.connection.tpc_prepare()
connection.info["oracledb_should_commit"] = should_commit
def do_rollback_twophase(
self, connection, xid, is_prepared=True, recover=False
):
if recover:
conn_xid = connection.connection.xid(*xid)
else:
conn_xid = None
connection.connection.tpc_rollback(conn_xid)
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False
):
conn_xid = None
if not is_prepared:
should_commit = connection.connection.tpc_prepare()
elif recover:
conn_xid = connection.connection.xid(*xid)
should_commit = True
else:
should_commit = connection.info["oracledb_should_commit"]
if should_commit:
connection.connection.tpc_commit(conn_xid)
def do_recover_twophase(self, connection):
return [
# oracledb seems to return bytes
(
fi,
gti.decode() if isinstance(gti, bytes) else gti,
bq.decode() if isinstance(bq, bytes) else bq,
)
for fi, gti, bq in connection.connection.tpc_recover()
]
def _check_max_identifier_length(self, connection):
if self.oracledb_ver >= (2, 5):
max_len = connection.connection.max_identifier_length
if max_len is not None:
return max_len
return super()._check_max_identifier_length(connection)
class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
_cursor: AsyncCursor
__slots__ = ()
@property
def outputtypehandler(self):
return self._cursor.outputtypehandler
@outputtypehandler.setter
def outputtypehandler(self, value):
self._cursor.outputtypehandler = value
def var(self, *args, **kwargs):
return self._cursor.var(*args, **kwargs)
def close(self):
self._rows.clear()
self._cursor.close()
def setinputsizes(self, *args: Any, **kwargs: Any) -> Any:
return self._cursor.setinputsizes(*args, **kwargs)
def _aenter_cursor(self, cursor: AsyncCursor) -> AsyncCursor:
try:
return cursor.__enter__()
except Exception as error:
self._adapt_connection._handle_exception(error)
async def _execute_async(self, operation, parameters):
# override to not use mutex, oracledb already has a mutex
if parameters is None:
result = await self._cursor.execute(operation)
else:
result = await self._cursor.execute(operation, parameters)
if self._cursor.description and not self.server_side:
self._rows = collections.deque(await self._cursor.fetchall())
return result
async def _executemany_async(
self,
operation,
seq_of_parameters,
):
# override to not use mutex, oracledb already has a mutex
return await self._cursor.executemany(operation, seq_of_parameters)
def __enter__(self):
return self
def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
self.close()
class AsyncAdapt_oracledb_ss_cursor(
AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_oracledb_cursor
):
__slots__ = ()
def close(self) -> None:
if self._cursor is not None:
self._cursor.close()
self._cursor = None # type: ignore
class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection):
_connection: AsyncConnection
__slots__ = ()
thin = True
_cursor_cls = AsyncAdapt_oracledb_cursor
_ss_cursor_cls = None
@property
def autocommit(self):
return self._connection.autocommit
@autocommit.setter
def autocommit(self, value):
self._connection.autocommit = value
@property
def outputtypehandler(self):
return self._connection.outputtypehandler
@outputtypehandler.setter
def outputtypehandler(self, value):
self._connection.outputtypehandler = value
@property
def version(self):
return self._connection.version
@property
def stmtcachesize(self):
return self._connection.stmtcachesize
@stmtcachesize.setter
def stmtcachesize(self, value):
self._connection.stmtcachesize = value
@property
def max_identifier_length(self):
return self._connection.max_identifier_length
def cursor(self):
return AsyncAdapt_oracledb_cursor(self)
def ss_cursor(self):
return AsyncAdapt_oracledb_ss_cursor(self)
def xid(self, *args: Any, **kwargs: Any) -> Any:
return self._connection.xid(*args, **kwargs)
def tpc_begin(self, *args: Any, **kwargs: Any) -> Any:
return self.await_(self._connection.tpc_begin(*args, **kwargs))
def tpc_commit(self, *args: Any, **kwargs: Any) -> Any:
return self.await_(self._connection.tpc_commit(*args, **kwargs))
def tpc_prepare(self, *args: Any, **kwargs: Any) -> Any:
return self.await_(self._connection.tpc_prepare(*args, **kwargs))
def tpc_recover(self, *args: Any, **kwargs: Any) -> Any:
return self.await_(self._connection.tpc_recover(*args, **kwargs))
def tpc_rollback(self, *args: Any, **kwargs: Any) -> Any:
return self.await_(self._connection.tpc_rollback(*args, **kwargs))
class AsyncAdaptFallback_oracledb_connection(
AsyncAdaptFallback_dbapi_connection, AsyncAdapt_oracledb_connection
):
__slots__ = ()
class OracledbAdaptDBAPI:
def __init__(self, oracledb) -> None:
self.oracledb = oracledb
for k, v in self.oracledb.__dict__.items():
if k != "connect":
self.__dict__[k] = v
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async)
if asbool(async_fallback):
return AsyncAdaptFallback_oracledb_connection(
self, await_fallback(creator_fn(*arg, **kw))
)
else:
return AsyncAdapt_oracledb_connection(
self, await_only(creator_fn(*arg, **kw))
)
class OracleExecutionContextAsync_oracledb(OracleExecutionContext_oracledb):
# restore default create cursor
create_cursor = default.DefaultExecutionContext.create_cursor
def create_default_cursor(self):
# copy of OracleExecutionContext_cx_oracle.create_cursor
c = self._dbapi_connection.cursor()
if self.dialect.arraysize:
c.arraysize = self.dialect.arraysize
return c
def create_server_side_cursor(self):
c = self._dbapi_connection.ss_cursor()
if self.dialect.arraysize:
c.arraysize = self.dialect.arraysize
return c
class OracleDialectAsync_oracledb(OracleDialect_oracledb):
is_async = True
supports_server_side_cursors = True
supports_statement_cache = True
execution_ctx_cls = OracleExecutionContextAsync_oracledb
_min_version = (2,)
# thick_mode mode is not supported by asyncio, oracledb will raise
@classmethod
def import_dbapi(cls):
import oracledb
return OracledbAdaptDBAPI(oracledb)
@classmethod
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if asbool(async_fallback):
return pool.FallbackAsyncAdaptedQueuePool
else:
return pool.AsyncAdaptedQueuePool
def get_driver_connection(self, connection):
return connection._connection
dialect = OracleDialect_oracledb
dialect_async = OracleDialectAsync_oracledb

View File

@@ -1,9 +1,3 @@
# dialects/oracle/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import create_engine
@@ -89,7 +83,7 @@ def _oracle_drop_db(cfg, eng, ident):
# cx_Oracle seems to occasionally leak open connections when a large
# suite it run, even if we confirm we have zero references to
# connection objects.
# while there is a "kill session" command in Oracle Database,
# while there is a "kill session" command in Oracle,
# it unfortunately does not release the connection sufficiently.
_ora_drop_ignore(conn, ident)
_ora_drop_ignore(conn, "%s_ts1" % ident)

View File

@@ -1,5 +1,4 @@
# dialects/oracle/types.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -64,18 +63,17 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer):
class FLOAT(sqltypes.FLOAT):
"""Oracle Database FLOAT.
"""Oracle FLOAT.
This is the same as :class:`_sqltypes.FLOAT` except that
an Oracle Database -specific :paramref:`_oracle.FLOAT.binary_precision`
an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
parameter is accepted, and
the :paramref:`_sqltypes.Float.precision` parameter is not accepted.
Oracle Database FLOAT types indicate precision in terms of "binary
precision", which defaults to 126. For a REAL type, the value is 63. This
parameter does not cleanly map to a specific number of decimal places but
is roughly equivalent to the desired number of decimal places divided by
0.3103.
Oracle FLOAT types indicate precision in terms of "binary precision", which
defaults to 126. For a REAL type, the value is 63. This parameter does not
cleanly map to a specific number of decimal places but is roughly
equivalent to the desired number of decimal places divided by 0.3103.
.. versionadded:: 2.0
@@ -92,11 +90,10 @@ class FLOAT(sqltypes.FLOAT):
r"""
Construct a FLOAT
:param binary_precision: Oracle Database binary precision value to be
rendered in DDL. This may be approximated to the number of decimal
characters using the formula "decimal precision = 0.30103 * binary
precision". The default value used by Oracle Database for FLOAT /
DOUBLE PRECISION is 126.
:param binary_precision: Oracle binary precision value to be rendered
in DDL. This may be approximated to the number of decimal characters
using the formula "decimal precision = 0.30103 * binary precision".
The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.
:param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`
@@ -111,36 +108,10 @@ class FLOAT(sqltypes.FLOAT):
class BINARY_DOUBLE(sqltypes.Double):
"""Implement the Oracle ``BINARY_DOUBLE`` datatype.
This datatype differs from the Oracle ``DOUBLE`` datatype in that it
delivers a true 8-byte FP value. The datatype may be combined with a
generic :class:`.Double` datatype using :meth:`.TypeEngine.with_variant`.
.. seealso::
:ref:`oracle_float_support`
"""
__visit_name__ = "BINARY_DOUBLE"
class BINARY_FLOAT(sqltypes.Float):
"""Implement the Oracle ``BINARY_FLOAT`` datatype.
This datatype differs from the Oracle ``FLOAT`` datatype in that it
delivers a true 4-byte FP value. The datatype may be combined with a
generic :class:`.Float` datatype using :meth:`.TypeEngine.with_variant`.
.. seealso::
:ref:`oracle_float_support`
"""
__visit_name__ = "BINARY_FLOAT"
@@ -191,10 +162,10 @@ class _OracleDateLiteralRender:
class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
"""Provide the Oracle Database DATE type.
"""Provide the oracle DATE type.
This type has no special Python behavior, except that it subclasses
:class:`_types.DateTime`; this is to suit the fact that the Oracle Database
:class:`_types.DateTime`; this is to suit the fact that the Oracle
``DATE`` type supports a time value.
"""
@@ -274,8 +245,8 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
class TIMESTAMP(sqltypes.TIMESTAMP):
"""Oracle Database implementation of ``TIMESTAMP``, which supports
additional Oracle Database-specific modes
"""Oracle implementation of ``TIMESTAMP``, which supports additional
Oracle-specific modes
.. versionadded:: 2.0
@@ -285,11 +256,10 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
"""Construct a new :class:`_oracle.TIMESTAMP`.
:param timezone: boolean. Indicates that the TIMESTAMP type should
use Oracle Database's ``TIMESTAMP WITH TIME ZONE`` datatype.
use Oracle's ``TIMESTAMP WITH TIME ZONE`` datatype.
:param local_timezone: boolean. Indicates that the TIMESTAMP type
should use Oracle Database's ``TIMESTAMP WITH LOCAL TIME ZONE``
datatype.
should use Oracle's ``TIMESTAMP WITH LOCAL TIME ZONE`` datatype.
"""
@@ -302,7 +272,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
class ROWID(sqltypes.TypeEngine):
"""Oracle Database ROWID type.
"""Oracle ROWID type.
When used in a cast() or similar, generates ROWID.

View File

@@ -1,364 +0,0 @@
# dialects/oracle/vector.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import array
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from typing import Union
import sqlalchemy.types as types
from sqlalchemy.types import Float
class VectorIndexType(Enum):
"""Enum representing different types of VECTOR index structures.
See :ref:`oracle_vector_datatype` for background.
.. versionadded:: 2.0.41
"""
HNSW = "HNSW"
"""
The HNSW (Hierarchical Navigable Small World) index type.
"""
IVF = "IVF"
"""
The IVF (Inverted File Index) index type
"""
class VectorDistanceType(Enum):
"""Enum representing different types of vector distance metrics.
See :ref:`oracle_vector_datatype` for background.
.. versionadded:: 2.0.41
"""
EUCLIDEAN = "EUCLIDEAN"
"""Euclidean distance (L2 norm).
Measures the straight-line distance between two vectors in space.
"""
DOT = "DOT"
"""Dot product similarity.
Measures the algebraic similarity between two vectors.
"""
COSINE = "COSINE"
"""Cosine similarity.
Measures the cosine of the angle between two vectors.
"""
MANHATTAN = "MANHATTAN"
"""Manhattan distance (L1 norm).
Calculates the sum of absolute differences across dimensions.
"""
class VectorStorageFormat(Enum):
"""Enum representing the data format used to store vector components.
See :ref:`oracle_vector_datatype` for background.
.. versionadded:: 2.0.41
"""
INT8 = "INT8"
"""
8-bit integer format.
"""
BINARY = "BINARY"
"""
Binary format.
"""
FLOAT32 = "FLOAT32"
"""
32-bit floating-point format.
"""
FLOAT64 = "FLOAT64"
"""
64-bit floating-point format.
"""
class VectorStorageType(Enum):
"""Enum representing the vector type,
See :ref:`oracle_vector_datatype` for background.
.. versionadded:: 2.0.43
"""
SPARSE = "SPARSE"
"""
A Sparse vector is a vector which has zero value for
most of its dimensions.
"""
DENSE = "DENSE"
"""
A Dense vector is a vector where most, if not all, elements
hold meaningful values.
"""
@dataclass
class VectorIndexConfig:
"""Define the configuration for Oracle VECTOR Index.
See :ref:`oracle_vector_datatype` for background.
.. versionadded:: 2.0.41
:param index_type: Enum value from :class:`.VectorIndexType`
Specifies the indexing method. For HNSW, this must be
:attr:`.VectorIndexType.HNSW`.
:param distance: Enum value from :class:`.VectorDistanceType`
specifies the metric for calculating distance between VECTORS.
:param accuracy: interger. Should be in the range 0 to 100
Specifies the accuracy of the nearest neighbor search during
query execution.
:param parallel: integer. Specifies degree of parallelism.
:param hnsw_neighbors: interger. Should be in the range 0 to
2048. Specifies the number of nearest neighbors considered
during the search. The attribute :attr:`.VectorIndexConfig.hnsw_neighbors`
is HNSW index specific.
:param hnsw_efconstruction: integer. Should be in the range 0
to 65535. Controls the trade-off between indexing speed and
recall quality during index construction. The attribute
:attr:`.VectorIndexConfig.hnsw_efconstruction` is HNSW index
specific.
:param ivf_neighbor_partitions: integer. Should be in the range
0 to 10,000,000. Specifies the number of partitions used to
divide the dataset. The attribute
:attr:`.VectorIndexConfig.ivf_neighbor_partitions` is IVF index
specific.
:param ivf_sample_per_partition: integer. Should be between 1
and ``num_vectors / neighbor partitions``. Specifies the
number of samples used per partition. The attribute
:attr:`.VectorIndexConfig.ivf_sample_per_partition` is IVF index
specific.
:param ivf_min_vectors_per_partition: integer. From 0 (no trimming)
to the total number of vectors (results in 1 partition). Specifies
the minimum number of vectors per partition. The attribute
:attr:`.VectorIndexConfig.ivf_min_vectors_per_partition`
is IVF index specific.
"""
index_type: VectorIndexType = VectorIndexType.HNSW
distance: Optional[VectorDistanceType] = None
accuracy: Optional[int] = None
hnsw_neighbors: Optional[int] = None
hnsw_efconstruction: Optional[int] = None
ivf_neighbor_partitions: Optional[int] = None
ivf_sample_per_partition: Optional[int] = None
ivf_min_vectors_per_partition: Optional[int] = None
parallel: Optional[int] = None
def __post_init__(self):
self.index_type = VectorIndexType(self.index_type)
for field in [
"hnsw_neighbors",
"hnsw_efconstruction",
"ivf_neighbor_partitions",
"ivf_sample_per_partition",
"ivf_min_vectors_per_partition",
"parallel",
"accuracy",
]:
value = getattr(self, field)
if value is not None and not isinstance(value, int):
raise TypeError(
f"{field} must be an integer if"
f"provided, got {type(value).__name__}"
)
class SparseVector:
"""
Lightweight SQLAlchemy-side version of SparseVector.
This mimics oracledb.SparseVector.
.. versionadded:: 2.0.43
"""
def __init__(
self,
num_dimensions: int,
indices: Union[list, array.array],
values: Union[list, array.array],
):
if not isinstance(indices, array.array) or indices.typecode != "I":
indices = array.array("I", indices)
if not isinstance(values, array.array):
values = array.array("d", values)
if len(indices) != len(values):
raise TypeError("indices and values must be of the same length!")
self.num_dimensions = num_dimensions
self.indices = indices
self.values = values
def __str__(self):
return (
f"SparseVector(num_dimensions={self.num_dimensions}, "
f"size={len(self.indices)}, typecode={self.values.typecode})"
)
class VECTOR(types.TypeEngine):
"""Oracle VECTOR datatype.
For complete background on using this type, see
:ref:`oracle_vector_datatype`.
.. versionadded:: 2.0.41
"""
cache_ok = True
__visit_name__ = "VECTOR"
_typecode_map = {
VectorStorageFormat.INT8: "b", # Signed int
VectorStorageFormat.BINARY: "B", # Unsigned int
VectorStorageFormat.FLOAT32: "f", # Float
VectorStorageFormat.FLOAT64: "d", # Double
}
def __init__(self, dim=None, storage_format=None, storage_type=None):
"""Construct a VECTOR.
:param dim: integer. The dimension of the VECTOR datatype. This
should be an integer value.
:param storage_format: VectorStorageFormat. The VECTOR storage
type format. This should be Enum values form
:class:`.VectorStorageFormat` INT8, BINARY, FLOAT32, or FLOAT64.
:param storage_type: VectorStorageType. The Vector storage type. This
should be Enum values from :class:`.VectorStorageType` SPARSE or
DENSE.
"""
if dim is not None and not isinstance(dim, int):
raise TypeError("dim must be an interger")
if storage_format is not None and not isinstance(
storage_format, VectorStorageFormat
):
raise TypeError(
"storage_format must be an enum of type VectorStorageFormat"
)
if storage_type is not None and not isinstance(
storage_type, VectorStorageType
):
raise TypeError(
"storage_type must be an enum of type VectorStorageType"
)
self.dim = dim
self.storage_format = storage_format
self.storage_type = storage_type
def _cached_bind_processor(self, dialect):
"""
Converts a Python-side SparseVector instance into an
oracledb.SparseVectormor a compatible array format before
binding it to the database.
"""
def process(value):
if value is None or isinstance(value, array.array):
return value
# Convert list to a array.array
elif isinstance(value, list):
typecode = self._array_typecode(self.storage_format)
value = array.array(typecode, value)
return value
# Convert SqlAlchemy SparseVector to oracledb SparseVector object
elif isinstance(value, SparseVector):
return dialect.dbapi.SparseVector(
value.num_dimensions,
value.indices,
value.values,
)
else:
raise TypeError(
"""
Invalid input for VECTOR: expected a list, an array.array,
or a SparseVector object.
"""
)
return process
def _cached_result_processor(self, dialect, coltype):
"""
Converts database-returned values into Python-native representations.
If the value is an oracledb.SparseVector, it is converted into the
SQLAlchemy-side SparseVector class.
If the value is a array.array, it is converted to a plain Python list.
"""
def process(value):
if value is None:
return None
elif isinstance(value, array.array):
return list(value)
# Convert Oracledb SparseVector to SqlAlchemy SparseVector object
elif isinstance(value, dialect.dbapi.SparseVector):
return SparseVector(
num_dimensions=value.num_dimensions,
indices=value.indices,
values=value.values,
)
return process
def _array_typecode(self, typecode):
"""
Map storage format to array typecode.
"""
return self._typecode_map.get(typecode, "d")
class comparator_factory(types.TypeEngine.Comparator):
def l2_distance(self, other):
return self.op("<->", return_type=Float)(other)
def inner_product(self, other):
return self.op("<#>", return_type=Float)(other)
def cosine_distance(self, other):
return self.op("<=>", return_type=Float)(other)

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -8,7 +8,6 @@
from types import ModuleType
from . import array as arraylib # noqa # keep above base and other dialects
from . import asyncpg # noqa
from . import base
from . import pg8000 # noqa
@@ -57,14 +56,12 @@ from .named_types import ENUM
from .named_types import NamedType
from .ranges import AbstractMultiRange
from .ranges import AbstractRange
from .ranges import AbstractSingleRange
from .ranges import DATEMULTIRANGE
from .ranges import DATERANGE
from .ranges import INT4MULTIRANGE
from .ranges import INT4RANGE
from .ranges import INT8MULTIRANGE
from .ranges import INT8RANGE
from .ranges import MultiRange
from .ranges import NUMMULTIRANGE
from .ranges import NUMRANGE
from .ranges import Range
@@ -89,7 +86,6 @@ from .types import TIMESTAMP
from .types import TSQUERY
from .types import TSVECTOR
# Alias psycopg also as psycopg_async
psycopg_async = type(
"psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async}

View File

@@ -1,5 +1,4 @@
# dialects/postgresql/_psycopg_common.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -170,10 +169,8 @@ class _PGDialect_common_psycopg(PGDialect):
def _do_autocommit(self, connection, value):
connection.autocommit = value
def detect_autocommit_setting(self, dbapi_connection):
return bool(dbapi_connection.autocommit)
def do_ping(self, dbapi_connection):
cursor = None
before_autocommit = dbapi_connection.autocommit
if not before_autocommit:

View File

@@ -1,21 +1,18 @@
# dialects/postgresql/array.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/array.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import re
from typing import Any as typing_Any
from typing import Iterable
from typing import Any
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .operators import CONTAINED_BY
from .operators import CONTAINS
@@ -24,55 +21,32 @@ from ... import types as sqltypes
from ... import util
from ...sql import expression
from ...sql import operators
from ...sql.visitors import InternalTraversal
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql._typing import _ColumnExpressionArgument
from ...sql._typing import _TypeEngineArgument
from ...sql.elements import ColumnElement
from ...sql.elements import Grouping
from ...sql.expression import BindParameter
from ...sql.operators import OperatorType
from ...sql.selectable import _SelectIterable
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _LiteralProcessorType
from ...sql.type_api import _ResultProcessorType
from ...sql.type_api import TypeEngine
from ...sql.visitors import _TraverseInternalsType
from ...util.typing import Self
from ...sql._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=typing_Any)
_T = TypeVar("_T", bound=Any)
def Any(
other: typing_Any,
arrexpr: _ColumnExpressionArgument[_T],
operator: OperatorType = operators.eq,
) -> ColumnElement[bool]:
def Any(other, arrexpr, operator=operators.eq):
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
See that method for details.
"""
return arrexpr.any(other, operator) # type: ignore[no-any-return, union-attr] # noqa: E501
return arrexpr.any(other, operator)
def All(
other: typing_Any,
arrexpr: _ColumnExpressionArgument[_T],
operator: OperatorType = operators.eq,
) -> ColumnElement[bool]:
def All(other, arrexpr, operator=operators.eq):
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
See that method for details.
"""
return arrexpr.all(other, operator) # type: ignore[no-any-return, union-attr] # noqa: E501
return arrexpr.all(other, operator)
class array(expression.ExpressionClauseList[_T]):
"""A PostgreSQL ARRAY literal.
This is used to produce ARRAY literals in SQL expressions, e.g.::
@@ -81,43 +55,20 @@ class array(expression.ExpressionClauseList[_T]):
from sqlalchemy.dialects import postgresql
from sqlalchemy import select, func
stmt = select(array([1, 2]) + array([3, 4, 5]))
stmt = select(array([1,2]) + array([3,4,5]))
print(stmt.compile(dialect=postgresql.dialect()))
Produces the SQL:
.. sourcecode:: sql
Produces the SQL::
SELECT ARRAY[%(param_1)s, %(param_2)s] ||
ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1
An instance of :class:`.array` will always have the datatype
:class:`_types.ARRAY`. The "inner" type of the array is inferred from the
values present, unless the :paramref:`_postgresql.array.type_` keyword
argument is passed::
:class:`_types.ARRAY`. The "inner" type of the array is inferred from
the values present, unless the ``type_`` keyword argument is passed::
array(["foo", "bar"], type_=CHAR)
When constructing an empty array, the :paramref:`_postgresql.array.type_`
argument is particularly important as PostgreSQL server typically requires
a cast to be rendered for the inner type in order to render an empty array.
SQLAlchemy's compilation for the empty array will produce this cast so
that::
stmt = array([], type_=Integer)
print(stmt.compile(dialect=postgresql.dialect()))
Produces:
.. sourcecode:: sql
ARRAY[]::INTEGER[]
As required by PostgreSQL for empty arrays.
.. versionadded:: 2.0.40 added support to render empty PostgreSQL array
literals with a required cast.
array(['foo', 'bar'], type_=CHAR)
Multidimensional arrays are produced by nesting :class:`.array` constructs.
The dimensionality of the final :class:`_types.ARRAY`
@@ -126,21 +77,16 @@ class array(expression.ExpressionClauseList[_T]):
type::
stmt = select(
array(
[array([1, 2]), array([3, 4]), array([column("q"), column("x")])]
)
array([
array([1, 2]), array([3, 4]), array([column('q'), column('x')])
])
)
print(stmt.compile(dialect=postgresql.dialect()))
Produces:
Produces::
.. sourcecode:: sql
SELECT ARRAY[
ARRAY[%(param_1)s, %(param_2)s],
ARRAY[%(param_3)s, %(param_4)s],
ARRAY[q, x]
] AS anon_1
SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s],
ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1
.. versionadded:: 1.3.6 added support for multidimensional array literals
@@ -148,63 +94,42 @@ class array(expression.ExpressionClauseList[_T]):
:class:`_postgresql.ARRAY`
""" # noqa: E501
"""
__visit_name__ = "array"
stringify_dialect = "postgresql"
inherit_cache = True
_traverse_internals: _TraverseInternalsType = [
("clauses", InternalTraversal.dp_clauseelement_tuple),
("type", InternalTraversal.dp_type),
]
def __init__(
self,
clauses: Iterable[_T],
*,
type_: Optional[_TypeEngineArgument[_T]] = None,
**kw: typing_Any,
):
r"""Construct an ARRAY literal.
:param clauses: iterable, such as a list, containing elements to be
rendered in the array
:param type\_: optional type. If omitted, the type is inferred
from the contents of the array.
"""
def __init__(self, clauses, **kw):
type_arg = kw.pop("type_", None)
super().__init__(operators.comma_op, *clauses, **kw)
self._type_tuple = [arg.type for arg in self.clauses]
main_type = (
type_
if type_ is not None
else self.clauses[0].type if self.clauses else sqltypes.NULLTYPE
type_arg
if type_arg is not None
else self._type_tuple[0]
if self._type_tuple
else sqltypes.NULLTYPE
)
if isinstance(main_type, ARRAY):
self.type = ARRAY(
main_type.item_type,
dimensions=(
main_type.dimensions + 1
if main_type.dimensions is not None
else 2
),
) # type: ignore[assignment]
dimensions=main_type.dimensions + 1
if main_type.dimensions is not None
else 2,
)
else:
self.type = ARRAY(main_type) # type: ignore[assignment]
self.type = ARRAY(main_type)
@property
def _select_iterable(self) -> _SelectIterable:
def _select_iterable(self):
return (self,)
def _bind_param(
self,
operator: OperatorType,
obj: typing_Any,
type_: Optional[TypeEngine[_T]] = None,
_assume_scalar: bool = False,
) -> BindParameter[_T]:
def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
if _assume_scalar or operator is operators.getitem:
return expression.BindParameter(
None,
@@ -223,18 +148,16 @@ class array(expression.ExpressionClauseList[_T]):
)
for o in obj
]
) # type: ignore[return-value]
)
def self_group(
self, against: Optional[OperatorType] = None
) -> Union[Self, Grouping[_T]]:
def self_group(self, against=None):
if against in (operators.any_op, operators.all_op, operators.getitem):
return expression.Grouping(self)
else:
return self
class ARRAY(sqltypes.ARRAY[_T]):
class ARRAY(sqltypes.ARRAY):
"""PostgreSQL ARRAY type.
The :class:`_postgresql.ARRAY` type is constructed in the same way
@@ -244,11 +167,9 @@ class ARRAY(sqltypes.ARRAY[_T]):
from sqlalchemy.dialects import postgresql
mytable = Table(
"mytable",
metadata,
Column("data", postgresql.ARRAY(Integer, dimensions=2)),
)
mytable = Table("mytable", metadata,
Column("data", postgresql.ARRAY(Integer, dimensions=2))
)
The :class:`_postgresql.ARRAY` type provides all operations defined on the
core :class:`_types.ARRAY` type, including support for "dimensions",
@@ -263,9 +184,8 @@ class ARRAY(sqltypes.ARRAY[_T]):
mytable.c.data.contains([1, 2])
Indexed access is one-based by default, to match that of PostgreSQL;
for zero-based indexed access, set
:paramref:`_postgresql.ARRAY.zero_indexes`.
The :class:`_postgresql.ARRAY` type may not be supported on all
PostgreSQL DBAPIs; it is currently known to work on psycopg2 only.
Additionally, the :class:`_postgresql.ARRAY`
type does not work directly in
@@ -284,7 +204,6 @@ class ARRAY(sqltypes.ARRAY[_T]):
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.ext.mutable import MutableList
class SomeOrmClass(Base):
# ...
@@ -306,9 +225,45 @@ class ARRAY(sqltypes.ARRAY[_T]):
"""
class Comparator(sqltypes.ARRAY.Comparator):
"""Define comparison operations for :class:`_types.ARRAY`.
Note that these operations are in addition to those provided
by the base :class:`.types.ARRAY.Comparator` class, including
:meth:`.types.ARRAY.Comparator.any` and
:meth:`.types.ARRAY.Comparator.all`.
"""
def contains(self, other, **kwargs):
"""Boolean expression. Test if elements are a superset of the
elements of the argument array expression.
kwargs may be ignored by this operator but are required for API
conformance.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
"""Boolean expression. Test if elements are a proper subset of the
elements of the argument array expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
def overlap(self, other):
"""Boolean expression. Test if array has elements in common with
an argument array expression.
"""
return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)
comparator_factory = Comparator
def __init__(
self,
item_type: _TypeEngineArgument[_T],
item_type: _TypeEngineArgument[Any],
as_tuple: bool = False,
dimensions: Optional[int] = None,
zero_indexes: bool = False,
@@ -317,7 +272,7 @@ class ARRAY(sqltypes.ARRAY[_T]):
E.g.::
Column("myarray", ARRAY(Integer))
Column('myarray', ARRAY(Integer))
Arguments are:
@@ -357,63 +312,35 @@ class ARRAY(sqltypes.ARRAY[_T]):
self.dimensions = dimensions
self.zero_indexes = zero_indexes
class Comparator(sqltypes.ARRAY.Comparator[_T]):
"""Define comparison operations for :class:`_types.ARRAY`.
@property
def hashable(self):
return self.as_tuple
Note that these operations are in addition to those provided
by the base :class:`.types.ARRAY.Comparator` class, including
:meth:`.types.ARRAY.Comparator.any` and
:meth:`.types.ARRAY.Comparator.all`.
@property
def python_type(self):
return list
"""
def contains(
self, other: typing_Any, **kwargs: typing_Any
) -> ColumnElement[bool]:
"""Boolean expression. Test if elements are a superset of the
elements of the argument array expression.
kwargs may be ignored by this operator but are required for API
conformance.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other: typing_Any) -> ColumnElement[bool]:
"""Boolean expression. Test if elements are a proper subset of the
elements of the argument array expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
def overlap(self, other: typing_Any) -> ColumnElement[bool]:
"""Boolean expression. Test if array has elements in common with
an argument array expression.
"""
return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)
comparator_factory = Comparator
def compare_values(self, x, y):
return x == y
@util.memoized_property
def _against_native_enum(self) -> bool:
def _against_native_enum(self):
return (
isinstance(self.item_type, sqltypes.Enum)
and self.item_type.native_enum
)
def literal_processor(
self, dialect: Dialect
) -> Optional[_LiteralProcessorType[_T]]:
def literal_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).literal_processor(
dialect
)
if item_proc is None:
return None
def to_str(elements: Iterable[typing_Any]) -> str:
def to_str(elements):
return f"ARRAY[{', '.join(elements)}]"
def process(value: Sequence[typing_Any]) -> str:
def process(value):
inner = self._apply_item_processor(
value, item_proc, self.dimensions, to_str
)
@@ -421,16 +348,12 @@ class ARRAY(sqltypes.ARRAY[_T]):
return process
def bind_processor(
self, dialect: Dialect
) -> Optional[_BindProcessorType[Sequence[typing_Any]]]:
def bind_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).bind_processor(
dialect
)
def process(
value: Optional[Sequence[typing_Any]],
) -> Optional[list[typing_Any]]:
def process(value):
if value is None:
return value
else:
@@ -440,16 +363,12 @@ class ARRAY(sqltypes.ARRAY[_T]):
return process
def result_processor(
self, dialect: Dialect, coltype: object
) -> _ResultProcessorType[Sequence[typing_Any]]:
def result_processor(self, dialect, coltype):
item_proc = self.item_type.dialect_impl(dialect).result_processor(
dialect, coltype
)
def process(
value: Sequence[typing_Any],
) -> Optional[Sequence[typing_Any]]:
def process(value):
if value is None:
return value
else:
@@ -464,13 +383,11 @@ class ARRAY(sqltypes.ARRAY[_T]):
super_rp = process
pattern = re.compile(r"^{(.*)}$")
def handle_raw_string(value: str) -> list[str]:
inner = pattern.match(value).group(1) # type: ignore[union-attr] # noqa: E501
def handle_raw_string(value):
inner = pattern.match(value).group(1)
return _split_enum_values(inner)
def process(
value: Sequence[typing_Any],
) -> Optional[Sequence[typing_Any]]:
def process(value):
if value is None:
return value
# isinstance(value, str) is required to handle
@@ -485,7 +402,7 @@ class ARRAY(sqltypes.ARRAY[_T]):
return process
def _split_enum_values(array_string: str) -> list[str]:
def _split_enum_values(array_string):
if '"' not in array_string:
# no escape char is present so it can just split on the comma
return array_string.split(",") if array_string else []

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/asyncpg.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# postgresql/asyncpg.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
@@ -23,10 +23,18 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname")
The dialect can also be run as a "synchronous" dialect within the
:func:`_sa.create_engine` function, which will pass "await" calls into
an ad-hoc event loop. This mode of operation is of **limited use**
and is for special testing scenarios only. The mode can be enabled by
adding the SQLAlchemy-specific flag ``async_fallback`` to the URL
in conjunction with :func:`_sa.create_engine`::
# for testing purposes only; do not use in production!
engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true")
engine = create_async_engine(
"postgresql+asyncpg://user:pass@hostname/dbname"
)
.. versionadded:: 1.4
@@ -81,15 +89,11 @@ asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect
argument)::
engine = create_async_engine(
"postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500"
)
engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500")
To disable the prepared statement cache, use a value of zero::
engine = create_async_engine(
"postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0"
)
engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0")
.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.
@@ -119,8 +123,8 @@ To disable the prepared statement cache, use a value of zero::
.. _asyncpg_prepared_statement_name:
Prepared Statement Name with PGBouncer
--------------------------------------
Prepared Statement Name
-----------------------
By default, asyncpg enumerates prepared statements in numeric order, which
can lead to errors if a name has already been taken for another prepared
@@ -135,10 +139,10 @@ a prepared statement is prepared::
from uuid import uuid4
engine = create_async_engine(
"postgresql+asyncpg://user:pass@somepgbouncer/dbname",
"postgresql+asyncpg://user:pass@hostname/dbname",
poolclass=NullPool,
connect_args={
"prepared_statement_name_func": lambda: f"__asyncpg_{uuid4()}__",
'prepared_statement_name_func': lambda: f'__asyncpg_{uuid4()}__',
},
)
@@ -148,7 +152,7 @@ a prepared statement is prepared::
https://github.com/sqlalchemy/sqlalchemy/issues/6467
.. warning:: When using PGBouncer, to prevent a buildup of useless prepared statements in
.. warning:: To prevent a buildup of useless prepared statements in
your application, it's important to use the :class:`.NullPool` pool
class, and to configure PgBouncer to use `DISCARD <https://www.postgresql.org/docs/current/sql-discard.html>`_
when returning connections. The DISCARD command is used to release resources held by the db connection,
@@ -178,11 +182,13 @@ client using this setting passed to :func:`_asyncio.create_async_engine`::
from __future__ import annotations
from collections import deque
import collections
import decimal
import json as _py_json
import re
import time
from typing import cast
from typing import TYPE_CHECKING
from . import json
from . import ranges
@@ -212,6 +218,9 @@ from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
if TYPE_CHECKING:
from typing import Iterable
class AsyncpgARRAY(PGARRAY):
render_bind_cast = True
@@ -265,20 +274,20 @@ class AsyncpgInteger(sqltypes.Integer):
render_bind_cast = True
class AsyncpgSmallInteger(sqltypes.SmallInteger):
render_bind_cast = True
class AsyncpgBigInteger(sqltypes.BigInteger):
render_bind_cast = True
class AsyncpgJSON(json.JSON):
render_bind_cast = True
def result_processor(self, dialect, coltype):
return None
class AsyncpgJSONB(json.JSONB):
render_bind_cast = True
def result_processor(self, dialect, coltype):
return None
@@ -363,7 +372,7 @@ class AsyncpgCHAR(sqltypes.CHAR):
render_bind_cast = True
class _AsyncpgRange(ranges.AbstractSingleRangeImpl):
class _AsyncpgRange(ranges.AbstractRangeImpl):
def bind_processor(self, dialect):
asyncpg_Range = dialect.dbapi.asyncpg.Range
@@ -417,7 +426,10 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
)
return value
return [to_range(element) for element in value]
return [
to_range(element)
for element in cast("Iterable[ranges.Range]", value)
]
return to_range
@@ -436,7 +448,7 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
return rvalue
if value is not None:
value = ranges.MultiRange(to_range(elem) for elem in value)
value = [to_range(elem) for elem in value]
return value
@@ -494,7 +506,7 @@ class AsyncAdapt_asyncpg_cursor:
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self._rows = deque()
self._rows = []
self._cursor = None
self.description = None
self.arraysize = 1
@@ -502,7 +514,7 @@ class AsyncAdapt_asyncpg_cursor:
self._invalidate_schema_cache_asof = 0
def close(self):
self._rows.clear()
self._rows[:] = []
def _handle_exception(self, error):
self._adapt_connection._handle_exception(error)
@@ -542,12 +554,11 @@ class AsyncAdapt_asyncpg_cursor:
self._cursor = await prepared_stmt.cursor(*parameters)
self.rowcount = -1
else:
self._rows = deque(await prepared_stmt.fetch(*parameters))
self._rows = await prepared_stmt.fetch(*parameters)
status = prepared_stmt.get_statusmsg()
reg = re.match(
r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)",
status or "",
r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status
)
if reg:
self.rowcount = int(reg.group(1))
@@ -591,11 +602,11 @@ class AsyncAdapt_asyncpg_cursor:
def __iter__(self):
while self._rows:
yield self._rows.popleft()
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.popleft()
return self._rows.pop(0)
else:
return None
@@ -603,12 +614,13 @@ class AsyncAdapt_asyncpg_cursor:
if size is None:
size = self.arraysize
rr = self._rows
return [rr.popleft() for _ in range(min(size, len(rr)))]
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = list(self._rows)
self._rows.clear()
retval = self._rows[:]
self._rows[:] = []
return retval
@@ -618,21 +630,23 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
def __init__(self, adapt_connection):
super().__init__(adapt_connection)
self._rowbuffer = deque()
self._rowbuffer = None
def close(self):
self._cursor = None
self._rowbuffer.clear()
self._rowbuffer = None
def _buffer_rows(self):
assert self._cursor is not None
new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
self._rowbuffer.extend(new_rows)
self._rowbuffer = collections.deque(new_rows)
def __aiter__(self):
return self
async def __anext__(self):
if not self._rowbuffer:
self._buffer_rows()
while True:
while self._rowbuffer:
yield self._rowbuffer.popleft()
@@ -655,19 +669,21 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
if not self._rowbuffer:
self._buffer_rows()
assert self._cursor is not None
rb = self._rowbuffer
lb = len(rb)
buf = list(self._rowbuffer)
lb = len(buf)
if size > lb:
rb.extend(
buf.extend(
self._adapt_connection.await_(self._cursor.fetch(size - lb))
)
return [rb.popleft() for _ in range(min(size, len(rb)))]
result = buf[0:size]
self._rowbuffer = collections.deque(buf[size:])
return result
def fetchall(self):
ret = list(self._rowbuffer)
ret.extend(self._adapt_connection.await_(self._all()))
ret = list(self._rowbuffer) + list(
self._adapt_connection.await_(self._all())
)
self._rowbuffer.clear()
return ret
@@ -717,7 +733,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
):
self.dbapi = dbapi
self._connection = connection
self.isolation_level = self._isolation_setting = None
self.isolation_level = self._isolation_setting = "read_committed"
self.readonly = False
self.deferrable = False
self._transaction = None
@@ -786,9 +802,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
translated_error = exception_mapping[super_](
"%s: %s" % (type(error), error)
)
translated_error.pgcode = translated_error.sqlstate = (
getattr(error, "sqlstate", None)
)
translated_error.pgcode = (
translated_error.sqlstate
) = getattr(error, "sqlstate", None)
raise translated_error from error
else:
raise error
@@ -852,45 +868,25 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
else:
return AsyncAdapt_asyncpg_cursor(self)
async def _rollback_and_discard(self):
try:
await self._transaction.rollback()
finally:
# if asyncpg .rollback() was actually called, then whether or
# not it raised or succeeded, the transation is done, discard it
self._transaction = None
self._started = False
async def _commit_and_discard(self):
try:
await self._transaction.commit()
finally:
# if asyncpg .commit() was actually called, then whether or
# not it raised or succeeded, the transation is done, discard it
self._transaction = None
self._started = False
def rollback(self):
if self._started:
try:
self.await_(self._rollback_and_discard())
self.await_(self._transaction.rollback())
except Exception as error:
self._handle_exception(error)
finally:
self._transaction = None
self._started = False
except Exception as error:
# don't dereference asyncpg transaction if we didn't
# actually try to call rollback() on it
self._handle_exception(error)
def commit(self):
if self._started:
try:
self.await_(self._commit_and_discard())
self.await_(self._transaction.commit())
except Exception as error:
self._handle_exception(error)
finally:
self._transaction = None
self._started = False
except Exception as error:
# don't dereference asyncpg transaction if we didn't
# actually try to call commit() on it
self._handle_exception(error)
def close(self):
self.rollback()
@@ -898,31 +894,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
self.await_(self._connection.close())
def terminate(self):
if util.concurrency.in_greenlet():
# in a greenlet; this is the connection was invalidated
# case.
try:
# try to gracefully close; see #10717
# timeout added in asyncpg 0.14.0 December 2017
self.await_(asyncio.shield(self._connection.close(timeout=2)))
except (
asyncio.TimeoutError,
asyncio.CancelledError,
OSError,
self.dbapi.asyncpg.PostgresError,
) as e:
# in the case where we are recycling an old connection
# that may have already been disconnected, close() will
# fail with the above timeout. in this case, terminate
# the connection without any further waiting.
# see issue #8419
self._connection.terminate()
if isinstance(e, asyncio.CancelledError):
# re-raise CancelledError if we were cancelled
raise
else:
# not in a greenlet; this is the gc cleanup case
self._connection.terminate()
self._connection.terminate()
self._started = False
@staticmethod
@@ -1059,7 +1031,6 @@ class PGDialect_asyncpg(PGDialect):
INTERVAL: AsyncPgInterval,
sqltypes.Boolean: AsyncpgBoolean,
sqltypes.Integer: AsyncpgInteger,
sqltypes.SmallInteger: AsyncpgSmallInteger,
sqltypes.BigInteger: AsyncpgBigInteger,
sqltypes.Numeric: AsyncpgNumeric,
sqltypes.Float: AsyncpgFloat,
@@ -1074,7 +1045,7 @@ class PGDialect_asyncpg(PGDialect):
OID: AsyncpgOID,
REGCLASS: AsyncpgREGCLASS,
sqltypes.CHAR: AsyncpgCHAR,
ranges.AbstractSingleRange: _AsyncpgRange,
ranges.AbstractRange: _AsyncpgRange,
ranges.AbstractMultiRange: _AsyncpgMultiRange,
},
)
@@ -1117,9 +1088,6 @@ class PGDialect_asyncpg(PGDialect):
def set_isolation_level(self, dbapi_connection, level):
dbapi_connection.set_isolation_level(self._isolation_lookup[level])
def detect_autocommit_setting(self, dbapi_conn) -> bool:
return bool(dbapi_conn.autocommit)
def set_readonly(self, connection, value):
connection.readonly = value

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/dml.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/dml.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,10 +7,7 @@
from __future__ import annotations
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from . import ext
from .._typing import _OnConflictConstraintT
@@ -29,9 +26,7 @@ from ...sql.base import ColumnCollection
from ...sql.base import ReadOnlyColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.elements import ColumnElement
from ...sql.elements import KeyedColumnElement
from ...sql.elements import TextClause
from ...sql.expression import alias
from ...util.typing import Self
@@ -158,10 +153,11 @@ class Insert(StandardInsert):
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
:param where:
Optional argument. An expression object representing a ``WHERE``
clause that restricts the rows affected by ``DO UPDATE SET``. Rows not
meeting the ``WHERE`` condition will not be updated (effectively a
``DO NOTHING`` for those rows).
Optional argument. If present, can be a literal SQL
string or an acceptable expression for a ``WHERE`` clause
that restricts the rows affected by ``DO UPDATE SET``. Rows
not meeting the ``WHERE`` condition will not be updated
(effectively a ``DO NOTHING`` for those rows).
.. seealso::
@@ -216,10 +212,8 @@ class OnConflictClause(ClauseElement):
stringify_dialect = "postgresql"
constraint_target: Optional[str]
inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]]
inferred_target_whereclause: Optional[
Union[ColumnElement[Any], TextClause]
]
inferred_target_elements: _OnConflictIndexElementsT
inferred_target_whereclause: _OnConflictIndexWhereT
def __init__(
self,
@@ -260,28 +254,12 @@ class OnConflictClause(ClauseElement):
if index_elements is not None:
self.constraint_target = None
self.inferred_target_elements = [
coercions.expect(roles.DDLConstraintColumnRole, column)
for column in index_elements
]
self.inferred_target_whereclause = (
coercions.expect(
(
roles.StatementOptionRole
if isinstance(constraint, ext.ExcludeConstraint)
else roles.WhereHavingRole
),
index_where,
)
if index_where is not None
else None
)
self.inferred_target_elements = index_elements
self.inferred_target_whereclause = index_where
elif constraint is None:
self.constraint_target = self.inferred_target_elements = (
self.inferred_target_whereclause
) = None
self.constraint_target = (
self.inferred_target_elements
) = self.inferred_target_whereclause = None
class OnConflictDoNothing(OnConflictClause):
@@ -291,9 +269,6 @@ class OnConflictDoNothing(OnConflictClause):
class OnConflictDoUpdate(OnConflictClause):
__visit_name__ = "on_conflict_do_update"
update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]]
update_whereclause: Optional[ColumnElement[Any]]
def __init__(
self,
constraint: _OnConflictConstraintT = None,
@@ -332,8 +307,4 @@ class OnConflictDoUpdate(OnConflictClause):
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
]
self.update_whereclause = (
coercions.expect(roles.WhereHavingRole, where)
if where is not None
else None
)
self.update_whereclause = where

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/ext.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/ext.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -8,10 +8,6 @@
from __future__ import annotations
from typing import Any
from typing import Iterable
from typing import List
from typing import Optional
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -27,44 +23,34 @@ from ...sql.schema import ColumnCollectionConstraint
from ...sql.sqltypes import TEXT
from ...sql.visitors import InternalTraversal
if TYPE_CHECKING:
from ...sql._typing import _ColumnExpressionArgument
from ...sql.elements import ClauseElement
from ...sql.elements import ColumnElement
from ...sql.operators import OperatorType
from ...sql.selectable import FromClause
from ...sql.visitors import _CloneCallableType
from ...sql.visitors import _TraverseInternalsType
_T = TypeVar("_T", bound=Any)
if TYPE_CHECKING:
from ...sql.visitors import _TraverseInternalsType
class aggregate_order_by(expression.ColumnElement[_T]):
class aggregate_order_by(expression.ColumnElement):
"""Represent a PostgreSQL aggregate order by expression.
E.g.::
from sqlalchemy.dialects.postgresql import aggregate_order_by
expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc()))
stmt = select(expr)
would represent the expression:
.. sourcecode:: sql
would represent the expression::
SELECT array_agg(a ORDER BY b DESC) FROM table;
Similarly::
expr = func.string_agg(
table.c.a, aggregate_order_by(literal_column("','"), table.c.a)
table.c.a,
aggregate_order_by(literal_column("','"), table.c.a)
)
stmt = select(expr)
Would represent:
.. sourcecode:: sql
Would represent::
SELECT string_agg(a, ',' ORDER BY a) FROM table;
@@ -85,32 +71,11 @@ class aggregate_order_by(expression.ColumnElement[_T]):
("order_by", InternalTraversal.dp_clauseelement),
]
@overload
def __init__(
self,
target: ColumnElement[_T],
*order_by: _ColumnExpressionArgument[Any],
): ...
@overload
def __init__(
self,
target: _ColumnExpressionArgument[_T],
*order_by: _ColumnExpressionArgument[Any],
): ...
def __init__(
self,
target: _ColumnExpressionArgument[_T],
*order_by: _ColumnExpressionArgument[Any],
):
self.target: ClauseElement = coercions.expect(
roles.ExpressionElementRole, target
)
def __init__(self, target, *order_by):
self.target = coercions.expect(roles.ExpressionElementRole, target)
self.type = self.target.type
_lob = len(order_by)
self.order_by: ClauseElement
if _lob == 0:
raise TypeError("at least one ORDER BY element is required")
elif _lob == 1:
@@ -122,22 +87,18 @@ class aggregate_order_by(expression.ColumnElement[_T]):
*order_by, _literal_as_text_role=roles.ExpressionElementRole
)
def self_group(
self, against: Optional[OperatorType] = None
) -> ClauseElement:
def self_group(self, against=None):
return self
def get_children(self, **kwargs: Any) -> Iterable[ClauseElement]:
def get_children(self, **kwargs):
return self.target, self.order_by
def _copy_internals(
self, clone: _CloneCallableType = elements._clone, **kw: Any
) -> None:
def _copy_internals(self, clone=elements._clone, **kw):
self.target = clone(self.target, **kw)
self.order_by = clone(self.order_by, **kw)
@property
def _from_objects(self) -> List[FromClause]:
def _from_objects(self):
return self.target._from_objects + self.order_by._from_objects
@@ -170,10 +131,10 @@ class ExcludeConstraint(ColumnCollectionConstraint):
E.g.::
const = ExcludeConstraint(
(Column("period"), "&&"),
(Column("group"), "="),
where=(Column("group") != "some group"),
ops={"group": "my_operator_class"},
(Column('period'), '&&'),
(Column('group'), '='),
where=(Column('group') != 'some group'),
ops={'group': 'my_operator_class'}
)
The constraint is normally embedded into the :class:`_schema.Table`
@@ -181,20 +142,19 @@ class ExcludeConstraint(ColumnCollectionConstraint):
directly, or added later using :meth:`.append_constraint`::
some_table = Table(
"some_table",
metadata,
Column("id", Integer, primary_key=True),
Column("period", TSRANGE()),
Column("group", String),
'some_table', metadata,
Column('id', Integer, primary_key=True),
Column('period', TSRANGE()),
Column('group', String)
)
some_table.append_constraint(
ExcludeConstraint(
(some_table.c.period, "&&"),
(some_table.c.group, "="),
where=some_table.c.group != "some group",
name="some_table_excl_const",
ops={"group": "my_operator_class"},
(some_table.c.period, '&&'),
(some_table.c.group, '='),
where=some_table.c.group != 'some group',
name='some_table_excl_const',
ops={'group': 'my_operator_class'}
)
)

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/hstore.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/hstore.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -28,29 +28,28 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::
data_table = Table(
"data_table",
metadata,
Column("id", Integer, primary_key=True),
Column("data", HSTORE),
data_table = Table('data_table', metadata,
Column('id', Integer, primary_key=True),
Column('data', HSTORE)
)
with engine.connect() as conn:
conn.execute(
data_table.insert(), data={"key1": "value1", "key2": "value2"}
data_table.insert(),
data = {"key1": "value1", "key2": "value2"}
)
:class:`.HSTORE` provides for a wide range of operations, including:
* Index operations::
data_table.c.data["some key"] == "some value"
data_table.c.data['some key'] == 'some value'
* Containment operations::
data_table.c.data.has_key("some key")
data_table.c.data.has_key('some key')
data_table.c.data.has_all(["one", "two", "three"])
data_table.c.data.has_all(['one', 'two', 'three'])
* Concatenation::
@@ -73,19 +72,17 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
from sqlalchemy.ext.mutable import MutableDict
class MyClass(Base):
__tablename__ = "data_table"
__tablename__ = 'data_table'
id = Column(Integer, primary_key=True)
data = Column(MutableDict.as_mutable(HSTORE))
my_object = session.query(MyClass).one()
# in-place mutation, requires Mutable extension
# in order for the ORM to detect
my_object.data["some_key"] = "some value"
my_object.data['some_key'] = 'some value'
session.commit()
@@ -99,7 +96,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
:class:`.hstore` - render the PostgreSQL ``hstore()`` function.
""" # noqa: E501
"""
__visit_name__ = "HSTORE"
hashable = False
@@ -195,9 +192,6 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
comparator_factory = Comparator
def bind_processor(self, dialect):
# note that dialect-specific types like that of psycopg and
# psycopg2 will override this method to allow driver-level conversion
# instead, see _PsycopgHStore
def process(value):
if isinstance(value, dict):
return _serialize_hstore(value)
@@ -207,9 +201,6 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
return process
def result_processor(self, dialect, coltype):
# note that dialect-specific types like that of psycopg and
# psycopg2 will override this method to allow driver-level conversion
# instead, see _PsycopgHStore
def process(value):
if value is not None:
return _parse_hstore(value)
@@ -230,12 +221,12 @@ class hstore(sqlfunc.GenericFunction):
from sqlalchemy.dialects.postgresql import array, hstore
select(hstore("key1", "value1"))
select(hstore('key1', 'value1'))
select(
hstore(
array(["key1", "key2", "key3"]),
array(["value1", "value2", "value3"]),
array(['key1', 'key2', 'key3']),
array(['value1', 'value2', 'value3'])
)
)

View File

@@ -1,18 +1,11 @@
# dialects/postgresql/json.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/json.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from .array import ARRAY
from .array import array as _pg_array
@@ -28,23 +21,13 @@ from .operators import PATH_EXISTS
from .operators import PATH_MATCH
from ... import types as sqltypes
from ...sql import cast
from ...sql._typing import _T
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.elements import ColumnElement
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _LiteralProcessorType
from ...sql.type_api import TypeEngine
__all__ = ("JSON", "JSONB")
class JSONPathType(sqltypes.JSON.JSONPathType):
def _processor(
self, dialect: Dialect, super_proc: Optional[Callable[[Any], Any]]
) -> Callable[[Any], Any]:
def process(value: Any) -> Any:
def _processor(self, dialect, super_proc):
def process(value):
if isinstance(value, str):
# If it's already a string assume that it's in json path
# format. This allows using cast with json paths literals
@@ -61,13 +44,11 @@ class JSONPathType(sqltypes.JSON.JSONPathType):
return process
def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
return self._processor(dialect, self.string_bind_processor(dialect)) # type: ignore[return-value] # noqa: E501
def bind_processor(self, dialect):
return self._processor(dialect, self.string_bind_processor(dialect))
def literal_processor(
self, dialect: Dialect
) -> _LiteralProcessorType[Any]:
return self._processor(dialect, self.string_literal_processor(dialect)) # type: ignore[return-value] # noqa: E501
def literal_processor(self, dialect):
return self._processor(dialect, self.string_literal_processor(dialect))
class JSONPATH(JSONPathType):
@@ -109,14 +90,14 @@ class JSON(sqltypes.JSON):
* Index operations (the ``->`` operator)::
data_table.c.data["some key"]
data_table.c.data['some key']
data_table.c.data[5]
* Index operations returning text
(the ``->>`` operator)::
data_table.c.data["some key"].astext == "some value"
* Index operations returning text (the ``->>`` operator)::
data_table.c.data['some key'].astext == 'some value'
Note that equivalent functionality is available via the
:attr:`.JSON.Comparator.as_string` accessor.
@@ -124,20 +105,18 @@ class JSON(sqltypes.JSON):
* Index operations with CAST
(equivalent to ``CAST(col ->> ['some key'] AS <type>)``)::
data_table.c.data["some key"].astext.cast(Integer) == 5
data_table.c.data['some key'].astext.cast(Integer) == 5
Note that equivalent functionality is available via the
:attr:`.JSON.Comparator.as_integer` and similar accessors.
* Path index operations (the ``#>`` operator)::
data_table.c.data[("key_1", "key_2", 5, ..., "key_n")]
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]
* Path index operations returning text (the ``#>>`` operator)::
data_table.c.data[
("key_1", "key_2", 5, ..., "key_n")
].astext == "some value"
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value'
Index operations return an expression object whose type defaults to
:class:`_types.JSON` by default,
@@ -149,11 +128,10 @@ class JSON(sqltypes.JSON):
using psycopg2, the DBAPI only allows serializers at the per-cursor
or per-connection level. E.g.::
engine = create_engine(
"postgresql+psycopg2://scott:tiger@localhost/test",
json_serializer=my_serialize_fn,
json_deserializer=my_deserialize_fn,
)
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
json_serializer=my_serialize_fn,
json_deserializer=my_deserialize_fn
)
When using the psycopg2 dialect, the json_deserializer is registered
against the database using ``psycopg2.extras.register_default_json``.
@@ -166,14 +144,9 @@ class JSON(sqltypes.JSON):
""" # noqa
render_bind_cast = True
astext_type: TypeEngine[str] = sqltypes.Text()
astext_type = sqltypes.Text()
def __init__(
self,
none_as_null: bool = False,
astext_type: Optional[TypeEngine[str]] = None,
):
def __init__(self, none_as_null=False, astext_type=None):
"""Construct a :class:`_types.JSON` type.
:param none_as_null: if True, persist the value ``None`` as a
@@ -182,8 +155,7 @@ class JSON(sqltypes.JSON):
be used to persist a NULL value::
from sqlalchemy import null
conn.execute(table.insert(), {"data": null()})
conn.execute(table.insert(), data=null())
.. seealso::
@@ -198,19 +170,17 @@ class JSON(sqltypes.JSON):
if astext_type is not None:
self.astext_type = astext_type
class Comparator(sqltypes.JSON.Comparator[_T]):
class Comparator(sqltypes.JSON.Comparator):
"""Define comparison operations for :class:`_types.JSON`."""
type: JSON
@property
def astext(self) -> ColumnElement[str]:
def astext(self):
"""On an indexed expression, use the "astext" (e.g. "->>")
conversion when rendered in SQL.
E.g.::
select(data_table.c.data["some key"].astext)
select(data_table.c.data['some key'].astext)
.. seealso::
@@ -218,13 +188,13 @@ class JSON(sqltypes.JSON):
"""
if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
return self.expr.left.operate( # type: ignore[no-any-return]
return self.expr.left.operate(
JSONPATH_ASTEXT,
self.expr.right,
result_type=self.type.astext_type,
)
else:
return self.expr.left.operate( # type: ignore[no-any-return]
return self.expr.left.operate(
ASTEXT, self.expr.right, result_type=self.type.astext_type
)
@@ -237,16 +207,15 @@ class JSONB(JSON):
The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data,
e.g.::
data_table = Table(
"data_table",
metadata,
Column("id", Integer, primary_key=True),
Column("data", JSONB),
data_table = Table('data_table', metadata,
Column('id', Integer, primary_key=True),
Column('data', JSONB)
)
with engine.connect() as conn:
conn.execute(
data_table.insert(), data={"key1": "value1", "key2": "value2"}
data_table.insert(),
data = {"key1": "value1", "key2": "value2"}
)
The :class:`_postgresql.JSONB` type includes all operations provided by
@@ -283,53 +252,43 @@ class JSONB(JSON):
__visit_name__ = "JSONB"
class Comparator(JSON.Comparator[_T]):
class Comparator(JSON.Comparator):
"""Define comparison operations for :class:`_types.JSON`."""
type: JSONB
def has_key(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Test for presence of a key (equivalent of
the ``?`` operator). Note that the key may be a SQLA expression.
def has_key(self, other):
"""Boolean expression. Test for presence of a key. Note that the
key may be a SQLA expression.
"""
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
def has_all(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Test for presence of all keys in jsonb
(equivalent of the ``?&`` operator)
"""
def has_all(self, other):
"""Boolean expression. Test for presence of all keys in jsonb"""
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
def has_any(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Test for presence of any key in jsonb
(equivalent of the ``?|`` operator)
"""
def has_any(self, other):
"""Boolean expression. Test for presence of any key in jsonb"""
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
def contains(self, other: Any, **kwargs: Any) -> ColumnElement[bool]:
def contains(self, other, **kwargs):
"""Boolean expression. Test if keys (or array) are a superset
of/contained the keys of the argument jsonb expression
(equivalent of the ``@>`` operator).
of/contained the keys of the argument jsonb expression.
kwargs may be ignored by this operator but are required for API
conformance.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other: Any) -> ColumnElement[bool]:
def contained_by(self, other):
"""Boolean expression. Test if keys are a proper subset of the
keys of the argument jsonb expression
(equivalent of the ``<@`` operator).
keys of the argument jsonb expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
def delete_path(
self, array: Union[List[str], _pg_array[str]]
) -> ColumnElement[JSONB]:
def delete_path(self, array):
"""JSONB expression. Deletes field or array element specified in
the argument array (equivalent of the ``#-`` operator).
the argument array.
The input may be a list of strings that will be coerced to an
``ARRAY`` or an instance of :meth:`_postgres.array`.
@@ -341,9 +300,9 @@ class JSONB(JSON):
right_side = cast(array, ARRAY(sqltypes.TEXT))
return self.operate(DELETE_PATH, right_side, result_type=JSONB)
def path_exists(self, other: Any) -> ColumnElement[bool]:
def path_exists(self, other):
"""Boolean expression. Test for presence of item given by the
argument JSONPath expression (equivalent of the ``@?`` operator).
argument JSONPath expression.
.. versionadded:: 2.0
"""
@@ -351,10 +310,9 @@ class JSONB(JSON):
PATH_EXISTS, other, result_type=sqltypes.Boolean
)
def path_match(self, other: Any) -> ColumnElement[bool]:
def path_match(self, other):
"""Boolean expression. Test if JSONPath predicate given by the
argument JSONPath expression matches
(equivalent of the ``@@`` operator).
argument JSONPath expression matches.
Only the first item of the result is taken into account.

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/named_types.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/named_types.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,9 +7,7 @@
# mypy: ignore-errors
from __future__ import annotations
from types import ModuleType
from typing import Any
from typing import Dict
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
@@ -27,11 +25,10 @@ from ...sql.ddl import InvokeCreateDDLBase
from ...sql.ddl import InvokeDropDDLBase
if TYPE_CHECKING:
from ...sql._typing import _CreateDropBind
from ...sql._typing import _TypeEngineArgument
class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
class NamedType(sqltypes.TypeEngine):
"""Base for named types."""
__abstract__ = True
@@ -39,9 +36,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
DDLDropper: Type[NamedTypeDropper]
create_type: bool
def create(
self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
) -> None:
def create(self, bind, checkfirst=True, **kw):
"""Emit ``CREATE`` DDL for this type.
:param bind: a connectable :class:`_engine.Engine`,
@@ -55,9 +50,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
"""
bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)
def drop(
self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
) -> None:
def drop(self, bind, checkfirst=True, **kw):
"""Emit ``DROP`` DDL for this type.
:param bind: a connectable :class:`_engine.Engine`,
@@ -70,9 +63,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
"""
bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)
def _check_for_name_in_memos(
self, checkfirst: bool, kw: Dict[str, Any]
) -> bool:
def _check_for_name_in_memos(self, checkfirst, kw):
"""Look in the 'ddl runner' for 'memos', then
note our name in that collection.
@@ -96,13 +87,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
else:
return False
def _on_table_create(
self,
target: Any,
bind: _CreateDropBind,
checkfirst: bool = False,
**kw: Any,
) -> None:
def _on_table_create(self, target, bind, checkfirst=False, **kw):
if (
checkfirst
or (
@@ -112,13 +97,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
) and not self._check_for_name_in_memos(checkfirst, kw):
self.create(bind=bind, checkfirst=checkfirst)
def _on_table_drop(
self,
target: Any,
bind: _CreateDropBind,
checkfirst: bool = False,
**kw: Any,
) -> None:
def _on_table_drop(self, target, bind, checkfirst=False, **kw):
if (
not self.metadata
and not kw.get("_is_metadata_operation", False)
@@ -126,23 +105,11 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
):
self.drop(bind=bind, checkfirst=checkfirst)
def _on_metadata_create(
self,
target: Any,
bind: _CreateDropBind,
checkfirst: bool = False,
**kw: Any,
) -> None:
def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
if not self._check_for_name_in_memos(checkfirst, kw):
self.create(bind=bind, checkfirst=checkfirst)
def _on_metadata_drop(
self,
target: Any,
bind: _CreateDropBind,
checkfirst: bool = False,
**kw: Any,
) -> None:
def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
if not self._check_for_name_in_memos(checkfirst, kw):
self.drop(bind=bind, checkfirst=checkfirst)
@@ -196,6 +163,7 @@ class EnumDropper(NamedTypeDropper):
class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
"""PostgreSQL ENUM type.
This is a subclass of :class:`_types.Enum` which includes
@@ -218,10 +186,8 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
:meth:`_schema.Table.drop`
methods are called::
table = Table(
"sometable",
metadata,
Column("some_enum", ENUM("a", "b", "c", name="myenum")),
table = Table('sometable', metadata,
Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
)
table.create(engine) # will emit CREATE ENUM and CREATE TABLE
@@ -232,17 +198,21 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
:class:`_postgresql.ENUM` independently, and associate it with the
:class:`_schema.MetaData` object itself::
my_enum = ENUM("a", "b", "c", name="myenum", metadata=metadata)
my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
t1 = Table("sometable_one", metadata, Column("some_enum", myenum))
t1 = Table('sometable_one', metadata,
Column('some_enum', myenum)
)
t2 = Table("sometable_two", metadata, Column("some_enum", myenum))
t2 = Table('sometable_two', metadata,
Column('some_enum', myenum)
)
When this pattern is used, care must still be taken at the level
of individual table creates. Emitting CREATE TABLE without also
specifying ``checkfirst=True`` will still cause issues::
t1.create(engine) # will fail: no such type 'myenum'
t1.create(engine) # will fail: no such type 'myenum'
If we specify ``checkfirst=True``, the individual table-level create
operation will check for the ``ENUM`` and create if not exists::
@@ -347,7 +317,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
return cls(**kw)
def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
def create(self, bind=None, checkfirst=True):
"""Emit ``CREATE TYPE`` for this
:class:`_postgresql.ENUM`.
@@ -368,7 +338,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
super().create(bind, checkfirst=checkfirst)
def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
def drop(self, bind=None, checkfirst=True):
"""Emit ``DROP TYPE`` for this
:class:`_postgresql.ENUM`.
@@ -388,7 +358,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
super().drop(bind, checkfirst=checkfirst)
def get_dbapi_type(self, dbapi: ModuleType) -> None:
def get_dbapi_type(self, dbapi):
"""dont return dbapi.STRING for ENUM in PostgreSQL, since that's
a different type"""
@@ -418,12 +388,14 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
A domain is essentially a data type with optional constraints
that restrict the allowed set of values. E.g.::
PositiveInt = DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True)
PositiveInt = DOMAIN(
"pos_int", Integer, check="VALUE > 0", not_null=True
)
UsPostalCode = DOMAIN(
"us_postal_code",
Text,
check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'",
check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'"
)
See the `PostgreSQL documentation`__ for additional details
@@ -432,7 +404,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
.. versionadded:: 2.0
""" # noqa: E501
"""
DDLGenerator = DomainGenerator
DDLDropper = DomainDropper
@@ -445,10 +417,10 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
data_type: _TypeEngineArgument[Any],
*,
collation: Optional[str] = None,
default: Union[elements.TextClause, str, None] = None,
default: Optional[Union[str, elements.TextClause]] = None,
constraint_name: Optional[str] = None,
not_null: Optional[bool] = None,
check: Union[elements.TextClause, str, None] = None,
check: Optional[str] = None,
create_type: bool = True,
**kw: Any,
):
@@ -492,7 +464,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
self.default = default
self.collation = collation
self.constraint_name = constraint_name
self.not_null = bool(not_null)
self.not_null = not_null
if check is not None:
check = coercions.expect(roles.DDLExpressionRole, check)
self.check = check

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/operators.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/operators.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/pg8000.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# postgresql/pg8000.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
@@ -27,21 +27,19 @@ PostgreSQL ``client_encoding`` parameter; by default this is the value in
the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
Typically, this can be changed to ``utf-8``, as a more useful default::
# client_encoding = sql_ascii # actually, defaults to database encoding
#client_encoding = sql_ascii # actually, defaults to database
# encoding
client_encoding = utf8
The ``client_encoding`` can be overridden for a session by executing the SQL:
.. sourcecode:: sql
SET CLIENT_ENCODING TO 'utf8';
SET CLIENT_ENCODING TO 'utf8';
SQLAlchemy will execute this SQL on all new connections based on the value
passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter::
engine = create_engine(
"postgresql+pg8000://user:pass@host/dbname", client_encoding="utf8"
)
"postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
.. _pg8000_ssl:
@@ -52,7 +50,6 @@ pg8000 accepts a Python ``SSLContext`` object which may be specified using the
:paramref:`_sa.create_engine.connect_args` dictionary::
import ssl
ssl_context = ssl.create_default_context()
engine = sa.create_engine(
"postgresql+pg8000://scott:tiger@192.168.0.199/test",
@@ -64,7 +61,6 @@ or does not match the host name (as seen from the client), it may also be
necessary to disable hostname checking::
import ssl
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
@@ -257,7 +253,7 @@ class _PGOIDVECTOR(_SpaceVector, OIDVECTOR):
pass
class _Pg8000Range(ranges.AbstractSingleRangeImpl):
class _Pg8000Range(ranges.AbstractRangeImpl):
def bind_processor(self, dialect):
pg8000_Range = dialect.dbapi.Range
@@ -308,13 +304,15 @@ class _Pg8000MultiRange(ranges.AbstractMultiRangeImpl):
def to_multirange(value):
if value is None:
return None
else:
return ranges.MultiRange(
mr = []
for v in value:
mr.append(
ranges.Range(
v.lower, v.upper, bounds=v.bounds, empty=v.is_empty
)
for v in value
)
return mr
return to_multirange
@@ -540,9 +538,6 @@ class PGDialect_pg8000(PGDialect):
cursor.execute("COMMIT")
cursor.close()
def detect_autocommit_setting(self, dbapi_conn) -> bool:
return bool(dbapi_conn.autocommit)
def set_readonly(self, connection, value):
cursor = connection.cursor()
try:
@@ -589,8 +584,8 @@ class PGDialect_pg8000(PGDialect):
cursor = dbapi_connection.cursor()
cursor.execute(
f"""SET CLIENT_ENCODING TO '{
client_encoding.replace("'", "''")
}'"""
client_encoding.replace("'", "''")
}'"""
)
cursor.execute("COMMIT")
cursor.close()

View File

@@ -1,16 +1,10 @@
# dialects/postgresql/pg_catalog.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/pg_catalog.py
# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
# mypy: ignore-errors
from .array import ARRAY
from .types import OID
@@ -29,37 +23,31 @@ from ...types import String
from ...types import Text
from ...types import TypeDecorator
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.type_api import _ResultProcessorType
# types
class NAME(TypeDecorator[str]):
class NAME(TypeDecorator):
impl = String(64, collation="C")
cache_ok = True
class PG_NODE_TREE(TypeDecorator[str]):
class PG_NODE_TREE(TypeDecorator):
impl = Text(collation="C")
cache_ok = True
class INT2VECTOR(TypeDecorator[Sequence[int]]):
class INT2VECTOR(TypeDecorator):
impl = ARRAY(SmallInteger)
cache_ok = True
class OIDVECTOR(TypeDecorator[Sequence[int]]):
class OIDVECTOR(TypeDecorator):
impl = ARRAY(OID)
cache_ok = True
class _SpaceVector:
def result_processor(
self, dialect: Dialect, coltype: object
) -> _ResultProcessorType[list[int]]:
def process(value: Any) -> Optional[list[int]]:
def result_processor(self, dialect, coltype):
def process(value):
if value is None:
return value
return [int(p) for p in value.split(" ")]
@@ -89,7 +77,7 @@ RELKINDS_MAT_VIEW = ("m",)
RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW
# tables
pg_catalog_meta = MetaData(schema="pg_catalog")
pg_catalog_meta = MetaData()
pg_namespace = Table(
"pg_namespace",
@@ -97,6 +85,7 @@ pg_namespace = Table(
Column("oid", OID),
Column("nspname", NAME),
Column("nspowner", OID),
schema="pg_catalog",
)
pg_class = Table(
@@ -131,6 +120,7 @@ pg_class = Table(
Column("relispartition", Boolean, info={"server_version": (10,)}),
Column("relrewrite", OID, info={"server_version": (11,)}),
Column("reloptions", ARRAY(Text)),
schema="pg_catalog",
)
pg_type = Table(
@@ -165,6 +155,7 @@ pg_type = Table(
Column("typndims", Integer),
Column("typcollation", OID, info={"server_version": (9, 1)}),
Column("typdefault", Text),
schema="pg_catalog",
)
pg_index = Table(
@@ -191,6 +182,7 @@ pg_index = Table(
Column("indoption", INT2VECTOR),
Column("indexprs", PG_NODE_TREE),
Column("indpred", PG_NODE_TREE),
schema="pg_catalog",
)
pg_attribute = Table(
@@ -217,6 +209,7 @@ pg_attribute = Table(
Column("attislocal", Boolean),
Column("attinhcount", Integer),
Column("attcollation", OID, info={"server_version": (9, 1)}),
schema="pg_catalog",
)
pg_constraint = Table(
@@ -242,6 +235,7 @@ pg_constraint = Table(
Column("connoinherit", Boolean, info={"server_version": (9, 2)}),
Column("conkey", ARRAY(SmallInteger)),
Column("confkey", ARRAY(SmallInteger)),
schema="pg_catalog",
)
pg_sequence = Table(
@@ -255,6 +249,7 @@ pg_sequence = Table(
Column("seqmin", BigInteger),
Column("seqcache", BigInteger),
Column("seqcycle", Boolean),
schema="pg_catalog",
info={"server_version": (10,)},
)
@@ -265,6 +260,7 @@ pg_attrdef = Table(
Column("adrelid", OID),
Column("adnum", SmallInteger),
Column("adbin", PG_NODE_TREE),
schema="pg_catalog",
)
pg_description = Table(
@@ -274,6 +270,7 @@ pg_description = Table(
Column("classoid", OID),
Column("objsubid", Integer),
Column("description", Text(collation="C")),
schema="pg_catalog",
)
pg_enum = Table(
@@ -283,6 +280,7 @@ pg_enum = Table(
Column("enumtypid", OID),
Column("enumsortorder", Float(), info={"server_version": (9, 1)}),
Column("enumlabel", NAME),
schema="pg_catalog",
)
pg_am = Table(
@@ -292,35 +290,5 @@ pg_am = Table(
Column("amname", NAME),
Column("amhandler", REGPROC, info={"server_version": (9, 6)}),
Column("amtype", CHAR, info={"server_version": (9, 6)}),
)
pg_collation = Table(
"pg_collation",
pg_catalog_meta,
Column("oid", OID, info={"server_version": (9, 3)}),
Column("collname", NAME),
Column("collnamespace", OID),
Column("collowner", OID),
Column("collprovider", CHAR, info={"server_version": (10,)}),
Column("collisdeterministic", Boolean, info={"server_version": (12,)}),
Column("collencoding", Integer),
Column("collcollate", Text),
Column("collctype", Text),
Column("colliculocale", Text),
Column("collicurules", Text, info={"server_version": (16,)}),
Column("collversion", Text, info={"server_version": (10,)}),
)
pg_opclass = Table(
"pg_opclass",
pg_catalog_meta,
Column("oid", OID, info={"server_version": (9, 3)}),
Column("opcmethod", NAME),
Column("opcname", NAME),
Column("opsnamespace", OID),
Column("opsowner", OID),
Column("opcfamily", OID),
Column("opcintype", OID),
Column("opcdefault", Boolean),
Column("opckeytype", OID),
schema="pg_catalog",
)

View File

@@ -1,9 +1,3 @@
# dialects/postgresql/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import time
@@ -97,7 +91,7 @@ def drop_all_schema_objects_pre_tables(cfg, eng):
for xid in conn.exec_driver_sql(
"select gid from pg_prepared_xacts"
).scalars():
conn.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)
conn.execute("ROLLBACK PREPARED '%s'" % xid)
@drop_all_schema_objects_post_tables.for_db("postgresql")

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/psycopg.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/psycopg2.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -29,29 +29,20 @@ selected depending on how the engine is created:
automatically select the sync version, e.g.::
from sqlalchemy import create_engine
sync_engine = create_engine(
"postgresql+psycopg://scott:tiger@localhost/test"
)
sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")
* calling :func:`_asyncio.create_async_engine` with
``postgresql+psycopg://...`` will automatically select the async version,
e.g.::
from sqlalchemy.ext.asyncio import create_async_engine
asyncio_engine = create_async_engine(
"postgresql+psycopg://scott:tiger@localhost/test"
)
asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test")
The asyncio version of the dialect may also be specified explicitly using the
``psycopg_async`` suffix, as::
from sqlalchemy.ext.asyncio import create_async_engine
asyncio_engine = create_async_engine(
"postgresql+psycopg_async://scott:tiger@localhost/test"
)
asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test")
.. seealso::
@@ -59,42 +50,9 @@ The asyncio version of the dialect may also be specified explicitly using the
dialect shares most of its behavior with the ``psycopg2`` dialect.
Further documentation is available there.
Using a different Cursor class
------------------------------
One of the differences between ``psycopg`` and the older ``psycopg2``
is how bound parameters are handled: ``psycopg2`` would bind them
client side, while ``psycopg`` by default will bind them server side.
It's possible to configure ``psycopg`` to do client side binding by
specifying the ``cursor_factory`` to be ``ClientCursor`` when creating
the engine::
from psycopg import ClientCursor
client_side_engine = create_engine(
"postgresql+psycopg://...",
connect_args={"cursor_factory": ClientCursor},
)
Similarly when using an async engine the ``AsyncClientCursor`` can be
specified::
from psycopg import AsyncClientCursor
client_side_engine = create_async_engine(
"postgresql+psycopg://...",
connect_args={"cursor_factory": AsyncClientCursor},
)
.. seealso::
`Client-side-binding cursors <https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#client-side-binding-cursors>`_
""" # noqa
from __future__ import annotations
from collections import deque
import logging
import re
from typing import cast
@@ -121,8 +79,6 @@ from ...util.concurrency import await_only
if TYPE_CHECKING:
from typing import Iterable
from psycopg import AsyncConnection
logger = logging.getLogger("sqlalchemy.dialects.postgresql")
@@ -135,6 +91,8 @@ class _PGREGCONFIG(REGCONFIG):
class _PGJSON(JSON):
render_bind_cast = True
def bind_processor(self, dialect):
return self._make_bind_processor(None, dialect._psycopg_Json)
@@ -143,6 +101,8 @@ class _PGJSON(JSON):
class _PGJSONB(JSONB):
render_bind_cast = True
def bind_processor(self, dialect):
return self._make_bind_processor(None, dialect._psycopg_Jsonb)
@@ -202,7 +162,7 @@ class _PGBoolean(sqltypes.Boolean):
render_bind_cast = True
class _PsycopgRange(ranges.AbstractSingleRangeImpl):
class _PsycopgRange(ranges.AbstractRangeImpl):
def bind_processor(self, dialect):
psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
@@ -258,10 +218,8 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):
def result_processor(self, dialect, coltype):
def to_range(value):
if value is None:
return None
else:
return ranges.MultiRange(
if value is not None:
value = [
ranges.Range(
elem._lower,
elem._upper,
@@ -269,7 +227,9 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):
empty=not elem._bounds,
)
for elem in value
)
]
return value
return to_range
@@ -326,7 +286,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
sqltypes.Integer: _PGInteger,
sqltypes.SmallInteger: _PGSmallInteger,
sqltypes.BigInteger: _PGBigInteger,
ranges.AbstractSingleRange: _PsycopgRange,
ranges.AbstractRange: _PsycopgRange,
ranges.AbstractMultiRange: _PsycopgMultiRange,
},
)
@@ -406,12 +366,10 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
# register the adapter for connections made subsequent to
# this one
assert self._psycopg_adapters_map
register_hstore(info, self._psycopg_adapters_map)
# register the adapter for this connection
assert connection.connection
register_hstore(info, connection.connection.driver_connection)
register_hstore(info, connection.connection)
@classmethod
def import_dbapi(cls):
@@ -572,7 +530,7 @@ class AsyncAdapt_psycopg_cursor:
def __init__(self, cursor, await_) -> None:
self._cursor = cursor
self.await_ = await_
self._rows = deque()
self._rows = []
def __getattr__(self, name):
return getattr(self._cursor, name)
@@ -599,19 +557,24 @@ class AsyncAdapt_psycopg_cursor:
# eq/ne
if res and res.status == self._psycopg_ExecStatus.TUPLES_OK:
rows = self.await_(self._cursor.fetchall())
self._rows = deque(rows)
if not isinstance(rows, list):
self._rows = list(rows)
else:
self._rows = rows
return result
def executemany(self, query, params_seq):
return self.await_(self._cursor.executemany(query, params_seq))
def __iter__(self):
# TODO: try to avoid pop(0) on a list
while self._rows:
yield self._rows.popleft()
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.popleft()
# TODO: try to avoid pop(0) on a list
return self._rows.pop(0)
else:
return None
@@ -619,12 +582,13 @@ class AsyncAdapt_psycopg_cursor:
if size is None:
size = self._cursor.arraysize
rr = self._rows
return [rr.popleft() for _ in range(min(size, len(rr)))]
retval = self._rows[0:size]
self._rows = self._rows[size:]
return retval
def fetchall(self):
retval = list(self._rows)
self._rows.clear()
retval = self._rows
self._rows = []
return retval
@@ -655,7 +619,6 @@ class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor):
class AsyncAdapt_psycopg_connection(AdaptedConnection):
_connection: AsyncConnection
__slots__ = ()
await_ = staticmethod(await_only)

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/psycopg2.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/psycopg2.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -88,6 +88,7 @@ connection URI::
"postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require"
)
Unix Domain Connections
------------------------
@@ -102,17 +103,13 @@ in ``/tmp``, or whatever socket directory was specified when PostgreSQL
was built. This value can be overridden by passing a pathname to psycopg2,
using ``host`` as an additional keyword argument::
create_engine(
"postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql"
)
create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql")
.. warning:: The format accepted here allows for a hostname in the main URL
in addition to the "host" query string argument. **When using this URL
format, the initial host is silently ignored**. That is, this URL::
engine = create_engine(
"postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2"
)
engine = create_engine("postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2")
Above, the hostname ``myhost1`` is **silently ignored and discarded.** The
host which is connected is the ``myhost2`` host.
@@ -193,7 +190,7 @@ any or all elements of the connection string.
For this form, the URL can be passed without any elements other than the
initial scheme::
engine = create_engine("postgresql+psycopg2://")
engine = create_engine('postgresql+psycopg2://')
In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()``
function which in turn represents an empty DSN passed to libpq.
@@ -245,7 +242,7 @@ Psycopg2 Fast Execution Helpers
Modern versions of psycopg2 include a feature known as
`Fast Execution Helpers \
<https://www.psycopg.org/docs/extras.html#fast-execution-helpers>`_, which
<https://initd.org/psycopg/docs/extras.html#fast-execution-helpers>`_, which
have been shown in benchmarking to improve psycopg2's executemany()
performance, primarily with INSERT statements, by at least
an order of magnitude.
@@ -267,8 +264,8 @@ used feature. The use of this extension may be enabled using the
engine = create_engine(
"postgresql+psycopg2://scott:tiger@host/dbname",
executemany_mode="values_plus_batch",
)
executemany_mode='values_plus_batch')
Possible options for ``executemany_mode`` include:
@@ -314,10 +311,8 @@ is below::
engine = create_engine(
"postgresql+psycopg2://scott:tiger@host/dbname",
executemany_mode="values_plus_batch",
insertmanyvalues_page_size=5000,
executemany_batch_page_size=500,
)
executemany_mode='values_plus_batch',
insertmanyvalues_page_size=5000, executemany_batch_page_size=500)
.. seealso::
@@ -343,9 +338,7 @@ in the following ways:
passed in the database URL; this parameter is consumed by the underlying
``libpq`` PostgreSQL client library::
engine = create_engine(
"postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8"
)
engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8")
Alternatively, the above ``client_encoding`` value may be passed using
:paramref:`_sa.create_engine.connect_args` for programmatic establishment with
@@ -353,7 +346,7 @@ in the following ways:
engine = create_engine(
"postgresql+psycopg2://user:pass@host/dbname",
connect_args={"client_encoding": "utf8"},
connect_args={'client_encoding': 'utf8'}
)
* For all PostgreSQL versions, psycopg2 supports a client-side encoding
@@ -362,7 +355,8 @@ in the following ways:
``client_encoding`` parameter passed to :func:`_sa.create_engine`::
engine = create_engine(
"postgresql+psycopg2://user:pass@host/dbname", client_encoding="utf8"
"postgresql+psycopg2://user:pass@host/dbname",
client_encoding="utf8"
)
.. tip:: The above ``client_encoding`` parameter admittedly is very similar
@@ -381,9 +375,11 @@ in the following ways:
# postgresql.conf file
# client_encoding = sql_ascii # actually, defaults to database
# encoding
# encoding
client_encoding = utf8
Transactions
------------
@@ -430,15 +426,15 @@ is set to the ``logging.INFO`` level, notice messages will be logged::
import logging
logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO)
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
Above, it is assumed that logging is configured externally. If this is not
the case, configuration such as ``logging.basicConfig()`` must be utilized::
import logging
logging.basicConfig() # log messages to stdout
logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO)
logging.basicConfig() # log messages to stdout
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
.. seealso::
@@ -475,10 +471,8 @@ textual HSTORE expression. If this behavior is not desired, disable the
use of the hstore extension by setting ``use_native_hstore`` to ``False`` as
follows::
engine = create_engine(
"postgresql+psycopg2://scott:tiger@localhost/test",
use_native_hstore=False,
)
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
use_native_hstore=False)
The ``HSTORE`` type is **still supported** when the
``psycopg2.extensions.register_hstore()`` extension is not used. It merely
@@ -519,7 +513,7 @@ class _PGJSONB(JSONB):
return None
class _Psycopg2Range(ranges.AbstractSingleRangeImpl):
class _Psycopg2Range(ranges.AbstractRangeImpl):
_psycopg2_range_cls = "none"
def bind_processor(self, dialect):
@@ -850,43 +844,33 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg):
# checks based on strings. in the case that .closed
# didn't cut it, fall back onto these.
str_e = str(e).partition("\n")[0]
for msg in self._is_disconnect_messages:
for msg in [
# these error messages from libpq: interfaces/libpq/fe-misc.c
# and interfaces/libpq/fe-secure.c.
"terminating connection",
"closed the connection",
"connection not open",
"could not receive data from server",
"could not send data to server",
# psycopg2 client errors, psycopg2/connection.h,
# psycopg2/cursor.h
"connection already closed",
"cursor already closed",
# not sure where this path is originally from, it may
# be obsolete. It really says "losed", not "closed".
"losed the connection unexpectedly",
# these can occur in newer SSL
"connection has been closed unexpectedly",
"SSL error: decryption failed or bad record mac",
"SSL SYSCALL error: Bad file descriptor",
"SSL SYSCALL error: EOF detected",
"SSL SYSCALL error: Operation timed out",
"SSL SYSCALL error: Bad address",
]:
idx = str_e.find(msg)
if idx >= 0 and '"' not in str_e[:idx]:
return True
return False
@util.memoized_property
def _is_disconnect_messages(self):
return (
# these error messages from libpq: interfaces/libpq/fe-misc.c
# and interfaces/libpq/fe-secure.c.
"terminating connection",
"closed the connection",
"connection not open",
"could not receive data from server",
"could not send data to server",
# psycopg2 client errors, psycopg2/connection.h,
# psycopg2/cursor.h
"connection already closed",
"cursor already closed",
# not sure where this path is originally from, it may
# be obsolete. It really says "losed", not "closed".
"losed the connection unexpectedly",
# these can occur in newer SSL
"connection has been closed unexpectedly",
"SSL error: decryption failed or bad record mac",
"SSL SYSCALL error: Bad file descriptor",
"SSL SYSCALL error: EOF detected",
"SSL SYSCALL error: Operation timed out",
"SSL SYSCALL error: Bad address",
# This can occur in OpenSSL 1 when an unexpected EOF occurs.
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html#BUGS
# It may also occur in newer OpenSSL for a non-recoverable I/O
# error as a result of a system call that does not set 'errno'
# in libc.
"SSL SYSCALL error: Success",
)
dialect = PGDialect_psycopg2

View File

@@ -1,5 +1,5 @@
# dialects/postgresql/psycopg2cffi.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# testing/engines.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,5 +1,4 @@
# dialects/postgresql/ranges.py
# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -15,10 +14,8 @@ from decimal import Decimal
from typing import Any
from typing import cast
from typing import Generic
from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
@@ -154,8 +151,8 @@ class Range(Generic[_T]):
return not self.empty and self.upper is None
@property
def __sa_type_engine__(self) -> AbstractSingleRange[_T]:
return AbstractSingleRange()
def __sa_type_engine__(self) -> AbstractRange[Range[_T]]:
return AbstractRange()
def _contains_value(self, value: _T) -> bool:
"""Return True if this range contains the given value."""
@@ -271,9 +268,9 @@ class Range(Generic[_T]):
value2 += step
value2_inc = False
if value1 < value2:
if value1 < value2: # type: ignore
return -1
elif value1 > value2:
elif value1 > value2: # type: ignore
return 1
elif only_values:
return 0
@@ -360,8 +357,6 @@ class Range(Generic[_T]):
else:
return self._contains_value(value)
__contains__ = contains
def overlaps(self, other: Range[_T]) -> bool:
"Determine whether this range overlaps with `other`."
@@ -712,46 +707,27 @@ class Range(Generic[_T]):
return f"{b0}{l},{r}{b1}"
class MultiRange(List[Range[_T]]):
"""Represents a multirange sequence.
This list subclass is an utility to allow automatic type inference of
the proper multi-range SQL type depending on the single range values.
This is useful when operating on literal multi-ranges::
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import MultiRange, Range
value = literal(MultiRange([Range(2, 4)]))
select(tbl).where(tbl.c.value.op("@")(MultiRange([Range(-3, 7)])))
.. versionadded:: 2.0.26
class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
"""
Base for PostgreSQL RANGE types.
.. seealso::
- :ref:`postgresql_multirange_list_use`.
"""
`PostgreSQL range functions <https://www.postgresql.org/docs/current/static/functions-range.html>`_
@property
def __sa_type_engine__(self) -> AbstractMultiRange[_T]:
return AbstractMultiRange()
class AbstractRange(sqltypes.TypeEngine[_T]):
"""Base class for single and multi Range SQL types."""
""" # noqa: E501
render_bind_cast = True
__abstract__ = True
@overload
def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ...
def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
...
@overload
def adapt(
self, cls: Type[TypeEngineMixin], **kw: Any
) -> TypeEngine[Any]: ...
def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
...
def adapt(
self,
@@ -765,10 +741,7 @@ class AbstractRange(sqltypes.TypeEngine[_T]):
and also render as ``INT4RANGE`` in SQL and DDL.
"""
if (
issubclass(cls, (AbstractSingleRangeImpl, AbstractMultiRangeImpl))
and cls is not self.__class__
):
if issubclass(cls, AbstractRangeImpl) and cls is not self.__class__:
# two ways to do this are: 1. create a new type on the fly
# or 2. have AbstractRangeImpl(visit_name) constructor and a
# visit_abstract_range_impl() method in the PG compiler.
@@ -787,6 +760,21 @@ class AbstractRange(sqltypes.TypeEngine[_T]):
else:
return super().adapt(cls)
def _resolve_for_literal(self, value: Any) -> Any:
spec = value.lower if value.lower is not None else value.upper
if isinstance(spec, int):
return INT8RANGE()
elif isinstance(spec, (Decimal, float)):
return NUMRANGE()
elif isinstance(spec, datetime):
return TSRANGE() if not spec.tzinfo else TSTZRANGE()
elif isinstance(spec, date):
return DATERANGE()
else:
# empty Range, SQL datatype can't be determined here
return sqltypes.NULLTYPE
class comparator_factory(TypeEngine.Comparator[Range[Any]]):
"""Define comparison operations for range types."""
@@ -868,164 +856,91 @@ class AbstractRange(sqltypes.TypeEngine[_T]):
return self.expr.operate(operators.mul, other)
class AbstractSingleRange(AbstractRange[Range[_T]]):
"""Base for PostgreSQL RANGE types.
These are types that return a single :class:`_postgresql.Range` object.
.. seealso::
`PostgreSQL range functions <https://www.postgresql.org/docs/current/static/functions-range.html>`_
""" # noqa: E501
__abstract__ = True
def _resolve_for_literal(self, value: Range[Any]) -> Any:
spec = value.lower if value.lower is not None else value.upper
if isinstance(spec, int):
# pg is unreasonably picky here: the query
# "select 1::INTEGER <@ '[1, 4)'::INT8RANGE" raises
# "operator does not exist: integer <@ int8range" as of pg 16
if _is_int32(value):
return INT4RANGE()
else:
return INT8RANGE()
elif isinstance(spec, (Decimal, float)):
return NUMRANGE()
elif isinstance(spec, datetime):
return TSRANGE() if not spec.tzinfo else TSTZRANGE()
elif isinstance(spec, date):
return DATERANGE()
else:
# empty Range, SQL datatype can't be determined here
return sqltypes.NULLTYPE
class AbstractSingleRangeImpl(AbstractSingleRange[_T]):
"""Marker for AbstractSingleRange that will apply a subclass-specific
class AbstractRangeImpl(AbstractRange[Range[_T]]):
"""Marker for AbstractRange that will apply a subclass-specific
adaptation"""
class AbstractMultiRange(AbstractRange[Sequence[Range[_T]]]):
"""Base for PostgreSQL MULTIRANGE types.
these are types that return a sequence of :class:`_postgresql.Range`
objects.
"""
class AbstractMultiRange(AbstractRange[Range[_T]]):
"""base for PostgreSQL MULTIRANGE types"""
__abstract__ = True
def _resolve_for_literal(self, value: Sequence[Range[Any]]) -> Any:
if not value:
# empty MultiRange, SQL datatype can't be determined here
return sqltypes.NULLTYPE
first = value[0]
spec = first.lower if first.lower is not None else first.upper
if isinstance(spec, int):
# pg is unreasonably picky here: the query
# "select 1::INTEGER <@ '{[1, 4),[6,19)}'::INT8MULTIRANGE" raises
# "operator does not exist: integer <@ int8multirange" as of pg 16
if all(_is_int32(r) for r in value):
return INT4MULTIRANGE()
else:
return INT8MULTIRANGE()
elif isinstance(spec, (Decimal, float)):
return NUMMULTIRANGE()
elif isinstance(spec, datetime):
return TSMULTIRANGE() if not spec.tzinfo else TSTZMULTIRANGE()
elif isinstance(spec, date):
return DATEMULTIRANGE()
else:
# empty Range, SQL datatype can't be determined here
return sqltypes.NULLTYPE
class AbstractMultiRangeImpl(AbstractMultiRange[_T]):
"""Marker for AbstractMultiRange that will apply a subclass-specific
class AbstractMultiRangeImpl(
AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]]
):
"""Marker for AbstractRange that will apply a subclass-specific
adaptation"""
class INT4RANGE(AbstractSingleRange[int]):
class INT4RANGE(AbstractRange[Range[int]]):
"""Represent the PostgreSQL INT4RANGE type."""
__visit_name__ = "INT4RANGE"
class INT8RANGE(AbstractSingleRange[int]):
class INT8RANGE(AbstractRange[Range[int]]):
"""Represent the PostgreSQL INT8RANGE type."""
__visit_name__ = "INT8RANGE"
class NUMRANGE(AbstractSingleRange[Decimal]):
class NUMRANGE(AbstractRange[Range[Decimal]]):
"""Represent the PostgreSQL NUMRANGE type."""
__visit_name__ = "NUMRANGE"
class DATERANGE(AbstractSingleRange[date]):
class DATERANGE(AbstractRange[Range[date]]):
"""Represent the PostgreSQL DATERANGE type."""
__visit_name__ = "DATERANGE"
class TSRANGE(AbstractSingleRange[datetime]):
class TSRANGE(AbstractRange[Range[datetime]]):
"""Represent the PostgreSQL TSRANGE type."""
__visit_name__ = "TSRANGE"
class TSTZRANGE(AbstractSingleRange[datetime]):
class TSTZRANGE(AbstractRange[Range[datetime]]):
"""Represent the PostgreSQL TSTZRANGE type."""
__visit_name__ = "TSTZRANGE"
class INT4MULTIRANGE(AbstractMultiRange[int]):
class INT4MULTIRANGE(AbstractMultiRange[Range[int]]):
"""Represent the PostgreSQL INT4MULTIRANGE type."""
__visit_name__ = "INT4MULTIRANGE"
class INT8MULTIRANGE(AbstractMultiRange[int]):
class INT8MULTIRANGE(AbstractMultiRange[Range[int]]):
"""Represent the PostgreSQL INT8MULTIRANGE type."""
__visit_name__ = "INT8MULTIRANGE"
class NUMMULTIRANGE(AbstractMultiRange[Decimal]):
class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]):
"""Represent the PostgreSQL NUMMULTIRANGE type."""
__visit_name__ = "NUMMULTIRANGE"
class DATEMULTIRANGE(AbstractMultiRange[date]):
class DATEMULTIRANGE(AbstractMultiRange[Range[date]]):
"""Represent the PostgreSQL DATEMULTIRANGE type."""
__visit_name__ = "DATEMULTIRANGE"
class TSMULTIRANGE(AbstractMultiRange[datetime]):
class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]):
"""Represent the PostgreSQL TSRANGE type."""
__visit_name__ = "TSMULTIRANGE"
class TSTZMULTIRANGE(AbstractMultiRange[datetime]):
class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]):
"""Represent the PostgreSQL TSTZRANGE type."""
__visit_name__ = "TSTZMULTIRANGE"
_max_int_32 = 2**31 - 1
_min_int_32 = -(2**31)
def _is_int32(r: Range[int]) -> bool:
return (r.lower is None or _min_int_32 <= r.lower <= _max_int_32) and (
r.upper is None or _min_int_32 <= r.upper <= _max_int_32
)

View File

@@ -1,5 +1,4 @@
# dialects/postgresql/types.py
# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -38,52 +37,43 @@ class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]):
@overload
def __init__(
self: PGUuid[_python_UUID], as_uuid: Literal[True] = ...
) -> None: ...
) -> None:
...
@overload
def __init__(
self: PGUuid[str], as_uuid: Literal[False] = ...
) -> None: ...
def __init__(self: PGUuid[str], as_uuid: Literal[False] = ...) -> None:
...
def __init__(self, as_uuid: bool = True) -> None: ...
def __init__(self, as_uuid: bool = True) -> None:
...
class BYTEA(sqltypes.LargeBinary):
__visit_name__ = "BYTEA"
class _NetworkAddressTypeMixin:
def coerce_compared_value(
self, op: Optional[OperatorType], value: Any
) -> TypeEngine[Any]:
if TYPE_CHECKING:
assert isinstance(self, TypeEngine)
return self
class INET(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
class INET(sqltypes.TypeEngine[str]):
__visit_name__ = "INET"
PGInet = INET
class CIDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
class CIDR(sqltypes.TypeEngine[str]):
__visit_name__ = "CIDR"
PGCidr = CIDR
class MACADDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
class MACADDR(sqltypes.TypeEngine[str]):
__visit_name__ = "MACADDR"
PGMacAddr = MACADDR
class MACADDR8(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
class MACADDR8(sqltypes.TypeEngine[str]):
__visit_name__ = "MACADDR8"
@@ -104,11 +94,12 @@ class MONEY(sqltypes.TypeEngine[str]):
from sqlalchemy import Dialect
from sqlalchemy import TypeDecorator
class NumericMoney(TypeDecorator):
impl = MONEY
def process_result_value(self, value: Any, dialect: Dialect) -> None:
def process_result_value(
self, value: Any, dialect: Dialect
) -> None:
if value is not None:
# adjust this for the currency and numeric
m = re.match(r"\$([\d.]+)", value)
@@ -123,7 +114,6 @@ class MONEY(sqltypes.TypeEngine[str]):
from sqlalchemy import cast
from sqlalchemy import TypeDecorator
class NumericMoney(TypeDecorator):
impl = MONEY
@@ -132,18 +122,20 @@ class MONEY(sqltypes.TypeEngine[str]):
.. versionadded:: 1.2
""" # noqa: E501
"""
__visit_name__ = "MONEY"
class OID(sqltypes.TypeEngine[int]):
"""Provide the PostgreSQL OID type."""
__visit_name__ = "OID"
class REGCONFIG(sqltypes.TypeEngine[str]):
"""Provide the PostgreSQL REGCONFIG type.
.. versionadded:: 2.0.0rc1
@@ -154,6 +146,7 @@ class REGCONFIG(sqltypes.TypeEngine[str]):
class TSQUERY(sqltypes.TypeEngine[str]):
"""Provide the PostgreSQL TSQUERY type.
.. versionadded:: 2.0.0rc1
@@ -164,6 +157,7 @@ class TSQUERY(sqltypes.TypeEngine[str]):
class REGCLASS(sqltypes.TypeEngine[str]):
"""Provide the PostgreSQL REGCLASS type.
.. versionadded:: 1.2.7
@@ -174,6 +168,7 @@ class REGCLASS(sqltypes.TypeEngine[str]):
class TIMESTAMP(sqltypes.TIMESTAMP):
"""Provide the PostgreSQL TIMESTAMP type."""
__visit_name__ = "TIMESTAMP"
@@ -194,6 +189,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
class TIME(sqltypes.TIME):
"""PostgreSQL TIME type."""
__visit_name__ = "TIME"
@@ -214,6 +210,7 @@ class TIME(sqltypes.TIME):
class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval):
"""PostgreSQL INTERVAL type."""
__visit_name__ = "INTERVAL"
@@ -283,6 +280,7 @@ PGBit = BIT
class TSVECTOR(sqltypes.TypeEngine[str]):
"""The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
text search type TSVECTOR.
@@ -299,6 +297,7 @@ class TSVECTOR(sqltypes.TypeEngine[str]):
class CITEXT(sqltypes.TEXT):
"""Provide the PostgreSQL CITEXT type.
.. versionadded:: 2.0.7

View File

@@ -1,5 +1,5 @@
# dialects/sqlite/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# sqlite/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

View File

@@ -1,9 +1,10 @@
# dialects/sqlite/aiosqlite.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# sqlite/aiosqlite.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
@@ -30,7 +31,6 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("sqlite+aiosqlite:///filename")
The URL passes through all arguments to the ``pysqlite`` driver, so all
@@ -49,71 +49,45 @@ in Python and use them directly in SQLite queries as described here: :ref:`pysql
Serializable isolation / Savepoints / Transactional DDL (asyncio version)
-------------------------------------------------------------------------
A newly revised version of this important section is now available
at the top level of the SQLAlchemy SQLite documentation, in the section
:ref:`sqlite_transactions`.
Similarly to pysqlite, aiosqlite does not support SAVEPOINT feature.
The solution is similar to :ref:`pysqlite_serializable`. This is achieved by the event listeners in async::
.. _aiosqlite_pooling:
from sqlalchemy import create_engine, event
from sqlalchemy.ext.asyncio import create_async_engine
Pooling Behavior
----------------
engine = create_async_engine("sqlite+aiosqlite:///myfile.db")
The SQLAlchemy ``aiosqlite`` DBAPI establishes the connection pool differently
based on the kind of SQLite database that's requested:
@event.listens_for(engine.sync_engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable aiosqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None
* When a ``:memory:`` SQLite database is specified, the dialect by default
will use :class:`.StaticPool`. This pool maintains a single
connection, so that all access to the engine
use the same ``:memory:`` database.
* When a file-based database is specified, the dialect will use
:class:`.AsyncAdaptedQueuePool` as the source of connections.
@event.listens_for(engine.sync_engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.exec_driver_sql("BEGIN")
.. versionchanged:: 2.0.38
SQLite file database engines now use :class:`.AsyncAdaptedQueuePool` by default.
Previously, :class:`.NullPool` were used. The :class:`.NullPool` class
may be used by specifying it via the
:paramref:`_sa.create_engine.poolclass` parameter.
.. warning:: When using the above recipe, it is advised to not use the
:paramref:`.Connection.execution_options.isolation_level` setting on
:class:`_engine.Connection` and :func:`_sa.create_engine`
with the SQLite driver,
as this function necessarily will also alter the ".isolation_level" setting.
""" # noqa
from __future__ import annotations
import asyncio
from collections import deque
from functools import partial
from types import ModuleType
from typing import Any
from typing import cast
from typing import Deque
from typing import Iterator
from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
from typing import Union
from .base import SQLiteExecutionContext
from .pysqlite import SQLiteDialect_pysqlite
from ... import pool
from ... import util
from ...connectors.asyncio import AsyncAdapt_dbapi_module
from ...engine import AdaptedConnection
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
if TYPE_CHECKING:
from ...connectors.asyncio import AsyncIODBAPIConnection
from ...connectors.asyncio import AsyncIODBAPICursor
from ...engine.interfaces import _DBAPICursorDescription
from ...engine.interfaces import _DBAPIMultiExecuteParams
from ...engine.interfaces import _DBAPISingleExecuteParams
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.url import URL
from ...pool.base import PoolProxiedConnection
class AsyncAdapt_aiosqlite_cursor:
# TODO: base on connectors/asyncio.py
@@ -132,26 +106,21 @@ class AsyncAdapt_aiosqlite_cursor:
server_side = False
def __init__(self, adapt_connection: AsyncAdapt_aiosqlite_connection):
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
self.arraysize = 1
self.rowcount = -1
self.description: Optional[_DBAPICursorDescription] = None
self._rows: Deque[Any] = deque()
self.description = None
self._rows = []
def close(self) -> None:
self._rows.clear()
def execute(
self,
operation: Any,
parameters: Optional[_DBAPISingleExecuteParams] = None,
) -> Any:
def close(self):
self._rows[:] = []
def execute(self, operation, parameters=None):
try:
_cursor: AsyncIODBAPICursor = self.await_(self._connection.cursor()) # type: ignore[arg-type] # noqa: E501
_cursor = self.await_(self._connection.cursor())
if parameters is None:
self.await_(_cursor.execute(operation))
@@ -163,7 +132,7 @@ class AsyncAdapt_aiosqlite_cursor:
self.lastrowid = self.rowcount = -1
if not self.server_side:
self._rows = deque(self.await_(_cursor.fetchall()))
self._rows = self.await_(_cursor.fetchall())
else:
self.description = None
self.lastrowid = _cursor.lastrowid
@@ -172,17 +141,13 @@ class AsyncAdapt_aiosqlite_cursor:
if not self.server_side:
self.await_(_cursor.close())
else:
self._cursor = _cursor # type: ignore[misc]
self._cursor = _cursor
except Exception as error:
self._adapt_connection._handle_exception(error)
def executemany(
self,
operation: Any,
seq_of_parameters: _DBAPIMultiExecuteParams,
) -> Any:
def executemany(self, operation, seq_of_parameters):
try:
_cursor: AsyncIODBAPICursor = self.await_(self._connection.cursor()) # type: ignore[arg-type] # noqa: E501
_cursor = self.await_(self._connection.cursor())
self.await_(_cursor.executemany(operation, seq_of_parameters))
self.description = None
self.lastrowid = _cursor.lastrowid
@@ -191,29 +156,30 @@ class AsyncAdapt_aiosqlite_cursor:
except Exception as error:
self._adapt_connection._handle_exception(error)
def setinputsizes(self, *inputsizes: Any) -> None:
def setinputsizes(self, *inputsizes):
pass
def __iter__(self) -> Iterator[Any]:
def __iter__(self):
while self._rows:
yield self._rows.popleft()
yield self._rows.pop(0)
def fetchone(self) -> Optional[Any]:
def fetchone(self):
if self._rows:
return self._rows.popleft()
return self._rows.pop(0)
else:
return None
def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
rr = self._rows
return [rr.popleft() for _ in range(min(size, len(rr)))]
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self) -> Sequence[Any]:
retval = list(self._rows)
self._rows.clear()
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
@@ -224,27 +190,24 @@ class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):
server_side = True
def __init__(self, *arg: Any, **kw: Any) -> None:
def __init__(self, *arg, **kw):
super().__init__(*arg, **kw)
self._cursor: Optional[AsyncIODBAPICursor] = None
self._cursor = None
def close(self) -> None:
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self) -> Optional[Any]:
assert self._cursor is not None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
assert self._cursor is not None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self) -> Sequence[Any]:
assert self._cursor is not None
def fetchall(self):
return self.await_(self._cursor.fetchall())
@@ -252,24 +215,22 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
await_ = staticmethod(await_only)
__slots__ = ("dbapi",)
def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection) -> None:
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
@property
def isolation_level(self) -> Optional[str]:
return cast(str, self._connection.isolation_level)
def isolation_level(self):
return self._connection.isolation_level
@isolation_level.setter
def isolation_level(self, value: Optional[str]) -> None:
def isolation_level(self, value):
# aiosqlite's isolation_level setter works outside the Thread
# that it's supposed to, necessitating setting check_same_thread=False.
# for improved stability, we instead invent our own awaitable version
# using aiosqlite's async queue directly.
def set_iso(
connection: AsyncAdapt_aiosqlite_connection, value: Optional[str]
) -> None:
def set_iso(connection, value):
connection.isolation_level = value
function = partial(set_iso, self._connection._conn, value)
@@ -278,38 +239,38 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
self._connection._tx.put_nowait((future, function))
try:
self.await_(future)
return self.await_(future)
except Exception as error:
self._handle_exception(error)
def create_function(self, *args: Any, **kw: Any) -> None:
def create_function(self, *args, **kw):
try:
self.await_(self._connection.create_function(*args, **kw))
except Exception as error:
self._handle_exception(error)
def cursor(self, server_side: bool = False) -> AsyncAdapt_aiosqlite_cursor:
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_aiosqlite_ss_cursor(self)
else:
return AsyncAdapt_aiosqlite_cursor(self)
def execute(self, *args: Any, **kw: Any) -> Any:
def execute(self, *args, **kw):
return self.await_(self._connection.execute(*args, **kw))
def rollback(self) -> None:
def rollback(self):
try:
self.await_(self._connection.rollback())
except Exception as error:
self._handle_exception(error)
def commit(self) -> None:
def commit(self):
try:
self.await_(self._connection.commit())
except Exception as error:
self._handle_exception(error)
def close(self) -> None:
def close(self):
try:
self.await_(self._connection.close())
except ValueError:
@@ -325,7 +286,7 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
except Exception as error:
self._handle_exception(error)
def _handle_exception(self, error: Exception) -> NoReturn:
def _handle_exception(self, error):
if (
isinstance(error, ValueError)
and error.args[0] == "no active connection"
@@ -343,14 +304,14 @@ class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
await_ = staticmethod(await_fallback)
class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):
def __init__(self, aiosqlite: ModuleType, sqlite: ModuleType):
class AsyncAdapt_aiosqlite_dbapi:
def __init__(self, aiosqlite, sqlite):
self.aiosqlite = aiosqlite
self.sqlite = sqlite
self.paramstyle = "qmark"
self._init_dbapi_attributes()
def _init_dbapi_attributes(self) -> None:
def _init_dbapi_attributes(self):
for name in (
"DatabaseError",
"Error",
@@ -369,7 +330,7 @@ class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):
for name in ("Binary",):
setattr(self, name, getattr(self.sqlite, name))
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiosqlite_connection:
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", None)
@@ -393,7 +354,7 @@ class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):
class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
def create_server_side_cursor(self) -> DBAPICursor:
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(server_side=True)
@@ -408,25 +369,19 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
execution_ctx_cls = SQLiteExecutionContext_aiosqlite
@classmethod
def import_dbapi(cls) -> AsyncAdapt_aiosqlite_dbapi:
def import_dbapi(cls):
return AsyncAdapt_aiosqlite_dbapi(
__import__("aiosqlite"), __import__("sqlite3")
)
@classmethod
def get_pool_class(cls, url: URL) -> type[pool.Pool]:
def get_pool_class(cls, url):
if cls._is_url_file_db(url):
return pool.AsyncAdaptedQueuePool
return pool.NullPool
else:
return pool.StaticPool
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
self.dbapi = cast("DBAPIModule", self.dbapi)
def is_disconnect(self, e, connection, cursor):
if isinstance(
e, self.dbapi.OperationalError
) and "no active connection" in str(e):
@@ -434,10 +389,8 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
return super().is_disconnect(e, connection, cursor)
def get_driver_connection(
self, connection: DBAPIConnection
) -> AsyncIODBAPIConnection:
return connection._connection # type: ignore[no-any-return]
def get_driver_connection(self, connection):
return connection._connection
dialect = SQLiteDialect_aiosqlite

View File

@@ -1,5 +1,5 @@
# dialects/sqlite/dml.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# sqlite/dml.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,10 +7,6 @@
from __future__ import annotations
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from .._typing import _OnConflictIndexElementsT
from .._typing import _OnConflictIndexWhereT
@@ -19,7 +15,6 @@ from .._typing import _OnConflictWhereT
from ... import util
from ...sql import coercions
from ...sql import roles
from ...sql import schema
from ...sql._typing import _DMLTableArgument
from ...sql.base import _exclusive_against
from ...sql.base import _generative
@@ -27,9 +22,7 @@ from ...sql.base import ColumnCollection
from ...sql.base import ReadOnlyColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.elements import ColumnElement
from ...sql.elements import KeyedColumnElement
from ...sql.elements import TextClause
from ...sql.expression import alias
from ...util.typing import Self
@@ -148,10 +141,11 @@ class Insert(StandardInsert):
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
:param where:
Optional argument. An expression object representing a ``WHERE``
clause that restricts the rows affected by ``DO UPDATE SET``. Rows not
meeting the ``WHERE`` condition will not be updated (effectively a
``DO NOTHING`` for those rows).
Optional argument. If present, can be a literal SQL
string or an acceptable expression for a ``WHERE`` clause
that restricts the rows affected by ``DO UPDATE SET``. Rows
not meeting the ``WHERE`` condition will not be updated
(effectively a ``DO NOTHING`` for those rows).
"""
@@ -190,10 +184,9 @@ class Insert(StandardInsert):
class OnConflictClause(ClauseElement):
stringify_dialect = "sqlite"
inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]]
inferred_target_whereclause: Optional[
Union[ColumnElement[Any], TextClause]
]
constraint_target: None
inferred_target_elements: _OnConflictIndexElementsT
inferred_target_whereclause: _OnConflictIndexWhereT
def __init__(
self,
@@ -201,22 +194,13 @@ class OnConflictClause(ClauseElement):
index_where: _OnConflictIndexWhereT = None,
):
if index_elements is not None:
self.inferred_target_elements = [
coercions.expect(roles.DDLConstraintColumnRole, column)
for column in index_elements
]
self.inferred_target_whereclause = (
coercions.expect(
roles.WhereHavingRole,
index_where,
)
if index_where is not None
else None
)
self.constraint_target = None
self.inferred_target_elements = index_elements
self.inferred_target_whereclause = index_where
else:
self.inferred_target_elements = (
self.inferred_target_whereclause
) = None
self.constraint_target = (
self.inferred_target_elements
) = self.inferred_target_whereclause = None
class OnConflictDoNothing(OnConflictClause):
@@ -226,9 +210,6 @@ class OnConflictDoNothing(OnConflictClause):
class OnConflictDoUpdate(OnConflictClause):
__visit_name__ = "on_conflict_do_update"
update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]]
update_whereclause: Optional[ColumnElement[Any]]
def __init__(
self,
index_elements: _OnConflictIndexElementsT = None,
@@ -256,8 +237,4 @@ class OnConflictDoUpdate(OnConflictClause):
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
]
self.update_whereclause = (
coercions.expect(roles.WhereHavingRole, where)
if where is not None
else None
)
self.update_whereclause = where

View File

@@ -1,9 +1,3 @@
# dialects/sqlite/json.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from ... import types as sqltypes

View File

@@ -1,9 +1,3 @@
# dialects/sqlite/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import os
@@ -52,6 +46,8 @@ def _format_url(url, driver, ident):
assert "test_schema" not in filename
tokens = re.split(r"[_\.]", filename)
new_filename = f"{driver}"
for token in tokens:
if token in _drivernames:
if driver is None:

View File

@@ -1,5 +1,5 @@
# dialects/sqlite/pysqlcipher.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# sqlite/pysqlcipher.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -39,7 +39,7 @@ Current dialect selection logic is:
e = create_engine(
"sqlite+pysqlcipher://:password@/dbname.db",
module=sqlcipher_compatible_driver,
module=sqlcipher_compatible_driver
)
These drivers make use of the SQLCipher engine. This system essentially
@@ -55,12 +55,12 @@ The format of the connect string is in every way the same as that
of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the
"password" field is now accepted, which should contain a passphrase::
e = create_engine("sqlite+pysqlcipher://:testing@/foo.db")
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db')
For an absolute file path, two leading slashes should be used for the
database name::
e = create_engine("sqlite+pysqlcipher://:testing@//path/to/foo.db")
e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db')
A selection of additional encryption-related pragmas supported by SQLCipher
as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed
@@ -68,9 +68,7 @@ in the query string, and will result in that PRAGMA being called for each
new connection. Currently, ``cipher``, ``kdf_iter``
``cipher_page_size`` and ``cipher_use_hmac`` are supported::
e = create_engine(
"sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000"
)
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000')
.. warning:: Previous versions of sqlalchemy did not take into consideration
the encryption-related pragmas passed in the url string, that were silently

View File

@@ -1,5 +1,5 @@
# dialects/sqlite/pysqlite.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# sqlite/pysqlite.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -28,9 +28,7 @@ Connect Strings
---------------
The file specification for the SQLite database is taken as the "database"
portion of the URL. Note that the format of a SQLAlchemy url is:
.. sourcecode:: text
portion of the URL. Note that the format of a SQLAlchemy url is::
driver://user:pass@host/database
@@ -39,28 +37,25 @@ the **right** of the third slash. So connecting to a relative filepath
looks like::
# relative path
e = create_engine("sqlite:///path/to/database.db")
e = create_engine('sqlite:///path/to/database.db')
An absolute path, which is denoted by starting with a slash, means you
need **four** slashes::
# absolute path
e = create_engine("sqlite:////path/to/database.db")
e = create_engine('sqlite:////path/to/database.db')
To use a Windows path, regular drive specifications and backslashes can be
used. Double backslashes are probably needed::
# absolute path on Windows
e = create_engine("sqlite:///C:\\path\\to\\database.db")
e = create_engine('sqlite:///C:\\path\\to\\database.db')
To use sqlite ``:memory:`` database specify it as the filename using
``sqlite:///:memory:``. It's also the default if no filepath is
present, specifying only ``sqlite://`` and nothing else::
The sqlite ``:memory:`` identifier is the default if no filepath is
present. Specify ``sqlite://`` and nothing else::
# in-memory database (note three slashes)
e = create_engine("sqlite:///:memory:")
# also in-memory database
e2 = create_engine("sqlite://")
# in-memory database
e = create_engine('sqlite://')
.. _pysqlite_uri_connections:
@@ -100,9 +95,7 @@ Above, the pysqlite / sqlite3 DBAPI would be passed arguments as::
sqlite3.connect(
"file:path/to/database?mode=ro&nolock=1",
check_same_thread=True,
timeout=10,
uri=True,
check_same_thread=True, timeout=10, uri=True
)
Regarding future parameters added to either the Python or native drivers. new
@@ -148,11 +141,8 @@ as follows::
def regexp(a, b):
return re.search(a, b) is not None
sqlite_connection.create_function(
"regexp",
2,
regexp,
"regexp", 2, regexp,
)
There is currently no support for regular expression flags as a separate
@@ -193,12 +183,10 @@ Keeping in mind that pysqlite's parsing option is not recommended,
nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES
can be forced if one configures "native_datetime=True" on create_engine()::
engine = create_engine(
"sqlite://",
connect_args={
"detect_types": sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
},
native_datetime=True,
engine = create_engine('sqlite://',
connect_args={'detect_types':
sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES},
native_datetime=True
)
With this flag enabled, the DATE and TIMESTAMP types (but note - not the
@@ -253,7 +241,6 @@ Pooling may be disabled for a file based database by specifying the
parameter::
from sqlalchemy import NullPool
engine = create_engine("sqlite:///myfile.db", poolclass=NullPool)
It's been observed that the :class:`.NullPool` implementation incurs an
@@ -273,12 +260,9 @@ globally, and the ``check_same_thread`` flag can be passed to Pysqlite
as ``False``::
from sqlalchemy.pool import StaticPool
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
engine = create_engine('sqlite://',
connect_args={'check_same_thread':False},
poolclass=StaticPool)
Note that using a ``:memory:`` database in multiple threads requires a recent
version of SQLite.
@@ -297,14 +281,14 @@ needed within multiple threads for this case::
# maintain the same connection per thread
from sqlalchemy.pool import SingletonThreadPool
engine = create_engine("sqlite:///mydb.db", poolclass=SingletonThreadPool)
engine = create_engine('sqlite:///mydb.db',
poolclass=SingletonThreadPool)
# maintain the same connection across all threads
from sqlalchemy.pool import StaticPool
engine = create_engine("sqlite:///mydb.db", poolclass=StaticPool)
engine = create_engine('sqlite:///mydb.db',
poolclass=StaticPool)
Note that :class:`.SingletonThreadPool` should be configured for the number
of threads that are to be used; beyond that number, connections will be
@@ -333,14 +317,13 @@ same column, use a custom type that will check each row individually::
from sqlalchemy import String
from sqlalchemy import TypeDecorator
class MixedBinary(TypeDecorator):
impl = String
cache_ok = True
def process_result_value(self, value, dialect):
if isinstance(value, str):
value = bytes(value, "utf-8")
value = bytes(value, 'utf-8')
elif value is not None:
value = bytes(value)
@@ -354,10 +337,74 @@ Then use the above ``MixedBinary`` datatype in the place where
Serializable isolation / Savepoints / Transactional DDL
-------------------------------------------------------
A newly revised version of this important section is now available
at the top level of the SQLAlchemy SQLite documentation, in the section
:ref:`sqlite_transactions`.
In the section :ref:`sqlite_concurrency`, we refer to the pysqlite
driver's assortment of issues that prevent several features of SQLite
from working correctly. The pysqlite DBAPI driver has several
long-standing bugs which impact the correctness of its transactional
behavior. In its default mode of operation, SQLite features such as
SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are
non-functional, and in order to use these features, workarounds must
be taken.
The issue is essentially that the driver attempts to second-guess the user's
intent, failing to start transactions and sometimes ending them prematurely, in
an effort to minimize the SQLite databases's file locking behavior, even
though SQLite itself uses "shared" locks for read-only activities.
SQLAlchemy chooses to not alter this behavior by default, as it is the
long-expected behavior of the pysqlite driver; if and when the pysqlite
driver attempts to repair these issues, that will be more of a driver towards
defaults for SQLAlchemy.
The good news is that with a few events, we can implement transactional
support fully, by disabling pysqlite's feature entirely and emitting BEGIN
ourselves. This is achieved using two event listeners::
from sqlalchemy import create_engine, event
engine = create_engine("sqlite:///myfile.db")
@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable pysqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None
@event.listens_for(engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.exec_driver_sql("BEGIN")
.. warning:: When using the above recipe, it is advised to not use the
:paramref:`.Connection.execution_options.isolation_level` setting on
:class:`_engine.Connection` and :func:`_sa.create_engine`
with the SQLite driver,
as this function necessarily will also alter the ".isolation_level" setting.
Above, we intercept a new pysqlite connection and disable any transactional
integration. Then, at the point at which SQLAlchemy knows that transaction
scope is to begin, we emit ``"BEGIN"`` ourselves.
When we take control of ``"BEGIN"``, we can also control directly SQLite's
locking modes, introduced at
`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_,
by adding the desired locking mode to our ``"BEGIN"``::
@event.listens_for(engine, "begin")
def do_begin(conn):
conn.exec_driver_sql("BEGIN EXCLUSIVE")
.. seealso::
`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_ -
on the SQLite site
`sqlite3 SELECT does not BEGIN a transaction <https://bugs.python.org/issue9924>`_ -
on the Python bug tracker
`sqlite3 module breaks transactions and potentially corrupts data <https://bugs.python.org/issue10740>`_ -
on the Python bug tracker
.. _pysqlite_udfs:
@@ -392,16 +439,12 @@ connection when it is created. That is accomplished with an event listener::
with engine.connect() as conn:
print(conn.scalar(text("SELECT UDF()")))
""" # noqa
from __future__ import annotations
import math
import os
import re
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from .base import DATE
from .base import DATETIME
@@ -411,13 +454,6 @@ from ... import pool
from ... import types as sqltypes
from ... import util
if TYPE_CHECKING:
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.url import URL
from ...pool.base import PoolProxiedConnection
class _SQLite_pysqliteTimeStamp(DATETIME):
def bind_processor(self, dialect):
@@ -471,7 +507,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
return sqlite
@classmethod
def _is_url_file_db(cls, url: URL):
def _is_url_file_db(cls, url):
if (url.database and url.database != ":memory:") and (
url.query.get("mode", None) != "memory"
):
@@ -502,9 +538,6 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
dbapi_connection.isolation_level = ""
return super().set_isolation_level(dbapi_connection, level)
def detect_autocommit_setting(self, dbapi_connection):
return dbapi_connection.isolation_level is None
def on_connect(self):
def regexp(a, b):
if b is None:
@@ -604,13 +637,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
return ([filename], pysqlite_opts)
def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
self.dbapi = cast("DBAPIModule", self.dbapi)
def is_disconnect(self, e, connection, cursor):
return isinstance(
e, self.dbapi.ProgrammingError
) and "Cannot operate on a closed database." in str(e)