@@ -1,5 +1,5 @@
# dialects/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,7 +7,6 @@

from __future__ import annotations

from typing import Any
from typing import Callable
from typing import Optional
from typing import Type
@@ -40,7 +39,7 @@ def _auto_fn(name: str) -> Optional[Callable[[], Type[Dialect]]]:
# hardcoded. if mysql / mariadb etc were third party dialects
# they would just publish all the entrypoints, which would actually
# look much nicer.
module: Any = __import__(
module = __import__(
"sqlalchemy.dialects.mysql.mariadb"
).dialects.mysql.mariadb
return module.loader(driver) # type: ignore

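The hardcoded mariadb import above sidesteps SQLAlchemy's normal plugin lookup. For genuinely third-party dialects, that lookup is driven by the ``sqlalchemy.dialects`` entrypoint group, or by explicit registration; a minimal sketch, assuming a hypothetical external package ``sqlalchemy_mariadb`` exposing a ``MariaDBDialect`` class::

    from sqlalchemy.dialects import registry

    # make "mariadb.mydriver://" URLs resolve to the hypothetical class;
    # an installed package would publish the same mapping through the
    # "sqlalchemy.dialects" setuptools entrypoint group instead
    registry.register(
        "mariadb.mydriver", "sqlalchemy_mariadb.dialect", "MariaDBDialect"
    )
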
@@ -1,9 +1,3 @@
# dialects/_typing.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations

from typing import Any
@@ -12,19 +6,14 @@ from typing import Mapping
from typing import Optional
from typing import Union

from ..sql import roles
from ..sql.base import ColumnCollection
from ..sql.schema import Column
from ..sql._typing import _DDLColumnArgument
from ..sql.elements import DQLDMLClauseElement
from ..sql.schema import ColumnCollectionConstraint
from ..sql.schema import Index


_OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None]
_OnConflictIndexElementsT = Optional[
Iterable[Union[Column[Any], str, roles.DDLConstraintColumnRole]]
]
_OnConflictIndexWhereT = Optional[roles.WhereHavingRole]
_OnConflictSetT = Optional[
Union[Mapping[Any, Any], ColumnCollection[Any, Any]]
]
_OnConflictWhereT = Optional[roles.WhereHavingRole]
_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]]
_OnConflictIndexWhereT = Optional[DQLDMLClauseElement]
_OnConflictSetT = Optional[Mapping[Any, Any]]
_OnConflictWhereT = Union[DQLDMLClauseElement, str, None]

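These aliases describe the argument shapes accepted by the ``on_conflict_do_update()`` and ``on_conflict_do_nothing()`` constructs of the PostgreSQL and SQLite dialects. A minimal usage sketch, assuming an existing ``my_table`` with an ``id`` primary key and a ``data`` column (the table itself is hypothetical)::

    from sqlalchemy.dialects.sqlite import insert

    stmt = insert(my_table).values(id=1, data="new value")
    stmt = stmt.on_conflict_do_update(
        index_elements=[my_table.c.id],  # _OnConflictIndexElementsT
        set_={"data": stmt.excluded.data},  # _OnConflictSetT
        where=my_table.c.data != "locked",  # _OnConflictWhereT
    )
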
@@ -1,5 +1,5 @@
# dialects/mssql/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

@@ -1,5 +1,5 @@
# dialects/mssql/aioodbc.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/aioodbc.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -32,12 +32,13 @@ This dialect should normally be used only with the
styles are otherwise equivalent to those documented in the pyodbc section::

from sqlalchemy.ext.asyncio import create_async_engine

engine = create_async_engine(
"mssql+aioodbc://scott:tiger@mssql2017:1433/test?"
"driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
)



"""

from __future__ import annotations

@@ -1,5 +1,5 @@
# dialects/mssql/base.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/base.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -9,6 +9,7 @@
"""
.. dialect:: mssql
:name: Microsoft SQL Server
:full_support: 2017
:normal_support: 2012+
:best_effort: 2005+

@@ -39,12 +40,9 @@ considered to be the identity column - unless it is associated with a
from sqlalchemy import Table, MetaData, Column, Integer

m = MetaData()
t = Table(
"t",
m,
Column("id", Integer, primary_key=True),
Column("x", Integer),
)
t = Table('t', m,
Column('id', Integer, primary_key=True),
Column('x', Integer))
m.create_all(engine)

The above example will generate DDL as:
@@ -62,12 +60,9 @@ specify ``False`` for the :paramref:`_schema.Column.autoincrement` flag,
on the first integer primary key column::

m = MetaData()
t = Table(
"t",
m,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", Integer),
)
t = Table('t', m,
Column('id', Integer, primary_key=True, autoincrement=False),
Column('x', Integer))
m.create_all(engine)

To add the ``IDENTITY`` keyword to a non-primary key column, specify
@@ -77,12 +72,9 @@ To add the ``IDENTITY`` keyword to a non-primary key column, specify
is set to ``False`` on any integer primary key column::

m = MetaData()
t = Table(
"t",
m,
Column("id", Integer, primary_key=True, autoincrement=False),
Column("x", Integer, autoincrement=True),
)
t = Table('t', m,
Column('id', Integer, primary_key=True, autoincrement=False),
Column('x', Integer, autoincrement=True))
m.create_all(engine)

.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct
@@ -145,12 +137,14 @@ parameters passed to the :class:`_schema.Identity` object::
from sqlalchemy import Table, Integer, Column, Identity

test = Table(
"test",
metadata,
'test', metadata,
Column(
"id", Integer, primary_key=True, Identity(start=100, increment=10)
'id',
Integer,
primary_key=True,
Identity(start=100, increment=10)
),
Column("name", String(20)),
Column('name', String(20))
)

The CREATE TABLE for the above :class:`_schema.Table` object would be:
@@ -160,7 +154,7 @@ The CREATE TABLE for the above :class:`_schema.Table` object would be:
CREATE TABLE test (
id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
name VARCHAR(20) NULL,
)
)

.. note::

@@ -193,7 +187,6 @@ type deployed to the SQL Server database can be specified as ``Numeric`` using

Base = declarative_base()


class TestTable(Base):
__tablename__ = "test"
id = Column(
@@ -219,9 +212,8 @@ integer values in Python 3), use :class:`_types.TypeDecorator` as follows::

from sqlalchemy import TypeDecorator


class NumericAsInteger(TypeDecorator):
"normalize floating point return values into ints"
'''normalize floating point return values into ints'''

impl = Numeric(10, 0, asdecimal=False)
cache_ok = True
@@ -231,7 +223,6 @@ integer values in Python 3), use :class:`_types.TypeDecorator` as follows::
value = int(value)
return value


class TestTable(Base):
__tablename__ = "test"
id = Column(
@@ -280,11 +271,11 @@ The process for fetching this value has several variants:
fetched in order to receive the value. Given a table as::

t = Table(
"t",
't',
metadata,
Column("id", Integer, primary_key=True),
Column("x", Integer),
implicit_returning=False,
Column('id', Integer, primary_key=True),
Column('x', Integer),
implicit_returning=False
)

an INSERT will look like:
@@ -310,13 +301,12 @@ statement proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the
execution. Given this example::

m = MetaData()
t = Table(
"t", m, Column("id", Integer, primary_key=True), Column("x", Integer)
)
t = Table('t', m, Column('id', Integer, primary_key=True),
Column('x', Integer))
m.create_all(engine)

with engine.begin() as conn:
conn.execute(t.insert(), {"id": 1, "x": 1}, {"id": 2, "x": 2})
conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2})

The above column will be created with IDENTITY, however the INSERT statement
we emit is specifying explicit values. In the echo output we can see
@@ -352,11 +342,7 @@ The :class:`.Sequence` object creates "real" sequences, i.e.,
>>> from sqlalchemy import Sequence
>>> from sqlalchemy.schema import CreateSequence
>>> from sqlalchemy.dialects import mssql
>>> print(
... CreateSequence(Sequence("my_seq", start=1)).compile(
... dialect=mssql.dialect()
... )
... )
>>> print(CreateSequence(Sequence("my_seq", start=1)).compile(dialect=mssql.dialect()))
{printsql}CREATE SEQUENCE my_seq START WITH 1

For integer primary key generation, SQL Server's ``IDENTITY`` construct should
@@ -390,12 +376,12 @@ more than one backend without using dialect-specific types.
To build a SQL Server VARCHAR or NVARCHAR with MAX length, use None::

my_table = Table(
"my_table",
metadata,
Column("my_data", VARCHAR(None)),
Column("my_n_data", NVARCHAR(None)),
'my_table', metadata,
Column('my_data', VARCHAR(None)),
Column('my_n_data', NVARCHAR(None))
)


Collation Support
-----------------

@@ -403,13 +389,10 @@ Character collations are supported by the base string types,
specified by the string argument "collation"::

from sqlalchemy import VARCHAR

Column("login", VARCHAR(32, collation="Latin1_General_CI_AS"))
Column('login', VARCHAR(32, collation='Latin1_General_CI_AS'))

When such a column is associated with a :class:`_schema.Table`, the
CREATE TABLE statement for this column will yield:

.. sourcecode:: sql
CREATE TABLE statement for this column will yield::

login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL

@@ -429,9 +412,7 @@ versions when no OFFSET clause is present. A statement such as::

select(some_table).limit(5)

will render similarly to:

.. sourcecode:: sql
will render similarly to::

SELECT TOP 5 col1, col2.. FROM table

@@ -441,9 +422,7 @@ LIMIT and OFFSET, or just OFFSET alone, will be rendered using the

select(some_table).order_by(some_table.c.col3).limit(5).offset(10)

will render similarly to:

.. sourcecode:: sql
will render similarly to::

SELECT anon_1.col1, anon_1.col2 FROM (SELECT col1, col2,
ROW_NUMBER() OVER (ORDER BY col3) AS
@@ -496,13 +475,16 @@ each new connection.
To set isolation level using :func:`_sa.create_engine`::

engine = create_engine(
"mssql+pyodbc://scott:tiger@ms_2008", isolation_level="REPEATABLE READ"
"mssql+pyodbc://scott:tiger@ms_2008",
isolation_level="REPEATABLE READ"
)

To set using per-connection execution options::

connection = engine.connect()
connection = connection.execution_options(isolation_level="READ COMMITTED")
connection = connection.execution_options(
isolation_level="READ COMMITTED"
)

Valid values for ``isolation_level`` include:

@@ -552,6 +534,7 @@ will remain consistent with the state of the transaction::

mssql_engine = create_engine(
"mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server",

# disable default reset-on-return scheme
pool_reset_on_return=None,
)
@@ -580,17 +563,13 @@ Nullability
-----------
MSSQL has support for three levels of column nullability. The default
nullability allows nulls and is explicit in the CREATE TABLE
construct:

.. sourcecode:: sql
construct::

name VARCHAR(20) NULL

If ``nullable=None`` is specified then no specification is made. In
other words the database's configured default is used. This will
render:

.. sourcecode:: sql
render::

name VARCHAR(20)

@@ -646,9 +625,8 @@ behavior of this flag is as follows:

* The flag can be set to either ``True`` or ``False`` when the dialect
is created, typically via :func:`_sa.create_engine`::

eng = create_engine(
"mssql+pymssql://user:pass@host/db", deprecate_large_types=True
)
eng = create_engine("mssql+pymssql://user:pass@host/db",
deprecate_large_types=True)

* Complete control over whether the "old" or "new" types are rendered is
available in all SQLAlchemy versions by using the UPPERCASE type objects
@@ -670,10 +648,9 @@ at once using the :paramref:`_schema.Table.schema` argument of
:class:`_schema.Table`::

Table(
"some_table",
metadata,
"some_table", metadata,
Column("q", String(50)),
schema="mydatabase.dbo",
schema="mydatabase.dbo"
)

When performing operations such as table or component reflection, a schema
@@ -685,10 +662,9 @@ components will be quoted separately for case sensitive names and other
special characters. Given an argument as below::

Table(
"some_table",
metadata,
"some_table", metadata,
Column("q", String(50)),
schema="MyDataBase.dbo",
schema="MyDataBase.dbo"
)

The above schema would be rendered as ``[MyDataBase].dbo``, and also in
@@ -701,22 +677,21 @@ Below, the "owner" will be considered as ``MyDataBase.dbo`` and the
"database" will be None::

Table(
"some_table",
metadata,
"some_table", metadata,
Column("q", String(50)),
schema="[MyDataBase.dbo]",
schema="[MyDataBase.dbo]"
)

To individually specify both database and owner name with special characters
or embedded dots, use two sets of brackets::

Table(
"some_table",
metadata,
"some_table", metadata,
Column("q", String(50)),
schema="[MyDataBase.Period].[MyOwner.Dot]",
schema="[MyDataBase.Period].[MyOwner.Dot]"
)


.. versionchanged:: 1.2 the SQL Server dialect now treats brackets as
identifier delimiters splitting the schema into separate database
and owner tokens, to allow dots within either name itself.
@@ -731,11 +706,10 @@ schema-qualified table would be auto-aliased when used in a
SELECT statement; given a table::

account_table = Table(
"account",
metadata,
Column("id", Integer, primary_key=True),
Column("info", String(100)),
schema="customer_schema",
'account', metadata,
Column('id', Integer, primary_key=True),
Column('info', String(100)),
schema="customer_schema"
)

this legacy mode of rendering would assume that "customer_schema.account"
@@ -778,55 +752,37 @@ which renders the index as ``CREATE CLUSTERED INDEX my_index ON table (x)``.

To generate a clustered primary key use::

Table(
"my_table",
metadata,
Column("x", ...),
Column("y", ...),
PrimaryKeyConstraint("x", "y", mssql_clustered=True),
)
Table('my_table', metadata,
Column('x', ...),
Column('y', ...),
PrimaryKeyConstraint("x", "y", mssql_clustered=True))

which will render the table, for example, as:
which will render the table, for example, as::

.. sourcecode:: sql

CREATE TABLE my_table (
x INTEGER NOT NULL,
y INTEGER NOT NULL,
PRIMARY KEY CLUSTERED (x, y)
)
CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL,
PRIMARY KEY CLUSTERED (x, y))

Similarly, we can generate a clustered unique constraint using::

Table(
"my_table",
metadata,
Column("x", ...),
Column("y", ...),
PrimaryKeyConstraint("x"),
UniqueConstraint("y", mssql_clustered=True),
)
Table('my_table', metadata,
Column('x', ...),
Column('y', ...),
PrimaryKeyConstraint("x"),
UniqueConstraint("y", mssql_clustered=True),
)

To explicitly request a non-clustered primary key (for example, when
a separate clustered index is desired), use::

Table(
"my_table",
metadata,
Column("x", ...),
Column("y", ...),
PrimaryKeyConstraint("x", "y", mssql_clustered=False),
)
Table('my_table', metadata,
Column('x', ...),
Column('y', ...),
PrimaryKeyConstraint("x", "y", mssql_clustered=False))

which will render the table, for example, as:
which will render the table, for example, as::

.. sourcecode:: sql

CREATE TABLE my_table (
x INTEGER NOT NULL,
y INTEGER NOT NULL,
PRIMARY KEY NONCLUSTERED (x, y)
)
CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL,
PRIMARY KEY NONCLUSTERED (x, y))

Columnstore Index Support
-------------------------
@@ -864,7 +820,7 @@ INCLUDE
The ``mssql_include`` option renders INCLUDE(colname) for the given string
names::

Index("my_index", table.c.x, mssql_include=["y"])
Index("my_index", table.c.x, mssql_include=['y'])

would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)``

@@ -919,19 +875,18 @@ To disable the usage of OUTPUT INSERTED on a per-table basis,
specify ``implicit_returning=False`` for each :class:`_schema.Table`
which has triggers::

Table(
"mytable",
metadata,
Column("id", Integer, primary_key=True),
Table('mytable', metadata,
Column('id', Integer, primary_key=True),
# ...,
implicit_returning=False,
implicit_returning=False
)

Declarative form::

class MyClass(Base):
# ...
__table_args__ = {"implicit_returning": False}
__table_args__ = {'implicit_returning':False}


.. _mssql_rowcount_versioning:

@@ -965,9 +920,7 @@ isolation mode that locks entire tables, and causes even mildly concurrent
applications to have long held locks and frequent deadlocks.
Enabling snapshot isolation for the database as a whole is recommended
for modern levels of concurrency support. This is accomplished via the
following ALTER DATABASE commands executed at the SQL prompt:

.. sourcecode:: sql
following ALTER DATABASE commands executed at the SQL prompt::

ALTER DATABASE MyDatabase SET ALLOW_SNAPSHOT_ISOLATION ON

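The companion command, typically issued alongside the one above so that the default READ COMMITTED level also uses row versioning, is:

.. sourcecode:: sql

    ALTER DATABASE MyDatabase SET READ_COMMITTED_SNAPSHOT ON
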
@@ -1473,6 +1426,7 @@ class ROWVERSION(TIMESTAMP):


class NTEXT(sqltypes.UnicodeText):

"""MSSQL NTEXT type, for variable-length unicode text up to 2^30
characters."""

@@ -1597,11 +1551,44 @@ class MSUUid(sqltypes.Uuid):

def process(value):
return f"""'{
value.replace("-", "").replace("'", "''")
}'"""
value.replace("-", "").replace("'", "''")
}'"""

return process

def _sentinel_value_resolver(self, dialect):
"""Return a callable that will receive the uuid object or string
as it is normally passed to the DB in the parameter set, after
bind_processor() is called. Convert this value to match
what it would be as coming back from an INSERT..OUTPUT inserted.

for the UUID type, there are four varieties of settings so here
we seek to convert to the string or UUID representation that comes
back from the driver.

"""
character_based_uuid = (
not dialect.supports_native_uuid or not self.native_uuid
)

if character_based_uuid:
if self.native_uuid:
# for pyodbc, uuid.uuid() objects are accepted for incoming
# data, as well as strings. but the driver will always return
# uppercase strings in result sets.
def process(value):
return str(value).upper()

else:

def process(value):
return str(value)

return process
else:
# for pymssql, we get uuid.uuid() objects back.
return None


class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]):
__visit_name__ = "UNIQUEIDENTIFIER"
@@ -1609,12 +1596,12 @@ class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]):
@overload
def __init__(
self: UNIQUEIDENTIFIER[_python_UUID], as_uuid: Literal[True] = ...
): ...
):
...

@overload
def __init__(
self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ...
): ...
def __init__(self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ...):
...

def __init__(self, as_uuid: bool = True):
"""Construct a :class:`_mssql.UNIQUEIDENTIFIER` type.
@@ -1865,6 +1852,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
_enable_identity_insert = False
_select_lastrowid = False
_lastrowid = None
_rowcount = None

dialect: MSDialect

@@ -1984,6 +1972,13 @@ class MSExecutionContext(default.DefaultExecutionContext):
def get_lastrowid(self):
return self._lastrowid

@property
def rowcount(self):
if self._rowcount is not None:
return self._rowcount
else:
return self.cursor.rowcount

def handle_dbapi_exception(self, e):
if self._enable_identity_insert:
try:
@@ -2035,10 +2030,6 @@ class MSSQLCompiler(compiler.SQLCompiler):
self.tablealiases = {}
super().__init__(*args, **kwargs)

def _format_frame_clause(self, range_, **kw):
kw["literal_execute"] = True
return super()._format_frame_clause(range_, **kw)

def _with_legacy_schema_aliasing(fn):
def decorate(self, *arg, **kw):
if self.dialect.legacy_schema_aliasing:
@@ -2492,12 +2483,10 @@ class MSSQLCompiler(compiler.SQLCompiler):
type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS %s)" % (
self.process(binary.left, **kw),
self.process(binary.right, **kw),
(
"FLOAT"
if isinstance(binary.type, sqltypes.Float)
else "NUMERIC(%s, %s)"
% (binary.type.precision, binary.type.scale)
),
"FLOAT"
if isinstance(binary.type, sqltypes.Float)
else "NUMERIC(%s, %s)"
% (binary.type.precision, binary.type.scale),
)
elif binary.type._type_affinity is sqltypes.Boolean:
# the NULL handling is particularly weird with boolean, so


class MSSQLStrictCompiler(MSSQLCompiler):

"""A subclass of MSSQLCompiler which disables the usage of bind
parameters where not allowed natively by MS-SQL.

@@ -3632,36 +3622,27 @@ where
@reflection.cache
@_db_plus_owner
def get_columns(self, connection, tablename, dbname, owner, schema, **kw):
sys_columns = ischema.sys_columns
sys_types = ischema.sys_types
sys_default_constraints = ischema.sys_default_constraints
computed_cols = ischema.computed_columns
identity_cols = ischema.identity_columns
extended_properties = ischema.extended_properties

# to access sys tables, need an object_id.
# object_id() can normally match to the unquoted name even if it
# has special characters. however it also accepts quoted names,
# which means for the special case that the name itself has
# "quotes" (e.g. brackets for SQL Server) we need to "quote" (e.g.
# bracket) that name anyway. Fixed as part of #12654

is_temp_table = tablename.startswith("#")
if is_temp_table:
owner, tablename = self._get_internal_temp_table_name(
connection, tablename
)

object_id_tokens = [self.identifier_preparer.quote(tablename)]
columns = ischema.mssql_temp_table_columns
else:
columns = ischema.columns

computed_cols = ischema.computed_columns
identity_cols = ischema.identity_columns
if owner:
object_id_tokens.insert(0, self.identifier_preparer.quote(owner))

if is_temp_table:
object_id_tokens.insert(0, "tempdb")

object_id = func.object_id(".".join(object_id_tokens))

whereclause = sys_columns.c.object_id == object_id
whereclause = sql.and_(
columns.c.table_name == tablename,
columns.c.table_schema == owner,
)
full_name = columns.c.table_schema + "." + columns.c.table_name
else:
whereclause = columns.c.table_name == tablename
full_name = columns.c.table_name

if self._supports_nvarchar_max:
computed_definition = computed_cols.c.definition
@@ -3671,112 +3652,92 @@ where
computed_cols.c.definition, NVARCHAR(4000)
)

object_id = func.object_id(full_name)

s = (
sql.select(
sys_columns.c.name,
sys_types.c.name,
sys_columns.c.is_nullable,
sys_columns.c.max_length,
sys_columns.c.precision,
sys_columns.c.scale,
sys_default_constraints.c.definition,
sys_columns.c.collation_name,
columns.c.column_name,
columns.c.data_type,
columns.c.is_nullable,
columns.c.character_maximum_length,
columns.c.numeric_precision,
columns.c.numeric_scale,
columns.c.column_default,
columns.c.collation_name,
computed_definition,
computed_cols.c.is_persisted,
identity_cols.c.is_identity,
identity_cols.c.seed_value,
identity_cols.c.increment_value,
extended_properties.c.value.label("comment"),
)
.select_from(sys_columns)
.join(
sys_types,
onclause=sys_columns.c.user_type_id
== sys_types.c.user_type_id,
)
.outerjoin(
sys_default_constraints,
sql.and_(
sys_default_constraints.c.object_id
== sys_columns.c.default_object_id,
sys_default_constraints.c.parent_column_id
== sys_columns.c.column_id,
),
ischema.extended_properties.c.value.label("comment"),
)
.select_from(columns)
.outerjoin(
computed_cols,
onclause=sql.and_(
computed_cols.c.object_id == sys_columns.c.object_id,
computed_cols.c.column_id == sys_columns.c.column_id,
computed_cols.c.object_id == object_id,
computed_cols.c.name
== columns.c.column_name.collate("DATABASE_DEFAULT"),
),
)
.outerjoin(
identity_cols,
onclause=sql.and_(
identity_cols.c.object_id == sys_columns.c.object_id,
identity_cols.c.column_id == sys_columns.c.column_id,
identity_cols.c.object_id == object_id,
identity_cols.c.name
== columns.c.column_name.collate("DATABASE_DEFAULT"),
),
)
.outerjoin(
extended_properties,
ischema.extended_properties,
onclause=sql.and_(
extended_properties.c["class"] == 1,
extended_properties.c.name == "MS_Description",
sys_columns.c.object_id == extended_properties.c.major_id,
sys_columns.c.column_id == extended_properties.c.minor_id,
ischema.extended_properties.c["class"] == 1,
ischema.extended_properties.c.major_id == object_id,
ischema.extended_properties.c.minor_id
== columns.c.ordinal_position,
ischema.extended_properties.c.name == "MS_Description",
),
)
.where(whereclause)
.order_by(sys_columns.c.column_id)
.order_by(columns.c.ordinal_position)
)

if is_temp_table:
exec_opts = {"schema_translate_map": {"sys": "tempdb.sys"}}
else:
exec_opts = {"schema_translate_map": {}}
c = connection.execution_options(**exec_opts).execute(s)
c = connection.execution_options(future_result=True).execute(s)

cols = []
for row in c.mappings():
name = row[sys_columns.c.name]
type_ = row[sys_types.c.name]
nullable = row[sys_columns.c.is_nullable] == 1
maxlen = row[sys_columns.c.max_length]
numericprec = row[sys_columns.c.precision]
numericscale = row[sys_columns.c.scale]
default = row[sys_default_constraints.c.definition]
collation = row[sys_columns.c.collation_name]
name = row[columns.c.column_name]
type_ = row[columns.c.data_type]
nullable = row[columns.c.is_nullable] == "YES"
charlen = row[columns.c.character_maximum_length]
numericprec = row[columns.c.numeric_precision]
numericscale = row[columns.c.numeric_scale]
default = row[columns.c.column_default]
collation = row[columns.c.collation_name]
definition = row[computed_definition]
is_persisted = row[computed_cols.c.is_persisted]
is_identity = row[identity_cols.c.is_identity]
identity_start = row[identity_cols.c.seed_value]
identity_increment = row[identity_cols.c.increment_value]
comment = row[extended_properties.c.value]
comment = row[ischema.extended_properties.c.value]

coltype = self.ischema_names.get(type_, None)

kwargs = {}

if coltype in (
MSString,
MSChar,
MSNVarchar,
MSNChar,
MSText,
MSNText,
MSBinary,
MSVarBinary,
sqltypes.LargeBinary,
):
kwargs["length"] = maxlen if maxlen != -1 else None
elif coltype in (
MSString,
MSChar,
MSText,
):
kwargs["length"] = maxlen if maxlen != -1 else None
if collation:
kwargs["collation"] = collation
elif coltype in (
MSNVarchar,
MSNChar,
MSNText,
):
kwargs["length"] = maxlen // 2 if maxlen != -1 else None
if charlen == -1:
charlen = None
kwargs["length"] = charlen
if collation:
kwargs["collation"] = collation

@@ -4020,8 +3981,10 @@ index_info AS (
)

# group rows by constraint ID, to handle multi-column FKs
fkeys = util.defaultdict(
lambda: {
fkeys = []

def fkey_rec():
return {
"name": None,
"constrained_columns": [],
"referred_schema": None,
@@ -4029,7 +3992,8 @@ index_info AS (
"referred_columns": [],
"options": {},
}
)

fkeys = util.defaultdict(fkey_rec)

for r in connection.execute(s).all():
(

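The replacement of the ``lambda`` with the named ``fkey_rec()`` factory keeps the same grouping idiom: ``defaultdict`` creates one record per constraint ID on first access. A self-contained sketch of that pattern with illustrative data (all names here are hypothetical)::

    from collections import defaultdict

    def fkey_rec():
        # invoked once per previously unseen constraint ID
        return {"name": None, "constrained_columns": [], "referred_columns": []}

    fkeys = defaultdict(fkey_rec)
    rows = [("fk_a", "x", "rx"), ("fk_a", "y", "ry"), ("fk_b", "z", "rz")]
    for constraint_id, col, refcol in rows:
        rec = fkeys[constraint_id]  # record is created on first access
        rec["constrained_columns"].append(col)
        rec["referred_columns"].append(refcol)
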
@@ -1,5 +1,5 @@
# dialects/mssql/information_schema.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/information_schema.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -88,41 +88,23 @@ columns = Table(
schema="INFORMATION_SCHEMA",
)

sys_columns = Table(
"columns",
mssql_temp_table_columns = Table(
"COLUMNS",
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("column_id", Integer),
Column("default_object_id", Integer),
Column("user_type_id", Integer),
Column("is_nullable", Integer),
Column("ordinal_position", Integer),
Column("max_length", Integer),
Column("precision", Integer),
Column("scale", Integer),
Column("collation_name", String),
schema="sys",
)

sys_types = Table(
"types",
ischema,
Column("name", CoerceUnicode, key="name"),
Column("system_type_id", Integer, key="system_type_id"),
Column("user_type_id", Integer, key="user_type_id"),
Column("schema_id", Integer, key="schema_id"),
Column("max_length", Integer, key="max_length"),
Column("precision", Integer, key="precision"),
Column("scale", Integer, key="scale"),
Column("collation_name", CoerceUnicode, key="collation_name"),
Column("is_nullable", Boolean, key="is_nullable"),
Column("is_user_defined", Boolean, key="is_user_defined"),
Column("is_assembly_type", Boolean, key="is_assembly_type"),
Column("default_object_id", Integer, key="default_object_id"),
Column("rule_object_id", Integer, key="rule_object_id"),
Column("is_table_type", Boolean, key="is_table_type"),
schema="sys",
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("IS_NULLABLE", Integer, key="is_nullable"),
Column("DATA_TYPE", String, key="data_type"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
Column(
"CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
),
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
Column("COLUMN_DEFAULT", Integer, key="column_default"),
Column("COLLATION_NAME", String, key="collation_name"),
schema="tempdb.INFORMATION_SCHEMA",
)

constraints = Table(
@@ -135,17 +117,6 @@ constraints = Table(
schema="INFORMATION_SCHEMA",
)

sys_default_constraints = Table(
"default_constraints",
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("schema_id", Integer),
Column("parent_column_id", Integer),
Column("definition", CoerceUnicode),
schema="sys",
)

column_constraints = Table(
"CONSTRAINT_COLUMN_USAGE",
ischema,
@@ -211,7 +182,6 @@ computed_columns = Table(
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("column_id", Integer),
Column("is_computed", Boolean),
Column("is_persisted", Boolean),
Column("definition", CoerceUnicode),
@@ -237,7 +207,6 @@ class NumericSqlVariant(TypeDecorator):
int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the
correct value as string.
"""

impl = Unicode
cache_ok = True

@@ -250,7 +219,6 @@ identity_columns = Table(
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("column_id", Integer),
Column("is_identity", Boolean),
Column("seed_value", NumericSqlVariant),
Column("increment_value", NumericSqlVariant),

@@ -1,9 +1,3 @@
# dialects/mssql/json.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from ... import types as sqltypes
@@ -54,7 +48,9 @@ class JSON(sqltypes.JSON):
dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor
should be used::

stmt = select(data_table.c.data["some key"].as_json()).where(
stmt = select(
data_table.c.data["some key"].as_json()
).where(
data_table.c.data["some key"].as_json() == {"sub": "structure"}
)

@@ -65,7 +61,9 @@ class JSON(sqltypes.JSON):
:meth:`_types.JSON.Comparator.as_integer`,
:meth:`_types.JSON.Comparator.as_float`::

stmt = select(data_table.c.data["some key"].as_string()).where(
stmt = select(
data_table.c.data["some key"].as_string()
).where(
data_table.c.data["some key"].as_string() == "some string"
)


@@ -1,9 +1,3 @@
# dialects/mssql/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from sqlalchemy import inspect
@@ -22,17 +16,10 @@ from ...testing.provision import generate_driver_url
from ...testing.provision import get_temp_table_name
from ...testing.provision import log
from ...testing.provision import normalize_sequence
from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
from ...testing.provision import temp_table_keyword_args


@post_configure_engine.for_db("mssql")
def post_configure_engine(url, engine, follower_ident):
if engine.driver == "pyodbc":
engine.dialect.dbapi.pooling = False


@generate_driver_url.for_db("mssql")
def generate_driver_url(url, driver, query_str):
backend = url.get_backend_name()
@@ -42,9 +29,6 @@ def generate_driver_url(url, driver, query_str):
if driver not in ("pyodbc", "aioodbc"):
new_url = new_url.set(query="")

if driver == "aioodbc":
new_url = new_url.update_query_dict({"MARS_Connection": "Yes"})

if query_str:
new_url = new_url.update_query_string(query_str)


@@ -1,5 +1,5 @@
# dialects/mssql/pymssql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/pymssql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -103,7 +103,6 @@ class MSDialect_pymssql(MSDialect):
"message 20006", # Write to the server failed
"message 20017", # Unexpected EOF from the server
"message 20047", # DBPROCESS is dead or not enabled
"The server failed to resume the transaction",
):
if msg in str(e):
return True

@@ -1,5 +1,5 @@
# dialects/mssql/pyodbc.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mssql/pyodbc.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -30,9 +30,7 @@ is configured on the client, a basic DSN-based connection looks like::

engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn")

Which above, will pass the following connection string to PyODBC:

.. sourcecode:: text
Which above, will pass the following connection string to PyODBC::

DSN=some_dsn;UID=scott;PWD=tiger

@@ -51,9 +49,7 @@ When using a hostname connection, the driver name must also be specified in the
query parameters of the URL. As these names usually have spaces in them, the
name must be URL encoded which means using plus signs for spaces::

engine = create_engine(
"mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server"
)
engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server")

The ``driver`` keyword is significant to the pyodbc dialect and must be
specified in lowercase.
@@ -73,7 +69,6 @@ internally::
The equivalent URL can be constructed using :class:`_sa.engine.URL`::

from sqlalchemy.engine import URL

connection_url = URL.create(
"mssql+pyodbc",
username="scott",
@@ -88,6 +83,7 @@ The equivalent URL can be constructed using :class:`_sa.engine.URL`::
},
)


Pass through exact Pyodbc string
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -98,11 +94,8 @@ using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object
can help make this easier::

from sqlalchemy.engine import URL

connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password"
connection_url = URL.create(
"mssql+pyodbc", query={"odbc_connect": connection_string}
)
connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string})

engine = create_engine(connection_url)

@@ -134,8 +127,7 @@ database using Azure credentials::
from sqlalchemy.engine.url import URL
from azure import identity

# Connection option for access tokens, as defined in msodbcsql.h
SQL_COPT_SS_ACCESS_TOKEN = 1256
SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h
TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database

connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server"
@@ -144,19 +136,14 @@ database using Azure credentials::

azure_credentials = identity.DefaultAzureCredential()


@event.listens_for(engine, "do_connect")
def provide_token(dialect, conn_rec, cargs, cparams):
# remove the "Trusted_Connection" parameter that SQLAlchemy adds
cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")

# create token credential
raw_token = azure_credentials.get_token(TOKEN_URL).token.encode(
"utf-16-le"
)
token_struct = struct.pack(
f"<I{len(raw_token)}s", len(raw_token), raw_token
)
raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le")
token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)

# apply it to keyword arguments
cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct}
@@ -189,9 +176,7 @@ emit a ``.rollback()`` after an operation had a failure of some kind.
This specific case can be handled by passing ``ignore_no_transaction_on_rollback=True`` to
the SQL Server dialect via the :func:`_sa.create_engine` function as follows::

engine = create_engine(
connection_url, ignore_no_transaction_on_rollback=True
)
engine = create_engine(connection_url, ignore_no_transaction_on_rollback=True)

Using the above parameter, the dialect will catch ``ProgrammingError``
exceptions raised during ``connection.rollback()`` and emit a warning
@@ -251,6 +236,7 @@ behavior and pass long strings as varchar(max)/nvarchar(max) using the
},
)


Pyodbc Pooling / connection close behavior
------------------------------------------

@@ -315,8 +301,7 @@ Server dialect supports this parameter by passing the

engine = create_engine(
"mssql+pyodbc://scott:tiger@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server",
fast_executemany=True,
)
fast_executemany=True)

.. versionchanged:: 2.0.9 - the ``fast_executemany`` parameter now has its
intended effect of this PyODBC feature taking effect for all INSERT
@@ -384,6 +369,7 @@ from ...engine import cursor as _cursor


class _ms_numeric_pyodbc:

"""Turns Decimals with adjusted() < 0 or > 7 into strings.

The routines here are needed for older pyodbc versions

@@ -1,5 +1,5 @@
# dialects/mysql/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -53,8 +53,7 @@ from .base import YEAR
from .dml import Insert
from .dml import insert
from .expression import match
from .mariadb import INET4
from .mariadb import INET6
from ...util import compat

# default dialect
base.dialect = dialect = mysqldb.dialect
@@ -72,8 +71,6 @@ __all__ = (
"DOUBLE",
"ENUM",
"FLOAT",
"INET4",
"INET6",
"INTEGER",
"INTEGER",
"JSON",

@@ -1,9 +1,10 @@
# dialects/mysql/aiomysql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# mysql/aiomysql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

r"""
.. dialect:: mysql+aiomysql
@@ -22,108 +23,207 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")

engine = create_async_engine(
"mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4"
)

""" # noqa
from __future__ import annotations

from types import ModuleType
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
from ...connectors.asyncio import AsyncAdapt_dbapi_module
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only

if TYPE_CHECKING:

from ...connectors.asyncio import AsyncIODBAPIConnection
from ...connectors.asyncio import AsyncIODBAPICursor
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL
class AsyncAdapt_aiomysql_cursor:
# TODO: base on connectors/asyncio.py
# see #10415
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)

def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_

class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor):
__slots__ = ()
cursor = self._connection.cursor(adapt_connection.dbapi.Cursor)

def _make_new_cursor(
self, connection: AsyncIODBAPIConnection
) -> AsyncIODBAPICursor:
return connection.cursor(self._adapt_connection.dbapi.Cursor)
# see https://github.com/aio-libs/aiomysql/issues/543
self._cursor = self.await_(cursor.__aenter__())
self._rows = []

@property
def description(self):
return self._cursor.description

class AsyncAdapt_aiomysql_ss_cursor(
AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_aiomysql_cursor
):
__slots__ = ()
@property
def rowcount(self):
return self._cursor.rowcount

def _make_new_cursor(
self, connection: AsyncIODBAPIConnection
) -> AsyncIODBAPICursor:
return connection.cursor(
self._adapt_connection.dbapi.aiomysql.cursors.SSCursor
@property
def arraysize(self):
return self._cursor.arraysize

@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value

@property
def lastrowid(self):
return self._cursor.lastrowid

def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. to allow this to be async
# we would need the Result to change how it does "Safe close cursor".
# MySQL "cursors" don't actually have state to be "closed" besides
# exhausting rows, which we already have done for sync cursor.
# another option would be to emulate aiosqlite dialect and assign
# cursor only if we are doing server side cursor operation.
self._rows[:] = []

def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))

def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)

async def _execute_async(self, operation, parameters):
async with self._adapt_connection._execute_mutex:
result = await self._cursor.execute(operation, parameters)

class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
if not self.server_side:
# aiomysql has a "fake" async result, so we have to pull it out
# of that here since our default result is not async.
# we could just as easily grab "_rows" here and be done with it
# but this is safer.
self._rows = list(await self._cursor.fetchall())
return result

async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._execute_mutex:
return await self._cursor.executemany(operation, seq_of_parameters)

def setinputsizes(self, *inputsizes):
pass

def __iter__(self):
while self._rows:
yield self._rows.pop(0)

def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None

def fetchmany(self, size=None):
if size is None:
size = self.arraysize

retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval

def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval


class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
__slots__ = ()
|
||||
server_side = True
|
||||
|
||||
_cursor_cls = AsyncAdapt_aiomysql_cursor
|
||||
_ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self.await_ = adapt_connection.await_
|
||||
|
||||
def ping(self, reconnect: bool) -> None:
|
||||
assert not reconnect
|
||||
self.await_(self._connection.ping(reconnect))
|
||||
cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor)
|
||||
|
||||
def character_set_name(self) -> Optional[str]:
|
||||
return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
|
||||
self._cursor = self.await_(cursor.__aenter__())
|
||||
|
||||
def autocommit(self, value: Any) -> None:
|
||||
def close(self):
|
||||
if self._cursor is not None:
|
||||
self.await_(self._cursor.close())
|
||||
self._cursor = None
|
||||
|
||||
def fetchone(self):
|
||||
return self.await_(self._cursor.fetchone())
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
return self.await_(self._cursor.fetchmany(size=size))
|
||||
|
||||
def fetchall(self):
|
||||
return self.await_(self._cursor.fetchall())
|
||||
|
||||
|
||||
class AsyncAdapt_aiomysql_connection(AdaptedConnection):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
await_ = staticmethod(await_only)
|
||||
__slots__ = ("dbapi", "_execute_mutex")
|
||||
|
||||
def __init__(self, dbapi, connection):
|
||||
self.dbapi = dbapi
|
||||
self._connection = connection
|
||||
self._execute_mutex = asyncio.Lock()
|
||||
|
||||
def ping(self, reconnect):
|
||||
return self.await_(self._connection.ping(reconnect))
|
||||
|
||||
def character_set_name(self):
|
||||
return self._connection.character_set_name()
|
||||
|
||||
def autocommit(self, value):
|
||||
self.await_(self._connection.autocommit(value))
|
||||
|
||||
def get_autocommit(self) -> bool:
|
||||
return self._connection.get_autocommit() # type: ignore
|
||||
def cursor(self, server_side=False):
|
||||
if server_side:
|
||||
return AsyncAdapt_aiomysql_ss_cursor(self)
|
||||
else:
|
||||
return AsyncAdapt_aiomysql_cursor(self)
|
||||
|
||||
def terminate(self) -> None:
|
||||
def rollback(self):
|
||||
self.await_(self._connection.rollback())
|
||||
|
||||
def commit(self):
|
||||
self.await_(self._connection.commit())
|
||||
|
||||
def close(self):
|
||||
# it's not awaitable.
|
||||
self._connection.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self.await_(self._connection.ensure_closed())
|
||||
|
||||
|
||||
class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
__slots__ = ()
|
||||
|
||||
await_ = staticmethod(await_fallback)
|
||||
|
||||
|
||||
class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
|
||||
def __init__(self, aiomysql: ModuleType, pymysql: ModuleType):
|
||||
class AsyncAdapt_aiomysql_dbapi:
|
||||
def __init__(self, aiomysql, pymysql):
|
||||
self.aiomysql = aiomysql
|
||||
self.pymysql = pymysql
|
||||
self.paramstyle = "format"
|
||||
self._init_dbapi_attributes()
|
||||
self.Cursor, self.SSCursor = self._init_cursors_subclasses()
|
||||
|
||||
def _init_dbapi_attributes(self) -> None:
|
||||
def _init_dbapi_attributes(self):
|
||||
for name in (
|
||||
"Warning",
|
||||
"Error",
|
||||
@@ -149,7 +249,7 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
|
||||
):
|
||||
setattr(self, name, getattr(self.pymysql, name))
|
||||
|
||||
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection:
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)

@@ -164,23 +264,17 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
await_only(creator_fn(*arg, **kw)),
)

def _init_cursors_subclasses(
self,
) -> Tuple[AsyncIODBAPICursor, AsyncIODBAPICursor]:
def _init_cursors_subclasses(self):
# suppress unconditional warning emitted by aiomysql
class Cursor(self.aiomysql.Cursor): # type: ignore[misc, name-defined]
async def _show_warnings(
self, conn: AsyncIODBAPIConnection
) -> None:
class Cursor(self.aiomysql.Cursor):
async def _show_warnings(self, conn):
pass

class SSCursor(self.aiomysql.SSCursor): # type: ignore[misc, name-defined] # noqa: E501
async def _show_warnings(
self, conn: AsyncIODBAPIConnection
) -> None:
class SSCursor(self.aiomysql.SSCursor):
async def _show_warnings(self, conn):
pass

return Cursor, SSCursor # type: ignore[return-value]
return Cursor, SSCursor


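The cursor subclasses built above exist only to stub out ``_show_warnings``; the same override-to-silence pattern in miniature, with illustrative names::

    # Sketch: silence an unconditionally-invoked hook by overriding it.
    class ChattyCursor:
        async def _show_warnings(self, conn):
            print("warning!")  # stands in for aiomysql's warning emission

    class QuietCursor(ChattyCursor):
        async def _show_warnings(self, conn):
            pass  # no-op override suppresses the warning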
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
@@ -191,16 +285,15 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
_sscursor = AsyncAdapt_aiomysql_ss_cursor

is_async = True
has_terminate = True

@classmethod
def import_dbapi(cls) -> AsyncAdapt_aiomysql_dbapi:
def import_dbapi(cls):
return AsyncAdapt_aiomysql_dbapi(
__import__("aiomysql"), __import__("pymysql")
)

@classmethod
def get_pool_class(cls, url: URL) -> type:
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)

if util.asbool(async_fallback):
@@ -208,37 +301,25 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql):
else:
return pool.AsyncAdaptedQueuePool

def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
dbapi_connection.terminate()

def create_connect_args(
self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
) -> ConnectArgsType:
def create_connect_args(self, url):
return super().create_connect_args(
url, _translate_args=dict(username="user", database="db")
)

def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
else:
str_e = str(e).lower()
return "not connected" in str_e

def _found_rows_client_flag(self) -> int:
from pymysql.constants import CLIENT # type: ignore
def _found_rows_client_flag(self):
from pymysql.constants import CLIENT

return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
return CLIENT.FOUND_ROWS

def get_driver_connection(
self, connection: DBAPIConnection
) -> AsyncIODBAPIConnection:
return connection._connection # type: ignore[no-any-return]
def get_driver_connection(self, connection):
return connection._connection


dialect = MySQLDialect_aiomysql

@@ -1,9 +1,10 @@
# dialects/mysql/asyncmy.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# mysql/asyncmy.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

r"""
.. dialect:: mysql+asyncmy
@@ -20,100 +21,210 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")

engine = create_async_engine(
"mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4"
)

""" # noqa
from __future__ import annotations

from types import ModuleType
from typing import Any
from typing import NoReturn
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from contextlib import asynccontextmanager

from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
from ...connectors.asyncio import AsyncAdapt_dbapi_module
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only

if TYPE_CHECKING:
from ...connectors.asyncio import AsyncIODBAPIConnection
from ...connectors.asyncio import AsyncIODBAPICursor
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL

class AsyncAdapt_asyncmy_cursor:
# TODO: base on connectors/asyncio.py
# see #10415
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)

class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor):
__slots__ = ()
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_

cursor = self._connection.cursor()

class AsyncAdapt_asyncmy_ss_cursor(
AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncmy_cursor
):
__slots__ = ()
self._cursor = self.await_(cursor.__aenter__())
self._rows = []

def _make_new_cursor(
self, connection: AsyncIODBAPIConnection
) -> AsyncIODBAPICursor:
return connection.cursor(
self._adapt_connection.dbapi.asyncmy.cursors.SSCursor
@property
def description(self):
return self._cursor.description

@property
def rowcount(self):
return self._cursor.rowcount

@property
def arraysize(self):
return self._cursor.arraysize

@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value

@property
def lastrowid(self):
return self._cursor.lastrowid

def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. to allow this to be async
# we would need the Result to change how it does "Safe close cursor".
# MySQL "cursors" don't actually have state to be "closed" besides
# exhausting rows, which we already have done for sync cursor.
# another option would be to emulate aiosqlite dialect and assign
# cursor only if we are doing server side cursor operation.
self._rows[:] = []

def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))

def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)

async def _execute_async(self, operation, parameters):
async with self._adapt_connection._mutex_and_adapt_errors():
if parameters is None:
result = await self._cursor.execute(operation)
else:
result = await self._cursor.execute(operation, parameters)

class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
if not self.server_side:
# asyncmy has a "fake" async result, so we have to pull it out
# of that here since our default result is not async.
# we could just as easily grab "_rows" here and be done with it
# but this is safer.
self._rows = list(await self._cursor.fetchall())
return result

async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._mutex_and_adapt_errors():
return await self._cursor.executemany(operation, seq_of_parameters)

def setinputsizes(self, *inputsizes):
pass

def __iter__(self):
while self._rows:
yield self._rows.pop(0)

def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None

def fetchmany(self, size=None):
if size is None:
size = self.arraysize

retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval

def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval


class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
# TODO: base on connectors/asyncio.py
# see #10415
__slots__ = ()
server_side = True

_cursor_cls = AsyncAdapt_asyncmy_cursor
_ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_

def _handle_exception(self, error: Exception) -> NoReturn:
if isinstance(error, AttributeError):
raise self.dbapi.InternalError(
"network operation failed due to asyncmy attribute error"
)
cursor = self._connection.cursor(
adapt_connection.dbapi.asyncmy.cursors.SSCursor
)

raise error
self._cursor = self.await_(cursor.__aenter__())

def ping(self, reconnect: bool) -> None:
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None

def fetchone(self):
return self.await_(self._cursor.fetchone())

def fetchmany(self, size=None):
return self.await_(self._cursor.fetchmany(size=size))

def fetchall(self):
return self.await_(self._cursor.fetchall())


class AsyncAdapt_asyncmy_connection(AdaptedConnection):
# TODO: base on connectors/asyncio.py
# see #10415
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_execute_mutex")

def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = asyncio.Lock()

@asynccontextmanager
async def _mutex_and_adapt_errors(self):
async with self._execute_mutex:
try:
yield
except AttributeError:
raise self.dbapi.InternalError(
"network operation failed due to asyncmy attribute error"
)

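``_mutex_and_adapt_errors`` folds locking and error translation into one reusable async context manager; a standalone sketch of the pattern, with stand-in names::

    import asyncio
    from contextlib import asynccontextmanager

    class AdapterError(Exception):
        """Stand-in for the DBAPI's InternalError."""

    class Guarded:
        def __init__(self):
            self._mutex = asyncio.Lock()

        @asynccontextmanager
        async def _mutex_and_adapt_errors(self):
            async with self._mutex:
                try:
                    yield
                except AttributeError:
                    # asyncmy-style failure mode: attributes vanish when
                    # the network drops; re-raise as a DBAPI-style error
                    raise AdapterError("network operation failed")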
def ping(self, reconnect):
assert not reconnect
return self.await_(self._do_ping())

async def _do_ping(self) -> None:
try:
async with self._execute_mutex:
await self._connection.ping(False)
except Exception as error:
self._handle_exception(error)
async def _do_ping(self):
async with self._mutex_and_adapt_errors():
return await self._connection.ping(False)

def character_set_name(self) -> Optional[str]:
return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501
def character_set_name(self):
return self._connection.character_set_name()

def autocommit(self, value: Any) -> None:
def autocommit(self, value):
self.await_(self._connection.autocommit(value))

def get_autocommit(self) -> bool:
return self._connection.get_autocommit() # type: ignore
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_asyncmy_ss_cursor(self)
else:
return AsyncAdapt_asyncmy_cursor(self)

def terminate(self) -> None:
def rollback(self):
self.await_(self._connection.rollback())

def commit(self):
self.await_(self._connection.commit())

def close(self):
# it's not awaitable.
self._connection.close()

def close(self) -> None:
self.await_(self._connection.ensure_closed())


class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
__slots__ = ()
@@ -121,13 +232,18 @@ class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
await_ = staticmethod(await_fallback)


class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
def __init__(self, asyncmy: ModuleType):
def _Binary(x):
"""Return x as a binary type."""
return bytes(x)


class AsyncAdapt_asyncmy_dbapi:
def __init__(self, asyncmy):
self.asyncmy = asyncmy
self.paramstyle = "format"
self._init_dbapi_attributes()

def _init_dbapi_attributes(self) -> None:
def _init_dbapi_attributes(self):
for name in (
"Warning",
"Error",
@@ -148,9 +264,9 @@ class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
BINARY = util.symbol("BINARY")
DATETIME = util.symbol("DATETIME")
TIMESTAMP = util.symbol("TIMESTAMP")
Binary = staticmethod(bytes)
Binary = staticmethod(_Binary)

def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_asyncmy_connection:
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)

@@ -174,14 +290,13 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
_sscursor = AsyncAdapt_asyncmy_ss_cursor

is_async = True
has_terminate = True

@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))

@classmethod
def get_pool_class(cls, url: URL) -> type:
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)

if util.asbool(async_fallback):
@@ -189,20 +304,12 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
else:
return pool.AsyncAdaptedQueuePool

def do_terminate(self, dbapi_connection: DBAPIConnection) -> None:
dbapi_connection.terminate()

def create_connect_args(self, url: URL) -> ConnectArgsType: # type: ignore[override] # noqa: E501
def create_connect_args(self, url):
return super().create_connect_args(
url, _translate_args=dict(username="user", database="db")
)

def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
else:
@@ -211,15 +318,13 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql):
"not connected" in str_e or "network operation failed" in str_e
)

def _found_rows_client_flag(self) -> int:
from asyncmy.constants import CLIENT # type: ignore
def _found_rows_client_flag(self):
from asyncmy.constants import CLIENT

return CLIENT.FOUND_ROWS # type: ignore[no-any-return]
return CLIENT.FOUND_ROWS

def get_driver_connection(
self, connection: DBAPIConnection
) -> AsyncIODBAPIConnection:
return connection._connection # type: ignore[no-any-return]
def get_driver_connection(self, connection):
return connection._connection


dialect = MySQLDialect_asyncmy

File diff suppressed because it is too large

@@ -1,9 +1,10 @@
# dialects/mysql/cymysql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/cymysql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

r"""

@@ -20,36 +21,18 @@ r"""
dialects are mysqlclient and PyMySQL.

""" # noqa
from __future__ import annotations

from typing import Any
from typing import Iterable
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from .base import BIT
from .base import MySQLDialect
from .mysqldb import MySQLDialect_mysqldb
from .types import BIT
from ... import util

if TYPE_CHECKING:
from ...engine.base import Connection
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import Dialect
from ...engine.interfaces import PoolProxiedConnection
from ...sql.type_api import _ResultProcessorType


class _cymysqlBIT(BIT):
def result_processor(
self, dialect: Dialect, coltype: object
) -> Optional[_ResultProcessorType[Any]]:
def result_processor(self, dialect, coltype):
"""Convert MySQL's 64 bit, variable length binary string to a long."""

def process(value: Optional[Iterable[int]]) -> Optional[int]:
def process(value):
if value is not None:
v = 0
for i in iter(value):
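The body of ``process`` is cut off at the hunk boundary, but the decoding idea, accumulating big-endian bytes into an integer, can be sketched (an assumed reconstruction, not the verbatim source)::

    # Sketch of BIT decoding: fold each byte into the running value.
    def bits_to_int(value: bytes) -> int:
        v = 0
        for i in value:
            v = (v << 8) | i
        return v

    assert bits_to_int(b"\x01\x00") == 256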
@@ -72,22 +55,17 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})

@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return __import__("cymysql")

def _detect_charset(self, connection: Connection) -> str:
return connection.connection.charset # type: ignore[no-any-return]
def _detect_charset(self, connection):
return connection.connection.charset

def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
return exception.errno # type: ignore[no-any-return]
def _extract_error_code(self, exception):
return exception.errno

def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
if isinstance(e, self.loaded_dbapi.OperationalError):
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError):
return self._extract_error_code(e) in (
2006,
2013,
@@ -95,7 +73,7 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb):
2045,
2055,
)
elif isinstance(e, self.loaded_dbapi.InterfaceError):
elif isinstance(e, self.dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
return True

@@ -1,5 +1,5 @@
# dialects/mysql/dml.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/dml.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,7 +7,6 @@
from __future__ import annotations

from typing import Any
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
@@ -142,11 +141,7 @@ class Insert(StandardInsert):
in :ref:`tutorial_parameter_ordered_updates`::

insert().on_duplicate_key_update(
[
("name", "some name"),
("value", "some value"),
]
)
[("name", "some name"), ("value", "some value")])

.. versionchanged:: 1.3 parameters can be specified as a dictionary
or list of 2-tuples; the latter form provides for parameter
@@ -186,7 +181,6 @@ class OnDuplicateClause(ClauseElement):

_parameter_ordering: Optional[List[str]] = None

update: Dict[str, Any]
stringify_dialect = "mysql"

def __init__(

@@ -1,51 +1,34 @@
# dialects/mysql/enumerated.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/enumerated.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

import enum
import re
from typing import Any
from typing import Dict
from typing import Optional
from typing import Set
from typing import Type
from typing import TYPE_CHECKING
from typing import Union

from .types import _StringType
from ... import exc
from ... import sql
from ... import util
from ...sql import sqltypes
from ...sql import type_api

if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.elements import ColumnElement
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _ResultProcessorType
from ...sql.type_api import TypeEngine
from ...sql.type_api import TypeEngineMixin


class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
"""MySQL ENUM type."""

__visit_name__ = "ENUM"

native_enum = True

def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None:
def __init__(self, *enums, **kw):
"""Construct an ENUM.

E.g.::

Column("myenum", ENUM("foo", "bar", "baz"))
Column('myenum', ENUM("foo", "bar", "baz"))

:param enums: The range of valid values for this ENUM. Values in
enums are not quoted, they will be escaped and surrounded by single
@@ -79,27 +62,21 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):

"""
kw.pop("strict", None)
self._enum_init(enums, kw) # type: ignore[arg-type]
self._enum_init(enums, kw)
_StringType.__init__(self, length=self.length, **kw)

@classmethod
def adapt_emulated_to_native(
cls,
impl: Union[TypeEngine[Any], TypeEngineMixin],
**kw: Any,
) -> ENUM:
def adapt_emulated_to_native(cls, impl, **kw):
"""Produce a MySQL native :class:`.mysql.ENUM` from plain
:class:`.Enum`.

"""
if TYPE_CHECKING:
assert isinstance(impl, ENUM)
kw.setdefault("validate_strings", impl.validate_strings)
kw.setdefault("values_callable", impl.values_callable)
kw.setdefault("omit_aliases", impl._omit_aliases)
return cls(**kw)

def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]:
def _object_value_for_elem(self, elem):
# mysql sends back a blank string for any value that
# was persisted that was not in the enums; that is, it does no
# validation on the incoming data, it "truncates" it to be
@@ -109,27 +86,24 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType):
else:
return super()._object_value_for_elem(elem)

def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
)


# TODO: SET is a string as far as configuration but does not act like
# a string at the python level. We either need to make a py-type agnostic
# version of String as a base to be used for this, make this some kind of
# TypeDecorator, or just vendor it out as its own type.
class SET(_StringType):
"""MySQL SET type."""

__visit_name__ = "SET"

def __init__(self, *values: str, **kw: Any):
def __init__(self, *values, **kw):
"""Construct a SET.

E.g.::

Column("myset", SET("foo", "bar", "baz"))
Column('myset', SET("foo", "bar", "baz"))


The list of potential values is required in the case that this
set will be used to generate DDL for a table, or if the
@@ -177,19 +151,17 @@ class SET(_StringType):
"setting retrieve_as_bitwise=True"
)
if self.retrieve_as_bitwise:
self._inversed_bitmap: Dict[str, int] = {
self._bitmap = {
value: 2**idx for idx, value in enumerate(self.values)
}
self._bitmap: Dict[int, str] = {
2**idx: value for idx, value in enumerate(self.values)
}
self._bitmap.update(
(2**idx, value) for idx, value in enumerate(self.values)
)
length = max([len(v) for v in values] + [0])
kw.setdefault("length", length)
super().__init__(**kw)

def column_expression(
self, colexpr: ColumnElement[Any]
) -> ColumnElement[Any]:
def column_expression(self, colexpr):
if self.retrieve_as_bitwise:
return sql.type_coerce(
sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
@@ -197,12 +169,10 @@ class SET(_StringType):
else:
return colexpr

def result_processor(
self, dialect: Dialect, coltype: Any
) -> Optional[_ResultProcessorType[Any]]:
def result_processor(self, dialect, coltype):
if self.retrieve_as_bitwise:

def process(value: Union[str, int, None]) -> Optional[Set[str]]:
def process(value):
if value is not None:
value = int(value)

@@ -213,14 +183,11 @@ class SET(_StringType):
else:
super_convert = super().result_processor(dialect, coltype)

def process(value: Union[str, Set[str], None]) -> Optional[Set[str]]: # type: ignore[misc] # noqa: E501
def process(value):
if isinstance(value, str):
# MySQLdb returns a string, let's parse
if super_convert:
value = super_convert(value)
assert value is not None
if TYPE_CHECKING:
assert isinstance(value, str)
return set(re.findall(r"[^,]+", value))
else:
# mysql-connector-python does a naive
@@ -231,48 +198,43 @@ class SET(_StringType):

return process

def bind_processor(
self, dialect: Dialect
) -> _BindProcessorType[Union[str, int]]:
def bind_processor(self, dialect):
super_convert = super().bind_processor(dialect)
if self.retrieve_as_bitwise:

def process(
value: Union[str, int, set[str], None],
) -> Union[str, int, None]:
def process(value):
if value is None:
return None
elif isinstance(value, (int, str)):
if super_convert:
return super_convert(value) # type: ignore[arg-type, no-any-return] # noqa: E501
return super_convert(value)
else:
return value
else:
int_value = 0
for v in value:
int_value |= self._inversed_bitmap[v]
int_value |= self._bitmap[v]
return int_value

else:

def process(
value: Union[str, int, set[str], None],
) -> Union[str, int, None]:
def process(value):
# accept strings and int (actually bitflag) values directly
if value is not None and not isinstance(value, (int, str)):
value = ",".join(value)

if super_convert:
return super_convert(value) # type: ignore
return super_convert(value)
else:
return value

return process

def adapt(self, cls: type, **kw: Any) -> Any:
def adapt(self, impltype, **kw):
kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
return util.constructor_copy(self, cls, *self.values, **kw)
return util.constructor_copy(self, impltype, *self.values, **kw)

def __repr__(self) -> str:
def __repr__(self):
return util.generic_repr(
self,
to_inspect=[SET, _StringType],

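For reference, the ``retrieve_as_bitwise`` decoding performed by ``result_processor`` amounts to testing each ``2**idx`` bit; a self-contained sketch::

    # Sketch: decode an integer bitflag back into a set of SET values.
    values = ["a", "b", "c"]
    bitmap = {2**idx: v for idx, v in enumerate(values)}

    def decode(flag: int) -> set:
        return {v for bit, v in bitmap.items() if flag & bit}

    assert decode(0b101) == {"a", "c"}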
@@ -1,13 +1,10 @@
# dialects/mysql/expression.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

from typing import Any

from ... import exc
from ... import util
@@ -20,7 +17,7 @@ from ...sql.base import Generative
from ...util.typing import Self


class match(Generative, elements.BinaryExpression[Any]):
class match(Generative, elements.BinaryExpression):
"""Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.

E.g.::
@@ -40,9 +37,7 @@ class match(Generative, elements.BinaryExpression[Any]):
.order_by(desc(match_expr))
)

Would produce SQL resembling:

.. sourcecode:: sql
Would produce SQL resembling::

SELECT id, firstname, lastname
FROM user
@@ -75,9 +70,8 @@ class match(Generative, elements.BinaryExpression[Any]):
__visit_name__ = "mysql_match"

inherit_cache = True
modifiers: util.immutabledict[str, Any]

def __init__(self, *cols: elements.ColumnElement[Any], **kw: Any):
def __init__(self, *cols, **kw):
if not cols:
raise exc.ArgumentError("columns are required")


@@ -1,21 +1,13 @@
# dialects/mysql/json.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/json.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING
# mypy: ignore-errors

from ... import types as sqltypes

if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _LiteralProcessorType


class JSON(sqltypes.JSON):
"""MySQL JSON type.
@@ -42,13 +34,13 @@ class JSON(sqltypes.JSON):


class _FormatTypeMixin:
def _format_value(self, value: Any) -> str:
def _format_value(self, value):
raise NotImplementedError()

def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)

def process(value: Any) -> Any:
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
@@ -56,31 +48,29 @@ class _FormatTypeMixin:

return process

def literal_processor(
self, dialect: Dialect
) -> _LiteralProcessorType[Any]:
super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)

def process(value: Any) -> str:
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value # type: ignore[no-any-return]
return value

return process


class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value: Any) -> str:
def _format_value(self, value):
if isinstance(value, int):
formatted_value = "$[%s]" % value
value = "$[%s]" % value
else:
formatted_value = '$."%s"' % value
return formatted_value
value = '$."%s"' % value
return value


class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value: Any) -> str:
def _format_value(self, value):
return "$%s" % (
"".join(
[

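The two ``_format_value`` variants above reduce to one conditional; an illustrative sketch of the JSON paths they produce::

    # Sketch: MySQL JSON index paths. Integer indexes become "$[n]",
    # string keys become '$."k"'.
    def format_index(value):
        if isinstance(value, int):
            return "$[%s]" % value
        return '$."%s"' % value

    assert format_index(0) == "$[0]"
    assert format_index("name") == '$."name"'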
@@ -1,73 +1,32 @@
# dialects/mysql/mariadb.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mariadb.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

from __future__ import annotations

from typing import Any
from typing import Callable

# mypy: ignore-errors
from .base import MariaDBIdentifierPreparer
from .base import MySQLDialect
from .base import MySQLIdentifierPreparer
from .base import MySQLTypeCompiler
from ...sql import sqltypes


class INET4(sqltypes.TypeEngine[str]):
"""INET4 column type for MariaDB

.. versionadded:: 2.0.37
"""

__visit_name__ = "INET4"


class INET6(sqltypes.TypeEngine[str]):
"""INET6 column type for MariaDB

.. versionadded:: 2.0.37
"""

__visit_name__ = "INET6"


class MariaDBTypeCompiler(MySQLTypeCompiler):
def visit_INET4(self, type_: INET4, **kwargs: Any) -> str:
return "INET4"

def visit_INET6(self, type_: INET6, **kwargs: Any) -> str:
return "INET6"


class MariaDBDialect(MySQLDialect):
is_mariadb = True
supports_statement_cache = True
name = "mariadb"
preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer
type_compiler_cls = MariaDBTypeCompiler
preparer = MariaDBIdentifierPreparer


def loader(driver: str) -> Callable[[], type[MariaDBDialect]]:
dialect_mod = __import__(
def loader(driver):
driver_mod = __import__(
"sqlalchemy.dialects.mysql.%s" % driver
).dialects.mysql
driver_cls = getattr(driver_mod, driver).dialect

driver_mod = getattr(dialect_mod, driver)
if hasattr(driver_mod, "mariadb_dialect"):
driver_cls = driver_mod.mariadb_dialect
return driver_cls # type: ignore[no-any-return]
else:
driver_cls = driver_mod.dialect

return type(
"MariaDBDialect_%s" % driver,
(
MariaDBDialect,
driver_cls,
),
{"supports_statement_cache": True},
)
return type(
"MariaDBDialect_%s" % driver,
(
MariaDBDialect,
driver_cls,
),
{"supports_statement_cache": True},
)

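Both versions of ``loader`` end in the same dynamic ``type()`` call; it can be exercised on its own with stand-in base classes (names here are placeholders, not SQLAlchemy API)::

    # Sketch only: combine a MariaDB mixin with a driver dialect at
    # runtime, the way loader() does.
    class MariaDBDialectStandIn:
        is_mariadb = True

    class DriverDialectStandIn:
        supports_statement_cache = False

    combined = type(
        "MariaDBDialect_dummy",
        (MariaDBDialectStandIn, DriverDialectStandIn),
        {"supports_statement_cache": True},
    )
    assert combined.is_mariadb and combined.supports_statement_cache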
@@ -1,9 +1,11 @@
# dialects/mysql/mariadbconnector.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mariadbconnector.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


"""

@@ -27,15 +29,7 @@ be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
.. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python

""" # noqa
from __future__ import annotations

import re
from typing import Any
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from uuid import UUID as _python_UUID

from .base import MySQLCompiler
@@ -45,19 +39,6 @@ from ... import sql
from ... import util
from ...sql import sqltypes

if TYPE_CHECKING:
from ...engine.base import Connection
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import Dialect
from ...engine.interfaces import IsolationLevel
from ...engine.interfaces import PoolProxiedConnection
from ...engine.url import URL
from ...sql.compiler import SQLCompiler
from ...sql.type_api import _ResultProcessorType


mariadb_cpy_minimum_version = (1, 0, 1)

@@ -66,12 +47,10 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
# work around JIRA issue
# https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed,
# this type can be removed.
def result_processor(
self, dialect: Dialect, coltype: object
) -> Optional[_ResultProcessorType[Any]]:
def result_processor(self, dialect, coltype):
if self.as_uuid:

def process(value: Any) -> Any:
def process(value):
if value is not None:
if hasattr(value, "decode"):
value = value.decode("ascii")
@@ -81,7 +60,7 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
return process
else:

def process(value: Any) -> Any:
def process(value):
if value is not None:
if hasattr(value, "decode"):
value = value.decode("ascii")
@@ -92,27 +71,30 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):


class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
_lastrowid: Optional[int] = None
_lastrowid = None

def create_server_side_cursor(self) -> DBAPICursor:
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(buffered=False)

def create_default_cursor(self) -> DBAPICursor:
def create_default_cursor(self):
return self._dbapi_connection.cursor(buffered=True)

def post_exec(self) -> None:
def post_exec(self):
super().post_exec()

self._rowcount = self.cursor.rowcount

if TYPE_CHECKING:
assert isinstance(self.compiled, SQLCompiler)
if self.isinsert and self.compiled.postfetch_lastrowid:
self._lastrowid = self.cursor.lastrowid

def get_lastrowid(self) -> int:
if TYPE_CHECKING:
assert self._lastrowid is not None
@property
def rowcount(self):
if self._rowcount is not None:
return self._rowcount
else:
return self.cursor.rowcount

def get_lastrowid(self):
return self._lastrowid


@@ -151,7 +133,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)

@util.memoized_property
def _dbapi_version(self) -> Tuple[int, ...]:
def _dbapi_version(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
return tuple(
[
@@ -164,7 +146,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
else:
return (99, 99, 99)

def __init__(self, **kwargs: Any) -> None:
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.paramstyle = "qmark"
if self.dbapi is not None:
@@ -176,26 +158,20 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)

@classmethod
def import_dbapi(cls) -> DBAPIModule:
def import_dbapi(cls):
return __import__("mariadb")

def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
def is_disconnect(self, e, connection, cursor):
if super().is_disconnect(e, connection, cursor):
return True
elif isinstance(e, self.loaded_dbapi.Error):
elif isinstance(e, self.dbapi.Error):
str_e = str(e).lower()
return "not connected" in str_e or "isn't valid" in str_e
else:
return False

def create_connect_args(self, url: URL) -> ConnectArgsType:
def create_connect_args(self, url):
opts = url.translate_connect_args()
opts.update(url.query)

int_params = [
"connect_timeout",
@@ -210,7 +186,6 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
"ssl_verify_cert",
"ssl",
"pool_reset_connection",
"compress",
]

for key in int_params:
@@ -230,21 +205,19 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
except (AttributeError, ImportError):
self.supports_sane_rowcount = False
opts["client_flag"] = client_flag
return [], opts
return [[], opts]

def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
def _extract_error_code(self, exception):
try:
rc: int = exception.errno
rc = exception.errno
except:
rc = -1
return rc

def _detect_charset(self, connection: Connection) -> str:
def _detect_charset(self, connection):
return "utf8mb4"

def get_isolation_level_values(
self, dbapi_conn: DBAPIConnection
) -> Sequence[IsolationLevel]:
def get_isolation_level_values(self, dbapi_connection):
return (
"SERIALIZABLE",
"READ UNCOMMITTED",
@@ -253,26 +226,21 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
"AUTOCOMMIT",
)

def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
return bool(dbapi_conn.autocommit)

def set_isolation_level(
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
) -> None:
def set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
dbapi_connection.autocommit = True
connection.autocommit = True
else:
dbapi_connection.autocommit = False
super().set_isolation_level(dbapi_connection, level)
connection.autocommit = False
super().set_isolation_level(connection, level)

def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
def do_begin_twophase(self, connection, xid):
connection.execute(
sql.text("XA BEGIN :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)

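``do_begin_twophase`` renders the xid inline via ``literal_execute``; a hedged sketch of the same statement construction (the xid value is a placeholder)::

    # Sketch: build the XA BEGIN statement with a literal-rendered bind.
    from sqlalchemy import bindparam, text

    stmt = text("XA BEGIN :xid").bindparams(
        bindparam("xid", "some_xid", literal_execute=True)
    )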
def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
def do_prepare_twophase(self, connection, xid):
connection.execute(
sql.text("XA END :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
@@ -285,12 +253,8 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)

def do_rollback_twophase(
self,
connection: Connection,
xid: Any,
is_prepared: bool = True,
recover: bool = False,
) -> None:
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
connection.execute(
sql.text("XA END :xid").bindparams(
@@ -304,12 +268,8 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
)

def do_commit_twophase(
self,
connection: Connection,
xid: Any,
is_prepared: bool = True,
recover: bool = False,
) -> None:
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
connection.execute(

@@ -1,9 +1,10 @@
# dialects/mysql/mysqlconnector.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mysqlconnector.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


r"""
@@ -13,85 +14,26 @@ r"""
:connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://pypi.org/project/mysql-connector-python/

Driver Status
-------------

MySQL Connector/Python is supported as of SQLAlchemy 2.0.39 to the
degree that the driver is functional. There are still ongoing issues
with features such as server side cursors, which remain disabled until
upstream issues are repaired.

.. warning:: The MySQL Connector/Python driver published by Oracle is subject
to frequent, major regressions of essential functionality such as being able
to correctly persist simple binary strings which indicate it is not well
tested. The SQLAlchemy project is not able to maintain this dialect fully as
regressions in the driver prevent it from being included in continuous
integration.

.. versionchanged:: 2.0.39

The MySQL Connector/Python dialect has been updated to support the
latest version of this DBAPI. Previously, MySQL Connector/Python
was not fully supported. However, support remains limited due to ongoing
regressions introduced in this driver.

Connecting to MariaDB with MySQL Connector/Python
--------------------------------------------------

MySQL Connector/Python may attempt to pass an incompatible collation to the
database when connecting to MariaDB. Experimentation has shown that using
``?charset=utf8mb4&collation=utf8mb4_general_ci`` or similar MariaDB-compatible
charset/collation will allow connectivity.
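For example, a URL of this form would carry the workaround (credentials and host are placeholders, and mysql-connector-python must be installed)::

    from sqlalchemy import create_engine

    # placeholder credentials/host; the query string carries the
    # MariaDB-compatible charset and collation described above
    engine = create_engine(
        "mariadb+mysqlconnector://scott:tiger@mariadb_host/test"
        "?charset=utf8mb4&collation=utf8mb4_general_ci"
    )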
.. note::

The MySQL Connector/Python DBAPI has had many issues since its release,
some of which may remain unresolved, and the mysqlconnector dialect is
**not tested as part of SQLAlchemy's continuous integration**.
The recommended MySQL dialects are mysqlclient and PyMySQL.

""" # noqa
from __future__ import annotations

import re
from typing import Any
from typing import cast
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from .base import MariaDBIdentifierPreparer
from .base import BIT
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .base import MySQLIdentifierPreparer
from .mariadb import MariaDBDialect
from .types import BIT
from ... import util

if TYPE_CHECKING:

from ...engine.base import Connection
from ...engine.cursor import CursorResult
from ...engine.interfaces import ConnectArgsType
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.interfaces import IsolationLevel
from ...engine.interfaces import PoolProxiedConnection
from ...engine.row import Row
from ...engine.url import URL
from ...sql.elements import BinaryExpression


class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext):
def create_server_side_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor(buffered=False)

def create_default_cursor(self) -> DBAPICursor:
return self._dbapi_connection.cursor(buffered=True)


class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod_binary(
self, binary: BinaryExpression[Any], operator: Any, **kw: Any
) -> str:
def visit_mod_binary(self, binary, operator, **kw):
return (
self.process(binary.left, **kw)
+ " % "
@@ -99,37 +41,22 @@ class MySQLCompiler_mysqlconnector(MySQLCompiler):
)


class IdentifierPreparerCommon_mysqlconnector:
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
@property
def _double_percents(self) -> bool:
def _double_percents(self):
return False

@_double_percents.setter
def _double_percents(self, value: Any) -> None:
def _double_percents(self, value):
pass

def _escape_identifier(self, value: str) -> str:
value = value.replace(
self.escape_quote, # type:ignore[attr-defined]
self.escape_to_quote, # type:ignore[attr-defined]
)
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value


class MySQLIdentifierPreparer_mysqlconnector(
IdentifierPreparerCommon_mysqlconnector, MySQLIdentifierPreparer
):
pass


class MariaDBIdentifierPreparer_mysqlconnector(
IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer
):
pass


class _myconnpyBIT(BIT):
def result_processor(self, dialect: Any, coltype: Any) -> None:
def result_processor(self, dialect, coltype):
"""MySQL-connector already converts mysql bits, so."""

return None
@@ -144,31 +71,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect):

supports_native_decimal = True

supports_native_bit = True

# not until https://bugs.mysql.com/bug.php?id=117548
supports_server_side_cursors = False

default_paramstyle = "format"
statement_compiler = MySQLCompiler_mysqlconnector

execution_ctx_cls = MySQLExecutionContext_mysqlconnector

preparer: type[MySQLIdentifierPreparer] = (
MySQLIdentifierPreparer_mysqlconnector
)
preparer = MySQLIdentifierPreparer_mysqlconnector

colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})

@classmethod
def import_dbapi(cls) -> DBAPIModule:
return cast("DBAPIModule", __import__("mysql.connector").connector)
def import_dbapi(cls):
from mysql import connector

def do_ping(self, dbapi_connection: DBAPIConnection) -> bool:
return connector

def do_ping(self, dbapi_connection):
dbapi_connection.ping(False)
return True

def create_connect_args(self, url: URL) -> ConnectArgsType:
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")

opts.update(url.query)
@@ -176,7 +96,6 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
util.coerce_kw_type(opts, "allow_local_infile", bool)
util.coerce_kw_type(opts, "autocommit", bool)
util.coerce_kw_type(opts, "buffered", bool)
util.coerce_kw_type(opts, "client_flag", int)
util.coerce_kw_type(opts, "compress", bool)
util.coerce_kw_type(opts, "connection_timeout", int)
util.coerce_kw_type(opts, "connect_timeout", int)
@@ -191,21 +110,15 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
util.coerce_kw_type(opts, "use_pure", bool)
util.coerce_kw_type(opts, "use_unicode", bool)

# note that "buffered" is set to False by default in MySQL/connector
# python. If you set it to True, then there is no way to get a server
# side cursor because the logic is written to disallow that.

# leaving this at True until
# https://bugs.mysql.com/bug.php?id=117548 can be fixed
opts["buffered"] = True
# unfortunately, MySQL/connector python refuses to release a
# cursor without reading fully, so non-buffered isn't an option
opts.setdefault("buffered", True)

# FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount.
if self.dbapi is not None:
try:
from mysql.connector import constants # type: ignore

ClientFlag = constants.ClientFlag
from mysql.connector.constants import ClientFlag

client_flags = opts.get(
"client_flags", ClientFlag.get_default()
@@ -214,35 +127,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
opts["client_flags"] = client_flags
except Exception:
pass

return [], opts
return [[], opts]

@util.memoized_property
def _mysqlconnector_version_info(self) -> Optional[Tuple[int, ...]]:
def _mysqlconnector_version_info(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
return None

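The memoized version parse is a small regex trick; shown standalone::

    # Sketch: parse "major.minor[.patch]" into a tuple of ints.
    import re

    m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", "8.0.33")
    assert m is not None
    version = tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
    assert version == (8, 0, 33)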
def _detect_charset(self, connection: Connection) -> str:
|
||||
return connection.connection.charset # type: ignore
|
||||
def _detect_charset(self, connection):
|
||||
return connection.connection.charset
|
||||
|
||||
def _extract_error_code(self, exception: BaseException) -> int:
|
||||
return exception.errno # type: ignore
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.errno
|
||||
|
||||
def is_disconnect(
|
||||
self,
|
||||
e: Exception,
|
||||
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
|
||||
cursor: Optional[DBAPICursor],
|
||||
) -> bool:
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
|
||||
exceptions = (
|
||||
self.loaded_dbapi.OperationalError, #
|
||||
self.loaded_dbapi.InterfaceError,
|
||||
self.loaded_dbapi.ProgrammingError,
|
||||
)
|
||||
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
|
||||
if isinstance(e, exceptions):
|
||||
return (
|
||||
e.errno in errnos
|
||||
@@ -252,51 +154,26 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
|
||||
else:
|
||||
return False
|
||||
|
||||
def _compat_fetchall(
|
||||
self,
|
||||
rp: CursorResult[Tuple[Any, ...]],
|
||||
charset: Optional[str] = None,
|
||||
) -> Sequence[Row[Tuple[Any, ...]]]:
|
||||
def _compat_fetchall(self, rp, charset=None):
|
||||
return rp.fetchall()
|
||||
|
||||
def _compat_fetchone(
|
||||
self,
|
||||
rp: CursorResult[Tuple[Any, ...]],
|
||||
charset: Optional[str] = None,
|
||||
) -> Optional[Row[Tuple[Any, ...]]]:
|
||||
def _compat_fetchone(self, rp, charset=None):
|
||||
return rp.fetchone()
|
||||
|
||||
def get_isolation_level_values(
|
||||
self, dbapi_conn: DBAPIConnection
|
||||
) -> Sequence[IsolationLevel]:
|
||||
return (
|
||||
"SERIALIZABLE",
|
||||
"READ UNCOMMITTED",
|
||||
"READ COMMITTED",
|
||||
"REPEATABLE READ",
|
||||
"AUTOCOMMIT",
|
||||
)
|
||||
_isolation_lookup = {
|
||||
"SERIALIZABLE",
|
||||
"READ UNCOMMITTED",
|
||||
"READ COMMITTED",
|
||||
"REPEATABLE READ",
|
||||
"AUTOCOMMIT",
|
||||
}
|
||||
|
||||
def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
|
||||
return bool(dbapi_conn.autocommit)
|
||||
|
||||
def set_isolation_level(
|
||||
self, dbapi_connection: DBAPIConnection, level: IsolationLevel
|
||||
) -> None:
|
||||
def _set_isolation_level(self, connection, level):
|
||||
if level == "AUTOCOMMIT":
|
||||
dbapi_connection.autocommit = True
|
||||
connection.autocommit = True
|
||||
else:
|
||||
dbapi_connection.autocommit = False
|
||||
super().set_isolation_level(dbapi_connection, level)
|
||||
|
||||
|
||||
class MariaDBDialect_mysqlconnector(
|
||||
MariaDBDialect, MySQLDialect_mysqlconnector
|
||||
):
|
||||
supports_statement_cache = True
|
||||
_allows_uuid_binds = False
|
||||
preparer = MariaDBIdentifierPreparer_mysqlconnector
|
||||
connection.autocommit = False
|
||||
super()._set_isolation_level(connection, level)
|
||||
|
||||
|
||||
dialect = MySQLDialect_mysqlconnector
|
||||
mariadb_dialect = MariaDBDialect_mysqlconnector
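# Editor's note: a hedged usage sketch; the dialect classes above are picked
# up automatically from the URL scheme (host and database are placeholders):
#
#     from sqlalchemy import create_engine
#
#     engine = create_engine("mysql+mysqlconnector://scott:tiger@localhost/test")
#     mariadb_engine = create_engine("mariadb+mysqlconnector://scott:tiger@localhost/test")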

@@ -1,9 +1,11 @@
# dialects/mysql/mysqldb.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/mysqldb.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


"""

@@ -46,9 +48,9 @@ key "ssl", which may be specified using the
        "ssl": {
            "ca": "/home/gord/client-ssl/ca.pem",
            "cert": "/home/gord/client-ssl/client-cert.pem",
            "key": "/home/gord/client-ssl/client-key.pem",
            "key": "/home/gord/client-ssl/client-key.pem"
        }
    },
    }
)

For convenience, the following keys may also be specified inline within the URL
@@ -72,9 +74,7 @@ Using MySQLdb with Google Cloud SQL
-----------------------------------

Google Cloud SQL now recommends use of the MySQLdb dialect.  Connect
using a URL like the following:

.. sourcecode:: text
using a URL like the following::

    mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>

@@ -84,39 +84,25 @@ Server Side Cursors
The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.

"""
from __future__ import annotations
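# Editor's note: a minimal sketch (assumed, not from this diff) of the "ssl"
# connect_args form documented above; host and certificate paths are
# placeholders:
#
#     from sqlalchemy import create_engine
#
#     engine = create_engine(
#         "mysql+mysqldb://scott:tiger@192.168.0.134/test",
#         connect_args={"ssl": {"ca": "/path/to/ca.pem"}},
#     )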

import re
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING

from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .base import MySQLIdentifierPreparer
from .base import TEXT
from ... import sql
from ... import util
from ...util.typing import Literal

if TYPE_CHECKING:

    from ...engine.base import Connection
    from ...engine.interfaces import _DBAPIMultiExecuteParams
    from ...engine.interfaces import ConnectArgsType
    from ...engine.interfaces import DBAPIConnection
    from ...engine.interfaces import DBAPICursor
    from ...engine.interfaces import DBAPIModule
    from ...engine.interfaces import ExecutionContext
    from ...engine.interfaces import IsolationLevel
    from ...engine.url import URL


class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
    pass
    @property
    def rowcount(self):
        if hasattr(self, "_rowcount"):
            return self._rowcount
        else:
            return self.cursor.rowcount


class MySQLCompiler_mysqldb(MySQLCompiler):
@@ -136,9 +122,8 @@ class MySQLDialect_mysqldb(MySQLDialect):
    execution_ctx_cls = MySQLExecutionContext_mysqldb
    statement_compiler = MySQLCompiler_mysqldb
    preparer = MySQLIdentifierPreparer
    server_version_info: Tuple[int, ...]

    def __init__(self, **kwargs: Any):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._mysql_dbapi_version = (
            self._parse_dbapi_version(self.dbapi.__version__)
@@ -146,7 +131,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
            else (0, 0, 0)
        )

    def _parse_dbapi_version(self, version: str) -> Tuple[int, ...]:
    def _parse_dbapi_version(self, version):
        m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
        if m:
            return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
@@ -154,7 +139,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
            return (0, 0, 0)

    @util.langhelpers.memoized_property
    def supports_server_side_cursors(self) -> bool:
    def supports_server_side_cursors(self):
        try:
            cursors = __import__("MySQLdb.cursors").cursors
            self._sscursor = cursors.SSCursor
@@ -163,13 +148,13 @@ class MySQLDialect_mysqldb(MySQLDialect):
            return False

    @classmethod
    def import_dbapi(cls) -> DBAPIModule:
    def import_dbapi(cls):
        return __import__("MySQLdb")

    def on_connect(self) -> Callable[[DBAPIConnection], None]:
    def on_connect(self):
        super_ = super().on_connect()

        def on_connect(conn: DBAPIConnection) -> None:
        def on_connect(conn):
            if super_ is not None:
                super_(conn)

@@ -182,24 +167,43 @@ class MySQLDialect_mysqldb(MySQLDialect):

        return on_connect

    def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:
    def do_ping(self, dbapi_connection):
        dbapi_connection.ping()
        return True

    def do_executemany(
        self,
        cursor: DBAPICursor,
        statement: str,
        parameters: _DBAPIMultiExecuteParams,
        context: Optional[ExecutionContext] = None,
    ) -> None:
    def do_executemany(self, cursor, statement, parameters, context=None):
        rowcount = cursor.executemany(statement, parameters)
        if context is not None:
            cast(MySQLExecutionContext, context)._rowcount = rowcount
            context._rowcount = rowcount

    def create_connect_args(
        self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
    ) -> ConnectArgsType:
    def _check_unicode_returns(self, connection):
        # work around issue fixed in
        # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
        # specific issue w/ the utf8mb4_bin collation and unicode returns

        collation = connection.exec_driver_sql(
            "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
            % (
                self.identifier_preparer.quote("Charset"),
                self.identifier_preparer.quote("Collation"),
            )
        ).scalar()
        has_utf8mb4_bin = self.server_version_info > (5,) and collation
        if has_utf8mb4_bin:
            additional_tests = [
                sql.collate(
                    sql.cast(
                        sql.literal_column("'test collated returns'"),
                        TEXT(charset="utf8mb4"),
                    ),
                    "utf8mb4_bin",
                )
            ]
        else:
            additional_tests = []
        return super()._check_unicode_returns(connection, additional_tests)

    def create_connect_args(self, url, _translate_args=None):
        if _translate_args is None:
            _translate_args = dict(
                database="db", username="user", password="passwd"
@@ -213,7 +217,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
        util.coerce_kw_type(opts, "read_timeout", int)
        util.coerce_kw_type(opts, "write_timeout", int)
        util.coerce_kw_type(opts, "client_flag", int)
        util.coerce_kw_type(opts, "local_infile", bool)
        util.coerce_kw_type(opts, "local_infile", int)
        # Note: using either of the below will cause all strings to be
        # returned as Unicode, both in raw SQL operations and with column
        # types like String and MSString.
@@ -248,9 +252,9 @@ class MySQLDialect_mysqldb(MySQLDialect):
        if client_flag_found_rows is not None:
            client_flag |= client_flag_found_rows
            opts["client_flag"] = client_flag
        return [], opts
        return [[], opts]
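    # Editor's note: an illustrative URL (assumed) exercising the
    # query-string coercions in create_connect_args() above; each value
    # arrives as a string and is coerced by util.coerce_kw_type():
    #
    #     mysql+mysqldb://scott:tiger@localhost/test?client_flag=2&local_infile=1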

    def _found_rows_client_flag(self) -> Optional[int]:
    def _found_rows_client_flag(self):
        if self.dbapi is not None:
            try:
                CLIENT_FLAGS = __import__(
@@ -259,23 +263,20 @@ class MySQLDialect_mysqldb(MySQLDialect):
            except (AttributeError, ImportError):
                return None
            else:
                return CLIENT_FLAGS.FOUND_ROWS  # type: ignore
                return CLIENT_FLAGS.FOUND_ROWS
        else:
            return None

    def _extract_error_code(self, exception: DBAPIModule.Error) -> int:
        return exception.args[0]  # type: ignore[no-any-return]
    def _extract_error_code(self, exception):
        return exception.args[0]

    def _detect_charset(self, connection: Connection) -> str:
    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results."""

        try:
            # note: the SQL here would be
            # "SHOW VARIABLES LIKE 'character_set%%'"

            cset_name: Callable[[], str] = (
                connection.connection.character_set_name
            )
            cset_name = connection.connection.character_set_name
        except AttributeError:
            util.warn(
                "No 'character_set_name' can be detected with "
@@ -287,9 +288,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
        else:
            return cset_name()

    def get_isolation_level_values(
        self, dbapi_conn: DBAPIConnection
    ) -> Tuple[IsolationLevel, ...]:
    def get_isolation_level_values(self, dbapi_connection):
        return (
            "SERIALIZABLE",
            "READ UNCOMMITTED",
@@ -298,12 +297,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
            "AUTOCOMMIT",
        )

    def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool:
        return dbapi_conn.get_autocommit()  # type: ignore[no-any-return]

    def set_isolation_level(
        self, dbapi_connection: DBAPIConnection, level: IsolationLevel
    ) -> None:
    def set_isolation_level(self, dbapi_connection, level):
        if level == "AUTOCOMMIT":
            dbapi_connection.autocommit(True)
        else:

@@ -1,10 +1,5 @@
# dialects/mysql/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from ... import exc
from ...testing.provision import configure_follower
from ...testing.provision import create_db
@@ -39,13 +34,6 @@ def generate_driver_url(url, driver, query_str):
        drivername="%s+%s" % (backend, driver)
    ).update_query_string(query_str)

    if driver == "mariadbconnector":
        new_url = new_url.difference_update_query(["charset"])
    elif driver == "mysqlconnector":
        new_url = new_url.update_query_pairs(
            [("collation", "utf8mb4_general_ci")]
        )

    try:
        new_url.get_dialect()
    except exc.NoSuchModuleError:

@@ -1,9 +1,11 @@
# dialects/mysql/pymysql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/pymysql.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


r"""

@@ -39,6 +41,7 @@ necessary to indicate ``ssl_check_hostname=false`` in PyMySQL::
        "&ssl_check_hostname=false"
    )


MySQL-Python Compatibility
--------------------------

@@ -47,26 +50,9 @@ and targets 100% compatibility. Most behavioral notes for MySQL-python apply
to the pymysql driver as well.

"""  # noqa
from __future__ import annotations

from typing import Any
from typing import Dict
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from .mysqldb import MySQLDialect_mysqldb
from ...util import langhelpers
from ...util.typing import Literal

if TYPE_CHECKING:

    from ...engine.interfaces import ConnectArgsType
    from ...engine.interfaces import DBAPIConnection
    from ...engine.interfaces import DBAPICursor
    from ...engine.interfaces import DBAPIModule
    from ...engine.interfaces import PoolProxiedConnection
    from ...engine.url import URL


class MySQLDialect_pymysql(MySQLDialect_mysqldb):
@@ -76,7 +62,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
    description_encoding = None

    @langhelpers.memoized_property
    def supports_server_side_cursors(self) -> bool:
    def supports_server_side_cursors(self):
        try:
            cursors = __import__("pymysql.cursors").cursors
            self._sscursor = cursors.SSCursor
@@ -85,11 +71,11 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
            return False

    @classmethod
    def import_dbapi(cls) -> DBAPIModule:
    def import_dbapi(cls):
        return __import__("pymysql")

    @langhelpers.memoized_property
    def _send_false_to_ping(self) -> bool:
    def _send_false_to_ping(self):
        """determine if pymysql has deprecated, changed the default of,
        or removed the 'reconnect' argument of connection.ping().

@@ -100,9 +86,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
        """  # noqa: E501

        try:
            Connection = __import__(
                "pymysql.connections"
            ).connections.Connection
            Connection = __import__("pymysql.connections").Connection
        except (ImportError, AttributeError):
            return True
        else:
@@ -116,7 +100,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
                not insp.defaults or insp.defaults[0] is not False
            )

    def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]:
    def do_ping(self, dbapi_connection):
        if self._send_false_to_ping:
            dbapi_connection.ping(False)
        else:
@@ -124,24 +108,17 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):

        return True

    def create_connect_args(
        self, url: URL, _translate_args: Optional[Dict[str, Any]] = None
    ) -> ConnectArgsType:
    def create_connect_args(self, url, _translate_args=None):
        if _translate_args is None:
            _translate_args = dict(username="user")
        return super().create_connect_args(
            url, _translate_args=_translate_args
        )

    def is_disconnect(
        self,
        e: DBAPIModule.Error,
        connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
        cursor: Optional[DBAPICursor],
    ) -> bool:
    def is_disconnect(self, e, connection, cursor):
        if super().is_disconnect(e, connection, cursor):
            return True
        elif isinstance(e, self.loaded_dbapi.Error):
        elif isinstance(e, self.dbapi.Error):
            str_e = str(e).lower()
            return (
                "already closed" in str_e or "connection was killed" in str_e
@@ -149,7 +126,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
        else:
            return False

    def _extract_error_code(self, exception: BaseException) -> Any:
    def _extract_error_code(self, exception):
        if isinstance(exception.args[0], Exception):
            exception = exception.args[0]
        return exception.args[0]
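# Editor's note: a hedged usage sketch for this dialect; because pymysql
# subclasses the mysqldb dialect above, the same URL options apply
# (credentials and hostname are placeholders):
#
#     from sqlalchemy import create_engine
#
#     engine = create_engine("mysql+pymysql://scott:tiger@localhost/test")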

@@ -1,13 +1,15 @@
# dialects/mysql/pyodbc.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/pyodbc.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


r"""

.. dialect:: mysql+pyodbc
    :name: PyODBC
    :dbapi: pyodbc
@@ -28,30 +30,21 @@ r"""
Pass through exact pyodbc connection string::

    import urllib

    connection_string = (
        "DRIVER=MySQL ODBC 8.0 ANSI Driver;"
        "SERVER=localhost;"
        "PORT=3307;"
        "DATABASE=mydb;"
        "UID=root;"
        "PWD=(whatever);"
        "charset=utf8mb4;"
        'DRIVER=MySQL ODBC 8.0 ANSI Driver;'
        'SERVER=localhost;'
        'PORT=3307;'
        'DATABASE=mydb;'
        'UID=root;'
        'PWD=(whatever);'
        'charset=utf8mb4;'
    )
    params = urllib.parse.quote_plus(connection_string)
    connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params

"""  # noqa
from __future__ import annotations

import datetime
import re
from typing import Any
from typing import Callable
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from .base import MySQLDialect
from .base import MySQLExecutionContext
@@ -61,31 +54,23 @@ from ... import util
from ...connectors.pyodbc import PyODBCConnector
from ...sql.sqltypes import Time

if TYPE_CHECKING:
    from ...engine import Connection
    from ...engine.interfaces import DBAPIConnection
    from ...engine.interfaces import Dialect
    from ...sql.type_api import _ResultProcessorType


class _pyodbcTIME(TIME):
    def result_processor(
        self, dialect: Dialect, coltype: object
    ) -> _ResultProcessorType[datetime.time]:
        def process(value: Any) -> Union[datetime.time, None]:
    def result_processor(self, dialect, coltype):
        def process(value):
            # pyodbc returns a datetime.time object; no need to convert
            return value  # type: ignore[no-any-return]
            return value

        return process


class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
    def get_lastrowid(self) -> int:
    def get_lastrowid(self):
        cursor = self.create_cursor()
        cursor.execute("SELECT LAST_INSERT_ID()")
        lastrowid = cursor.fetchone()[0]  # type: ignore[index]
        lastrowid = cursor.fetchone()[0]
        cursor.close()
        return lastrowid  # type: ignore[no-any-return]
        return lastrowid


class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
@@ -96,7 +81,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):

    pyodbc_driver_name = "MySQL"

    def _detect_charset(self, connection: Connection) -> str:
    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results."""

        # Prefer 'character_set_results' for the current connection over the
@@ -121,25 +106,21 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
        )
        return "latin1"

    def _get_server_version_info(
        self, connection: Connection
    ) -> Tuple[int, ...]:
    def _get_server_version_info(self, connection):
        return MySQLDialect._get_server_version_info(self, connection)

    def _extract_error_code(self, exception: BaseException) -> Optional[int]:
    def _extract_error_code(self, exception):
        m = re.compile(r"\((\d+)\)").search(str(exception.args))
        if m is None:
            return None
        c: Optional[str] = m.group(1)
        c = m.group(1)
        if c:
            return int(c)
        else:
            return None
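    # Editor's note: a doctest-style illustration (not part of the change)
    # of the error-code regex above, run against a typical pyodbc args tuple:
    #
    #     >>> m = re.compile(r"\((\d+)\)").search(str(("HY000", "(2006) server gone")))
    #     >>> m.group(1)
    #     '2006'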

    def on_connect(self) -> Callable[[DBAPIConnection], None]:
    def on_connect(self):
        super_ = super().on_connect()

        def on_connect(conn: DBAPIConnection) -> None:
        def on_connect(conn):
            if super_ is not None:
                super_(conn)


@@ -1,65 +1,46 @@
# dialects/mysql/reflection.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/reflection.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
# mypy: ignore-errors


import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

from .enumerated import ENUM
from .enumerated import SET
from .types import DATETIME
from .types import TIME
from .types import TIMESTAMP
from ... import log
from ... import types as sqltypes
from ... import util
from ...util.typing import Literal

if TYPE_CHECKING:
    from .base import MySQLDialect
    from .base import MySQLIdentifierPreparer
    from ...engine.interfaces import ReflectedColumn


class ReflectedState:
    """Stores raw information about a SHOW CREATE TABLE statement."""

    charset: Optional[str]

    def __init__(self) -> None:
        self.columns: List[ReflectedColumn] = []
        self.table_options: Dict[str, str] = {}
        self.table_name: Optional[str] = None
        self.keys: List[Dict[str, Any]] = []
        self.fk_constraints: List[Dict[str, Any]] = []
        self.ck_constraints: List[Dict[str, Any]] = []
    def __init__(self):
        self.columns = []
        self.table_options = {}
        self.table_name = None
        self.keys = []
        self.fk_constraints = []
        self.ck_constraints = []


@log.class_logger
class MySQLTableDefinitionParser:
    """Parses the results of a SHOW CREATE TABLE statement."""

    def __init__(
        self, dialect: MySQLDialect, preparer: MySQLIdentifierPreparer
    ):
    def __init__(self, dialect, preparer):
        self.dialect = dialect
        self.preparer = preparer
        self._prep_regexes()

    def parse(
        self, show_create: str, charset: Optional[str]
    ) -> ReflectedState:
    def parse(self, show_create, charset):
        state = ReflectedState()
        state.charset = charset
        for line in re.split(r"\r?\n", show_create):
@@ -84,11 +65,11 @@ class MySQLTableDefinitionParser:
            if type_ is None:
                util.warn("Unknown schema content: %r" % line)
            elif type_ == "key":
                state.keys.append(spec)  # type: ignore[arg-type]
                state.keys.append(spec)
            elif type_ == "fk_constraint":
                state.fk_constraints.append(spec)  # type: ignore[arg-type]
                state.fk_constraints.append(spec)
            elif type_ == "ck_constraint":
                state.ck_constraints.append(spec)  # type: ignore[arg-type]
                state.ck_constraints.append(spec)
            else:
                pass
        return state
@@ -96,13 +77,7 @@ class MySQLTableDefinitionParser:
    def _check_view(self, sql: str) -> bool:
        return bool(self._re_is_view.match(sql))

    def _parse_constraints(self, line: str) -> Union[
        Tuple[None, str],
        Tuple[Literal["partition"], str],
        Tuple[
            Literal["ck_constraint", "fk_constraint", "key"], Dict[str, str]
        ],
    ]:
    def _parse_constraints(self, line):
        """Parse a KEY or CONSTRAINT line.

        :param line: A line of SHOW CREATE TABLE output
@@ -152,7 +127,7 @@ class MySQLTableDefinitionParser:
        # No match.
        return (None, line)

    def _parse_table_name(self, line: str, state: ReflectedState) -> None:
    def _parse_table_name(self, line, state):
        """Extract the table name.

        :param line: The first line of SHOW CREATE TABLE
@@ -163,7 +138,7 @@ class MySQLTableDefinitionParser:
        if m:
            state.table_name = cleanup(m.group("name"))

    def _parse_table_options(self, line: str, state: ReflectedState) -> None:
    def _parse_table_options(self, line, state):
        """Build a dictionary of all reflected table-level options.

        :param line: The final line of SHOW CREATE TABLE output.
@@ -189,9 +164,7 @@ class MySQLTableDefinitionParser:
        for opt, val in options.items():
            state.table_options["%s_%s" % (self.dialect.name, opt)] = val

    def _parse_partition_options(
        self, line: str, state: ReflectedState
    ) -> None:
    def _parse_partition_options(self, line, state):
        options = {}
        new_line = line[:]

@@ -247,7 +220,7 @@ class MySQLTableDefinitionParser:
        else:
            state.table_options["%s_%s" % (self.dialect.name, opt)] = val

    def _parse_column(self, line: str, state: ReflectedState) -> None:
    def _parse_column(self, line, state):
        """Extract column details.

        Falls back to a 'minimal support' variant if full parse fails.
@@ -310,16 +283,13 @@ class MySQLTableDefinitionParser:

        type_instance = col_type(*type_args, **type_kw)

        col_kw: Dict[str, Any] = {}
        col_kw = {}

        # NOT NULL
        col_kw["nullable"] = True
        # this can be "NULL" in the case of TIMESTAMP
        if spec.get("notnull", False) == "NOT NULL":
            col_kw["nullable"] = False
        # For generated columns, the nullability is marked in a different place
        if spec.get("notnull_generated", False) == "NOT NULL":
            col_kw["nullable"] = False

        # AUTO_INCREMENT
        if spec.get("autoincr", False):
@@ -351,13 +321,9 @@ class MySQLTableDefinitionParser:
            name=name, type=type_instance, default=default, comment=comment
        )
        col_d.update(col_kw)
        state.columns.append(col_d)  # type: ignore[arg-type]
        state.columns.append(col_d)

    def _describe_to_create(
        self,
        table_name: str,
        columns: Sequence[Tuple[str, str, str, str, str, str]],
    ) -> str:
    def _describe_to_create(self, table_name, columns):
        """Re-format DESCRIBE output as a SHOW CREATE TABLE string.

        DESCRIBE is a much simpler reflection and is sufficient for
@@ -410,9 +376,7 @@ class MySQLTableDefinitionParser:
            ]
        )

    def _parse_keyexprs(
        self, identifiers: str
    ) -> List[Tuple[str, Optional[int], str]]:
    def _parse_keyexprs(self, identifiers):
        """Unpack '"col"(2),"col" ASC'-ish strings into components."""

        return [
@@ -422,12 +386,11 @@ class MySQLTableDefinitionParser:
            )
        ]

    def _prep_regexes(self) -> None:
    def _prep_regexes(self):
        """Pre-compile regular expressions."""

        self._pr_options: List[
            Tuple[re.Pattern[Any], Optional[Callable[[str], str]]]
        ] = []
        self._re_columns = []
        self._pr_options = []

        _final = self.preparer.final_quote

@@ -485,13 +448,11 @@ class MySQLTableDefinitionParser:
            r"(?: +COLLATE +(?P<collate>[\w_]+))?"
            r"(?: +(?P<notnull>(?:NOT )?NULL))?"
            r"(?: +DEFAULT +(?P<default>"
            r"(?:NULL|'(?:''|[^'])*'|\(.+?\)|[\-\w\.\(\)]+"
            r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
            r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
            r"))?"
            r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
            r".*\))? ?(?P<persistence>VIRTUAL|STORED)?"
            r"(?: +(?P<notnull_generated>(?:NOT )?NULL))?"
            r")?"
            r".*\))? ?(?P<persistence>VIRTUAL|STORED)?)?"
            r"(?: +(?P<autoincr>AUTO_INCREMENT))?"
            r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?"
            r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?"
@@ -539,7 +500,7 @@ class MySQLTableDefinitionParser:
        #
        # unique constraints come back as KEYs
        kw = quotes.copy()
        kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT"
        kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
        self._re_fk_constraint = _re_compile(
            r" "
            r"CONSTRAINT +"
@@ -616,21 +577,21 @@ class MySQLTableDefinitionParser:

    _optional_equals = r"(?:\s*(?:=\s*)|\s+)"

    def _add_option_string(self, directive: str) -> None:
    def _add_option_string(self, directive):
        regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
            re.escape(directive),
            self._optional_equals,
        )
        self._pr_options.append(_pr_compile(regex, cleanup_text))

    def _add_option_word(self, directive: str) -> None:
    def _add_option_word(self, directive):
        regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
            re.escape(directive),
            self._optional_equals,
        )
        self._pr_options.append(_pr_compile(regex))

    def _add_partition_option_word(self, directive: str) -> None:
    def _add_partition_option_word(self, directive):
        if directive == "PARTITION BY" or directive == "SUBPARTITION BY":
            regex = r"(?<!\S)(?P<directive>%s)%s" r"(?P<val>\w+.*)" % (
                re.escape(directive),
@@ -645,7 +606,7 @@ class MySQLTableDefinitionParser:
            regex = r"(?<!\S)(?P<directive>%s)(?!\S)" % (re.escape(directive),)
            self._pr_options.append(_pr_compile(regex))

    def _add_option_regex(self, directive: str, regex: str) -> None:
    def _add_option_regex(self, directive, regex):
        regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
            re.escape(directive),
            self._optional_equals,
@@ -663,35 +624,21 @@ _options_of_type_string = (
)


@overload
def _pr_compile(
    regex: str, cleanup: Callable[[str], str]
) -> Tuple[re.Pattern[Any], Callable[[str], str]]: ...


@overload
def _pr_compile(
    regex: str, cleanup: None = None
) -> Tuple[re.Pattern[Any], None]: ...


def _pr_compile(
    regex: str, cleanup: Optional[Callable[[str], str]] = None
) -> Tuple[re.Pattern[Any], Optional[Callable[[str], str]]]:
def _pr_compile(regex, cleanup=None):
    """Prepare a 2-tuple of compiled regex and callable."""

    return (_re_compile(regex), cleanup)


def _re_compile(regex: str) -> re.Pattern[Any]:
def _re_compile(regex):
    """Compile a string to regex, I and UNICODE."""

    return re.compile(regex, re.I | re.UNICODE)


def _strip_values(values: Sequence[str]) -> List[str]:
def _strip_values(values):
    "Strip reflected values quotes"
    strip_values: List[str] = []
    strip_values = []
    for a in values:
        if a[0:1] == '"' or a[0:1] == "'":
            # strip enclosing quotes and unquote interior
@@ -703,9 +650,7 @@ def _strip_values(values: Sequence[str]) -> List[str]:
def cleanup_text(raw_text: str) -> str:
    if "\\" in raw_text:
        raw_text = re.sub(
            _control_char_regexp,
            lambda s: _control_char_map[s[0]],  # type: ignore[index]
            raw_text,
            _control_char_regexp, lambda s: _control_char_map[s[0]], raw_text
        )
    return raw_text.replace("''", "'")
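# Editor's note: a minimal illustration (assumed input) of cleanup_text()
# above; the final replace() collapses SQL-style doubled quotes:
#
#     >>> cleanup_text("O''Reilly")
#     "O'Reilly"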


@@ -1,5 +1,5 @@
# dialects/mysql/reserved_words.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/reserved_words.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -11,6 +11,7 @@
# https://mariadb.com/kb/en/reserved-words/
# includes: Reserved Words, Oracle Mode (separate set unioned)
# excludes: Exceptions, Function Names
# mypy: ignore-errors

RESERVED_WORDS_MARIADB = {
    "accessible",
@@ -281,7 +282,6 @@ RESERVED_WORDS_MARIADB = {
    }
)

# https://dev.mysql.com/doc/refman/8.3/en/keywords.html
# https://dev.mysql.com/doc/refman/8.0/en/keywords.html
# https://dev.mysql.com/doc/refman/5.7/en/keywords.html
# https://dev.mysql.com/doc/refman/5.6/en/keywords.html
@@ -403,7 +403,6 @@ RESERVED_WORDS_MYSQL = {
    "int4",
    "int8",
    "integer",
    "intersect",
    "interval",
    "into",
    "io_after_gtids",
@@ -469,7 +468,6 @@ RESERVED_WORDS_MYSQL = {
    "outfile",
    "over",
    "parse_gcol_expr",
    "parallel",
    "partition",
    "percent_rank",
    "persist",
@@ -478,7 +476,6 @@ RESERVED_WORDS_MYSQL = {
    "primary",
    "procedure",
    "purge",
    "qualify",
    "range",
    "rank",
    "read",

@@ -1,30 +1,18 @@
# dialects/mysql/types.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# mysql/types.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
# mypy: ignore-errors


import datetime
import decimal
from typing import Any
from typing import Iterable
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from ... import exc
from ... import util
from ...sql import sqltypes

if TYPE_CHECKING:
    from .base import MySQLDialect
    from ...engine.interfaces import Dialect
    from ...sql.type_api import _BindProcessorType
    from ...sql.type_api import _ResultProcessorType
    from ...sql.type_api import TypeEngine


class _NumericType:
    """Base for MySQL numeric types.
@@ -34,27 +22,19 @@ class _NumericType:

    """

    def __init__(
        self, unsigned: bool = False, zerofill: bool = False, **kw: Any
    ):
    def __init__(self, unsigned=False, zerofill=False, **kw):
        self.unsigned = unsigned
        self.zerofill = zerofill
        super().__init__(**kw)

    def __repr__(self) -> str:
    def __repr__(self):
        return util.generic_repr(
            self, to_inspect=[_NumericType, sqltypes.Numeric]
        )


class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]):
    def __init__(
        self,
        precision: Optional[int] = None,
        scale: Optional[int] = None,
        asdecimal: bool = True,
        **kw: Any,
    ):
class _FloatType(_NumericType, sqltypes.Float):
    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
        if isinstance(self, (REAL, DOUBLE)) and (
            (precision is None and scale is not None)
            or (precision is not None and scale is None)
@@ -66,18 +46,18 @@ class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]):
        super().__init__(precision=precision, asdecimal=asdecimal, **kw)
        self.scale = scale

    def __repr__(self) -> str:
    def __repr__(self):
        return util.generic_repr(
            self, to_inspect=[_FloatType, _NumericType, sqltypes.Float]
        )


class _IntegerType(_NumericType, sqltypes.Integer):
    def __init__(self, display_width: Optional[int] = None, **kw: Any):
    def __init__(self, display_width=None, **kw):
        self.display_width = display_width
        super().__init__(**kw)

    def __repr__(self) -> str:
    def __repr__(self):
        return util.generic_repr(
            self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]
        )
@@ -88,13 +68,13 @@ class _StringType(sqltypes.String):

    def __init__(
        self,
        charset: Optional[str] = None,
        collation: Optional[str] = None,
        ascii: bool = False,  # noqa
        binary: bool = False,
        unicode: bool = False,
        national: bool = False,
        **kw: Any,
        charset=None,
        collation=None,
        ascii=False,  # noqa
        binary=False,
        unicode=False,
        national=False,
        **kw,
    ):
        self.charset = charset

@@ -107,33 +87,25 @@ class _StringType(sqltypes.String):
        self.national = national
        super().__init__(**kw)

    def __repr__(self) -> str:
    def __repr__(self):
        return util.generic_repr(
            self, to_inspect=[_StringType, sqltypes.String]
        )


class _MatchType(
    sqltypes.Float[Union[decimal.Decimal, float]], sqltypes.MatchType
):
    def __init__(self, **kw: Any):
class _MatchType(sqltypes.Float, sqltypes.MatchType):
    def __init__(self, **kw):
        # TODO: float arguments?
        sqltypes.Float.__init__(self)  # type: ignore[arg-type]
        sqltypes.Float.__init__(self)
        sqltypes.MatchType.__init__(self)


class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]):
class NUMERIC(_NumericType, sqltypes.NUMERIC):
    """MySQL NUMERIC type."""

    __visit_name__ = "NUMERIC"

    def __init__(
        self,
        precision: Optional[int] = None,
        scale: Optional[int] = None,
        asdecimal: bool = True,
        **kw: Any,
    ):
    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
        """Construct a NUMERIC.

        :param precision: Total digits in this number.  If scale and precision
@@ -154,18 +126,12 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]):
        )


class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]):
class DECIMAL(_NumericType, sqltypes.DECIMAL):
    """MySQL DECIMAL type."""

    __visit_name__ = "DECIMAL"

    def __init__(
        self,
        precision: Optional[int] = None,
        scale: Optional[int] = None,
        asdecimal: bool = True,
        **kw: Any,
    ):
    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
        """Construct a DECIMAL.

        :param precision: Total digits in this number.  If scale and precision
@@ -186,18 +152,12 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]):
        )


class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]):
class DOUBLE(_FloatType, sqltypes.DOUBLE):
    """MySQL DOUBLE type."""

    __visit_name__ = "DOUBLE"

    def __init__(
        self,
        precision: Optional[int] = None,
        scale: Optional[int] = None,
        asdecimal: bool = True,
        **kw: Any,
    ):
    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
        """Construct a DOUBLE.

        .. note::
@@ -226,18 +186,12 @@ class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]):
        )


class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]):
class REAL(_FloatType, sqltypes.REAL):
    """MySQL REAL type."""

    __visit_name__ = "REAL"

    def __init__(
        self,
        precision: Optional[int] = None,
        scale: Optional[int] = None,
        asdecimal: bool = True,
        **kw: Any,
    ):
    def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
        """Construct a REAL.

        .. note::
@@ -266,18 +220,12 @@ class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]):
        )


class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]):
class FLOAT(_FloatType, sqltypes.FLOAT):
    """MySQL FLOAT type."""

    __visit_name__ = "FLOAT"

    def __init__(
        self,
        precision: Optional[int] = None,
        scale: Optional[int] = None,
        asdecimal: bool = False,
        **kw: Any,
    ):
    def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
        """Construct a FLOAT.

        :param precision: Total digits in this number.  If scale and precision
@@ -297,9 +245,7 @@ class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]):
            precision=precision, scale=scale, asdecimal=asdecimal, **kw
        )

    def bind_processor(
        self, dialect: Dialect
    ) -> Optional[_BindProcessorType[Union[decimal.Decimal, float]]]:
    def bind_processor(self, dialect):
        return None


@@ -308,7 +254,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER):

    __visit_name__ = "INTEGER"

    def __init__(self, display_width: Optional[int] = None, **kw: Any):
    def __init__(self, display_width=None, **kw):
        """Construct an INTEGER.

        :param display_width: Optional, maximum display width for this number.
@@ -329,7 +275,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT):

    __visit_name__ = "BIGINT"

    def __init__(self, display_width: Optional[int] = None, **kw: Any):
    def __init__(self, display_width=None, **kw):
        """Construct a BIGINTEGER.

        :param display_width: Optional, maximum display width for this number.
@@ -350,7 +296,7 @@ class MEDIUMINT(_IntegerType):

    __visit_name__ = "MEDIUMINT"

    def __init__(self, display_width: Optional[int] = None, **kw: Any):
    def __init__(self, display_width=None, **kw):
        """Construct a MEDIUMINTEGER

        :param display_width: Optional, maximum display width for this number.
@@ -371,7 +317,7 @@ class TINYINT(_IntegerType):

    __visit_name__ = "TINYINT"

    def __init__(self, display_width: Optional[int] = None, **kw: Any):
    def __init__(self, display_width=None, **kw):
        """Construct a TINYINT.

        :param display_width: Optional, maximum display width for this number.
@@ -386,19 +332,13 @@ class TINYINT(_IntegerType):
        """
        super().__init__(display_width=display_width, **kw)

    def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool:
        return (
            self._type_affinity is other._type_affinity
            or other._type_affinity is sqltypes.Boolean
        )


class SMALLINT(_IntegerType, sqltypes.SMALLINT):
    """MySQL SMALLINTEGER type."""

    __visit_name__ = "SMALLINT"

    def __init__(self, display_width: Optional[int] = None, **kw: Any):
    def __init__(self, display_width=None, **kw):
        """Construct a SMALLINTEGER.

        :param display_width: Optional, maximum display width for this number.
@@ -414,7 +354,7 @@ class SMALLINT(_IntegerType, sqltypes.SMALLINT):
        super().__init__(display_width=display_width, **kw)


class BIT(sqltypes.TypeEngine[Any]):
class BIT(sqltypes.TypeEngine):
    """MySQL BIT type.

    This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
@@ -425,7 +365,7 @@ class BIT(sqltypes.TypeEngine[Any]):

    __visit_name__ = "BIT"

    def __init__(self, length: Optional[int] = None):
    def __init__(self, length=None):
        """Construct a BIT.

        :param length: Optional, number of bits.
@@ -433,19 +373,20 @@ class BIT(sqltypes.TypeEngine[Any]):
        """
        self.length = length

    def result_processor(
        self, dialect: MySQLDialect, coltype: object  # type: ignore[override]
    ) -> Optional[_ResultProcessorType[Any]]:
        """Convert a MySQL's 64 bit, variable length binary string to a
        long."""
    def result_processor(self, dialect, coltype):
        """Convert a MySQL's 64 bit, variable length binary string to a long.

        if dialect.supports_native_bit:
            return None
        TODO: this is MySQL-db, pyodbc specific.  OurSQL and mysqlconnector
        already do this, so this logic should be moved to those dialects.

        def process(value: Optional[Iterable[int]]) -> Optional[int]:
        """

        def process(value):
            if value is not None:
                v = 0
                for i in value:
                    if not isinstance(i, int):
                        i = ord(i)  # convert byte to int on Python 2
                    v = v << 8 | i
                return v
            return value
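    # Editor's note: an illustration (not part of the change) of the bit
    # accumulation performed by process() above; b"\x02\x01" folds to
    # (2 << 8) | 1:
    #
    #     >>> v = 0
    #     >>> for i in b"\x02\x01":
    #     ...     v = v << 8 | i
    #     >>> v
    #     513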
@@ -458,7 +399,7 @@ class TIME(sqltypes.TIME):

    __visit_name__ = "TIME"

    def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
    def __init__(self, timezone=False, fsp=None):
        """Construct a MySQL TIME type.

        :param timezone: not used by the MySQL dialect.
@@ -477,12 +418,10 @@ class TIME(sqltypes.TIME):
        super().__init__(timezone=timezone)
        self.fsp = fsp

    def result_processor(
        self, dialect: Dialect, coltype: object
    ) -> _ResultProcessorType[datetime.time]:
    def result_processor(self, dialect, coltype):
        time = datetime.time

        def process(value: Any) -> Optional[datetime.time]:
        def process(value):
            # convert from a timedelta value
            if value is not None:
                microseconds = value.microseconds
@@ -505,7 +444,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP):

    __visit_name__ = "TIMESTAMP"

    def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
    def __init__(self, timezone=False, fsp=None):
        """Construct a MySQL TIMESTAMP type.

        :param timezone: not used by the MySQL dialect.
@@ -530,7 +469,7 @@ class DATETIME(sqltypes.DATETIME):

    __visit_name__ = "DATETIME"

    def __init__(self, timezone: bool = False, fsp: Optional[int] = None):
    def __init__(self, timezone=False, fsp=None):
        """Construct a MySQL DATETIME type.

        :param timezone: not used by the MySQL dialect.
@@ -550,26 +489,26 @@ class DATETIME(sqltypes.DATETIME):
        self.fsp = fsp


class YEAR(sqltypes.TypeEngine[Any]):
class YEAR(sqltypes.TypeEngine):
    """MySQL YEAR type, for single byte storage of years 1901-2155."""

    __visit_name__ = "YEAR"

    def __init__(self, display_width: Optional[int] = None):
    def __init__(self, display_width=None):
        self.display_width = display_width


class TEXT(_StringType, sqltypes.TEXT):
    """MySQL TEXT type, for character storage encoded up to 2^16 bytes."""
    """MySQL TEXT type, for text up to 2^16 characters."""

    __visit_name__ = "TEXT"

    def __init__(self, length: Optional[int] = None, **kw: Any):
    def __init__(self, length=None, **kw):
        """Construct a TEXT.

        :param length: Optional, if provided the server may optimize storage
            by substituting the smallest TEXT type sufficient to store
            ``length`` bytes of characters.
            ``length`` characters.

        :param charset: Optional, a column-level character set for this string
            value.  Takes precedence to 'ascii' or 'unicode' short-hand.
@@ -596,11 +535,11 @@ class TEXT(_StringType, sqltypes.TEXT):


class TINYTEXT(_StringType):
    """MySQL TINYTEXT type, for character storage encoded up to 2^8 bytes."""
    """MySQL TINYTEXT type, for text up to 2^8 characters."""

    __visit_name__ = "TINYTEXT"

    def __init__(self, **kwargs: Any):
    def __init__(self, **kwargs):
        """Construct a TINYTEXT.

        :param charset: Optional, a column-level character set for this string
@@ -628,12 +567,11 @@ class TINYTEXT(_StringType):


class MEDIUMTEXT(_StringType):
    """MySQL MEDIUMTEXT type, for character storage encoded up
    to 2^24 bytes."""
    """MySQL MEDIUMTEXT type, for text up to 2^24 characters."""

    __visit_name__ = "MEDIUMTEXT"

    def __init__(self, **kwargs: Any):
    def __init__(self, **kwargs):
        """Construct a MEDIUMTEXT.

        :param charset: Optional, a column-level character set for this string
@@ -661,11 +599,11 @@ class MEDIUMTEXT(_StringType):


class LONGTEXT(_StringType):
    """MySQL LONGTEXT type, for character storage encoded up to 2^32 bytes."""
    """MySQL LONGTEXT type, for text up to 2^32 characters."""

    __visit_name__ = "LONGTEXT"

    def __init__(self, **kwargs: Any):
    def __init__(self, **kwargs):
        """Construct a LONGTEXT.

        :param charset: Optional, a column-level character set for this string
@@ -697,7 +635,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR):

    __visit_name__ = "VARCHAR"

    def __init__(self, length: Optional[int] = None, **kwargs: Any) -> None:
    def __init__(self, length=None, **kwargs):
        """Construct a VARCHAR.

        :param charset: Optional, a column-level character set for this string
@@ -729,7 +667,7 @@ class CHAR(_StringType, sqltypes.CHAR):

    __visit_name__ = "CHAR"

    def __init__(self, length: Optional[int] = None, **kwargs: Any):
    def __init__(self, length=None, **kwargs):
        """Construct a CHAR.

        :param length: Maximum data length, in characters.
@@ -745,7 +683,7 @@ class CHAR(_StringType, sqltypes.CHAR):
        super().__init__(length=length, **kwargs)

    @classmethod
    def _adapt_string_for_cast(cls, type_: sqltypes.String) -> sqltypes.CHAR:
    def _adapt_string_for_cast(self, type_):
        # copy the given string type into a CHAR
        # for the purposes of rendering a CAST expression
        type_ = sqltypes.to_instance(type_)
@@ -774,7 +712,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR):

    __visit_name__ = "NVARCHAR"

    def __init__(self, length: Optional[int] = None, **kwargs: Any):
    def __init__(self, length=None, **kwargs):
        """Construct an NVARCHAR.

        :param length: Maximum data length, in characters.
@@ -800,7 +738,7 @@ class NCHAR(_StringType, sqltypes.NCHAR):

    __visit_name__ = "NCHAR"

    def __init__(self, length: Optional[int] = None, **kwargs: Any):
    def __init__(self, length=None, **kwargs):
        """Construct an NCHAR.

        :param length: Maximum data length, in characters.

@@ -1,11 +1,11 @@
# dialects/oracle/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# oracle/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from types import ModuleType


from . import base  # noqa
from . import cx_oracle  # noqa
@@ -32,18 +32,7 @@ from .base import ROWID
from .base import TIMESTAMP
from .base import VARCHAR
from .base import VARCHAR2
from .base import VECTOR
from .base import VectorIndexConfig
from .base import VectorIndexType
from .vector import SparseVector
from .vector import VectorDistanceType
from .vector import VectorStorageFormat
from .vector import VectorStorageType

# Alias oracledb also as oracledb_async
oracledb_async = type(
    "oracledb_async", (ModuleType,), {"dialect": oracledb.dialect_async}
)

base.dialect = dialect = cx_oracle.dialect

@@ -71,11 +60,4 @@ __all__ = (
    "NVARCHAR2",
    "ROWID",
    "REAL",
    "VECTOR",
    "VectorDistanceType",
    "VectorIndexType",
    "VectorIndexConfig",
    "VectorStorageFormat",
    "VectorStorageType",
    "SparseVector",
)

File diff suppressed because it is too large
@@ -1,5 +1,4 @@
|
||||
# dialects/oracle/cx_oracle.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
@@ -7,18 +6,13 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
r""".. dialect:: oracle+cx_oracle
|
||||
r"""
|
||||
.. dialect:: oracle+cx_oracle
|
||||
:name: cx-Oracle
|
||||
:dbapi: cx_oracle
|
||||
:connectstring: oracle+cx_oracle://user:pass@hostname:port[/dbname][?service_name=<service>[&key=value&key=value...]]
|
||||
:url: https://oracle.github.io/python-cx_Oracle/
|
||||
|
||||
Description
|
||||
-----------
|
||||
|
||||
cx_Oracle was the original driver for Oracle Database. It was superseded by
python-oracledb, which should be used instead.

DSN vs. Hostname connections
-----------------------------

@@ -28,41 +22,27 @@ dialect translates from a series of different URL forms.
Hostname Connections with Easy Connect Syntax
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Given a hostname, port and service name of the target database, for example
from Oracle Database's Easy Connect syntax, connect in SQLAlchemy using the
``service_name`` query string parameter::
Given a hostname, port and service name of the target Oracle Database, for
example from Oracle's `Easy Connect syntax
<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#easy-connect-syntax-for-connection-strings>`_,
then connect in SQLAlchemy using the ``service_name`` query string parameter::

    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@hostname:port?service_name=myservice&encoding=UTF-8&nencoding=UTF-8"
    )
    engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:port/?service_name=myservice&encoding=UTF-8&nencoding=UTF-8")

Note that the default driver value for ``encoding`` and ``nencoding`` was
changed to "UTF-8" in cx_Oracle 8.0, so these parameters can be omitted when
using that version, or later.
The `full Easy Connect syntax
<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-B0437826-43C1-49EC-A94D-B650B6A4A6EE>`_
is not supported. Instead, use a ``tnsnames.ora`` file and connect using a
DSN.

To use a full Easy Connect string, pass it as the ``dsn`` key value in a
:paramref:`_sa.create_engine.connect_args` dictionary::
Connections with tnsnames.ora or Oracle Cloud
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    import cx_Oracle

    e = create_engine(
        "oracle+cx_oracle://@",
        connect_args={
            "user": "scott",
            "password": "tiger",
            "dsn": "hostname:port/myservice?transport_connect_timeout=30&expire_time=60",
        },
    )

Connections with tnsnames.ora or to Oracle Autonomous Database
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Alternatively, if no port, database name, or service name is provided, the
dialect will use an Oracle Database DSN "connection string". This takes the
"hostname" portion of the URL as the data source name. For example, if the
``tnsnames.ora`` file contains a TNS Alias of ``myalias`` as below:

.. sourcecode:: text
Alternatively, if no port, database name, or ``service_name`` is provided, the
dialect will use an Oracle DSN "connection string". This takes the "hostname"
portion of the URL as the data source name. For example, if the
``tnsnames.ora`` file contains a `Net Service Name
<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#net-service-names-for-connection-strings>`_
of ``myalias`` as below::

    myalias =
    (DESCRIPTION =
@@ -77,22 +57,19 @@ The cx_Oracle dialect connects to this database service when ``myalias`` is the
hostname portion of the URL, without specifying a port, database name or
``service_name``::

    engine = create_engine("oracle+cx_oracle://scott:tiger@myalias")
    engine = create_engine("oracle+cx_oracle://scott:tiger@myalias/?encoding=UTF-8&nencoding=UTF-8")

Users of Oracle Autonomous Database should use this syntax. If the database is
configured for mutual TLS ("mTLS"), then you must also configure the cloud
Users of Oracle Cloud should use this syntax and also configure the cloud
wallet as shown in the cx_Oracle documentation `Connecting to Autonomous Databases
<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#autonomousdb>`_.
<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#connecting-to-autononmous-databases>`_.

SID Connections
^^^^^^^^^^^^^^^

To use Oracle Database's obsolete System Identifier connection syntax, the SID
can be passed in a "database name" portion of the URL::
To use Oracle's obsolete SID connection syntax, the SID can be passed in a
"database name" portion of the URL as below::

    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@hostname:port/dbname"
    )
    engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:1521/dbname?encoding=UTF-8&nencoding=UTF-8")

Above, the DSN passed to cx_Oracle is created by ``cx_Oracle.makedsn()`` as
follows::
@@ -101,23 +78,17 @@ follows::
    >>> cx_Oracle.makedsn("hostname", 1521, sid="dbname")
    '(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=hostname)(PORT=1521))(CONNECT_DATA=(SID=dbname)))'

Note that although the SQLAlchemy syntax ``hostname:port/dbname`` looks like
Oracle's Easy Connect syntax, it is different. It uses a SID in place of the
service name required by Easy Connect. The Easy Connect syntax does not
support SIDs.
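
For comparison, the equivalent Easy Connect style connection uses a service
name instead; a minimal sketch, assuming ``myservice`` is a valid service name
on that host::

    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@hostname:port?service_name=myservice"
    )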

Passing cx_Oracle connect arguments
-----------------------------------

Additional connection arguments can usually be passed via the URL query string;
particular symbols like ``SYSDBA`` are intercepted and converted to the correct
symbol::
Additional connection arguments can usually be passed via the URL
query string; particular symbols like ``cx_Oracle.SYSDBA`` are intercepted
and converted to the correct symbol::

    e = create_engine(
        "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true"
    )
        "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true")

.. versionchanged:: 1.3 the cx_Oracle dialect now accepts all argument names
.. versionchanged:: 1.3 the cx_oracle dialect now accepts all argument names
   within the URL string itself, to be passed to the cx_Oracle DBAPI. As
   was the case earlier but not correctly documented, the
   :paramref:`_sa.create_engine.connect_args` parameter also accepts all
@@ -128,20 +99,19 @@ string, use the :paramref:`_sa.create_engine.connect_args` dictionary.
Any cx_Oracle parameter value and/or constant may be passed, such as::

    import cx_Oracle

    e = create_engine(
        "oracle+cx_oracle://user:pass@dsn",
        connect_args={
            "encoding": "UTF-8",
            "nencoding": "UTF-8",
            "mode": cx_Oracle.SYSDBA,
            "events": True,
        },
            "events": True
        }
    )

Note that the default driver value for ``encoding`` and ``nencoding`` was
changed to "UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when
using that version, or later.
Note that the default value for ``encoding`` and ``nencoding`` was changed to
"UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when using that
version, or later.

Options consumed by the SQLAlchemy cx_Oracle dialect outside of the driver
--------------------------------------------------------------------------
@@ -151,19 +121,14 @@ itself. These options are always passed directly to :func:`_sa.create_engine`
, such as::

    e = create_engine(
        "oracle+cx_oracle://user:pass@dsn", coerce_to_decimal=False
    )
        "oracle+cx_oracle://user:pass@dsn", coerce_to_decimal=False)

The parameters accepted by the cx_oracle dialect are as follows:

* ``arraysize`` - set the cx_oracle.arraysize value on cursors; defaults
  to ``None``, indicating that the driver default should be used (typically
  the value is 100). This setting controls how many rows are buffered when
  fetching rows, and can have a significant effect on performance when
  modified.

  .. versionchanged:: 2.0.26 - changed the default value from 50 to None,
     to use the default value of the driver itself.
* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulted
  to 50. This setting is significant with cx_Oracle as the contents of LOB
  objects are only readable within a "live" row (e.g. within a batch of
  50 rows).

* ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`.

@@ -176,16 +141,10 @@ The parameters accepted by the cx_oracle dialect are as follows:
Using cx_Oracle SessionPool
---------------------------

The cx_Oracle driver provides its own connection pool implementation that may
be used in place of SQLAlchemy's pooling functionality. The driver pool
supports Oracle Database features such as dead connection detection, connection
draining for planned database downtime, support for Oracle Application
Continuity and Transparent Application Continuity, and gives support for
Database Resident Connection Pooling (DRCP).

Using the driver pool can be achieved by using the
:paramref:`_sa.create_engine.creator` parameter to provide a function that
returns a new connection, along with setting
The cx_Oracle library provides its own connection pool implementation that may
be used in place of SQLAlchemy's pooling functionality. This can be achieved
by using the :paramref:`_sa.create_engine.creator` parameter to provide a
function that returns a new connection, along with setting
:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable
SQLAlchemy's pooling::

@@ -194,41 +153,32 @@ SQLAlchemy's pooling::
    from sqlalchemy.pool import NullPool

    pool = cx_Oracle.SessionPool(
        user="scott",
        password="tiger",
        dsn="orclpdb",
        min=1,
        max=4,
        increment=1,
        threaded=True,
        encoding="UTF-8",
        nencoding="UTF-8",
        user="scott", password="tiger", dsn="orclpdb",
        min=2, max=5, increment=1, threaded=True,
        encoding="UTF-8", nencoding="UTF-8"
    )

    engine = create_engine(
        "oracle+cx_oracle://", creator=pool.acquire, poolclass=NullPool
    )
    engine = create_engine("oracle+cx_oracle://", creator=pool.acquire, poolclass=NullPool)

The above engine may then be used normally, where cx_Oracle's pool handles
connection pooling::

    with engine.connect() as conn:
        print(conn.scalar("select 1 from dual"))
        print(conn.scalar("select 1 FROM dual"))


As well as providing a scalable solution for multi-user applications, the
cx_Oracle session pool supports some Oracle features such as DRCP and
`Application Continuity
<https://cx-oracle.readthedocs.io/en/latest/user_guide/ha.html#application-continuity-ac>`_.

Note that the pool creation parameters ``threaded``, ``encoding`` and
``nencoding`` were deprecated in later cx_Oracle releases.
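
On those releases the pool can simply be created without them; a sketch using
the same credentials as above::

    pool = cx_Oracle.SessionPool(
        user="scott", password="tiger", dsn="orclpdb", min=1, max=4, increment=1
    )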

Using Oracle Database Resident Connection Pooling (DRCP)
--------------------------------------------------------

When using Oracle Database's DRCP, the best practice is to pass a connection
class and "purity" when acquiring a connection from the SessionPool. Refer to
the `cx_Oracle DRCP documentation
When using Oracle's `DRCP
<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-015CA8C1-2386-4626-855D-CC546DDC1086>`_,
the best practice is to pass a connection class and "purity" when acquiring a
connection from the SessionPool. Refer to the `cx_Oracle DRCP documentation
<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#database-resident-connection-pooling-drcp>`_.

This can be achieved by wrapping ``pool.acquire()``::
@@ -238,33 +188,21 @@ This can be achieved by wrapping ``pool.acquire()``::
    from sqlalchemy.pool import NullPool

    pool = cx_Oracle.SessionPool(
        user="scott",
        password="tiger",
        dsn="orclpdb",
        min=2,
        max=5,
        increment=1,
        threaded=True,
        encoding="UTF-8",
        nencoding="UTF-8",
        user="scott", password="tiger", dsn="orclpdb",
        min=2, max=5, increment=1, threaded=True,
        encoding="UTF-8", nencoding="UTF-8"
    )


    def creator():
        return pool.acquire(
            cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF
        )
        return pool.acquire(cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF)


    engine = create_engine(
        "oracle+cx_oracle://", creator=creator, poolclass=NullPool
    )
    engine = create_engine("oracle+cx_oracle://", creator=creator, poolclass=NullPool)

The above engine may then be used normally, where cx_Oracle handles session
pooling and Oracle Database additionally uses DRCP::

    with engine.connect() as conn:
        print(conn.scalar("select 1 from dual"))
        print(conn.scalar("select 1 FROM dual"))

.. _cx_oracle_unicode:

@@ -272,28 +210,24 @@ Unicode
-------

As is the case for all DBAPIs under Python 3, all strings are inherently
Unicode strings. In all cases however, the driver requires an explicit
encoding configuration.

Ensuring the Correct Client Encoding
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The long accepted standard for establishing client encoding for nearly all
Oracle Database related software is via the `NLS_LANG
<https://www.oracle.com/database/technologies/faq-nls-lang.html>`_ environment
variable. Older versions of cx_Oracle use this environment variable as the
source of its encoding configuration. The format of this variable is
Language_Territory.CharacterSet; a typical value would be
``AMERICAN_AMERICA.AL32UTF8``. cx_Oracle version 8 and later use the character
set "UTF-8" by default, and ignore the character set component of NLS_LANG.
Oracle related software is via the `NLS_LANG <https://www.oracle.com/database/technologies/faq-nls-lang.html>`_
environment variable. cx_Oracle, like most other Oracle drivers, will use
this environment variable as the source of its encoding configuration. The
format of this variable is idiosyncratic; a typical value would be
``AMERICAN_AMERICA.AL32UTF8``.
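
For example, a sketch of setting the variable in a POSIX shell before the
application starts, assuming the typical value shown above:

.. sourcecode:: text

    export NLS_LANG=AMERICAN_AMERICA.AL32UTF8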

The cx_Oracle driver also supported a programmatic alternative which is to pass
the ``encoding`` and ``nencoding`` parameters directly to its ``.connect()``
function. These can be present in the URL as follows::
The cx_Oracle driver also supports a programmatic alternative which is to
pass the ``encoding`` and ``nencoding`` parameters directly to its
``.connect()`` function. These can be present in the URL as follows::

    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@tnsalias?encoding=UTF-8&nencoding=UTF-8"
    )
    engine = create_engine("oracle+cx_oracle://scott:tiger@orclpdb/?encoding=UTF-8&nencoding=UTF-8")

For the meaning of the ``encoding`` and ``nencoding`` parameters, please
consult
@@ -308,24 +242,25 @@ consult
Unicode-specific Column datatypes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The Core expression language handles unicode data by use of the
:class:`.Unicode` and :class:`.UnicodeText` datatypes. These types correspond
to the VARCHAR2 and CLOB Oracle Database datatypes by default. When using
these datatypes with Unicode data, it is expected that the database is
configured with a Unicode-aware character set, as well as that the ``NLS_LANG``
environment variable is set appropriately (this applies to older versions of
cx_Oracle), so that the VARCHAR2 and CLOB datatypes can accommodate the data.
The Core expression language handles unicode data by use of the :class:`.Unicode`
and :class:`.UnicodeText`
datatypes. These types correspond to the VARCHAR2 and CLOB Oracle datatypes by
default. When using these datatypes with Unicode data, it is expected that
the Oracle database is configured with a Unicode-aware character set, as well
as that the ``NLS_LANG`` environment variable is set appropriately, so that
the VARCHAR2 and CLOB datatypes can accommodate the data.

In the case that Oracle Database is not configured with a Unicode character
In the case that the Oracle database is not configured with a Unicode character
set, the two options are to use the :class:`_types.NCHAR` and
:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag
``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, which will cause
the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` /
``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`,
which will cause the
SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` /
:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB.

.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText`
   datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle Database
   datatypes unless the ``use_nchar_for_unicode=True`` is passed to the dialect
.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText`
   datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle datatypes
   unless the ``use_nchar_for_unicode=True`` is passed to the dialect
   when :func:`_sa.create_engine` is called.
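
A minimal sketch of the NCHAR/NCLOB option described above, assuming
``tnsalias`` resolves via ``tnsnames.ora``::

    from sqlalchemy import create_engine

    # Unicode / UnicodeText will now render and bind as NCHAR / NCLOB
    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@tnsalias",
        use_nchar_for_unicode=True,
    )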


@@ -334,7 +269,7 @@ the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` /
Encoding Errors
^^^^^^^^^^^^^^^

For the unusual case that data in Oracle Database is present with a broken
For the unusual case that data in the Oracle database is present with a broken
encoding, the dialect accepts a parameter ``encoding_errors`` which will be
passed to Unicode decoding functions in order to affect how decoding errors are
handled. The value is ultimately consumed by the Python `decode
@@ -352,13 +287,13 @@ Fine grained control over cx_Oracle data binding performance with setinputsizes
-------------------------------------------------------------------------------

The cx_Oracle DBAPI has a deep and fundamental reliance upon the usage of the
DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the
datatypes that are bound to a SQL statement for Python values being passed as
parameters. While virtually no other DBAPI assigns any use to the
``setinputsizes()`` call, the cx_Oracle DBAPI relies upon it heavily in its
interactions with the Oracle Database client interface, and in some scenarios
it is not possible for SQLAlchemy to know exactly how data should be bound, as
some settings can cause profoundly different performance characteristics, while
interactions with the Oracle client interface, and in some scenarios it is not
possible for SQLAlchemy to know exactly how data should be bound, as some
settings can cause profoundly different performance characteristics, while
altering the type coercion behavior at the same time.

Users of the cx_Oracle dialect are **strongly encouraged** to read through
@@ -387,16 +322,13 @@ objects which have a ``.key`` and a ``.type`` attribute::

    engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe")


    @event.listens_for(engine, "do_setinputsizes")
    def _log_setinputsizes(inputsizes, cursor, statement, parameters, context):
        for bindparam, dbapitype in inputsizes.items():
            log.info(
                "Bound parameter name: %s SQLAlchemy type: %r DBAPI object: %s",
                bindparam.key,
                bindparam.type,
                dbapitype,
            )
            log.info(
                "Bound parameter name: %s SQLAlchemy type: %r "
                "DBAPI object: %s",
                bindparam.key, bindparam.type, dbapitype)

Example 2 - remove all bindings to CLOB
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -410,28 +342,12 @@ series. This setting can be modified as follows::

    engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe")


    @event.listens_for(engine, "do_setinputsizes")
    def _remove_clob(inputsizes, cursor, statement, parameters, context):
        for bindparam, dbapitype in list(inputsizes.items()):
            if dbapitype is CLOB:
                del inputsizes[bindparam]

.. _cx_oracle_lob:

LOB Datatypes
--------------

LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and
BLOB. Modern versions of cx_Oracle are optimized for these datatypes to be
delivered as a single buffer. As such, SQLAlchemy makes use of these newer type
handlers by default.

To disable the use of newer type handlers and deliver LOB objects as classic
buffered objects with a ``read()`` method, the parameter
``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`,
which takes place only engine-wide.

.. _cx_oracle_returning:

RETURNING Support
@@ -440,12 +356,29 @@ RETURNING Support
The cx_Oracle dialect implements RETURNING using OUT parameters.
The dialect supports RETURNING fully.

Two Phase Transactions Not Supported
------------------------------------
.. _cx_oracle_lob:

Two phase transactions are **not supported** under cx_Oracle due to poor driver
support. The newer :ref:`oracledb` dialect however **does** support two phase
transactions.
LOB Datatypes
--------------

LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and
BLOB. Modern versions of cx_Oracle and oracledb are optimized for these
datatypes to be delivered as a single buffer. As such, SQLAlchemy makes use of
these newer type handlers by default.

To disable the use of newer type handlers and deliver LOB objects as classic
buffered objects with a ``read()`` method, the parameter
``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`,
which takes place only engine-wide.
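
For example, a sketch that restores the classic buffered LOB behavior, assuming
a resolvable ``tnsalias``::

    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@tnsalias", auto_convert_lobs=False
    )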

Two Phase Transactions Not Supported
-------------------------------------

Two phase transactions are **not supported** under cx_Oracle due to poor
driver support. As of cx_Oracle 6.0b1, the interface for
two phase transactions has been changed to be more of a direct pass-through
to the underlying OCI layer with less automation. The additional logic
to support this system is not implemented in SQLAlchemy.

.. _cx_oracle_numeric:

@@ -456,21 +389,20 @@ SQLAlchemy's numeric types can handle receiving and returning values as Python
``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a
subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in
use, the :paramref:`.Numeric.asdecimal` flag determines if values should be
coerced to ``Decimal`` upon return, or returned as float objects. To make
matters more complicated under Oracle Database, the ``NUMBER`` type can also
represent integer values if the "scale" is zero, so the Oracle
Database-specific :class:`_oracle.NUMBER` type takes this into account as well.
coerced to ``Decimal`` upon return, or returned as float objects. To make
matters more complicated under Oracle, Oracle's ``NUMBER`` type can also
represent integer values if the "scale" is zero, so the Oracle-specific
:class:`_oracle.NUMBER` type takes this into account as well.

The cx_Oracle dialect makes extensive use of connection- and cursor-level
"outputtypehandler" callables in order to coerce numeric values as requested.
These callables are specific to the specific flavor of :class:`.Numeric` in
use, as well as if no SQLAlchemy typing objects are present. There are
observed scenarios where Oracle Database may send incomplete or ambiguous
information about the numeric types being returned, such as a query where the
numeric types are buried under multiple levels of subquery. The type handlers
do their best to make the right decision in all cases, deferring to the
underlying cx_Oracle DBAPI for all those cases where the driver can make the
best decision.
use, as well as if no SQLAlchemy typing objects are present. There are
observed scenarios where Oracle may send incomplete or ambiguous information
about the numeric types being returned, such as a query where the numeric types
are buried under multiple levels of subquery. The type handlers do their best
to make the right decision in all cases, deferring to the underlying cx_Oracle
DBAPI for all those cases where the driver can make the best decision.

When no typing objects are present, as when executing plain SQL strings, a
default "outputtypehandler" is present which will generally return numeric
@@ -882,8 +814,6 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):

                out_parameters[name] = self.cursor.var(
                    dbtype,
                    # this is fine also in oracledb_async since
                    # the driver will await the read coroutine
                    outconverter=lambda value: value.read(),
                    arraysize=len_params,
                )
@@ -902,9 +832,9 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
            )

            for param in self.parameters:
                param[quoted_bind_names.get(name, name)] = (
                    out_parameters[name]
                )
                param[
                    quoted_bind_names.get(name, name)
                ] = out_parameters[name]

    def _generate_cursor_outputtype_handler(self):
        output_handlers = {}
@@ -1100,7 +1030,7 @@ class OracleDialect_cx_oracle(OracleDialect):
        self,
        auto_convert_lobs=True,
        coerce_to_decimal=True,
        arraysize=None,
        arraysize=50,
        encoding_errors=None,
        threaded=None,
        **kwargs,
@@ -1234,9 +1164,6 @@ class OracleDialect_cx_oracle(OracleDialect):
        with dbapi_connection.cursor() as cursor:
            cursor.execute(f"ALTER SESSION SET ISOLATION_LEVEL={level}")

    def detect_autocommit_setting(self, dbapi_conn) -> bool:
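        # report the DBAPI connection's current autocommit flag as-is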
        return bool(dbapi_conn.autocommit)

    def _detect_decimal_char(self, connection):
        # we have the option to change this setting upon connect,
        # or just look at what it is upon connect and convert.
@@ -1356,13 +1283,8 @@ class OracleDialect_cx_oracle(OracleDialect):
            cx_Oracle.CLOB,
            cx_Oracle.NCLOB,
        ):
            typ = (
                cx_Oracle.DB_TYPE_VARCHAR
                if default_type is cx_Oracle.CLOB
                else cx_Oracle.DB_TYPE_NVARCHAR
            )
            return cursor.var(
                typ,
                cx_Oracle.DB_TYPE_NVARCHAR,
                _CX_ORACLE_MAGIC_LOB_SIZE,
                cursor.arraysize,
                **dialect._cursor_var_unicode_kwargs,
@@ -1493,6 +1415,13 @@ class OracleDialect_cx_oracle(OracleDialect):
        return False

    def create_xid(self):
        """create a two-phase transaction ID.

        this id will be passed to do_begin_twophase(), do_rollback_twophase(),
        do_commit_twophase(). its format is unspecified.

        """

        id_ = random.randint(0, 2**128)
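        # XA convention: (format id, global transaction id, branch qualifier)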
        return (0x1234, "%032x" % id_, "%032x" % 9)


@@ -1,5 +1,4 @@
# dialects/oracle/dictionary.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

@@ -1,639 +1,68 @@
# dialects/oracle/oracledb.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

r""".. dialect:: oracle+oracledb
r"""
.. dialect:: oracle+oracledb
    :name: python-oracledb
    :dbapi: oracledb
    :connectstring: oracle+oracledb://user:pass@hostname:port[/dbname][?service_name=<service>[&key=value&key=value...]]
    :url: https://oracle.github.io/python-oracledb/

Description
-----------
python-oracledb is released by Oracle to supersede the cx_Oracle driver.
It is fully compatible with cx_Oracle and features both a "thin" client
mode that requires no dependencies, as well as a "thick" mode that uses
the Oracle Client Interface in the same way as cx_Oracle.

Python-oracledb is the Oracle Database driver for Python. It features a default
"thin" client mode that requires no dependencies, and an optional "thick" mode
that uses Oracle Client libraries. It supports SQLAlchemy features including
two phase transactions and Asyncio.
.. seealso::

Python-oracledb is the renamed, updated cx_Oracle driver. Oracle is no longer
doing any releases in the cx_Oracle namespace.

The SQLAlchemy ``oracledb`` dialect provides both a sync and an async
implementation under the same dialect name. The proper version is
selected depending on how the engine is created:

* calling :func:`_sa.create_engine` with ``oracle+oracledb://...`` will
  automatically select the sync version::

    from sqlalchemy import create_engine

    sync_engine = create_engine(
        "oracle+oracledb://scott:tiger@localhost?service_name=FREEPDB1"
    )

* calling :func:`_asyncio.create_async_engine` with ``oracle+oracledb://...``
  will automatically select the async version::

    from sqlalchemy.ext.asyncio import create_async_engine

    asyncio_engine = create_async_engine(
        "oracle+oracledb://scott:tiger@localhost?service_name=FREEPDB1"
    )

The asyncio version of the dialect may also be specified explicitly using the
``oracledb_async`` suffix::

    from sqlalchemy.ext.asyncio import create_async_engine

    asyncio_engine = create_async_engine(
        "oracle+oracledb_async://scott:tiger@localhost?service_name=FREEPDB1"
    )

.. versionadded:: 2.0.25 added support for the async version of oracledb.
    :ref:`cx_oracle` - all of cx_Oracle's notes apply to the oracledb driver
    as well.

Thick mode support
------------------

By default, the python-oracledb driver runs in a "thin" mode that does not
require Oracle Client libraries to be installed. The driver also supports a
"thick" mode that uses Oracle Client libraries to get functionality such as
Oracle Application Continuity.
By default the ``python-oracledb`` driver is started in thin mode, which does not
require Oracle Client libraries to be installed in the system. The
``python-oracledb`` driver also supports a "thick" mode, which behaves
similarly to ``cx_oracle`` and requires that the Oracle Client Interface (OCI)
is installed.

To enable thick mode, call `oracledb.init_oracle_client()
<https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.init_oracle_client>`_
explicitly, or pass the parameter ``thick_mode=True`` to
:func:`_sa.create_engine`. To pass custom arguments to
``init_oracle_client()``, like the ``lib_dir`` path, a dict may be passed, for
example::
To enable this mode, the user may call ``oracledb.init_oracle_client``
manually, or by passing the parameter ``thick_mode=True`` to
:func:`_sa.create_engine`. To pass custom arguments to ``init_oracle_client``,
like the ``lib_dir`` path, a dict may be passed to this parameter, as in::

    engine = sa.create_engine(
        "oracle+oracledb://...",
        thick_mode={
            "lib_dir": "/path/to/oracle/client/lib",
            "config_dir": "/path/to/network_config_file_directory",
            "driver_name": "my-app : 1.0.0",
        },
    )

Note that passing a ``lib_dir`` path should only be done on macOS or
Windows. On Linux it does not behave as you might expect.
    engine = sa.create_engine("oracle+oracledb://...", thick_mode={
        "lib_dir": "/path/to/oracle/client/lib", "driver_name": "my-app"
    })

.. seealso::

    python-oracledb documentation `Enabling python-oracledb Thick mode
    <https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html#enabling-python-oracledb-thick-mode>`_
    https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.init_oracle_client

Connecting to Oracle Database
-----------------------------

python-oracledb provides several methods of indicating the target database.
The dialect translates from a series of different URL forms.

Given the hostname, port and service name of the target database, you can
connect in SQLAlchemy using the ``service_name`` query string parameter::

    engine = create_engine(
        "oracle+oracledb://scott:tiger@hostname:port?service_name=myservice"
    )

Connecting with Easy Connect strings
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You can pass any valid python-oracledb connection string as the ``dsn`` key
value in a :paramref:`_sa.create_engine.connect_args` dictionary. See
python-oracledb documentation `Oracle Net Services Connection Strings
<https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#oracle-net-services-connection-strings>`_.

For example, to use an `Easy Connect string
<https://download.oracle.com/ocomdocs/global/Oracle-Net-Easy-Connect-Plus.pdf>`_
with a timeout to prevent connection establishment from hanging if the network
transport to the database cannot be established in 30 seconds, and also setting
a keep-alive time of 60 seconds to stop idle network connections from being
terminated by a firewall::

    e = create_engine(
        "oracle+oracledb://@",
        connect_args={
            "user": "scott",
            "password": "tiger",
            "dsn": "hostname:port/myservice?transport_connect_timeout=30&expire_time=60",
        },
    )

The Easy Connect syntax has been enhanced during the life of Oracle Database.
Review the documentation for your database version. The current documentation
is at `Understanding the Easy Connect Naming Method
<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-B0437826-43C1-49EC-A94D-B650B6A4A6EE>`_.

The general syntax is similar to:

.. sourcecode:: text

    [[protocol:]//]host[:port][/[service_name]][?parameter_name=value{&parameter_name=value}]

Note that although the SQLAlchemy URL syntax ``hostname:port/dbname`` looks
like Oracle's Easy Connect syntax, it is different. SQLAlchemy's URL requires a
system identifier (SID) for the ``dbname`` component::

    engine = create_engine("oracle+oracledb://scott:tiger@hostname:port/sid")

Easy Connect syntax does not support SIDs. It uses service names, which are
the preferred choice for connecting to Oracle Database.

Passing python-oracledb connect arguments
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Other python-oracledb driver `connection options
<https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.connect>`_
can be passed in ``connect_args``. For example::

    e = create_engine(
        "oracle+oracledb://@",
        connect_args={
            "user": "scott",
            "password": "tiger",
            "dsn": "hostname:port/myservice",
            "events": True,
            "mode": oracledb.AUTH_MODE_SYSDBA,
        },
    )

Connecting with tnsnames.ora TNS aliases
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If no port, database name, or service name is provided, the dialect will use an
Oracle Database DSN "connection string". This takes the "hostname" portion of
the URL as the data source name. For example, if the ``tnsnames.ora`` file
contains a `TNS Alias
<https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#tns-aliases-for-connection-strings>`_
of ``myalias`` as below:

.. sourcecode:: text

    myalias =
    (DESCRIPTION =
        (ADDRESS = (PROTOCOL = TCP)(HOST = mymachine.example.com)(PORT = 1521))
        (CONNECT_DATA =
            (SERVER = DEDICATED)
            (SERVICE_NAME = orclpdb1)
        )
    )

The python-oracledb dialect connects to this database service when ``myalias``
is the hostname portion of the URL, without specifying a port, database name or
``service_name``::

    engine = create_engine("oracle+oracledb://scott:tiger@myalias")

Connecting to Oracle Autonomous Database
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Users of Oracle Autonomous Database should either use the TNS Alias URL
shown above, or pass the TNS Alias as the ``dsn`` key value in a
:paramref:`_sa.create_engine.connect_args` dictionary.

If Oracle Autonomous Database is configured for mutual TLS ("mTLS")
connections, then additional configuration is required as shown in `Connecting
to Oracle Cloud Autonomous Databases
<https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#connecting-to-oracle-cloud-autonomous-databases>`_. In
summary, Thick mode users should configure file locations and set the wallet
path in ``sqlnet.ora`` appropriately::

    e = create_engine(
        "oracle+oracledb://@",
        thick_mode={
            # directory containing tnsnames.ora and cwallet.so
            "config_dir": "/opt/oracle/wallet_dir",
        },
        connect_args={
            "user": "scott",
            "password": "tiger",
            "dsn": "mydb_high",
        },
    )

Thin mode users of mTLS should pass the appropriate directories and PEM wallet
password when creating the engine, similar to::

    e = create_engine(
        "oracle+oracledb://@",
        connect_args={
            "user": "scott",
            "password": "tiger",
            "dsn": "mydb_high",
            "config_dir": "/opt/oracle/wallet_dir",  # directory containing tnsnames.ora
            "wallet_location": "/opt/oracle/wallet_dir",  # directory containing ewallet.pem
            "wallet_password": "top secret",  # password for the PEM file
        },
    )

Typically ``config_dir`` and ``wallet_location`` are the same directory, which
is where the Oracle Autonomous Database wallet zip file was extracted. Note
this directory should be protected.

Connection Pooling
------------------

Applications with multiple concurrent users should use connection pooling. A
minimal sized connection pool is also beneficial for long-running, single-user
applications that do not frequently use a connection.

The python-oracledb driver provides its own connection pool implementation that
may be used in place of SQLAlchemy's pooling functionality. The driver pool
gives support for high availability features such as dead connection detection,
connection draining for planned database downtime, support for Oracle
Application Continuity and Transparent Application Continuity, and gives
support for `Database Resident Connection Pooling (DRCP)
<https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#database-resident-connection-pooling-drcp>`_.

To take advantage of python-oracledb's pool, use the
:paramref:`_sa.create_engine.creator` parameter to provide a function that
returns a new connection, along with setting
:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable
SQLAlchemy's pooling::

    import oracledb
    from sqlalchemy import create_engine
    from sqlalchemy import text
    from sqlalchemy.pool import NullPool

    # Uncomment to use the optional python-oracledb Thick mode.
    # Review the python-oracledb doc for the appropriate parameters
    # oracledb.init_oracle_client(<your parameters>)

    pool = oracledb.create_pool(
        user="scott",
        password="tiger",
        dsn="localhost:1521/freepdb1",
        min=1,
        max=4,
        increment=1,
    )
    engine = create_engine(
        "oracle+oracledb://", creator=pool.acquire, poolclass=NullPool
    )

The above engine may then be used normally. Internally, python-oracledb handles
connection pooling::

    with engine.connect() as conn:
        print(conn.scalar(text("select 1 from dual")))

Refer to the python-oracledb documentation for `oracledb.create_pool()
<https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.create_pool>`_
for the arguments that can be used when creating a connection pool.

.. _drcp:

Using Oracle Database Resident Connection Pooling (DRCP)
--------------------------------------------------------

When using Oracle Database's Database Resident Connection Pooling (DRCP), the
best practice is to specify a connection class and "purity". Refer to the
`python-oracledb documentation on DRCP
<https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html#database-resident-connection-pooling-drcp>`_.
For example::

    import oracledb
    from sqlalchemy import create_engine
    from sqlalchemy import text
    from sqlalchemy.pool import NullPool

    # Uncomment to use the optional python-oracledb Thick mode.
    # Review the python-oracledb doc for the appropriate parameters
    # oracledb.init_oracle_client(<your parameters>)

    pool = oracledb.create_pool(
        user="scott",
        password="tiger",
        dsn="localhost:1521/freepdb1",
        min=1,
        max=4,
        increment=1,
        cclass="MYCLASS",
        purity=oracledb.PURITY_SELF,
    )
    engine = create_engine(
        "oracle+oracledb://", creator=pool.acquire, poolclass=NullPool
    )

The above engine may then be used normally where python-oracledb handles
application connection pooling and Oracle Database additionally uses DRCP::

    with engine.connect() as conn:
        print(conn.scalar(text("select 1 from dual")))

If you wish to use different connection classes or purities for different
connections, then wrap ``pool.acquire()``::

    import oracledb
    from sqlalchemy import create_engine
    from sqlalchemy import text
    from sqlalchemy.pool import NullPool

    # Uncomment to use python-oracledb Thick mode.
    # Review the python-oracledb doc for the appropriate parameters
    # oracledb.init_oracle_client(<your parameters>)

    pool = oracledb.create_pool(
        user="scott",
        password="tiger",
        dsn="localhost:1521/freepdb1",
        min=1,
        max=4,
        increment=1,
        cclass="MYCLASS",
        purity=oracledb.PURITY_SELF,
    )


    def creator():
        return pool.acquire(cclass="MYOTHERCLASS", purity=oracledb.PURITY_NEW)


    engine = create_engine(
        "oracle+oracledb://", creator=creator, poolclass=NullPool
    )

Engine Options consumed by the SQLAlchemy oracledb dialect outside of the driver
--------------------------------------------------------------------------------

There are also options that are consumed by the SQLAlchemy oracledb dialect
itself. These options are always passed directly to :func:`_sa.create_engine`,
such as::

    e = create_engine("oracle+oracledb://user:pass@tnsalias", arraysize=500)

The parameters accepted by the oracledb dialect are as follows (a combined
sketch follows the list):

* ``arraysize`` - set the driver cursor.arraysize value. It defaults to
  ``None``, indicating that the driver default value of 100 should be used.
  This setting controls how many rows are buffered when fetching rows, and can
  have a significant effect on performance if increased for queries that return
  large numbers of rows.

  .. versionchanged:: 2.0.26 - changed the default value from 50 to None,
     to use the default value of the driver itself.

* ``auto_convert_lobs`` - defaults to True; See :ref:`oracledb_lob`.

* ``coerce_to_decimal`` - see :ref:`oracledb_numeric` for detail.

* ``encoding_errors`` - see :ref:`oracledb_unicode_encoding_errors` for detail.
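
A combined sketch of these dialect options, assuming a resolvable
``tnsalias``::

    engine = create_engine(
        "oracle+oracledb://scott:tiger@tnsalias",
        arraysize=1000,
        coerce_to_decimal=False,
    )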

.. _oracledb_unicode:

Unicode
-------

As is the case for all DBAPIs under Python 3, all strings are inherently
Unicode strings.

Ensuring the Correct Client Encoding
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In python-oracledb, the encoding used for all character data is "UTF-8".

Unicode-specific Column datatypes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The Core expression language handles unicode data by use of the
:class:`.Unicode` and :class:`.UnicodeText` datatypes. These types correspond
to the VARCHAR2 and CLOB Oracle Database datatypes by default. When using
these datatypes with Unicode data, it is expected that the database is
configured with a Unicode-aware character set so that the VARCHAR2 and CLOB
datatypes can accommodate the data.

In the case that Oracle Database is not configured with a Unicode character
set, the two options are to use the :class:`_types.NCHAR` and
:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag
``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, which will cause
the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` /
:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB.

.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText`
   datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle Database
   datatypes unless the ``use_nchar_for_unicode=True`` is passed to the dialect
   when :func:`_sa.create_engine` is called.


.. _oracledb_unicode_encoding_errors:

Encoding Errors
^^^^^^^^^^^^^^^

For the unusual case that data in Oracle Database is present with a broken
encoding, the dialect accepts a parameter ``encoding_errors`` which will be
passed to Unicode decoding functions in order to affect how decoding errors are
handled. The value is ultimately consumed by the Python `decode
<https://docs.python.org/3/library/stdtypes.html#bytes.decode>`_ function, and
is passed both via python-oracledb's ``encodingErrors`` parameter consumed by
``Cursor.var()``, as well as SQLAlchemy's own decoding function, as the
python-oracledb dialect makes use of both under different circumstances.

.. versionadded:: 1.3.11
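
For example, a sketch that substitutes the Unicode replacement character for
undecodable bytes, using Python's standard ``replace`` error handler::

    engine = create_engine(
        "oracle+oracledb://scott:tiger@tnsalias", encoding_errors="replace"
    )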


.. _oracledb_setinputsizes:

Fine grained control over python-oracledb data binding with setinputsizes
-------------------------------------------------------------------------

The python-oracledb DBAPI has a deep and fundamental reliance upon the usage of
the DBAPI ``setinputsizes()`` call. The purpose of this call is to establish
the datatypes that are bound to a SQL statement for Python values being passed
as parameters. While virtually no other DBAPI assigns any use to the
``setinputsizes()`` call, the python-oracledb DBAPI relies upon it heavily in
its interactions with the Oracle Database, and in some scenarios it is not
possible for SQLAlchemy to know exactly how data should be bound, as some
settings can cause profoundly different performance characteristics, while
altering the type coercion behavior at the same time.

Users of the oracledb dialect are **strongly encouraged** to read through
python-oracledb's list of built-in datatype symbols at `Database Types
<https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#database-types>`_.
Note that in some cases, significant performance degradation can occur when
using these types vs. not.

On the SQLAlchemy side, the :meth:`.DialectEvents.do_setinputsizes` event can
be used both for runtime visibility (e.g. logging) of the setinputsizes step as
well as to fully control how ``setinputsizes()`` is used on a per-statement
basis.

.. versionadded:: 1.2.9 Added :meth:`.DialectEvents.setinputsizes`


Example 1 - logging all setinputsizes calls
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The following example illustrates how to log the intermediary values from a
SQLAlchemy perspective before they are converted to the raw ``setinputsizes()``
parameter dictionary. The keys of the dictionary are :class:`.BindParameter`
objects which have a ``.key`` and a ``.type`` attribute::

    import logging

    from sqlalchemy import create_engine, event

    log = logging.getLogger(__name__)

    engine = create_engine(
        "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1"
    )


    @event.listens_for(engine, "do_setinputsizes")
    def _log_setinputsizes(inputsizes, cursor, statement, parameters, context):
        for bindparam, dbapitype in inputsizes.items():
            log.info(
                "Bound parameter name: %s SQLAlchemy type: %r DBAPI object: %s",
                bindparam.key,
                bindparam.type,
                dbapitype,
            )

Example 2 - remove all bindings to CLOB
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For performance, fetching LOB datatypes from Oracle Database is set by default
for the ``Text`` type within SQLAlchemy. This setting can be modified as
follows::

    from sqlalchemy import create_engine, event
    from oracledb import CLOB

    engine = create_engine(
        "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1"
    )


    @event.listens_for(engine, "do_setinputsizes")
    def _remove_clob(inputsizes, cursor, statement, parameters, context):
        for bindparam, dbapitype in list(inputsizes.items()):
            if dbapitype is CLOB:
                del inputsizes[bindparam]

.. _oracledb_lob:

LOB Datatypes
--------------

LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and
BLOB. Oracle Database can efficiently return these datatypes as a single
buffer. SQLAlchemy makes use of type handlers to do this by default.

To disable the use of the type handlers and deliver LOB objects as classic
buffered objects with a ``read()`` method, the parameter
``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`.

.. _oracledb_returning:

RETURNING Support
-----------------

The oracledb dialect implements RETURNING using OUT parameters. The dialect
supports RETURNING fully.

Two Phase Transaction Support
-----------------------------

Two phase transactions are fully supported with python-oracledb (Thin mode
requires python-oracledb 2.3). APIs for two phase transactions are provided at
the Core level via :meth:`_engine.Connection.begin_twophase` and
:paramref:`_orm.Session.twophase` for transparent ORM use.

.. versionchanged:: 2.0.32 added support for two phase transactions
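
A minimal Core-level sketch of the above APIs, assuming a table ``some_table``
already exists::

    from sqlalchemy import create_engine, text

    engine = create_engine("oracle+oracledb://scott:tiger@tnsalias")

    with engine.connect() as conn:
        # begin_twophase() generates an XID when one is not supplied
        tx = conn.begin_twophase()
        conn.execute(text("INSERT INTO some_table (x) VALUES (1)"))
        tx.prepare()  # phase one
        tx.commit()  # phase two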
|
||||
|
||||
.. _oracledb_numeric:
|
||||
|
||||
Precision Numerics
|
||||
------------------
|
||||
|
||||
SQLAlchemy's numeric types can handle receiving and returning values as Python
|
||||
``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a
|
||||
subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in
|
||||
use, the :paramref:`.Numeric.asdecimal` flag determines if values should be
|
||||
coerced to ``Decimal`` upon return, or returned as float objects. To make
|
||||
matters more complicated under Oracle Database, the ``NUMBER`` type can also
|
||||
represent integer values if the "scale" is zero, so the Oracle
|
||||
Database-specific :class:`_oracle.NUMBER` type takes this into account as well.
|
||||
|
||||
The oracledb dialect makes extensive use of connection- and cursor-level
|
||||
"outputtypehandler" callables in order to coerce numeric values as requested.
|
||||
These callables are specific to the specific flavor of :class:`.Numeric` in
|
||||
use, as well as if no SQLAlchemy typing objects are present. There are
|
||||
observed scenarios where Oracle Database may send incomplete or ambiguous
|
||||
information about the numeric types being returned, such as a query where the
|
||||
numeric types are buried under multiple levels of subquery. The type handlers
|
||||
do their best to make the right decision in all cases, deferring to the
|
||||
underlying python-oracledb DBAPI for all those cases where the driver can make
|
||||
the best decision.
|
||||
|
||||
When no typing objects are present, as when executing plain SQL strings, a
|
||||
default "outputtypehandler" is present which will generally return numeric
|
||||
values which specify precision and scale as Python ``Decimal`` objects. To
|
||||
disable this coercion to decimal for performance reasons, pass the flag
|
||||
``coerce_to_decimal=False`` to :func:`_sa.create_engine`::
|
||||
|
||||
engine = create_engine(
|
||||
"oracle+oracledb://scott:tiger@tnsalias", coerce_to_decimal=False
|
||||
)
|
||||
|
||||
The ``coerce_to_decimal`` flag only impacts the results of plain string
|
||||
SQL statements that are not otherwise associated with a :class:`.Numeric`
|
||||
SQLAlchemy type (or a subclass of such).
|
||||
|
||||
.. versionchanged:: 1.2 The numeric handling system for the oracle dialects has
|
||||
been reworked to take advantage of newer driver features as well as better
|
||||
integration of outputtypehandlers.
|
||||
|
||||
.. versionadded:: 2.0.0 added support for the python-oracledb driver.
|
||||
.. versionadded:: 2.0.0 added support for oracledb driver.
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from . import cx_oracle as _cx_oracle
|
||||
from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle
|
||||
from ... import exc
|
||||
from ... import pool
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
from ...engine import default
from ...util import asbool
from ...util import await_fallback
from ...util import await_only

if TYPE_CHECKING:
    from oracledb import AsyncConnection
    from oracledb import AsyncCursor


class OracleExecutionContext_oracledb(
    _cx_oracle.OracleExecutionContext_cx_oracle
):
    pass


class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle):
class OracleDialect_oracledb(_OracleDialect_cx_oracle):
    supports_statement_cache = True
    execution_ctx_cls = OracleExecutionContext_oracledb

    driver = "oracledb"
    _min_version = (1,)

    def __init__(
        self,
        auto_convert_lobs=True,
        coerce_to_decimal=True,
        arraysize=None,
        arraysize=50,
        encoding_errors=None,
        thick_mode=None,
        **kwargs,
@@ -662,10 +91,6 @@ class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle):
    def is_thin_mode(cls, connection):
        return connection.connection.dbapi_connection.thin

    @classmethod
    def get_async_dialect_cls(cls, url):
        return OracleDialectAsync_oracledb

    def _load_version(self, dbapi_module):
        version = (0, 0, 0)
        if dbapi_module is not None:
@@ -675,273 +100,10 @@ class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle):
                int(x) for x in m.group(1, 2, 3) if x is not None
            )
        self.oracledb_ver = version
        if (
            self.oracledb_ver > (0, 0, 0)
            and self.oracledb_ver < self._min_version
        ):
        if self.oracledb_ver < (1,) and self.oracledb_ver > (0, 0, 0):
            raise exc.InvalidRequestError(
                f"oracledb version {self._min_version} and above are supported"
                "oracledb version 1 and above are supported"
            )

    def do_begin_twophase(self, connection, xid):
        conn_xis = connection.connection.xid(*xid)
        connection.connection.tpc_begin(conn_xis)
        connection.connection.info["oracledb_xid"] = conn_xis

    def do_prepare_twophase(self, connection, xid):
        should_commit = connection.connection.tpc_prepare()
        connection.info["oracledb_should_commit"] = should_commit

    def do_rollback_twophase(
        self, connection, xid, is_prepared=True, recover=False
    ):
        if recover:
            conn_xid = connection.connection.xid(*xid)
        else:
            conn_xid = None
        connection.connection.tpc_rollback(conn_xid)

    def do_commit_twophase(
        self, connection, xid, is_prepared=True, recover=False
    ):
        conn_xid = None
        if not is_prepared:
            should_commit = connection.connection.tpc_prepare()
        elif recover:
            conn_xid = connection.connection.xid(*xid)
            should_commit = True
        else:
            should_commit = connection.info["oracledb_should_commit"]
        if should_commit:
            connection.connection.tpc_commit(conn_xid)

    def do_recover_twophase(self, connection):
        return [
            # oracledb seems to return bytes
            (
                fi,
                gti.decode() if isinstance(gti, bytes) else gti,
                bq.decode() if isinstance(bq, bytes) else bq,
            )
            for fi, gti, bq in connection.connection.tpc_recover()
        ]

    def _check_max_identifier_length(self, connection):
        if self.oracledb_ver >= (2, 5):
            max_len = connection.connection.max_identifier_length
            if max_len is not None:
                return max_len
        return super()._check_max_identifier_length(connection)
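
# Editor's example (not part of the diff): a minimal sketch of how the
# two-phase hooks above are normally exercised, through SQLAlchemy's public
# Connection.begin_twophase() API rather than by calling them directly.
# The DSN and table name are placeholders.
from sqlalchemy import create_engine, text

engine = create_engine(
    "oracle+oracledb://scott:tiger@localhost/?service_name=freepdb1"
)
with engine.connect() as conn:
    xa = conn.begin_twophase()  # invokes do_begin_twophase() with a new xid
    conn.execute(text("INSERT INTO some_table (x) VALUES (1)"))
    xa.prepare()  # invokes do_prepare_twophase() -> tpc_prepare()
    xa.commit()  # invokes do_commit_twophase(); commits only if prepared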


class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
    _cursor: AsyncCursor
    __slots__ = ()

    @property
    def outputtypehandler(self):
        return self._cursor.outputtypehandler

    @outputtypehandler.setter
    def outputtypehandler(self, value):
        self._cursor.outputtypehandler = value

    def var(self, *args, **kwargs):
        return self._cursor.var(*args, **kwargs)

    def close(self):
        self._rows.clear()
        self._cursor.close()

    def setinputsizes(self, *args: Any, **kwargs: Any) -> Any:
        return self._cursor.setinputsizes(*args, **kwargs)

    def _aenter_cursor(self, cursor: AsyncCursor) -> AsyncCursor:
        try:
            return cursor.__enter__()
        except Exception as error:
            self._adapt_connection._handle_exception(error)

    async def _execute_async(self, operation, parameters):
        # override to not use mutex, oracledb already has a mutex

        if parameters is None:
            result = await self._cursor.execute(operation)
        else:
            result = await self._cursor.execute(operation, parameters)

        if self._cursor.description and not self.server_side:
            self._rows = collections.deque(await self._cursor.fetchall())
        return result

    async def _executemany_async(
        self,
        operation,
        seq_of_parameters,
    ):
        # override to not use mutex, oracledb already has a mutex
        return await self._cursor.executemany(operation, seq_of_parameters)

    def __enter__(self):
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        self.close()


class AsyncAdapt_oracledb_ss_cursor(
    AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_oracledb_cursor
):
    __slots__ = ()

    def close(self) -> None:
        if self._cursor is not None:
            self._cursor.close()
            self._cursor = None  # type: ignore


class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection):
    _connection: AsyncConnection
    __slots__ = ()

    thin = True

    _cursor_cls = AsyncAdapt_oracledb_cursor
    _ss_cursor_cls = None

    @property
    def autocommit(self):
        return self._connection.autocommit

    @autocommit.setter
    def autocommit(self, value):
        self._connection.autocommit = value

    @property
    def outputtypehandler(self):
        return self._connection.outputtypehandler

    @outputtypehandler.setter
    def outputtypehandler(self, value):
        self._connection.outputtypehandler = value

    @property
    def version(self):
        return self._connection.version

    @property
    def stmtcachesize(self):
        return self._connection.stmtcachesize

    @stmtcachesize.setter
    def stmtcachesize(self, value):
        self._connection.stmtcachesize = value

    @property
    def max_identifier_length(self):
        return self._connection.max_identifier_length

    def cursor(self):
        return AsyncAdapt_oracledb_cursor(self)

    def ss_cursor(self):
        return AsyncAdapt_oracledb_ss_cursor(self)

    def xid(self, *args: Any, **kwargs: Any) -> Any:
        return self._connection.xid(*args, **kwargs)

    def tpc_begin(self, *args: Any, **kwargs: Any) -> Any:
        return self.await_(self._connection.tpc_begin(*args, **kwargs))

    def tpc_commit(self, *args: Any, **kwargs: Any) -> Any:
        return self.await_(self._connection.tpc_commit(*args, **kwargs))

    def tpc_prepare(self, *args: Any, **kwargs: Any) -> Any:
        return self.await_(self._connection.tpc_prepare(*args, **kwargs))

    def tpc_recover(self, *args: Any, **kwargs: Any) -> Any:
        return self.await_(self._connection.tpc_recover(*args, **kwargs))

    def tpc_rollback(self, *args: Any, **kwargs: Any) -> Any:
        return self.await_(self._connection.tpc_rollback(*args, **kwargs))


class AsyncAdaptFallback_oracledb_connection(
    AsyncAdaptFallback_dbapi_connection, AsyncAdapt_oracledb_connection
):
    __slots__ = ()


class OracledbAdaptDBAPI:
    def __init__(self, oracledb) -> None:
        self.oracledb = oracledb

        for k, v in self.oracledb.__dict__.items():
            if k != "connect":
                self.__dict__[k] = v

    def connect(self, *arg, **kw):
        async_fallback = kw.pop("async_fallback", False)
        creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async)

        if asbool(async_fallback):
            return AsyncAdaptFallback_oracledb_connection(
                self, await_fallback(creator_fn(*arg, **kw))
            )

        else:
            return AsyncAdapt_oracledb_connection(
                self, await_only(creator_fn(*arg, **kw))
            )
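
# Editor's example (not part of the diff): a sketch of plugging a custom
# ``async_creator_fn`` into the connect() dispatch above, e.g. to adjust
# driver parameters before connecting. Names and DSN are placeholders.
import oracledb
from sqlalchemy.ext.asyncio import create_async_engine

async def my_creator(*arg, **kw):
    # delegate to the driver; extra parameters could be injected into ``kw``
    return await oracledb.connect_async(*arg, **kw)

engine = create_async_engine(
    "oracle+oracledb://scott:tiger@localhost/?service_name=freepdb1",
    connect_args={"async_creator_fn": my_creator},
)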


class OracleExecutionContextAsync_oracledb(OracleExecutionContext_oracledb):
    # restore default create cursor
    create_cursor = default.DefaultExecutionContext.create_cursor

    def create_default_cursor(self):
        # copy of OracleExecutionContext_cx_oracle.create_cursor
        c = self._dbapi_connection.cursor()
        if self.dialect.arraysize:
            c.arraysize = self.dialect.arraysize

        return c

    def create_server_side_cursor(self):
        c = self._dbapi_connection.ss_cursor()
        if self.dialect.arraysize:
            c.arraysize = self.dialect.arraysize

        return c


class OracleDialectAsync_oracledb(OracleDialect_oracledb):
    is_async = True
    supports_server_side_cursors = True
    supports_statement_cache = True
    execution_ctx_cls = OracleExecutionContextAsync_oracledb

    _min_version = (2,)

    # thick_mode is not supported by asyncio; oracledb will raise
    @classmethod
    def import_dbapi(cls):
        import oracledb

        return OracledbAdaptDBAPI(oracledb)

    @classmethod
    def get_pool_class(cls, url):
        async_fallback = url.query.get("async_fallback", False)

        if asbool(async_fallback):
            return pool.FallbackAsyncAdaptedQueuePool
        else:
            return pool.AsyncAdaptedQueuePool

    def get_driver_connection(self, connection):
        return connection._connection


dialect = OracleDialect_oracledb
dialect_async = OracleDialectAsync_oracledb
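
# Editor's example (not part of the diff): the module-level ``dialect`` and
# ``dialect_async`` names above are what URL resolution picks up, so the
# same URL serves both engine styles. Credentials are placeholders.
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine

sync_engine = create_engine(
    "oracle+oracledb://scott:tiger@localhost/?service_name=freepdb1"
)
async_engine = create_async_engine(
    "oracle+oracledb://scott:tiger@localhost/?service_name=freepdb1"
)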

@@ -1,9 +1,3 @@
# dialects/oracle/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from ... import create_engine
@@ -89,7 +83,7 @@ def _oracle_drop_db(cfg, eng, ident):
        # cx_Oracle seems to occasionally leak open connections when a large
        # suite is run, even if we confirm we have zero references to
        # connection objects.
        # while there is a "kill session" command in Oracle Database,
        # while there is a "kill session" command in Oracle,
        # it unfortunately does not release the connection sufficiently.
        _ora_drop_ignore(conn, ident)
        _ora_drop_ignore(conn, "%s_ts1" % ident)

@@ -1,5 +1,4 @@
# dialects/oracle/types.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -64,18 +63,17 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer):


class FLOAT(sqltypes.FLOAT):
    """Oracle Database FLOAT.
    """Oracle FLOAT.

    This is the same as :class:`_sqltypes.FLOAT` except that
    an Oracle Database-specific :paramref:`_oracle.FLOAT.binary_precision`
    an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
    parameter is accepted, and
    the :paramref:`_sqltypes.Float.precision` parameter is not accepted.

    Oracle Database FLOAT types indicate precision in terms of "binary
    precision", which defaults to 126. For a REAL type, the value is 63. This
    parameter does not cleanly map to a specific number of decimal places but
    is roughly equivalent to the desired number of decimal places divided by
    0.3103.
    Oracle FLOAT types indicate precision in terms of "binary precision", which
    defaults to 126. For a REAL type, the value is 63. This parameter does not
    cleanly map to a specific number of decimal places but is roughly
    equivalent to the desired number of decimal places divided by 0.3103.

    .. versionadded:: 2.0

@@ -92,11 +90,10 @@ class FLOAT(sqltypes.FLOAT):
        r"""
        Construct a FLOAT

        :param binary_precision: Oracle Database binary precision value to be
         rendered in DDL. This may be approximated to the number of decimal
         characters using the formula "decimal precision = 0.30103 * binary
         precision". The default value used by Oracle Database for FLOAT /
         DOUBLE PRECISION is 126.
        :param binary_precision: Oracle binary precision value to be rendered
         in DDL. This may be approximated to the number of decimal characters
         using the formula "decimal precision = 0.30103 * binary precision".
         The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.

        :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`

@@ -111,36 +108,10 @@ class FLOAT(sqltypes.FLOAT):


class BINARY_DOUBLE(sqltypes.Double):
    """Implement the Oracle ``BINARY_DOUBLE`` datatype.

    This datatype differs from the Oracle ``DOUBLE`` datatype in that it
    delivers a true 8-byte FP value. The datatype may be combined with a
    generic :class:`.Double` datatype using :meth:`.TypeEngine.with_variant`.

    .. seealso::

        :ref:`oracle_float_support`


    """

    __visit_name__ = "BINARY_DOUBLE"


class BINARY_FLOAT(sqltypes.Float):
    """Implement the Oracle ``BINARY_FLOAT`` datatype.

    This datatype differs from the Oracle ``FLOAT`` datatype in that it
    delivers a true 4-byte FP value. The datatype may be combined with a
    generic :class:`.Float` datatype using :meth:`.TypeEngine.with_variant`.

    .. seealso::

        :ref:`oracle_float_support`


    """

    __visit_name__ = "BINARY_FLOAT"
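
# Editor's example (not part of the diff): combining the generic Float and
# Double types with the Oracle-specific binary types via with_variant(),
# as the docstrings above suggest. Table and column names are placeholders;
# BINARY_FLOAT / BINARY_DOUBLE are the classes defined above.
from sqlalchemy import Column, Double, Float, MetaData, Table

measurements = Table(
    "measurements",
    MetaData(),
    Column("f4", Float().with_variant(BINARY_FLOAT(), "oracle")),
    Column("f8", Double().with_variant(BINARY_DOUBLE(), "oracle")),
)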


@@ -191,10 +162,10 @@ class _OracleDateLiteralRender:


class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
    """Provide the Oracle Database DATE type.
    """Provide the oracle DATE type.

    This type has no special Python behavior, except that it subclasses
    :class:`_types.DateTime`; this is to suit the fact that the Oracle Database
    :class:`_types.DateTime`; this is to suit the fact that the Oracle
    ``DATE`` type supports a time value.

    """
@@ -274,8 +245,8 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):


class TIMESTAMP(sqltypes.TIMESTAMP):
    """Oracle Database implementation of ``TIMESTAMP``, which supports
    additional Oracle Database-specific modes
    """Oracle implementation of ``TIMESTAMP``, which supports additional
    Oracle-specific modes

    .. versionadded:: 2.0

@@ -285,11 +256,10 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
        """Construct a new :class:`_oracle.TIMESTAMP`.

        :param timezone: boolean. Indicates that the TIMESTAMP type should
         use Oracle Database's ``TIMESTAMP WITH TIME ZONE`` datatype.
         use Oracle's ``TIMESTAMP WITH TIME ZONE`` datatype.

        :param local_timezone: boolean. Indicates that the TIMESTAMP type
         should use Oracle Database's ``TIMESTAMP WITH LOCAL TIME ZONE``
         datatype.
         should use Oracle's ``TIMESTAMP WITH LOCAL TIME ZONE`` datatype.


        """
@@ -302,7 +272,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP):


class ROWID(sqltypes.TypeEngine):
    """Oracle Database ROWID type.
    """Oracle ROWID type.

    When used in a cast() or similar, generates ROWID.


@@ -1,364 +0,0 @@
# dialects/oracle/vector.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

import array
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from typing import Union

import sqlalchemy.types as types
from sqlalchemy.types import Float


class VectorIndexType(Enum):
    """Enum representing different types of VECTOR index structures.

    See :ref:`oracle_vector_datatype` for background.

    .. versionadded:: 2.0.41

    """

    HNSW = "HNSW"
    """
    The HNSW (Hierarchical Navigable Small World) index type.
    """
    IVF = "IVF"
    """
    The IVF (Inverted File Index) index type.
    """


class VectorDistanceType(Enum):
    """Enum representing different types of vector distance metrics.

    See :ref:`oracle_vector_datatype` for background.

    .. versionadded:: 2.0.41

    """

    EUCLIDEAN = "EUCLIDEAN"
    """Euclidean distance (L2 norm).

    Measures the straight-line distance between two vectors in space.
    """
    DOT = "DOT"
    """Dot product similarity.

    Measures the algebraic similarity between two vectors.
    """
    COSINE = "COSINE"
    """Cosine similarity.

    Measures the cosine of the angle between two vectors.
    """
    MANHATTAN = "MANHATTAN"
    """Manhattan distance (L1 norm).

    Calculates the sum of absolute differences across dimensions.
    """


class VectorStorageFormat(Enum):
    """Enum representing the data format used to store vector components.

    See :ref:`oracle_vector_datatype` for background.

    .. versionadded:: 2.0.41

    """

    INT8 = "INT8"
    """
    8-bit integer format.
    """
    BINARY = "BINARY"
    """
    Binary format.
    """
    FLOAT32 = "FLOAT32"
    """
    32-bit floating-point format.
    """
    FLOAT64 = "FLOAT64"
    """
    64-bit floating-point format.
    """


class VectorStorageType(Enum):
    """Enum representing the vector storage type.

    See :ref:`oracle_vector_datatype` for background.

    .. versionadded:: 2.0.43

    """

    SPARSE = "SPARSE"
    """
    A sparse vector is a vector which has a value of zero for
    most of its dimensions.
    """
    DENSE = "DENSE"
    """
    A dense vector is a vector where most, if not all, elements
    hold meaningful values.
    """


@dataclass
class VectorIndexConfig:
    """Define the configuration for an Oracle VECTOR index.

    See :ref:`oracle_vector_datatype` for background.

    .. versionadded:: 2.0.41

    :param index_type: Enum value from :class:`.VectorIndexType`.
     Specifies the indexing method. For HNSW, this must be
     :attr:`.VectorIndexType.HNSW`.

    :param distance: Enum value from :class:`.VectorDistanceType`.
     Specifies the metric for calculating distance between VECTORS.

    :param accuracy: integer. Should be in the range 0 to 100.
     Specifies the accuracy of the nearest neighbor search during
     query execution.

    :param parallel: integer. Specifies degree of parallelism.

    :param hnsw_neighbors: integer. Should be in the range 0 to
     2048. Specifies the number of nearest neighbors considered
     during the search. The attribute :attr:`.VectorIndexConfig.hnsw_neighbors`
     is HNSW index specific.

    :param hnsw_efconstruction: integer. Should be in the range 0
     to 65535. Controls the trade-off between indexing speed and
     recall quality during index construction. The attribute
     :attr:`.VectorIndexConfig.hnsw_efconstruction` is HNSW index
     specific.

    :param ivf_neighbor_partitions: integer. Should be in the range
     0 to 10,000,000. Specifies the number of partitions used to
     divide the dataset. The attribute
     :attr:`.VectorIndexConfig.ivf_neighbor_partitions` is IVF index
     specific.

    :param ivf_sample_per_partition: integer. Should be between 1
     and ``num_vectors / neighbor partitions``. Specifies the
     number of samples used per partition. The attribute
     :attr:`.VectorIndexConfig.ivf_sample_per_partition` is IVF index
     specific.

    :param ivf_min_vectors_per_partition: integer. From 0 (no trimming)
     to the total number of vectors (results in 1 partition). Specifies
     the minimum number of vectors per partition. The attribute
     :attr:`.VectorIndexConfig.ivf_min_vectors_per_partition`
     is IVF index specific.

    """

    index_type: VectorIndexType = VectorIndexType.HNSW
    distance: Optional[VectorDistanceType] = None
    accuracy: Optional[int] = None
    hnsw_neighbors: Optional[int] = None
    hnsw_efconstruction: Optional[int] = None
    ivf_neighbor_partitions: Optional[int] = None
    ivf_sample_per_partition: Optional[int] = None
    ivf_min_vectors_per_partition: Optional[int] = None
    parallel: Optional[int] = None

    def __post_init__(self):
        self.index_type = VectorIndexType(self.index_type)
        for field in [
            "hnsw_neighbors",
            "hnsw_efconstruction",
            "ivf_neighbor_partitions",
            "ivf_sample_per_partition",
            "ivf_min_vectors_per_partition",
            "parallel",
            "accuracy",
        ]:
            value = getattr(self, field)
            if value is not None and not isinstance(value, int):
                raise TypeError(
                    f"{field} must be an integer if "
                    f"provided, got {type(value).__name__}"
                )
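
# Editor's example (not part of the diff): constructing an HNSW index
# configuration using the parameters documented above; all values here
# are illustrative only.
hnsw_config = VectorIndexConfig(
    index_type=VectorIndexType.HNSW,
    distance=VectorDistanceType.COSINE,
    accuracy=95,
    hnsw_neighbors=32,
    hnsw_efconstruction=200,
    parallel=4,
)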


class SparseVector:
    """
    Lightweight SQLAlchemy-side version of SparseVector.
    This mimics oracledb.SparseVector.

    .. versionadded:: 2.0.43

    """

    def __init__(
        self,
        num_dimensions: int,
        indices: Union[list, array.array],
        values: Union[list, array.array],
    ):
        if not isinstance(indices, array.array) or indices.typecode != "I":
            indices = array.array("I", indices)
        if not isinstance(values, array.array):
            values = array.array("d", values)
        if len(indices) != len(values):
            raise TypeError("indices and values must be of the same length!")

        self.num_dimensions = num_dimensions
        self.indices = indices
        self.values = values

    def __str__(self):
        return (
            f"SparseVector(num_dimensions={self.num_dimensions}, "
            f"size={len(self.indices)}, typecode={self.values.typecode})"
        )
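
# Editor's example (not part of the diff): a 10-dimensional sparse vector
# with three non-zero entries; indices and values must be the same length.
sv = SparseVector(10, [0, 3, 7], [1.5, -2.0, 0.25])
print(sv)  # SparseVector(num_dimensions=10, size=3, typecode=d)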


class VECTOR(types.TypeEngine):
    """Oracle VECTOR datatype.

    For complete background on using this type, see
    :ref:`oracle_vector_datatype`.

    .. versionadded:: 2.0.41

    """

    cache_ok = True
    __visit_name__ = "VECTOR"

    _typecode_map = {
        VectorStorageFormat.INT8: "b",  # Signed int
        VectorStorageFormat.BINARY: "B",  # Unsigned int
        VectorStorageFormat.FLOAT32: "f",  # Float
        VectorStorageFormat.FLOAT64: "d",  # Double
    }

    def __init__(self, dim=None, storage_format=None, storage_type=None):
        """Construct a VECTOR.

        :param dim: integer. The dimension of the VECTOR datatype. This
         should be an integer value.

        :param storage_format: VectorStorageFormat. The VECTOR storage
         format. This should be an Enum value from
         :class:`.VectorStorageFormat`: INT8, BINARY, FLOAT32, or FLOAT64.

        :param storage_type: VectorStorageType. The VECTOR storage type. This
         should be an Enum value from :class:`.VectorStorageType`: SPARSE or
         DENSE.

        """

        if dim is not None and not isinstance(dim, int):
            raise TypeError("dim must be an integer")
        if storage_format is not None and not isinstance(
            storage_format, VectorStorageFormat
        ):
            raise TypeError(
                "storage_format must be an enum of type VectorStorageFormat"
            )
        if storage_type is not None and not isinstance(
            storage_type, VectorStorageType
        ):
            raise TypeError(
                "storage_type must be an enum of type VectorStorageType"
            )

        self.dim = dim
        self.storage_format = storage_format
        self.storage_type = storage_type

    def _cached_bind_processor(self, dialect):
        """
        Converts a Python-side SparseVector instance into an
        oracledb.SparseVector or a compatible array format before
        binding it to the database.
        """

        def process(value):
            if value is None or isinstance(value, array.array):
                return value

            # Convert a list to an array.array
            elif isinstance(value, list):
                typecode = self._array_typecode(self.storage_format)
                value = array.array(typecode, value)
                return value

            # Convert a SQLAlchemy SparseVector to an oracledb SparseVector
            elif isinstance(value, SparseVector):
                return dialect.dbapi.SparseVector(
                    value.num_dimensions,
                    value.indices,
                    value.values,
                )

            else:
                raise TypeError(
                    """
                    Invalid input for VECTOR: expected a list, an array.array,
                    or a SparseVector object.
                    """
                )

        return process

    def _cached_result_processor(self, dialect, coltype):
        """
        Converts database-returned values into Python-native representations.
        If the value is an oracledb.SparseVector, it is converted into the
        SQLAlchemy-side SparseVector class.
        If the value is an array.array, it is converted to a plain Python list.

        """

        def process(value):
            if value is None:
                return None

            elif isinstance(value, array.array):
                return list(value)

            # Convert an oracledb SparseVector to a SQLAlchemy SparseVector
            elif isinstance(value, dialect.dbapi.SparseVector):
                return SparseVector(
                    num_dimensions=value.num_dimensions,
                    indices=value.indices,
                    values=value.values,
                )

        return process

    def _array_typecode(self, typecode):
        """
        Map storage format to array typecode.
        """
        return self._typecode_map.get(typecode, "d")

    class comparator_factory(types.TypeEngine.Comparator):
        def l2_distance(self, other):
            return self.op("<->", return_type=Float)(other)

        def inner_product(self, other):
            return self.op("<#>", return_type=Float)(other)

        def cosine_distance(self, other):
            return self.op("<=>", return_type=Float)(other)
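
# Editor's example (not part of the diff): declaring a VECTOR column and
# ordering by the comparator methods defined above. Table and column names
# are placeholders.
from sqlalchemy import Column, Integer, MetaData, Table, select

items = Table(
    "items",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column(
        "embedding",
        VECTOR(dim=3, storage_format=VectorStorageFormat.FLOAT32),
    ),
)

stmt = (
    select(items.c.id)
    .order_by(items.c.embedding.l2_distance([1.0, 2.0, 3.0]))
    .limit(5)
)
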
@@ -1,5 +1,5 @@
# dialects/postgresql/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -8,7 +8,6 @@

from types import ModuleType

from . import array as arraylib  # noqa # keep above base and other dialects
from . import asyncpg  # noqa
from . import base
from . import pg8000  # noqa
@@ -57,14 +56,12 @@ from .named_types import ENUM
from .named_types import NamedType
from .ranges import AbstractMultiRange
from .ranges import AbstractRange
from .ranges import AbstractSingleRange
from .ranges import DATEMULTIRANGE
from .ranges import DATERANGE
from .ranges import INT4MULTIRANGE
from .ranges import INT4RANGE
from .ranges import INT8MULTIRANGE
from .ranges import INT8RANGE
from .ranges import MultiRange
from .ranges import NUMMULTIRANGE
from .ranges import NUMRANGE
from .ranges import Range
@@ -89,7 +86,6 @@ from .types import TIMESTAMP
from .types import TSQUERY
from .types import TSVECTOR


# Alias psycopg also as psycopg_async
psycopg_async = type(
    "psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async}

@@ -1,5 +1,4 @@
# dialects/postgresql/_psycopg_common.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -170,10 +169,8 @@ class _PGDialect_common_psycopg(PGDialect):
    def _do_autocommit(self, connection, value):
        connection.autocommit = value

    def detect_autocommit_setting(self, dbapi_connection):
        return bool(dbapi_connection.autocommit)

    def do_ping(self, dbapi_connection):
        cursor = None
        before_autocommit = dbapi_connection.autocommit

        if not before_autocommit:

@@ -1,21 +1,18 @@
# dialects/postgresql/array.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/array.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

import re
from typing import Any as typing_Any
from typing import Iterable
from typing import Any
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union

from .operators import CONTAINED_BY
from .operators import CONTAINS
@@ -24,55 +21,32 @@ from ... import types as sqltypes
from ... import util
from ...sql import expression
from ...sql import operators
from ...sql.visitors import InternalTraversal

if TYPE_CHECKING:
    from ...engine.interfaces import Dialect
    from ...sql._typing import _ColumnExpressionArgument
    from ...sql._typing import _TypeEngineArgument
    from ...sql.elements import ColumnElement
    from ...sql.elements import Grouping
    from ...sql.expression import BindParameter
    from ...sql.operators import OperatorType
    from ...sql.selectable import _SelectIterable
    from ...sql.type_api import _BindProcessorType
    from ...sql.type_api import _LiteralProcessorType
    from ...sql.type_api import _ResultProcessorType
    from ...sql.type_api import TypeEngine
    from ...sql.visitors import _TraverseInternalsType
    from ...util.typing import Self
from ...sql._typing import _TypeEngineArgument


_T = TypeVar("_T", bound=typing_Any)
_T = TypeVar("_T", bound=Any)


def Any(
    other: typing_Any,
    arrexpr: _ColumnExpressionArgument[_T],
    operator: OperatorType = operators.eq,
) -> ColumnElement[bool]:
def Any(other, arrexpr, operator=operators.eq):
    """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
    See that method for details.

    """

    return arrexpr.any(other, operator)  # type: ignore[no-any-return, union-attr] # noqa: E501
    return arrexpr.any(other, operator)


def All(
    other: typing_Any,
    arrexpr: _ColumnExpressionArgument[_T],
    operator: OperatorType = operators.eq,
) -> ColumnElement[bool]:
def All(other, arrexpr, operator=operators.eq):
    """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
    See that method for details.

    """

    return arrexpr.all(other, operator)  # type: ignore[no-any-return, union-attr] # noqa: E501
    return arrexpr.all(other, operator)
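
# Editor's example (not part of the diff): the module-level Any()/All()
# helpers are shorthand for the column-level methods. Placeholder table;
# ARRAY is the PostgreSQL type defined below in this module.
from sqlalchemy import Column, Integer, MetaData, Table, select

t = Table("t", MetaData(), Column("data", ARRAY(Integer)))

any_stmt = select(t).where(Any(5, t.c.data))  # renders: 5 = ANY (t.data)
all_stmt = select(t).where(All(5, t.c.data))  # renders: 5 = ALL (t.data)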


class array(expression.ExpressionClauseList[_T]):

    """A PostgreSQL ARRAY literal.

    This is used to produce ARRAY literals in SQL expressions, e.g.::
@@ -81,43 +55,20 @@ class array(expression.ExpressionClauseList[_T]):
        from sqlalchemy.dialects import postgresql
        from sqlalchemy import select, func

        stmt = select(array([1, 2]) + array([3, 4, 5]))
        stmt = select(array([1,2]) + array([3,4,5]))

        print(stmt.compile(dialect=postgresql.dialect()))

    Produces the SQL:

    .. sourcecode:: sql
    Produces the SQL::

        SELECT ARRAY[%(param_1)s, %(param_2)s] ||
            ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1

    An instance of :class:`.array` will always have the datatype
    :class:`_types.ARRAY`. The "inner" type of the array is inferred from the
    values present, unless the :paramref:`_postgresql.array.type_` keyword
    argument is passed::
    :class:`_types.ARRAY`. The "inner" type of the array is inferred from
    the values present, unless the ``type_`` keyword argument is passed::

        array(["foo", "bar"], type_=CHAR)

    When constructing an empty array, the :paramref:`_postgresql.array.type_`
    argument is particularly important as PostgreSQL server typically requires
    a cast to be rendered for the inner type in order to render an empty array.
    SQLAlchemy's compilation for the empty array will produce this cast so
    that::

        stmt = array([], type_=Integer)
        print(stmt.compile(dialect=postgresql.dialect()))

    Produces:

    .. sourcecode:: sql

        ARRAY[]::INTEGER[]

    As required by PostgreSQL for empty arrays.

    .. versionadded:: 2.0.40 added support to render empty PostgreSQL array
       literals with a required cast.
        array(['foo', 'bar'], type_=CHAR)

    Multidimensional arrays are produced by nesting :class:`.array` constructs.
    The dimensionality of the final :class:`_types.ARRAY`
@@ -126,21 +77,16 @@ class array(expression.ExpressionClauseList[_T]):
    type::

        stmt = select(
            array(
                [array([1, 2]), array([3, 4]), array([column("q"), column("x")])]
            )
        array([
            array([1, 2]), array([3, 4]), array([column('q'), column('x')])
        ])
        )
        print(stmt.compile(dialect=postgresql.dialect()))

    Produces:
    Produces::

    .. sourcecode:: sql

        SELECT ARRAY[
            ARRAY[%(param_1)s, %(param_2)s],
            ARRAY[%(param_3)s, %(param_4)s],
            ARRAY[q, x]
        ] AS anon_1
        SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s],
        ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1

    .. versionadded:: 1.3.6 added support for multidimensional array literals

@@ -148,63 +94,42 @@ class array(expression.ExpressionClauseList[_T]):

        :class:`_postgresql.ARRAY`

    """  # noqa: E501
    """

    __visit_name__ = "array"

    stringify_dialect = "postgresql"
    inherit_cache = True

    _traverse_internals: _TraverseInternalsType = [
        ("clauses", InternalTraversal.dp_clauseelement_tuple),
        ("type", InternalTraversal.dp_type),
    ]

    def __init__(
        self,
        clauses: Iterable[_T],
        *,
        type_: Optional[_TypeEngineArgument[_T]] = None,
        **kw: typing_Any,
    ):
        r"""Construct an ARRAY literal.

        :param clauses: iterable, such as a list, containing elements to be
         rendered in the array
        :param type\_: optional type. If omitted, the type is inferred
         from the contents of the array.

        """
    def __init__(self, clauses, **kw):
        type_arg = kw.pop("type_", None)
        super().__init__(operators.comma_op, *clauses, **kw)

        self._type_tuple = [arg.type for arg in self.clauses]

        main_type = (
            type_
            if type_ is not None
            else self.clauses[0].type if self.clauses else sqltypes.NULLTYPE
            type_arg
            if type_arg is not None
            else self._type_tuple[0]
            if self._type_tuple
            else sqltypes.NULLTYPE
        )

        if isinstance(main_type, ARRAY):
            self.type = ARRAY(
                main_type.item_type,
                dimensions=(
                    main_type.dimensions + 1
                    if main_type.dimensions is not None
                    else 2
                ),
            )  # type: ignore[assignment]
                dimensions=main_type.dimensions + 1
                if main_type.dimensions is not None
                else 2,
            )
        else:
            self.type = ARRAY(main_type)  # type: ignore[assignment]
            self.type = ARRAY(main_type)

    @property
    def _select_iterable(self) -> _SelectIterable:
    def _select_iterable(self):
        return (self,)

    def _bind_param(
        self,
        operator: OperatorType,
        obj: typing_Any,
        type_: Optional[TypeEngine[_T]] = None,
        _assume_scalar: bool = False,
    ) -> BindParameter[_T]:
    def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
        if _assume_scalar or operator is operators.getitem:
            return expression.BindParameter(
                None,
@@ -223,18 +148,16 @@ class array(expression.ExpressionClauseList[_T]):
                    )
                    for o in obj
                ]
            )  # type: ignore[return-value]
            )

    def self_group(
        self, against: Optional[OperatorType] = None
    ) -> Union[Self, Grouping[_T]]:
    def self_group(self, against=None):
        if against in (operators.any_op, operators.all_op, operators.getitem):
            return expression.Grouping(self)
        else:
            return self


class ARRAY(sqltypes.ARRAY[_T]):
class ARRAY(sqltypes.ARRAY):
    """PostgreSQL ARRAY type.

    The :class:`_postgresql.ARRAY` type is constructed in the same way
@@ -244,11 +167,9 @@ class ARRAY(sqltypes.ARRAY[_T]):

        from sqlalchemy.dialects import postgresql

        mytable = Table(
            "mytable",
            metadata,
            Column("data", postgresql.ARRAY(Integer, dimensions=2)),
        )
        mytable = Table("mytable", metadata,
            Column("data", postgresql.ARRAY(Integer, dimensions=2))
        )

    The :class:`_postgresql.ARRAY` type provides all operations defined on the
    core :class:`_types.ARRAY` type, including support for "dimensions",
@@ -263,9 +184,8 @@ class ARRAY(sqltypes.ARRAY[_T]):

        mytable.c.data.contains([1, 2])

    Indexed access is one-based by default, to match that of PostgreSQL;
    for zero-based indexed access, set
    :paramref:`_postgresql.ARRAY.zero_indexes`.
    The :class:`_postgresql.ARRAY` type may not be supported on all
    PostgreSQL DBAPIs; it is currently known to work on psycopg2 only.

    Additionally, the :class:`_postgresql.ARRAY`
    type does not work directly in
@@ -284,7 +204,6 @@ class ARRAY(sqltypes.ARRAY[_T]):
        from sqlalchemy.dialects.postgresql import ARRAY
        from sqlalchemy.ext.mutable import MutableList


        class SomeOrmClass(Base):
            # ...

@@ -306,9 +225,45 @@ class ARRAY(sqltypes.ARRAY[_T]):

    """

    class Comparator(sqltypes.ARRAY.Comparator):

        """Define comparison operations for :class:`_types.ARRAY`.

        Note that these operations are in addition to those provided
        by the base :class:`.types.ARRAY.Comparator` class, including
        :meth:`.types.ARRAY.Comparator.any` and
        :meth:`.types.ARRAY.Comparator.all`.

        """

        def contains(self, other, **kwargs):
            """Boolean expression. Test if elements are a superset of the
            elements of the argument array expression.

            kwargs may be ignored by this operator but are required for API
            conformance.
            """
            return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)

        def contained_by(self, other):
            """Boolean expression. Test if elements are a proper subset of the
            elements of the argument array expression.
            """
            return self.operate(
                CONTAINED_BY, other, result_type=sqltypes.Boolean
            )

        def overlap(self, other):
            """Boolean expression. Test if array has elements in common with
            an argument array expression.
            """
            return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)

    comparator_factory = Comparator
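
# Editor's example (not part of the diff): the three comparator methods
# above render the PostgreSQL @>, <@ and && operators. Placeholder table.
from sqlalchemy import Column, Integer, MetaData, Table, select

t = Table("t", MetaData(), Column("data", ARRAY(Integer)))

contains_stmt = select(t).where(t.c.data.contains([1, 2]))  # data @> ARRAY[...]
subset_stmt = select(t).where(t.c.data.contained_by([1, 2, 3]))  # data <@ ARRAY[...]
overlap_stmt = select(t).where(t.c.data.overlap([2, 9]))  # data && ARRAY[...]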

    def __init__(
        self,
        item_type: _TypeEngineArgument[_T],
        item_type: _TypeEngineArgument[Any],
        as_tuple: bool = False,
        dimensions: Optional[int] = None,
        zero_indexes: bool = False,
@@ -317,7 +272,7 @@ class ARRAY(sqltypes.ARRAY[_T]):

        E.g.::

            Column("myarray", ARRAY(Integer))
            Column('myarray', ARRAY(Integer))

        Arguments are:

@@ -357,63 +312,35 @@ class ARRAY(sqltypes.ARRAY[_T]):
        self.dimensions = dimensions
        self.zero_indexes = zero_indexes

    class Comparator(sqltypes.ARRAY.Comparator[_T]):
        """Define comparison operations for :class:`_types.ARRAY`.
    @property
    def hashable(self):
        return self.as_tuple

        Note that these operations are in addition to those provided
        by the base :class:`.types.ARRAY.Comparator` class, including
        :meth:`.types.ARRAY.Comparator.any` and
        :meth:`.types.ARRAY.Comparator.all`.
    @property
    def python_type(self):
        return list

        """

        def contains(
            self, other: typing_Any, **kwargs: typing_Any
        ) -> ColumnElement[bool]:
            """Boolean expression. Test if elements are a superset of the
            elements of the argument array expression.

            kwargs may be ignored by this operator but are required for API
            conformance.
            """
            return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)

        def contained_by(self, other: typing_Any) -> ColumnElement[bool]:
            """Boolean expression. Test if elements are a proper subset of the
            elements of the argument array expression.
            """
            return self.operate(
                CONTAINED_BY, other, result_type=sqltypes.Boolean
            )

        def overlap(self, other: typing_Any) -> ColumnElement[bool]:
            """Boolean expression. Test if array has elements in common with
            an argument array expression.
            """
            return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)

    comparator_factory = Comparator
    def compare_values(self, x, y):
        return x == y

    @util.memoized_property
    def _against_native_enum(self) -> bool:
    def _against_native_enum(self):
        return (
            isinstance(self.item_type, sqltypes.Enum)
            and self.item_type.native_enum
        )

    def literal_processor(
        self, dialect: Dialect
    ) -> Optional[_LiteralProcessorType[_T]]:
    def literal_processor(self, dialect):
        item_proc = self.item_type.dialect_impl(dialect).literal_processor(
            dialect
        )
        if item_proc is None:
            return None

        def to_str(elements: Iterable[typing_Any]) -> str:
        def to_str(elements):
            return f"ARRAY[{', '.join(elements)}]"

        def process(value: Sequence[typing_Any]) -> str:
        def process(value):
            inner = self._apply_item_processor(
                value, item_proc, self.dimensions, to_str
            )
@@ -421,16 +348,12 @@ class ARRAY(sqltypes.ARRAY[_T]):

        return process

    def bind_processor(
        self, dialect: Dialect
    ) -> Optional[_BindProcessorType[Sequence[typing_Any]]]:
    def bind_processor(self, dialect):
        item_proc = self.item_type.dialect_impl(dialect).bind_processor(
            dialect
        )

        def process(
            value: Optional[Sequence[typing_Any]],
        ) -> Optional[list[typing_Any]]:
        def process(value):
            if value is None:
                return value
            else:
@@ -440,16 +363,12 @@ class ARRAY(sqltypes.ARRAY[_T]):

        return process

    def result_processor(
        self, dialect: Dialect, coltype: object
    ) -> _ResultProcessorType[Sequence[typing_Any]]:
    def result_processor(self, dialect, coltype):
        item_proc = self.item_type.dialect_impl(dialect).result_processor(
            dialect, coltype
        )

        def process(
            value: Sequence[typing_Any],
        ) -> Optional[Sequence[typing_Any]]:
        def process(value):
            if value is None:
                return value
            else:
@@ -464,13 +383,11 @@ class ARRAY(sqltypes.ARRAY[_T]):
            super_rp = process
            pattern = re.compile(r"^{(.*)}$")

            def handle_raw_string(value: str) -> list[str]:
                inner = pattern.match(value).group(1)  # type: ignore[union-attr] # noqa: E501
            def handle_raw_string(value):
                inner = pattern.match(value).group(1)
                return _split_enum_values(inner)

            def process(
                value: Sequence[typing_Any],
            ) -> Optional[Sequence[typing_Any]]:
            def process(value):
                if value is None:
                    return value
                # isinstance(value, str) is required to handle
@@ -485,7 +402,7 @@ class ARRAY(sqltypes.ARRAY[_T]):
                return process


def _split_enum_values(array_string: str) -> list[str]:
def _split_enum_values(array_string):
    if '"' not in array_string:
        # no escape char is present so it can just split on the comma
        return array_string.split(",") if array_string else []

@@ -1,5 +1,5 @@
# dialects/postgresql/asyncpg.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# postgresql/asyncpg.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
@@ -23,10 +23,18 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

    from sqlalchemy.ext.asyncio import create_async_engine
    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname")

The dialect can also be run as a "synchronous" dialect within the
:func:`_sa.create_engine` function, which will pass "await" calls into
an ad-hoc event loop. This mode of operation is of **limited use**
and is for special testing scenarios only. The mode can be enabled by
adding the SQLAlchemy-specific flag ``async_fallback`` to the URL
in conjunction with :func:`_sa.create_engine`::

    # for testing purposes only; do not use in production!
    engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true")

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname"
    )

.. versionadded:: 1.4

@@ -81,15 +89,11 @@ asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect
argument)::


    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500"
    )
    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500")

To disable the prepared statement cache, use a value of zero::

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0"
    )
    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0")

.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.

@@ -119,8 +123,8 @@ To disable the prepared statement cache, use a value of zero::

.. _asyncpg_prepared_statement_name:

Prepared Statement Name with PGBouncer
--------------------------------------
Prepared Statement Name
-----------------------

By default, asyncpg enumerates prepared statements in numeric order, which
can lead to errors if a name has already been taken for another prepared
@@ -135,10 +139,10 @@ a prepared statement is prepared::
    from uuid import uuid4

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@somepgbouncer/dbname",
        "postgresql+asyncpg://user:pass@hostname/dbname",
        poolclass=NullPool,
        connect_args={
            "prepared_statement_name_func": lambda: f"__asyncpg_{uuid4()}__",
            'prepared_statement_name_func': lambda: f'__asyncpg_{uuid4()}__',
        },
    )

@@ -148,7 +152,7 @@ a prepared statement is prepared::

    https://github.com/sqlalchemy/sqlalchemy/issues/6467

.. warning:: When using PGBouncer, to prevent a buildup of useless prepared statements in
.. warning:: To prevent a buildup of useless prepared statements in
   your application, it's important to use the :class:`.NullPool` pool
   class, and to configure PgBouncer to use `DISCARD <https://www.postgresql.org/docs/current/sql-discard.html>`_
   when returning connections. The DISCARD command is used to release resources held by the db connection,
@@ -178,11 +182,13 @@ client using this setting passed to :func:`_asyncio.create_async_engine`::

from __future__ import annotations

from collections import deque
import collections
import decimal
import json as _py_json
import re
import time
from typing import cast
from typing import TYPE_CHECKING

from . import json
from . import ranges
@@ -212,6 +218,9 @@ from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only

if TYPE_CHECKING:
    from typing import Iterable


class AsyncpgARRAY(PGARRAY):
    render_bind_cast = True
@@ -265,20 +274,20 @@ class AsyncpgInteger(sqltypes.Integer):
    render_bind_cast = True


class AsyncpgSmallInteger(sqltypes.SmallInteger):
    render_bind_cast = True


class AsyncpgBigInteger(sqltypes.BigInteger):
    render_bind_cast = True


class AsyncpgJSON(json.JSON):
    render_bind_cast = True

    def result_processor(self, dialect, coltype):
        return None


class AsyncpgJSONB(json.JSONB):
    render_bind_cast = True

    def result_processor(self, dialect, coltype):
        return None

@@ -363,7 +372,7 @@ class AsyncpgCHAR(sqltypes.CHAR):
    render_bind_cast = True


class _AsyncpgRange(ranges.AbstractSingleRangeImpl):
class _AsyncpgRange(ranges.AbstractRangeImpl):
    def bind_processor(self, dialect):
        asyncpg_Range = dialect.dbapi.asyncpg.Range

@@ -417,7 +426,10 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
                )
                return value

            return [to_range(element) for element in value]
            return [
                to_range(element)
                for element in cast("Iterable[ranges.Range]", value)
            ]

        return to_range

@@ -436,7 +448,7 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
            return rvalue

        if value is not None:
            value = ranges.MultiRange(to_range(elem) for elem in value)
            value = [to_range(elem) for elem in value]

        return value

@@ -494,7 +506,7 @@ class AsyncAdapt_asyncpg_cursor:
    def __init__(self, adapt_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self._rows = deque()
        self._rows = []
        self._cursor = None
        self.description = None
        self.arraysize = 1
@@ -502,7 +514,7 @@ class AsyncAdapt_asyncpg_cursor:
        self._invalidate_schema_cache_asof = 0

    def close(self):
        self._rows.clear()
        self._rows[:] = []

    def _handle_exception(self, error):
        self._adapt_connection._handle_exception(error)
@@ -542,12 +554,11 @@ class AsyncAdapt_asyncpg_cursor:
            self._cursor = await prepared_stmt.cursor(*parameters)
            self.rowcount = -1
        else:
            self._rows = deque(await prepared_stmt.fetch(*parameters))
            self._rows = await prepared_stmt.fetch(*parameters)
            status = prepared_stmt.get_statusmsg()

            reg = re.match(
                r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)",
                status or "",
                r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status
            )
            if reg:
                self.rowcount = int(reg.group(1))
@@ -591,11 +602,11 @@ class AsyncAdapt_asyncpg_cursor:

    def __iter__(self):
        while self._rows:
            yield self._rows.popleft()
            yield self._rows.pop(0)

    def fetchone(self):
        if self._rows:
            return self._rows.popleft()
            return self._rows.pop(0)
        else:
            return None

@@ -603,12 +614,13 @@ class AsyncAdapt_asyncpg_cursor:
        if size is None:
            size = self.arraysize

        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]
        retval = self._rows[0:size]
        self._rows[:] = self._rows[size:]
        return retval

    def fetchall(self):
        retval = list(self._rows)
        self._rows.clear()
        retval = self._rows[:]
        self._rows[:] = []
        return retval
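
# Editor's note (not part of the diff): the change above swaps the row
# buffer from a list to a deque because popleft() is O(1), while
# list.pop(0) shifts every remaining element and costs O(n) per call.
from collections import deque

rows = deque([("a",), ("b",), ("c",)])
first = rows.popleft()  # O(1), regardless of buffer size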


@@ -618,21 +630,23 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):

    def __init__(self, adapt_connection):
        super().__init__(adapt_connection)
        self._rowbuffer = deque()
        self._rowbuffer = None

    def close(self):
        self._cursor = None
        self._rowbuffer.clear()
        self._rowbuffer = None

    def _buffer_rows(self):
        assert self._cursor is not None
        new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
        self._rowbuffer.extend(new_rows)
        self._rowbuffer = collections.deque(new_rows)

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self._rowbuffer:
            self._buffer_rows()

        while True:
        while self._rowbuffer:
            yield self._rowbuffer.popleft()
@@ -655,19 +669,21 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
        if not self._rowbuffer:
            self._buffer_rows()

        assert self._cursor is not None
        rb = self._rowbuffer
        lb = len(rb)
        buf = list(self._rowbuffer)
        lb = len(buf)
        if size > lb:
            rb.extend(
            buf.extend(
                self._adapt_connection.await_(self._cursor.fetch(size - lb))
            )

        return [rb.popleft() for _ in range(min(size, len(rb)))]
        result = buf[0:size]
        self._rowbuffer = collections.deque(buf[size:])
        return result

    def fetchall(self):
        ret = list(self._rowbuffer)
        ret.extend(self._adapt_connection.await_(self._all()))
        ret = list(self._rowbuffer) + list(
            self._adapt_connection.await_(self._all())
        )
        self._rowbuffer.clear()
        return ret

@@ -717,7 +733,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
    ):
        self.dbapi = dbapi
        self._connection = connection
        self.isolation_level = self._isolation_setting = None
        self.isolation_level = self._isolation_setting = "read_committed"
        self.readonly = False
        self.deferrable = False
        self._transaction = None
@@ -786,9 +802,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
            translated_error = exception_mapping[super_](
                "%s: %s" % (type(error), error)
            )
            translated_error.pgcode = translated_error.sqlstate = (
                getattr(error, "sqlstate", None)
            )
            translated_error.pgcode = (
                translated_error.sqlstate
            ) = getattr(error, "sqlstate", None)
            raise translated_error from error
        else:
            raise error
@@ -852,45 +868,25 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
        else:
            return AsyncAdapt_asyncpg_cursor(self)

    async def _rollback_and_discard(self):
        try:
            await self._transaction.rollback()
        finally:
            # if asyncpg .rollback() was actually called, then whether or
            # not it raised or succeeded, the transaction is done, discard it
            self._transaction = None
            self._started = False

    async def _commit_and_discard(self):
        try:
            await self._transaction.commit()
        finally:
            # if asyncpg .commit() was actually called, then whether or
            # not it raised or succeeded, the transaction is done, discard it
            self._transaction = None
            self._started = False

    def rollback(self):
        if self._started:
            try:
                self.await_(self._rollback_and_discard())
                self.await_(self._transaction.rollback())
            except Exception as error:
                self._handle_exception(error)
            finally:
                self._transaction = None
                self._started = False
            except Exception as error:
                # don't dereference asyncpg transaction if we didn't
                # actually try to call rollback() on it
                self._handle_exception(error)

    def commit(self):
        if self._started:
            try:
                self.await_(self._commit_and_discard())
                self.await_(self._transaction.commit())
            except Exception as error:
                self._handle_exception(error)
            finally:
                self._transaction = None
                self._started = False
            except Exception as error:
                # don't dereference asyncpg transaction if we didn't
                # actually try to call commit() on it
                self._handle_exception(error)

    def close(self):
        self.rollback()
@@ -898,31 +894,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
        self.await_(self._connection.close())

    def terminate(self):
        if util.concurrency.in_greenlet():
            # in a greenlet; this is the "connection was invalidated"
            # case.
            try:
                # try to gracefully close; see #10717
                # timeout added in asyncpg 0.14.0 December 2017
                self.await_(asyncio.shield(self._connection.close(timeout=2)))
            except (
                asyncio.TimeoutError,
                asyncio.CancelledError,
                OSError,
                self.dbapi.asyncpg.PostgresError,
            ) as e:
                # in the case where we are recycling an old connection
                # that may have already been disconnected, close() will
                # fail with the above timeout. in this case, terminate
                # the connection without any further waiting.
                # see issue #8419
                self._connection.terminate()
                if isinstance(e, asyncio.CancelledError):
                    # re-raise CancelledError if we were cancelled
                    raise
        else:
            # not in a greenlet; this is the gc cleanup case
            self._connection.terminate()
        self._connection.terminate()
        self._started = False

    @staticmethod
@@ -1059,7 +1031,6 @@ class PGDialect_asyncpg(PGDialect):
            INTERVAL: AsyncPgInterval,
            sqltypes.Boolean: AsyncpgBoolean,
            sqltypes.Integer: AsyncpgInteger,
            sqltypes.SmallInteger: AsyncpgSmallInteger,
            sqltypes.BigInteger: AsyncpgBigInteger,
            sqltypes.Numeric: AsyncpgNumeric,
            sqltypes.Float: AsyncpgFloat,
@@ -1074,7 +1045,7 @@ class PGDialect_asyncpg(PGDialect):
            OID: AsyncpgOID,
            REGCLASS: AsyncpgREGCLASS,
            sqltypes.CHAR: AsyncpgCHAR,
            ranges.AbstractSingleRange: _AsyncpgRange,
            ranges.AbstractRange: _AsyncpgRange,
            ranges.AbstractMultiRange: _AsyncpgMultiRange,
        },
    )
@@ -1117,9 +1088,6 @@ class PGDialect_asyncpg(PGDialect):
    def set_isolation_level(self, dbapi_connection, level):
        dbapi_connection.set_isolation_level(self._isolation_lookup[level])

    def detect_autocommit_setting(self, dbapi_conn) -> bool:
        return bool(dbapi_conn.autocommit)

    def set_readonly(self, connection, value):
        connection.readonly = value
|
||||
|
||||
|
||||
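The rollback()/commit() rework above routes both calls through _rollback_and_discard() / _commit_and_discard(), so the asyncpg transaction object is dereferenced even when the driver call raises. A minimal sketch of that discard-in-finally pattern using only the standard library (the FlakyTransaction class here is a hypothetical stand-in for an asyncpg transaction)::

    import asyncio

    class FlakyTransaction:
        # hypothetical stand-in for an asyncpg transaction object
        async def rollback(self):
            raise OSError("connection dropped mid-rollback")

    class Adapted:
        def __init__(self):
            self._transaction = FlakyTransaction()
            self._started = True

        async def _rollback_and_discard(self):
            try:
                await self._transaction.rollback()
            finally:
                # whether rollback() raised or succeeded, the
                # transaction is done; discard it
                self._transaction = None
                self._started = False

    adapted = Adapted()
    try:
        asyncio.run(adapted._rollback_and_discard())
    except OSError:
        pass
    assert adapted._transaction is None and not adapted._started
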
File diff suppressed because it is too large

@@ -1,5 +1,5 @@
# dialects/postgresql/dml.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/dml.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,10 +7,7 @@
from __future__ import annotations

from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from . import ext
from .._typing import _OnConflictConstraintT
@@ -29,9 +26,7 @@ from ...sql.base import ColumnCollection
from ...sql.base import ReadOnlyColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.elements import ColumnElement
from ...sql.elements import KeyedColumnElement
from ...sql.elements import TextClause
from ...sql.expression import alias
from ...util.typing import Self

@@ -158,10 +153,11 @@ class Insert(StandardInsert):
        :paramref:`.Insert.on_conflict_do_update.set_` dictionary.

        :param where:
          Optional argument. An expression object representing a ``WHERE``
          clause that restricts the rows affected by ``DO UPDATE SET``. Rows not
          meeting the ``WHERE`` condition will not be updated (effectively a
          ``DO NOTHING`` for those rows).
          Optional argument. If present, can be a literal SQL
          string or an acceptable expression for a ``WHERE`` clause
          that restricts the rows affected by ``DO UPDATE SET``. Rows
          not meeting the ``WHERE`` condition will not be updated
          (effectively a ``DO NOTHING`` for those rows).


        .. seealso::
@@ -216,10 +212,8 @@ class OnConflictClause(ClauseElement):
    stringify_dialect = "postgresql"

    constraint_target: Optional[str]
    inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]]
    inferred_target_whereclause: Optional[
        Union[ColumnElement[Any], TextClause]
    ]
    inferred_target_elements: _OnConflictIndexElementsT
    inferred_target_whereclause: _OnConflictIndexWhereT

    def __init__(
        self,
@@ -260,28 +254,12 @@ class OnConflictClause(ClauseElement):

        if index_elements is not None:
            self.constraint_target = None
            self.inferred_target_elements = [
                coercions.expect(roles.DDLConstraintColumnRole, column)
                for column in index_elements
            ]

            self.inferred_target_whereclause = (
                coercions.expect(
                    (
                        roles.StatementOptionRole
                        if isinstance(constraint, ext.ExcludeConstraint)
                        else roles.WhereHavingRole
                    ),
                    index_where,
                )
                if index_where is not None
                else None
            )

            self.inferred_target_elements = index_elements
            self.inferred_target_whereclause = index_where
        elif constraint is None:
            self.constraint_target = self.inferred_target_elements = (
                self.inferred_target_whereclause
            ) = None
            self.constraint_target = (
                self.inferred_target_elements
            ) = self.inferred_target_whereclause = None


class OnConflictDoNothing(OnConflictClause):
@@ -291,9 +269,6 @@ class OnConflictDoNothing(OnConflictClause):
class OnConflictDoUpdate(OnConflictClause):
    __visit_name__ = "on_conflict_do_update"

    update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]]
    update_whereclause: Optional[ColumnElement[Any]]

    def __init__(
        self,
        constraint: _OnConflictConstraintT = None,
@@ -332,8 +307,4 @@ class OnConflictDoUpdate(OnConflictClause):
            (coercions.expect(roles.DMLColumnRole, key), value)
            for key, value in set_.items()
        ]
        self.update_whereclause = (
            coercions.expect(roles.WhereHavingRole, where)
            if where is not None
            else None
        )
        self.update_whereclause = where

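The OnConflictClause hunks above move between coercing ``index_elements`` / ``index_where`` / ``where`` through ``coercions.expect()`` at construction time and simply storing the raw arguments; from the public API, :meth:`.Insert.on_conflict_do_update` behaves the same either way. A small usage sketch (the table and columns here are illustrative)::

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import insert

    users = Table(
        "users",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("name", String),
    )

    stmt = insert(users).values(id=1, name="spongebob")
    stmt = stmt.on_conflict_do_update(
        index_elements=[users.c.id],
        set_={"name": stmt.excluded.name},
        # rows failing the WHERE are left alone (a DO NOTHING for them)
        where=(users.c.name != stmt.excluded.name),
    )
    print(stmt.compile(dialect=postgresql.dialect()))
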
@@ -1,5 +1,5 @@
# dialects/postgresql/ext.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/ext.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -8,10 +8,6 @@
from __future__ import annotations

from typing import Any
from typing import Iterable
from typing import List
from typing import Optional
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar

@@ -27,44 +23,34 @@ from ...sql.schema import ColumnCollectionConstraint
from ...sql.sqltypes import TEXT
from ...sql.visitors import InternalTraversal

if TYPE_CHECKING:
    from ...sql._typing import _ColumnExpressionArgument
    from ...sql.elements import ClauseElement
    from ...sql.elements import ColumnElement
    from ...sql.operators import OperatorType
    from ...sql.selectable import FromClause
    from ...sql.visitors import _CloneCallableType
    from ...sql.visitors import _TraverseInternalsType

_T = TypeVar("_T", bound=Any)

if TYPE_CHECKING:
    from ...sql.visitors import _TraverseInternalsType

class aggregate_order_by(expression.ColumnElement[_T]):

class aggregate_order_by(expression.ColumnElement):
    """Represent a PostgreSQL aggregate order by expression.

    E.g.::

        from sqlalchemy.dialects.postgresql import aggregate_order_by

        expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc()))
        stmt = select(expr)

    would represent the expression:

    .. sourcecode:: sql
    would represent the expression::

        SELECT array_agg(a ORDER BY b DESC) FROM table;

    Similarly::

        expr = func.string_agg(
            table.c.a, aggregate_order_by(literal_column("','"), table.c.a)
            table.c.a,
            aggregate_order_by(literal_column("','"), table.c.a)
        )
        stmt = select(expr)

    Would represent:

    .. sourcecode:: sql
    Would represent::

        SELECT string_agg(a, ',' ORDER BY a) FROM table;

@@ -85,32 +71,11 @@ class aggregate_order_by(expression.ColumnElement[_T]):
        ("order_by", InternalTraversal.dp_clauseelement),
    ]

    @overload
    def __init__(
        self,
        target: ColumnElement[_T],
        *order_by: _ColumnExpressionArgument[Any],
    ): ...

    @overload
    def __init__(
        self,
        target: _ColumnExpressionArgument[_T],
        *order_by: _ColumnExpressionArgument[Any],
    ): ...

    def __init__(
        self,
        target: _ColumnExpressionArgument[_T],
        *order_by: _ColumnExpressionArgument[Any],
    ):
        self.target: ClauseElement = coercions.expect(
            roles.ExpressionElementRole, target
        )
    def __init__(self, target, *order_by):
        self.target = coercions.expect(roles.ExpressionElementRole, target)
        self.type = self.target.type

        _lob = len(order_by)
        self.order_by: ClauseElement
        if _lob == 0:
            raise TypeError("at least one ORDER BY element is required")
        elif _lob == 1:
@@ -122,22 +87,18 @@ class aggregate_order_by(expression.ColumnElement[_T]):
                *order_by, _literal_as_text_role=roles.ExpressionElementRole
            )

    def self_group(
        self, against: Optional[OperatorType] = None
    ) -> ClauseElement:
    def self_group(self, against=None):
        return self

    def get_children(self, **kwargs: Any) -> Iterable[ClauseElement]:
    def get_children(self, **kwargs):
        return self.target, self.order_by

    def _copy_internals(
        self, clone: _CloneCallableType = elements._clone, **kw: Any
    ) -> None:
    def _copy_internals(self, clone=elements._clone, **kw):
        self.target = clone(self.target, **kw)
        self.order_by = clone(self.order_by, **kw)

    @property
    def _from_objects(self) -> List[FromClause]:
    def _from_objects(self):
        return self.target._from_objects + self.order_by._from_objects


@@ -170,10 +131,10 @@ class ExcludeConstraint(ColumnCollectionConstraint):
    E.g.::

        const = ExcludeConstraint(
            (Column("period"), "&&"),
            (Column("group"), "="),
            where=(Column("group") != "some group"),
            ops={"group": "my_operator_class"},
            (Column('period'), '&&'),
            (Column('group'), '='),
            where=(Column('group') != 'some group'),
            ops={'group': 'my_operator_class'}
        )

    The constraint is normally embedded into the :class:`_schema.Table`
@@ -181,20 +142,19 @@ class ExcludeConstraint(ColumnCollectionConstraint):
    directly, or added later using :meth:`.append_constraint`::

        some_table = Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("period", TSRANGE()),
            Column("group", String),
            'some_table', metadata,
            Column('id', Integer, primary_key=True),
            Column('period', TSRANGE()),
            Column('group', String)
        )

        some_table.append_constraint(
            ExcludeConstraint(
                (some_table.c.period, "&&"),
                (some_table.c.group, "="),
                where=some_table.c.group != "some group",
                name="some_table_excl_const",
                ops={"group": "my_operator_class"},
                (some_table.c.period, '&&'),
                (some_table.c.group, '='),
                where=some_table.c.group != 'some group',
                name='some_table_excl_const',
                ops={'group': 'my_operator_class'}
            )
        )

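The overloads added to aggregate_order_by above affect only static typing; the rendered SQL is unchanged. A quick check of the generated statement (the table here is illustrative)::

    from sqlalchemy import Column, MetaData, String, Table, func, select
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import aggregate_order_by

    t = Table("t", MetaData(), Column("a", String), Column("b", String))
    stmt = select(func.array_agg(aggregate_order_by(t.c.a, t.c.b.desc())))
    # expected to render roughly:
    #   SELECT array_agg(t.a ORDER BY t.b DESC) AS array_agg_1 FROM t
    print(stmt.compile(dialect=postgresql.dialect()))
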
@@ -1,5 +1,5 @@
# dialects/postgresql/hstore.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/hstore.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -28,29 +28,28 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):

    The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::

        data_table = Table(
            "data_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", HSTORE),
        data_table = Table('data_table', metadata,
            Column('id', Integer, primary_key=True),
            Column('data', HSTORE)
        )

        with engine.connect() as conn:
            conn.execute(
                data_table.insert(), data={"key1": "value1", "key2": "value2"}
                data_table.insert(),
                data = {"key1": "value1", "key2": "value2"}
            )

    :class:`.HSTORE` provides for a wide range of operations, including:

    * Index operations::

        data_table.c.data["some key"] == "some value"
        data_table.c.data['some key'] == 'some value'

    * Containment operations::

        data_table.c.data.has_key("some key")
        data_table.c.data.has_key('some key')

        data_table.c.data.has_all(["one", "two", "three"])
        data_table.c.data.has_all(['one', 'two', 'three'])

    * Concatenation::

@@ -73,19 +72,17 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):

        from sqlalchemy.ext.mutable import MutableDict


        class MyClass(Base):
            __tablename__ = "data_table"
            __tablename__ = 'data_table'

            id = Column(Integer, primary_key=True)
            data = Column(MutableDict.as_mutable(HSTORE))


        my_object = session.query(MyClass).one()

        # in-place mutation, requires Mutable extension
        # in order for the ORM to detect
        my_object.data["some_key"] = "some value"
        my_object.data['some_key'] = 'some value'

        session.commit()

@@ -99,7 +96,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
        :class:`.hstore` - render the PostgreSQL ``hstore()`` function.


    """  # noqa: E501
    """

    __visit_name__ = "HSTORE"
    hashable = False
@@ -195,9 +192,6 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
    comparator_factory = Comparator

    def bind_processor(self, dialect):
        # note that dialect-specific types like that of psycopg and
        # psycopg2 will override this method to allow driver-level conversion
        # instead, see _PsycopgHStore
        def process(value):
            if isinstance(value, dict):
                return _serialize_hstore(value)
@@ -207,9 +201,6 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
        return process

    def result_processor(self, dialect, coltype):
        # note that dialect-specific types like that of psycopg and
        # psycopg2 will override this method to allow driver-level conversion
        # instead, see _PsycopgHStore
        def process(value):
            if value is not None:
                return _parse_hstore(value)
@@ -230,12 +221,12 @@ class hstore(sqlfunc.GenericFunction):

        from sqlalchemy.dialects.postgresql import array, hstore

        select(hstore("key1", "value1"))
        select(hstore('key1', 'value1'))

        select(
            hstore(
                array(["key1", "key2", "key3"]),
                array(["value1", "value2", "value3"]),
                array(['key1', 'key2', 'key3']),
                array(['value1', 'value2', 'value3'])
            )
        )

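The bind and result processors above round-trip Python dicts through PostgreSQL's hstore text representation. A simplified serializer sketch of that format (an illustration only, not SQLAlchemy's private _serialize_hstore; assumes string keys and string-or-None values)::

    def serialize_hstore(d):
        # render a dict as hstore text, e.g. '"k1"=>"v1", "k2"=>NULL'
        def quote(s):
            if s is None:
                return "NULL"
            return '"%s"' % s.replace("\\", "\\\\").replace('"', '\\"')

        return ", ".join(
            "%s=>%s" % (quote(k), quote(v)) for k, v in d.items()
        )

    print(serialize_hstore({"key1": "value1", "key2": None}))
    # "key1"=>"value1", "key2"=>NULL
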
@@ -1,18 +1,11 @@
# dialects/postgresql/json.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/json.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from .array import ARRAY
from .array import array as _pg_array
@@ -28,23 +21,13 @@ from .operators import PATH_EXISTS
from .operators import PATH_MATCH
from ... import types as sqltypes
from ...sql import cast
from ...sql._typing import _T

if TYPE_CHECKING:
    from ...engine.interfaces import Dialect
    from ...sql.elements import ColumnElement
    from ...sql.type_api import _BindProcessorType
    from ...sql.type_api import _LiteralProcessorType
    from ...sql.type_api import TypeEngine

__all__ = ("JSON", "JSONB")


class JSONPathType(sqltypes.JSON.JSONPathType):
    def _processor(
        self, dialect: Dialect, super_proc: Optional[Callable[[Any], Any]]
    ) -> Callable[[Any], Any]:
        def process(value: Any) -> Any:
    def _processor(self, dialect, super_proc):
        def process(value):
            if isinstance(value, str):
                # If it's already a string assume that it's in json path
                # format. This allows using cast with json paths literals
@@ -61,13 +44,11 @@ class JSONPathType(sqltypes.JSON.JSONPathType):

        return process

    def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
        return self._processor(dialect, self.string_bind_processor(dialect))  # type: ignore[return-value]  # noqa: E501
    def bind_processor(self, dialect):
        return self._processor(dialect, self.string_bind_processor(dialect))

    def literal_processor(
        self, dialect: Dialect
    ) -> _LiteralProcessorType[Any]:
        return self._processor(dialect, self.string_literal_processor(dialect))  # type: ignore[return-value]  # noqa: E501
    def literal_processor(self, dialect):
        return self._processor(dialect, self.string_literal_processor(dialect))


class JSONPATH(JSONPathType):
@@ -109,14 +90,14 @@ class JSON(sqltypes.JSON):

    * Index operations (the ``->`` operator)::

        data_table.c.data["some key"]
        data_table.c.data['some key']

        data_table.c.data[5]

    * Index operations returning text
      (the ``->>`` operator)::

        data_table.c.data["some key"].astext == "some value"
    * Index operations returning text (the ``->>`` operator)::

        data_table.c.data['some key'].astext == 'some value'

      Note that equivalent functionality is available via the
      :attr:`.JSON.Comparator.as_string` accessor.
@@ -124,20 +105,18 @@ class JSON(sqltypes.JSON):
    * Index operations with CAST
      (equivalent to ``CAST(col ->> ['some key'] AS <type>)``)::

        data_table.c.data["some key"].astext.cast(Integer) == 5
        data_table.c.data['some key'].astext.cast(Integer) == 5

      Note that equivalent functionality is available via the
      :attr:`.JSON.Comparator.as_integer` and similar accessors.

    * Path index operations (the ``#>`` operator)::

        data_table.c.data[("key_1", "key_2", 5, ..., "key_n")]
        data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]

    * Path index operations returning text (the ``#>>`` operator)::

        data_table.c.data[
            ("key_1", "key_2", 5, ..., "key_n")
        ].astext == "some value"
        data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value'

    Index operations return an expression object whose type defaults to
    :class:`_types.JSON` by default,
@@ -149,11 +128,10 @@ class JSON(sqltypes.JSON):
    using psycopg2, the DBAPI only allows serializers at the per-cursor
    or per-connection level. E.g.::

        engine = create_engine(
            "postgresql+psycopg2://scott:tiger@localhost/test",
            json_serializer=my_serialize_fn,
            json_deserializer=my_deserialize_fn,
        )
        engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
                                json_serializer=my_serialize_fn,
                                json_deserializer=my_deserialize_fn
                        )

    When using the psycopg2 dialect, the json_deserializer is registered
    against the database using ``psycopg2.extras.register_default_json``.
@@ -166,14 +144,9 @@ class JSON(sqltypes.JSON):

    """  # noqa

    render_bind_cast = True
    astext_type: TypeEngine[str] = sqltypes.Text()
    astext_type = sqltypes.Text()

    def __init__(
        self,
        none_as_null: bool = False,
        astext_type: Optional[TypeEngine[str]] = None,
    ):
    def __init__(self, none_as_null=False, astext_type=None):
        """Construct a :class:`_types.JSON` type.

        :param none_as_null: if True, persist the value ``None`` as a
@@ -182,8 +155,7 @@ class JSON(sqltypes.JSON):
          be used to persist a NULL value::

            from sqlalchemy import null

            conn.execute(table.insert(), {"data": null()})
            conn.execute(table.insert(), data=null())

        .. seealso::

@@ -198,19 +170,17 @@ class JSON(sqltypes.JSON):
        if astext_type is not None:
            self.astext_type = astext_type

    class Comparator(sqltypes.JSON.Comparator[_T]):
    class Comparator(sqltypes.JSON.Comparator):
        """Define comparison operations for :class:`_types.JSON`."""

        type: JSON

        @property
        def astext(self) -> ColumnElement[str]:
        def astext(self):
            """On an indexed expression, use the "astext" (e.g. "->>")
            conversion when rendered in SQL.

            E.g.::

                select(data_table.c.data["some key"].astext)
                select(data_table.c.data['some key'].astext)

            .. seealso::

@@ -218,13 +188,13 @@ class JSON(sqltypes.JSON):

            """
            if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
                return self.expr.left.operate(  # type: ignore[no-any-return]
                return self.expr.left.operate(
                    JSONPATH_ASTEXT,
                    self.expr.right,
                    result_type=self.type.astext_type,
                )
            else:
                return self.expr.left.operate(  # type: ignore[no-any-return]
                return self.expr.left.operate(
                    ASTEXT, self.expr.right, result_type=self.type.astext_type
                )

@@ -237,16 +207,15 @@ class JSONB(JSON):
    The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data,
    e.g.::

        data_table = Table(
            "data_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", JSONB),
        data_table = Table('data_table', metadata,
            Column('id', Integer, primary_key=True),
            Column('data', JSONB)
        )

        with engine.connect() as conn:
            conn.execute(
                data_table.insert(), data={"key1": "value1", "key2": "value2"}
                data_table.insert(),
                data = {"key1": "value1", "key2": "value2"}
            )

    The :class:`_postgresql.JSONB` type includes all operations provided by
@@ -283,53 +252,43 @@ class JSONB(JSON):

    __visit_name__ = "JSONB"

    class Comparator(JSON.Comparator[_T]):
    class Comparator(JSON.Comparator):
        """Define comparison operations for :class:`_types.JSON`."""

        type: JSONB

        def has_key(self, other: Any) -> ColumnElement[bool]:
            """Boolean expression.  Test for presence of a key (equivalent of
            the ``?`` operator).  Note that the key may be a SQLA expression.
        def has_key(self, other):
            """Boolean expression.  Test for presence of a key.  Note that the
            key may be a SQLA expression.
            """
            return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)

        def has_all(self, other: Any) -> ColumnElement[bool]:
            """Boolean expression.  Test for presence of all keys in jsonb
            (equivalent of the ``?&`` operator)
            """
        def has_all(self, other):
            """Boolean expression.  Test for presence of all keys in jsonb"""
            return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)

        def has_any(self, other: Any) -> ColumnElement[bool]:
            """Boolean expression.  Test for presence of any key in jsonb
            (equivalent of the ``?|`` operator)
            """
        def has_any(self, other):
            """Boolean expression.  Test for presence of any key in jsonb"""
            return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)

        def contains(self, other: Any, **kwargs: Any) -> ColumnElement[bool]:
        def contains(self, other, **kwargs):
            """Boolean expression.  Test if keys (or array) are a superset
            of/contained the keys of the argument jsonb expression
            (equivalent of the ``@>`` operator).
            of/contained the keys of the argument jsonb expression.

            kwargs may be ignored by this operator but are required for API
            conformance.
            """
            return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)

        def contained_by(self, other: Any) -> ColumnElement[bool]:
        def contained_by(self, other):
            """Boolean expression.  Test if keys are a proper subset of the
            keys of the argument jsonb expression
            (equivalent of the ``<@`` operator).
            keys of the argument jsonb expression.
            """
            return self.operate(
                CONTAINED_BY, other, result_type=sqltypes.Boolean
            )

        def delete_path(
            self, array: Union[List[str], _pg_array[str]]
        ) -> ColumnElement[JSONB]:
        def delete_path(self, array):
            """JSONB expression. Deletes field or array element specified in
            the argument array (equivalent of the ``#-`` operator).
            the argument array.

            The input may be a list of strings that will be coerced to an
            ``ARRAY`` or an instance of :meth:`_postgres.array`.
@@ -341,9 +300,9 @@ class JSONB(JSON):
            right_side = cast(array, ARRAY(sqltypes.TEXT))
            return self.operate(DELETE_PATH, right_side, result_type=JSONB)

        def path_exists(self, other: Any) -> ColumnElement[bool]:
        def path_exists(self, other):
            """Boolean expression. Test for presence of item given by the
            argument JSONPath expression (equivalent of the ``@?`` operator).
            argument JSONPath expression.

            .. versionadded:: 2.0
            """
@@ -351,10 +310,9 @@ class JSONB(JSON):
                PATH_EXISTS, other, result_type=sqltypes.Boolean
            )

        def path_match(self, other: Any) -> ColumnElement[bool]:
        def path_match(self, other):
            """Boolean expression. Test if JSONPath predicate given by the
            argument JSONPath expression matches
            (equivalent of the ``@@`` operator).
            argument JSONPath expression matches.

            Only the first item of the result is taken into account.

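The docstring edits above spell out which PostgreSQL operator each JSONB comparator method renders (``?``, ``?&``, ``?|``, ``@>``, ``<@``, ``#-``, ``@?``, ``@@``). A quick way to confirm the rendered operators (the table here is illustrative)::

    from sqlalchemy import Column, MetaData, Table, select
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import JSONB

    data_table = Table("data_table", MetaData(), Column("data", JSONB))
    pg = postgresql.dialect()

    # has_key renders the ? operator
    stmt = select(data_table).where(data_table.c.data.has_key("k"))
    print(stmt.compile(dialect=pg))

    # path_exists renders the @? operator (SQLAlchemy 2.0 and later)
    stmt = select(data_table).where(data_table.c.data.path_exists("$.k"))
    print(stmt.compile(dialect=pg))
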
@@ -1,5 +1,5 @@
# dialects/postgresql/named_types.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/named_types.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,9 +7,7 @@
# mypy: ignore-errors
from __future__ import annotations

from types import ModuleType
from typing import Any
from typing import Dict
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
@@ -27,11 +25,10 @@ from ...sql.ddl import InvokeCreateDDLBase
from ...sql.ddl import InvokeDropDDLBase

if TYPE_CHECKING:
    from ...sql._typing import _CreateDropBind
    from ...sql._typing import _TypeEngineArgument


class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
class NamedType(sqltypes.TypeEngine):
    """Base for named types."""

    __abstract__ = True
@@ -39,9 +36,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
    DDLDropper: Type[NamedTypeDropper]
    create_type: bool

    def create(
        self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
    ) -> None:
    def create(self, bind, checkfirst=True, **kw):
        """Emit ``CREATE`` DDL for this type.

        :param bind: a connectable :class:`_engine.Engine`,
@@ -55,9 +50,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
        """
        bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)

    def drop(
        self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
    ) -> None:
    def drop(self, bind, checkfirst=True, **kw):
        """Emit ``DROP`` DDL for this type.

        :param bind: a connectable :class:`_engine.Engine`,
@@ -70,9 +63,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
        """
        bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)

    def _check_for_name_in_memos(
        self, checkfirst: bool, kw: Dict[str, Any]
    ) -> bool:
    def _check_for_name_in_memos(self, checkfirst, kw):
        """Look in the 'ddl runner' for 'memos', then
        note our name in that collection.

@@ -96,13 +87,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
        else:
            return False

    def _on_table_create(
        self,
        target: Any,
        bind: _CreateDropBind,
        checkfirst: bool = False,
        **kw: Any,
    ) -> None:
    def _on_table_create(self, target, bind, checkfirst=False, **kw):
        if (
            checkfirst
            or (
@@ -112,13 +97,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
        ) and not self._check_for_name_in_memos(checkfirst, kw):
            self.create(bind=bind, checkfirst=checkfirst)

    def _on_table_drop(
        self,
        target: Any,
        bind: _CreateDropBind,
        checkfirst: bool = False,
        **kw: Any,
    ) -> None:
    def _on_table_drop(self, target, bind, checkfirst=False, **kw):
        if (
            not self.metadata
            and not kw.get("_is_metadata_operation", False)
@@ -126,23 +105,11 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
        ):
            self.drop(bind=bind, checkfirst=checkfirst)

    def _on_metadata_create(
        self,
        target: Any,
        bind: _CreateDropBind,
        checkfirst: bool = False,
        **kw: Any,
    ) -> None:
    def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
        if not self._check_for_name_in_memos(checkfirst, kw):
            self.create(bind=bind, checkfirst=checkfirst)

    def _on_metadata_drop(
        self,
        target: Any,
        bind: _CreateDropBind,
        checkfirst: bool = False,
        **kw: Any,
    ) -> None:
    def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
        if not self._check_for_name_in_memos(checkfirst, kw):
            self.drop(bind=bind, checkfirst=checkfirst)

@@ -196,6 +163,7 @@ class EnumDropper(NamedTypeDropper):


class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):

    """PostgreSQL ENUM type.

    This is a subclass of :class:`_types.Enum` which includes
@@ -218,10 +186,8 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
    :meth:`_schema.Table.drop`
    methods are called::

        table = Table(
            "sometable",
            metadata,
            Column("some_enum", ENUM("a", "b", "c", name="myenum")),
        table = Table('sometable', metadata,
            Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
        )

        table.create(engine)  # will emit CREATE ENUM and CREATE TABLE
@@ -232,17 +198,21 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
    :class:`_postgresql.ENUM` independently, and associate it with the
    :class:`_schema.MetaData` object itself::

        my_enum = ENUM("a", "b", "c", name="myenum", metadata=metadata)
        my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)

        t1 = Table("sometable_one", metadata, Column("some_enum", myenum))
        t1 = Table('sometable_one', metadata,
            Column('some_enum', myenum)
        )

        t2 = Table("sometable_two", metadata, Column("some_enum", myenum))
        t2 = Table('sometable_two', metadata,
            Column('some_enum', myenum)
        )

    When this pattern is used, care must still be taken at the level
    of individual table creates.  Emitting CREATE TABLE without also
    specifying ``checkfirst=True`` will still cause issues::

        t1.create(engine)  # will fail: no such type 'myenum'
        t1.create(engine)  # will fail: no such type 'myenum'

    If we specify ``checkfirst=True``, the individual table-level create
    operation will check for the ``ENUM`` and create if not exists::
@@ -347,7 +317,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):

        return cls(**kw)

    def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
    def create(self, bind=None, checkfirst=True):
        """Emit ``CREATE TYPE`` for this
        :class:`_postgresql.ENUM`.

@@ -368,7 +338,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):

        super().create(bind, checkfirst=checkfirst)

    def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
    def drop(self, bind=None, checkfirst=True):
        """Emit ``DROP TYPE`` for this
        :class:`_postgresql.ENUM`.

@@ -388,7 +358,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):

        super().drop(bind, checkfirst=checkfirst)

    def get_dbapi_type(self, dbapi: ModuleType) -> None:
    def get_dbapi_type(self, dbapi):
        """don't return dbapi.STRING for ENUM in PostgreSQL, since that's
        a different type"""

@@ -418,12 +388,14 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
    A domain is essentially a data type with optional constraints
    that restrict the allowed set of values. E.g.::

        PositiveInt = DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True)
        PositiveInt = DOMAIN(
            "pos_int", Integer, check="VALUE > 0", not_null=True
        )

        UsPostalCode = DOMAIN(
            "us_postal_code",
            Text,
            check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'",
            check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'"
        )

    See the `PostgreSQL documentation`__ for additional details
@@ -432,7 +404,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType):

    .. versionadded:: 2.0

    """  # noqa: E501
    """

    DDLGenerator = DomainGenerator
    DDLDropper = DomainDropper
@@ -445,10 +417,10 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
        data_type: _TypeEngineArgument[Any],
        *,
        collation: Optional[str] = None,
        default: Union[elements.TextClause, str, None] = None,
        default: Optional[Union[str, elements.TextClause]] = None,
        constraint_name: Optional[str] = None,
        not_null: Optional[bool] = None,
        check: Union[elements.TextClause, str, None] = None,
        check: Optional[str] = None,
        create_type: bool = True,
        **kw: Any,
    ):
@@ -492,7 +464,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
        self.default = default
        self.collation = collation
        self.constraint_name = constraint_name
        self.not_null = bool(not_null)
        self.not_null = not_null
        if check is not None:
            check = coercions.expect(roles.DDLExpressionRole, check)
        self.check = check

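As the ENUM docstring above notes, a type shared between tables should either be created up front or guarded with ``checkfirst=True`` so that per-table creates neither fail on nor duplicate the CREATE TYPE. A condensed version of that pattern (the connection URL is illustrative)::

    from sqlalchemy import Column, MetaData, Table, create_engine
    from sqlalchemy.dialects.postgresql import ENUM

    metadata = MetaData()
    my_enum = ENUM("a", "b", "c", name="myenum", metadata=metadata)
    t1 = Table("sometable_one", metadata, Column("some_enum", my_enum))
    t2 = Table("sometable_two", metadata, Column("some_enum", my_enum))

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    # checkfirst=True probes for the type before emitting CREATE TYPE,
    # so the shared enum is created exactly once
    t1.create(engine, checkfirst=True)
    t2.create(engine, checkfirst=True)
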
@@ -1,5 +1,5 @@
# dialects/postgresql/operators.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/operators.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

@@ -1,5 +1,5 @@
# dialects/postgresql/pg8000.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# postgresql/pg8000.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
@@ -27,21 +27,19 @@ PostgreSQL ``client_encoding`` parameter; by default this is the value in
the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
Typically, this can be changed to ``utf-8``, as a more useful default::

    # client_encoding = sql_ascii  # actually, defaults to database encoding
    #client_encoding = sql_ascii # actually, defaults to database
                                 # encoding
    client_encoding = utf8

The ``client_encoding`` can be overridden for a session by executing the SQL:

.. sourcecode:: sql

    SET CLIENT_ENCODING TO 'utf8';
    SET CLIENT_ENCODING TO 'utf8';

SQLAlchemy will execute this SQL on all new connections based on the value
passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter::

    engine = create_engine(
        "postgresql+pg8000://user:pass@host/dbname", client_encoding="utf8"
    )
        "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')

.. _pg8000_ssl:

@@ -52,7 +50,6 @@ pg8000 accepts a Python ``SSLContext`` object which may be specified using the
:paramref:`_sa.create_engine.connect_args` dictionary::

    import ssl

    ssl_context = ssl.create_default_context()
    engine = sa.create_engine(
        "postgresql+pg8000://scott:tiger@192.168.0.199/test",
@@ -64,7 +61,6 @@ or does not match the host name (as seen from the client), it may also be
necessary to disable hostname checking::

    import ssl

    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
@@ -257,7 +253,7 @@ class _PGOIDVECTOR(_SpaceVector, OIDVECTOR):
    pass


class _Pg8000Range(ranges.AbstractSingleRangeImpl):
class _Pg8000Range(ranges.AbstractRangeImpl):
    def bind_processor(self, dialect):
        pg8000_Range = dialect.dbapi.Range

@@ -308,13 +304,15 @@ class _Pg8000MultiRange(ranges.AbstractMultiRangeImpl):
        def to_multirange(value):
            if value is None:
                return None
            else:
                return ranges.MultiRange(

            mr = []
            for v in value:
                mr.append(
                    ranges.Range(
                        v.lower, v.upper, bounds=v.bounds, empty=v.is_empty
                    )
                    for v in value
                )
            return mr

        return to_multirange

@@ -540,9 +538,6 @@ class PGDialect_pg8000(PGDialect):
        cursor.execute("COMMIT")
        cursor.close()

    def detect_autocommit_setting(self, dbapi_conn) -> bool:
        return bool(dbapi_conn.autocommit)

    def set_readonly(self, connection, value):
        cursor = connection.cursor()
        try:
@@ -589,8 +584,8 @@ class PGDialect_pg8000(PGDialect):
        cursor = dbapi_connection.cursor()
        cursor.execute(
            f"""SET CLIENT_ENCODING TO '{
                client_encoding.replace("'", "''")
            }'"""
            client_encoding.replace("'", "''")
        }'"""
        )
        cursor.execute("COMMIT")
        cursor.close()

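The to_multirange() hunk above swaps a build-a-list loop for constructing ranges.MultiRange directly from a generator; MultiRange is a list subclass, so the element conversion is the same either way, but the typed container lets result rows carry multirange identity. A standalone sketch with stand-in classes (DriverRange mimics what a pg8000-style driver might hand back)::

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Range:  # stand-in for sqlalchemy.dialects.postgresql.Range
        lower: int
        upper: int
        bounds: str = "[)"
        empty: bool = False

    class MultiRange(List[Range]):  # stand-in; the real one subclasses list
        pass

    @dataclass
    class DriverRange:  # hypothetical driver-level range object
        lower: int
        upper: int
        bounds: str
        is_empty: bool

    def to_multirange(value):
        if value is None:
            return None
        return MultiRange(
            Range(v.lower, v.upper, bounds=v.bounds, empty=v.is_empty)
            for v in value
        )

    print(to_multirange([DriverRange(1, 3, "[)", False)]))
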
@@ -1,16 +1,10 @@
# dialects/postgresql/pg_catalog.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/pg_catalog.py
# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

from __future__ import annotations

from typing import Any
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
# mypy: ignore-errors

from .array import ARRAY
from .types import OID
@@ -29,37 +23,31 @@ from ...types import String
from ...types import Text
from ...types import TypeDecorator

if TYPE_CHECKING:
    from ...engine.interfaces import Dialect
    from ...sql.type_api import _ResultProcessorType


# types
class NAME(TypeDecorator[str]):
class NAME(TypeDecorator):
    impl = String(64, collation="C")
    cache_ok = True


class PG_NODE_TREE(TypeDecorator[str]):
class PG_NODE_TREE(TypeDecorator):
    impl = Text(collation="C")
    cache_ok = True


class INT2VECTOR(TypeDecorator[Sequence[int]]):
class INT2VECTOR(TypeDecorator):
    impl = ARRAY(SmallInteger)
    cache_ok = True


class OIDVECTOR(TypeDecorator[Sequence[int]]):
class OIDVECTOR(TypeDecorator):
    impl = ARRAY(OID)
    cache_ok = True


class _SpaceVector:
    def result_processor(
        self, dialect: Dialect, coltype: object
    ) -> _ResultProcessorType[list[int]]:
        def process(value: Any) -> Optional[list[int]]:
    def result_processor(self, dialect, coltype):
        def process(value):
            if value is None:
                return value
            return [int(p) for p in value.split(" ")]
@@ -89,7 +77,7 @@ RELKINDS_MAT_VIEW = ("m",)
RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW

# tables
pg_catalog_meta = MetaData(schema="pg_catalog")
pg_catalog_meta = MetaData()

pg_namespace = Table(
    "pg_namespace",
@@ -97,6 +85,7 @@ pg_namespace = Table(
    Column("oid", OID),
    Column("nspname", NAME),
    Column("nspowner", OID),
    schema="pg_catalog",
)

pg_class = Table(
@@ -131,6 +120,7 @@ pg_class = Table(
    Column("relispartition", Boolean, info={"server_version": (10,)}),
    Column("relrewrite", OID, info={"server_version": (11,)}),
    Column("reloptions", ARRAY(Text)),
    schema="pg_catalog",
)

pg_type = Table(
@@ -165,6 +155,7 @@ pg_type = Table(
    Column("typndims", Integer),
    Column("typcollation", OID, info={"server_version": (9, 1)}),
    Column("typdefault", Text),
    schema="pg_catalog",
)

pg_index = Table(
@@ -191,6 +182,7 @@ pg_index = Table(
    Column("indoption", INT2VECTOR),
    Column("indexprs", PG_NODE_TREE),
    Column("indpred", PG_NODE_TREE),
    schema="pg_catalog",
)

pg_attribute = Table(
@@ -217,6 +209,7 @@ pg_attribute = Table(
    Column("attislocal", Boolean),
    Column("attinhcount", Integer),
    Column("attcollation", OID, info={"server_version": (9, 1)}),
    schema="pg_catalog",
)

pg_constraint = Table(
@@ -242,6 +235,7 @@ pg_constraint = Table(
    Column("connoinherit", Boolean, info={"server_version": (9, 2)}),
    Column("conkey", ARRAY(SmallInteger)),
    Column("confkey", ARRAY(SmallInteger)),
    schema="pg_catalog",
)

pg_sequence = Table(
@@ -255,6 +249,7 @@ pg_sequence = Table(
    Column("seqmin", BigInteger),
    Column("seqcache", BigInteger),
    Column("seqcycle", Boolean),
    schema="pg_catalog",
    info={"server_version": (10,)},
)

@@ -265,6 +260,7 @@ pg_attrdef = Table(
    Column("adrelid", OID),
    Column("adnum", SmallInteger),
    Column("adbin", PG_NODE_TREE),
    schema="pg_catalog",
)

pg_description = Table(
@@ -274,6 +270,7 @@ pg_description = Table(
    Column("classoid", OID),
    Column("objsubid", Integer),
    Column("description", Text(collation="C")),
    schema="pg_catalog",
)

pg_enum = Table(
@@ -283,6 +280,7 @@ pg_enum = Table(
    Column("enumtypid", OID),
    Column("enumsortorder", Float(), info={"server_version": (9, 1)}),
    Column("enumlabel", NAME),
    schema="pg_catalog",
)

pg_am = Table(
@@ -292,35 +290,5 @@ pg_am = Table(
    Column("amname", NAME),
    Column("amhandler", REGPROC, info={"server_version": (9, 6)}),
    Column("amtype", CHAR, info={"server_version": (9, 6)}),
)

pg_collation = Table(
    "pg_collation",
    pg_catalog_meta,
    Column("oid", OID, info={"server_version": (9, 3)}),
    Column("collname", NAME),
    Column("collnamespace", OID),
    Column("collowner", OID),
    Column("collprovider", CHAR, info={"server_version": (10,)}),
    Column("collisdeterministic", Boolean, info={"server_version": (12,)}),
    Column("collencoding", Integer),
    Column("collcollate", Text),
    Column("collctype", Text),
    Column("colliculocale", Text),
    Column("collicurules", Text, info={"server_version": (16,)}),
    Column("collversion", Text, info={"server_version": (10,)}),
)

pg_opclass = Table(
    "pg_opclass",
    pg_catalog_meta,
    Column("oid", OID, info={"server_version": (9, 3)}),
    Column("opcmethod", NAME),
    Column("opcname", NAME),
    Column("opsnamespace", OID),
    Column("opsowner", OID),
    Column("opcfamily", OID),
    Column("opcintype", OID),
    Column("opcdefault", Boolean),
    Column("opckeytype", OID),
    schema="pg_catalog",
)

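The pg_catalog_meta change above trades per-table ``schema="pg_catalog"`` arguments for a schema set once on the owning MetaData; a Table inherits its MetaData's default schema when none is given explicitly. In miniature::

    from sqlalchemy import Column, Integer, MetaData, Table

    meta = MetaData(schema="pg_catalog")
    pg_namespace = Table("pg_namespace", meta, Column("oid", Integer))

    assert pg_namespace.schema == "pg_catalog"
    assert pg_namespace.fullname == "pg_catalog.pg_namespace"
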
@@ -1,9 +1,3 @@
# dialects/postgresql/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

import time

@@ -97,7 +91,7 @@ def drop_all_schema_objects_pre_tables(cfg, eng):
        for xid in conn.exec_driver_sql(
            "select gid from pg_prepared_xacts"
        ).scalars():
            conn.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)
            conn.execute("ROLLBACK PREPARED '%s'" % xid)


@drop_all_schema_objects_post_tables.for_db("postgresql")

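The provision.py hunk toggles between ``conn.execute(raw_string)`` and ``conn.exec_driver_sql(raw_string)``; under SQLAlchemy 2.0, ``Connection.execute()`` no longer accepts plain strings, so driver-level SQL must go through ``exec_driver_sql()`` and textual SQL through :func:`_sa.text`. In short (the URL is illustrative)::

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    with engine.connect() as conn:
        conn.exec_driver_sql("SELECT 1")  # raw string, passed to the DBAPI
        conn.execute(text("SELECT 1"))    # 2.0-style textual SQL construct
        # conn.execute("SELECT 1")        # raises under SQLAlchemy 2.0
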
@@ -1,5 +1,5 @@
# dialects/postgresql/psycopg.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/psycopg2.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -29,29 +29,20 @@ selected depending on how the engine is created:
  automatically select the sync version, e.g.::

    from sqlalchemy import create_engine

    sync_engine = create_engine(
        "postgresql+psycopg://scott:tiger@localhost/test"
    )
    sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")

* calling :func:`_asyncio.create_async_engine` with
  ``postgresql+psycopg://...`` will automatically select the async version,
  e.g.::

    from sqlalchemy.ext.asyncio import create_async_engine

    asyncio_engine = create_async_engine(
        "postgresql+psycopg://scott:tiger@localhost/test"
    )
    asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test")

The asyncio version of the dialect may also be specified explicitly using the
``psycopg_async`` suffix, as::

    from sqlalchemy.ext.asyncio import create_async_engine

    asyncio_engine = create_async_engine(
        "postgresql+psycopg_async://scott:tiger@localhost/test"
    )
    asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test")

.. seealso::

@@ -59,42 +50,9 @@ The asyncio version of the dialect may also be specified explicitly using the
    dialect shares most of its behavior with the ``psycopg2`` dialect.
    Further documentation is available there.

Using a different Cursor class
------------------------------

One of the differences between ``psycopg`` and the older ``psycopg2``
is how bound parameters are handled: ``psycopg2`` would bind them
client side, while ``psycopg`` by default will bind them server side.

It's possible to configure ``psycopg`` to do client side binding by
specifying the ``cursor_factory`` to be ``ClientCursor`` when creating
the engine::

    from psycopg import ClientCursor

    client_side_engine = create_engine(
        "postgresql+psycopg://...",
        connect_args={"cursor_factory": ClientCursor},
    )

Similarly when using an async engine the ``AsyncClientCursor`` can be
specified::

    from psycopg import AsyncClientCursor

    client_side_engine = create_async_engine(
        "postgresql+psycopg://...",
        connect_args={"cursor_factory": AsyncClientCursor},
    )

.. seealso::

    `Client-side-binding cursors <https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#client-side-binding-cursors>`_

"""  # noqa
from __future__ import annotations

from collections import deque
import logging
import re
from typing import cast
@@ -121,8 +79,6 @@ from ...util.concurrency import await_only
if TYPE_CHECKING:
    from typing import Iterable

    from psycopg import AsyncConnection

logger = logging.getLogger("sqlalchemy.dialects.postgresql")


@@ -135,6 +91,8 @@ class _PGREGCONFIG(REGCONFIG):


class _PGJSON(JSON):
    render_bind_cast = True

    def bind_processor(self, dialect):
        return self._make_bind_processor(None, dialect._psycopg_Json)

@@ -143,6 +101,8 @@ class _PGJSON(JSON):


class _PGJSONB(JSONB):
    render_bind_cast = True

    def bind_processor(self, dialect):
        return self._make_bind_processor(None, dialect._psycopg_Jsonb)

@@ -202,7 +162,7 @@ class _PGBoolean(sqltypes.Boolean):
    render_bind_cast = True


class _PsycopgRange(ranges.AbstractSingleRangeImpl):
class _PsycopgRange(ranges.AbstractRangeImpl):
    def bind_processor(self, dialect):
        psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range

@@ -258,10 +218,8 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):

    def result_processor(self, dialect, coltype):
        def to_range(value):
            if value is None:
                return None
            else:
                return ranges.MultiRange(
            if value is not None:
                value = [
                    ranges.Range(
                        elem._lower,
                        elem._upper,
@@ -269,7 +227,9 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):
                        empty=not elem._bounds,
                    )
                    for elem in value
                )
                ]

            return value

        return to_range

@@ -326,7 +286,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
            sqltypes.Integer: _PGInteger,
            sqltypes.SmallInteger: _PGSmallInteger,
            sqltypes.BigInteger: _PGBigInteger,
            ranges.AbstractSingleRange: _PsycopgRange,
            ranges.AbstractRange: _PsycopgRange,
            ranges.AbstractMultiRange: _PsycopgMultiRange,
        },
    )
@@ -406,12 +366,10 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):

        # register the adapter for connections made subsequent to
        # this one
        assert self._psycopg_adapters_map
        register_hstore(info, self._psycopg_adapters_map)

        # register the adapter for this connection
        assert connection.connection
        register_hstore(info, connection.connection.driver_connection)
        register_hstore(info, connection.connection)

    @classmethod
    def import_dbapi(cls):
@@ -572,7 +530,7 @@ class AsyncAdapt_psycopg_cursor:
    def __init__(self, cursor, await_) -> None:
        self._cursor = cursor
        self.await_ = await_
        self._rows = deque()
        self._rows = []

    def __getattr__(self, name):
        return getattr(self._cursor, name)
@@ -599,19 +557,24 @@ class AsyncAdapt_psycopg_cursor:
        # eq/ne
        if res and res.status == self._psycopg_ExecStatus.TUPLES_OK:
            rows = self.await_(self._cursor.fetchall())
            self._rows = deque(rows)
            if not isinstance(rows, list):
                self._rows = list(rows)
            else:
                self._rows = rows
        return result

    def executemany(self, query, params_seq):
        return self.await_(self._cursor.executemany(query, params_seq))

    def __iter__(self):
        # TODO: try to avoid pop(0) on a list
        while self._rows:
            yield self._rows.popleft()
            yield self._rows.pop(0)

    def fetchone(self):
        if self._rows:
            return self._rows.popleft()
            # TODO: try to avoid pop(0) on a list
            return self._rows.pop(0)
        else:
            return None

@@ -619,12 +582,13 @@ class AsyncAdapt_psycopg_cursor:
        if size is None:
            size = self._cursor.arraysize

        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]
        retval = self._rows[0:size]
        self._rows = self._rows[size:]
        return retval

    def fetchall(self):
        retval = list(self._rows)
        self._rows.clear()
        retval = self._rows
        self._rows = []
        return retval


@@ -655,7 +619,6 @@ class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor):


class AsyncAdapt_psycopg_connection(AdaptedConnection):
    _connection: AsyncConnection
    __slots__ = ()
    await_ = staticmethod(await_only)

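The cursor changes above swap the fetched-row buffer between a plain list (``pop(0)`` shifts every remaining element, O(n) per call) and ``collections.deque`` (``popleft()`` is O(1)); ``fetchmany()`` likewise trades slicing for bounded ``popleft()`` calls. The difference in miniature::

    import timeit
    from collections import deque

    rows = list(range(100_000))

    def drain_list():
        r = list(rows)
        while r:
            r.pop(0)  # O(n) per call: shifts the whole remaining buffer

    def drain_deque():
        r = deque(rows)
        while r:
            r.popleft()  # O(1) per call

    print("list :", timeit.timeit(drain_list, number=1))
    print("deque:", timeit.timeit(drain_deque, number=1))
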
@@ -1,5 +1,5 @@
# dialects/postgresql/psycopg2.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# postgresql/psycopg2.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -88,6 +88,7 @@ connection URI::
"postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require"
)


Unix Domain Connections
------------------------

@@ -102,17 +103,13 @@ in ``/tmp``, or whatever socket directory was specified when PostgreSQL
was built. This value can be overridden by passing a pathname to psycopg2,
using ``host`` as an additional keyword argument::

create_engine(
"postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql"
)
create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql")

.. warning:: The format accepted here allows for a hostname in the main URL
in addition to the "host" query string argument. **When using this URL
format, the initial host is silently ignored**. That is, this URL::

engine = create_engine(
"postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2"
)
engine = create_engine("postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2")

Above, the hostname ``myhost1`` is **silently ignored and discarded.** The
host which is connected is the ``myhost2`` host.
@@ -193,7 +190,7 @@ any or all elements of the connection string.
For this form, the URL can be passed without any elements other than the
initial scheme::

engine = create_engine("postgresql+psycopg2://")
engine = create_engine('postgresql+psycopg2://')

In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()``
function which in turn represents an empty DSN passed to libpq.
@@ -245,7 +242,7 @@ Psycopg2 Fast Execution Helpers

Modern versions of psycopg2 include a feature known as
`Fast Execution Helpers \
<https://www.psycopg.org/docs/extras.html#fast-execution-helpers>`_, which
<https://initd.org/psycopg/docs/extras.html#fast-execution-helpers>`_, which
have been shown in benchmarking to improve psycopg2's executemany()
performance, primarily with INSERT statements, by at least
an order of magnitude.
@@ -267,8 +264,8 @@ used feature. The use of this extension may be enabled using the

engine = create_engine(
"postgresql+psycopg2://scott:tiger@host/dbname",
executemany_mode="values_plus_batch",
)
executemany_mode='values_plus_batch')


Possible options for ``executemany_mode`` include:

@@ -314,10 +311,8 @@ is below::

engine = create_engine(
"postgresql+psycopg2://scott:tiger@host/dbname",
executemany_mode="values_plus_batch",
insertmanyvalues_page_size=5000,
executemany_batch_page_size=500,
)
executemany_mode='values_plus_batch',
insertmanyvalues_page_size=5000, executemany_batch_page_size=500)

.. seealso::

@@ -343,9 +338,7 @@ in the following ways:
passed in the database URL; this parameter is consumed by the underlying
``libpq`` PostgreSQL client library::

engine = create_engine(
"postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8"
)
engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8")

Alternatively, the above ``client_encoding`` value may be passed using
:paramref:`_sa.create_engine.connect_args` for programmatic establishment with
@@ -353,7 +346,7 @@ in the following ways:

engine = create_engine(
"postgresql+psycopg2://user:pass@host/dbname",
connect_args={"client_encoding": "utf8"},
connect_args={'client_encoding': 'utf8'}
)

* For all PostgreSQL versions, psycopg2 supports a client-side encoding
@@ -362,7 +355,8 @@ in the following ways:
``client_encoding`` parameter passed to :func:`_sa.create_engine`::

engine = create_engine(
"postgresql+psycopg2://user:pass@host/dbname", client_encoding="utf8"
"postgresql+psycopg2://user:pass@host/dbname",
client_encoding="utf8"
)

.. tip:: The above ``client_encoding`` parameter admittedly is very similar
@@ -381,9 +375,11 @@ in the following ways:
# postgresql.conf file

# client_encoding = sql_ascii # actually, defaults to database
# encoding
# encoding
client_encoding = utf8



Transactions
------------

@@ -430,15 +426,15 @@ is set to the ``logging.INFO`` level, notice messages will be logged::

import logging

logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO)
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)

Above, it is assumed that logging is configured externally. If this is not
the case, configuration such as ``logging.basicConfig()`` must be utilized::

import logging

logging.basicConfig()  # log messages to stdout
logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO)
logging.basicConfig() # log messages to stdout
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)

.. seealso::

@@ -475,10 +471,8 @@ textual HSTORE expression. If this behavior is not desired, disable the
use of the hstore extension by setting ``use_native_hstore`` to ``False`` as
follows::

engine = create_engine(
"postgresql+psycopg2://scott:tiger@localhost/test",
use_native_hstore=False,
)
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
use_native_hstore=False)

The ``HSTORE`` type is **still supported** when the
``psycopg2.extensions.register_hstore()`` extension is not used. It merely
@@ -519,7 +513,7 @@ class _PGJSONB(JSONB):
return None


class _Psycopg2Range(ranges.AbstractSingleRangeImpl):
class _Psycopg2Range(ranges.AbstractRangeImpl):
_psycopg2_range_cls = "none"

def bind_processor(self, dialect):
@@ -850,43 +844,33 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg):
# checks based on strings. in the case that .closed
# didn't cut it, fall back onto these.
str_e = str(e).partition("\n")[0]
for msg in self._is_disconnect_messages:
for msg in [
# these error messages from libpq: interfaces/libpq/fe-misc.c
# and interfaces/libpq/fe-secure.c.
"terminating connection",
"closed the connection",
"connection not open",
"could not receive data from server",
"could not send data to server",
# psycopg2 client errors, psycopg2/connection.h,
# psycopg2/cursor.h
"connection already closed",
"cursor already closed",
# not sure where this path is originally from, it may
# be obsolete. It really says "losed", not "closed".
"losed the connection unexpectedly",
# these can occur in newer SSL
"connection has been closed unexpectedly",
"SSL error: decryption failed or bad record mac",
"SSL SYSCALL error: Bad file descriptor",
"SSL SYSCALL error: EOF detected",
"SSL SYSCALL error: Operation timed out",
"SSL SYSCALL error: Bad address",
]:
idx = str_e.find(msg)
if idx >= 0 and '"' not in str_e[:idx]:
return True
return False

@util.memoized_property
def _is_disconnect_messages(self):
return (
# these error messages from libpq: interfaces/libpq/fe-misc.c
# and interfaces/libpq/fe-secure.c.
"terminating connection",
"closed the connection",
"connection not open",
"could not receive data from server",
"could not send data to server",
# psycopg2 client errors, psycopg2/connection.h,
# psycopg2/cursor.h
"connection already closed",
"cursor already closed",
# not sure where this path is originally from, it may
# be obsolete. It really says "losed", not "closed".
"losed the connection unexpectedly",
# these can occur in newer SSL
"connection has been closed unexpectedly",
"SSL error: decryption failed or bad record mac",
"SSL SYSCALL error: Bad file descriptor",
"SSL SYSCALL error: EOF detected",
"SSL SYSCALL error: Operation timed out",
"SSL SYSCALL error: Bad address",
# This can occur in OpenSSL 1 when an unexpected EOF occurs.
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html#BUGS
# It may also occur in newer OpenSSL for a non-recoverable I/O
# error as a result of a system call that does not set 'errno'
# in libc.
"SSL SYSCALL error: Success",
)


dialect = PGDialect_psycopg2

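Editorial note: the refactor above hoists the disconnect-message list out of ``is_disconnect()`` into a ``@util.memoized_property`` so the tuple is built once per dialect instance rather than per call. A rough standard-library equivalent of the pattern, with ``functools.cached_property`` standing in for SQLAlchemy's ``util.memoized_property`` and a simplified match check (illustrative sketch, not the commit's code)::

    from functools import cached_property

    class Checker:
        @cached_property
        def _is_disconnect_messages(self):
            # built on first access, then cached on the instance
            return (
                "terminating connection",
                "closed the connection",
            )

        def is_disconnect(self, error):
            # simplified: the real method also rejects quoted matches
            str_e = str(error).partition("\n")[0]
            return any(msg in str_e for msg in self._is_disconnect_messages)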
@@ -1,5 +1,5 @@
# dialects/postgresql/psycopg2cffi.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# testing/engines.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

@@ -1,5 +1,4 @@
# dialects/postgresql/ranges.py
# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -15,10 +14,8 @@ from decimal import Decimal
from typing import Any
from typing import cast
from typing import Generic
from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
@@ -154,8 +151,8 @@ class Range(Generic[_T]):
return not self.empty and self.upper is None

@property
def __sa_type_engine__(self) -> AbstractSingleRange[_T]:
return AbstractSingleRange()
def __sa_type_engine__(self) -> AbstractRange[Range[_T]]:
return AbstractRange()

def _contains_value(self, value: _T) -> bool:
"""Return True if this range contains the given value."""
@@ -271,9 +268,9 @@ class Range(Generic[_T]):
value2 += step
value2_inc = False

if value1 < value2:
if value1 < value2:  # type: ignore
return -1
elif value1 > value2:
elif value1 > value2:  # type: ignore
return 1
elif only_values:
return 0
@@ -360,8 +357,6 @@ class Range(Generic[_T]):
else:
return self._contains_value(value)

__contains__ = contains

def overlaps(self, other: Range[_T]) -> bool:
"Determine whether this range overlaps with `other`."

@@ -712,46 +707,27 @@ class Range(Generic[_T]):
return f"{b0}{l},{r}{b1}"


class MultiRange(List[Range[_T]]):
"""Represents a multirange sequence.

This list subclass is a utility to allow automatic type inference of
the proper multi-range SQL type depending on the single range values.
This is useful when operating on literal multi-ranges::

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import MultiRange, Range

value = literal(MultiRange([Range(2, 4)]))

select(tbl).where(tbl.c.value.op("@")(MultiRange([Range(-3, 7)])))

.. versionadded:: 2.0.26
class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
"""
Base for PostgreSQL RANGE types.

.. seealso::

- :ref:`postgresql_multirange_list_use`.
"""
`PostgreSQL range functions <https://www.postgresql.org/docs/current/static/functions-range.html>`_

@property
def __sa_type_engine__(self) -> AbstractMultiRange[_T]:
return AbstractMultiRange()


class AbstractRange(sqltypes.TypeEngine[_T]):
"""Base class for single and multi Range SQL types."""
""" # noqa: E501

render_bind_cast = True

__abstract__ = True

@overload
def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ...
def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
...

@overload
def adapt(
self, cls: Type[TypeEngineMixin], **kw: Any
) -> TypeEngine[Any]: ...
def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
...

def adapt(
self,
@@ -765,10 +741,7 @@ class AbstractRange(sqltypes.TypeEngine[_T]):
and also render as ``INT4RANGE`` in SQL and DDL.

"""
if (
issubclass(cls, (AbstractSingleRangeImpl, AbstractMultiRangeImpl))
and cls is not self.__class__
):
if issubclass(cls, AbstractRangeImpl) and cls is not self.__class__:
# two ways to do this are: 1. create a new type on the fly
# or 2. have AbstractRangeImpl(visit_name) constructor and a
# visit_abstract_range_impl() method in the PG compiler.
@@ -787,6 +760,21 @@ class AbstractRange(sqltypes.TypeEngine[_T]):
else:
return super().adapt(cls)

def _resolve_for_literal(self, value: Any) -> Any:
spec = value.lower if value.lower is not None else value.upper

if isinstance(spec, int):
return INT8RANGE()
elif isinstance(spec, (Decimal, float)):
return NUMRANGE()
elif isinstance(spec, datetime):
return TSRANGE() if not spec.tzinfo else TSTZRANGE()
elif isinstance(spec, date):
return DATERANGE()
else:
# empty Range, SQL datatype can't be determined here
return sqltypes.NULLTYPE

class comparator_factory(TypeEngine.Comparator[Range[Any]]):
"""Define comparison operations for range types."""

@@ -868,164 +856,91 @@ class AbstractRange(sqltypes.TypeEngine[_T]):
return self.expr.operate(operators.mul, other)


class AbstractSingleRange(AbstractRange[Range[_T]]):
"""Base for PostgreSQL RANGE types.

These are types that return a single :class:`_postgresql.Range` object.

.. seealso::

`PostgreSQL range functions <https://www.postgresql.org/docs/current/static/functions-range.html>`_

""" # noqa: E501

__abstract__ = True

def _resolve_for_literal(self, value: Range[Any]) -> Any:
spec = value.lower if value.lower is not None else value.upper

if isinstance(spec, int):
# pg is unreasonably picky here: the query
# "select 1::INTEGER <@ '[1, 4)'::INT8RANGE" raises
# "operator does not exist: integer <@ int8range" as of pg 16
if _is_int32(value):
return INT4RANGE()
else:
return INT8RANGE()
elif isinstance(spec, (Decimal, float)):
return NUMRANGE()
elif isinstance(spec, datetime):
return TSRANGE() if not spec.tzinfo else TSTZRANGE()
elif isinstance(spec, date):
return DATERANGE()
else:
# empty Range, SQL datatype can't be determined here
return sqltypes.NULLTYPE


class AbstractSingleRangeImpl(AbstractSingleRange[_T]):
"""Marker for AbstractSingleRange that will apply a subclass-specific
class AbstractRangeImpl(AbstractRange[Range[_T]]):
"""Marker for AbstractRange that will apply a subclass-specific
adaptation"""


class AbstractMultiRange(AbstractRange[Sequence[Range[_T]]]):
"""Base for PostgreSQL MULTIRANGE types.

these are types that return a sequence of :class:`_postgresql.Range`
objects.

"""
class AbstractMultiRange(AbstractRange[Range[_T]]):
"""base for PostgreSQL MULTIRANGE types"""

__abstract__ = True

def _resolve_for_literal(self, value: Sequence[Range[Any]]) -> Any:
if not value:
# empty MultiRange, SQL datatype can't be determined here
return sqltypes.NULLTYPE
first = value[0]
spec = first.lower if first.lower is not None else first.upper

if isinstance(spec, int):
# pg is unreasonably picky here: the query
# "select 1::INTEGER <@ '{[1, 4),[6,19)}'::INT8MULTIRANGE" raises
# "operator does not exist: integer <@ int8multirange" as of pg 16
if all(_is_int32(r) for r in value):
return INT4MULTIRANGE()
else:
return INT8MULTIRANGE()
elif isinstance(spec, (Decimal, float)):
return NUMMULTIRANGE()
elif isinstance(spec, datetime):
return TSMULTIRANGE() if not spec.tzinfo else TSTZMULTIRANGE()
elif isinstance(spec, date):
return DATEMULTIRANGE()
else:
# empty Range, SQL datatype can't be determined here
return sqltypes.NULLTYPE


class AbstractMultiRangeImpl(AbstractMultiRange[_T]):
"""Marker for AbstractMultiRange that will apply a subclass-specific
class AbstractMultiRangeImpl(
AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]]
):
"""Marker for AbstractRange that will apply a subclass-specific
adaptation"""


class INT4RANGE(AbstractSingleRange[int]):
class INT4RANGE(AbstractRange[Range[int]]):
"""Represent the PostgreSQL INT4RANGE type."""

__visit_name__ = "INT4RANGE"


class INT8RANGE(AbstractSingleRange[int]):
class INT8RANGE(AbstractRange[Range[int]]):
"""Represent the PostgreSQL INT8RANGE type."""

__visit_name__ = "INT8RANGE"


class NUMRANGE(AbstractSingleRange[Decimal]):
class NUMRANGE(AbstractRange[Range[Decimal]]):
"""Represent the PostgreSQL NUMRANGE type."""

__visit_name__ = "NUMRANGE"


class DATERANGE(AbstractSingleRange[date]):
class DATERANGE(AbstractRange[Range[date]]):
"""Represent the PostgreSQL DATERANGE type."""

__visit_name__ = "DATERANGE"


class TSRANGE(AbstractSingleRange[datetime]):
class TSRANGE(AbstractRange[Range[datetime]]):
"""Represent the PostgreSQL TSRANGE type."""

__visit_name__ = "TSRANGE"


class TSTZRANGE(AbstractSingleRange[datetime]):
class TSTZRANGE(AbstractRange[Range[datetime]]):
"""Represent the PostgreSQL TSTZRANGE type."""

__visit_name__ = "TSTZRANGE"


class INT4MULTIRANGE(AbstractMultiRange[int]):
class INT4MULTIRANGE(AbstractMultiRange[Range[int]]):
"""Represent the PostgreSQL INT4MULTIRANGE type."""

__visit_name__ = "INT4MULTIRANGE"


class INT8MULTIRANGE(AbstractMultiRange[int]):
class INT8MULTIRANGE(AbstractMultiRange[Range[int]]):
"""Represent the PostgreSQL INT8MULTIRANGE type."""

__visit_name__ = "INT8MULTIRANGE"


class NUMMULTIRANGE(AbstractMultiRange[Decimal]):
class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]):
"""Represent the PostgreSQL NUMMULTIRANGE type."""

__visit_name__ = "NUMMULTIRANGE"


class DATEMULTIRANGE(AbstractMultiRange[date]):
class DATEMULTIRANGE(AbstractMultiRange[Range[date]]):
"""Represent the PostgreSQL DATEMULTIRANGE type."""

__visit_name__ = "DATEMULTIRANGE"


class TSMULTIRANGE(AbstractMultiRange[datetime]):
class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]):
"""Represent the PostgreSQL TSRANGE type."""

__visit_name__ = "TSMULTIRANGE"


class TSTZMULTIRANGE(AbstractMultiRange[datetime]):
class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]):
"""Represent the PostgreSQL TSTZRANGE type."""

__visit_name__ = "TSTZMULTIRANGE"


_max_int_32 = 2**31 - 1
_min_int_32 = -(2**31)


def _is_int32(r: Range[int]) -> bool:
return (r.lower is None or _min_int_32 <= r.lower <= _max_int_32) and (
r.upper is None or _min_int_32 <= r.upper <= _max_int_32
)

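Editorial note: the new ``_resolve_for_literal`` logic above chooses a concrete range type from a literal's bound values, including the int32/int64 split enforced by ``_is_int32()``. A hedged sketch of just that selection rule in isolation (the helper names here are hypothetical; only the constants and comparison mirror the commit)::

    _MAX_INT_32 = 2**31 - 1
    _MIN_INT_32 = -(2**31)

    def in_int32(lower, upper):
        return all(
            v is None or _MIN_INT_32 <= v <= _MAX_INT_32
            for v in (lower, upper)
        )

    def range_type_for(lower, upper):
        # mirrors the int branch of _resolve_for_literal: bounds that fit
        # in 32 bits map to INT4RANGE, anything larger to INT8RANGE
        return "INT4RANGE" if in_int32(lower, upper) else "INT8RANGE"

    print(range_type_for(1, 4))      # INT4RANGE
    print(range_type_for(1, 2**40))  # INT8RANGE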
@@ -1,5 +1,4 @@
# dialects/postgresql/types.py
# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors
# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -38,52 +37,43 @@ class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]):
@overload
def __init__(
self: PGUuid[_python_UUID], as_uuid: Literal[True] = ...
) -> None: ...
) -> None:
...

@overload
def __init__(
self: PGUuid[str], as_uuid: Literal[False] = ...
) -> None: ...
def __init__(self: PGUuid[str], as_uuid: Literal[False] = ...) -> None:
...

def __init__(self, as_uuid: bool = True) -> None: ...
def __init__(self, as_uuid: bool = True) -> None:
...


class BYTEA(sqltypes.LargeBinary):
__visit_name__ = "BYTEA"


class _NetworkAddressTypeMixin:

def coerce_compared_value(
self, op: Optional[OperatorType], value: Any
) -> TypeEngine[Any]:
if TYPE_CHECKING:
assert isinstance(self, TypeEngine)
return self


class INET(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
class INET(sqltypes.TypeEngine[str]):
__visit_name__ = "INET"


PGInet = INET


class CIDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
class CIDR(sqltypes.TypeEngine[str]):
__visit_name__ = "CIDR"


PGCidr = CIDR


class MACADDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
class MACADDR(sqltypes.TypeEngine[str]):
__visit_name__ = "MACADDR"


PGMacAddr = MACADDR


class MACADDR8(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
class MACADDR8(sqltypes.TypeEngine[str]):
__visit_name__ = "MACADDR8"


@@ -104,11 +94,12 @@ class MONEY(sqltypes.TypeEngine[str]):
from sqlalchemy import Dialect
from sqlalchemy import TypeDecorator


class NumericMoney(TypeDecorator):
impl = MONEY

def process_result_value(self, value: Any, dialect: Dialect) -> None:
def process_result_value(
self, value: Any, dialect: Dialect
) -> None:
if value is not None:
# adjust this for the currency and numeric
m = re.match(r"\$([\d.]+)", value)
@@ -123,7 +114,6 @@ class MONEY(sqltypes.TypeEngine[str]):
from sqlalchemy import cast
from sqlalchemy import TypeDecorator


class NumericMoney(TypeDecorator):
impl = MONEY

@@ -132,18 +122,20 @@ class MONEY(sqltypes.TypeEngine[str]):

.. versionadded:: 1.2

""" # noqa: E501
"""

__visit_name__ = "MONEY"


class OID(sqltypes.TypeEngine[int]):

"""Provide the PostgreSQL OID type."""

__visit_name__ = "OID"


class REGCONFIG(sqltypes.TypeEngine[str]):

"""Provide the PostgreSQL REGCONFIG type.

.. versionadded:: 2.0.0rc1
@@ -154,6 +146,7 @@ class REGCONFIG(sqltypes.TypeEngine[str]):


class TSQUERY(sqltypes.TypeEngine[str]):

"""Provide the PostgreSQL TSQUERY type.

.. versionadded:: 2.0.0rc1
@@ -164,6 +157,7 @@ class TSQUERY(sqltypes.TypeEngine[str]):


class REGCLASS(sqltypes.TypeEngine[str]):

"""Provide the PostgreSQL REGCLASS type.

.. versionadded:: 1.2.7
@@ -174,6 +168,7 @@ class REGCLASS(sqltypes.TypeEngine[str]):


class TIMESTAMP(sqltypes.TIMESTAMP):

"""Provide the PostgreSQL TIMESTAMP type."""

__visit_name__ = "TIMESTAMP"
@@ -194,6 +189,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP):


class TIME(sqltypes.TIME):

"""PostgreSQL TIME type."""

__visit_name__ = "TIME"
@@ -214,6 +210,7 @@ class TIME(sqltypes.TIME):


class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval):

"""PostgreSQL INTERVAL type."""

__visit_name__ = "INTERVAL"
@@ -283,6 +280,7 @@ PGBit = BIT


class TSVECTOR(sqltypes.TypeEngine[str]):

"""The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
text search type TSVECTOR.

@@ -299,6 +297,7 @@ class TSVECTOR(sqltypes.TypeEngine[str]):


class CITEXT(sqltypes.TEXT):

"""Provide the PostgreSQL CITEXT type.

.. versionadded:: 2.0.7

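Editorial note: the ``MONEY`` docstring edited above recommends a ``TypeDecorator`` to convert the string-based ``MONEY`` result to a numeric value. A self-contained sketch adapted from that docstring example (the comma handling and ``Decimal`` conversion are additions for illustration, not part of the commit)::

    import re
    from decimal import Decimal
    from typing import Any

    from sqlalchemy import Dialect, TypeDecorator
    from sqlalchemy.dialects.postgresql import MONEY

    class NumericMoney(TypeDecorator):
        impl = MONEY
        cache_ok = True

        def process_result_value(self, value: Any, dialect: Dialect) -> Any:
            if value is not None:
                # strip the leading currency symbol, e.g. "$1,234.56"
                m = re.match(r"\$([\d,.]+)", value)
                if m:
                    value = Decimal(m.group(1).replace(",", ""))
            return value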
@@ -1,5 +1,5 @@
# dialects/sqlite/__init__.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# sqlite/__init__.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under

@@ -1,9 +1,10 @@
# dialects/sqlite/aiosqlite.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# sqlite/aiosqlite.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


r"""
@@ -30,7 +31,6 @@ This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

from sqlalchemy.ext.asyncio import create_async_engine

engine = create_async_engine("sqlite+aiosqlite:///filename")

The URL passes through all arguments to the ``pysqlite`` driver, so all
@@ -49,71 +49,45 @@ in Python and use them directly in SQLite queries as described here: :ref:`pysql
Serializable isolation / Savepoints / Transactional DDL (asyncio version)
-------------------------------------------------------------------------

A newly revised version of this important section is now available
at the top level of the SQLAlchemy SQLite documentation, in the section
:ref:`sqlite_transactions`.
Similarly to pysqlite, aiosqlite does not support SAVEPOINT feature.

The solution is similar to :ref:`pysqlite_serializable`. This is achieved by the event listeners in async::

.. _aiosqlite_pooling:
from sqlalchemy import create_engine, event
from sqlalchemy.ext.asyncio import create_async_engine

Pooling Behavior
----------------
engine = create_async_engine("sqlite+aiosqlite:///myfile.db")

The SQLAlchemy ``aiosqlite`` DBAPI establishes the connection pool differently
based on the kind of SQLite database that's requested:
@event.listens_for(engine.sync_engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable aiosqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None

* When a ``:memory:`` SQLite database is specified, the dialect by default
will use :class:`.StaticPool`. This pool maintains a single
connection, so that all access to the engine
use the same ``:memory:`` database.
* When a file-based database is specified, the dialect will use
:class:`.AsyncAdaptedQueuePool` as the source of connections.
@event.listens_for(engine.sync_engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.exec_driver_sql("BEGIN")

.. versionchanged:: 2.0.38

SQLite file database engines now use :class:`.AsyncAdaptedQueuePool` by default.
Previously, :class:`.NullPool` were used. The :class:`.NullPool` class
may be used by specifying it via the
:paramref:`_sa.create_engine.poolclass` parameter.
.. warning:: When using the above recipe, it is advised to not use the
:paramref:`.Connection.execution_options.isolation_level` setting on
:class:`_engine.Connection` and :func:`_sa.create_engine`
with the SQLite driver,
as this function necessarily will also alter the ".isolation_level" setting.

""" # noqa
from __future__ import annotations

import asyncio
from collections import deque
from functools import partial
from types import ModuleType
from typing import Any
from typing import cast
from typing import Deque
from typing import Iterator
from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
from typing import Union

from .base import SQLiteExecutionContext
from .pysqlite import SQLiteDialect_pysqlite
from ... import pool
from ... import util
from ...connectors.asyncio import AsyncAdapt_dbapi_module
from ...engine import AdaptedConnection
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only

if TYPE_CHECKING:
from ...connectors.asyncio import AsyncIODBAPIConnection
from ...connectors.asyncio import AsyncIODBAPICursor
from ...engine.interfaces import _DBAPICursorDescription
from ...engine.interfaces import _DBAPIMultiExecuteParams
from ...engine.interfaces import _DBAPISingleExecuteParams
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.url import URL
from ...pool.base import PoolProxiedConnection


class AsyncAdapt_aiosqlite_cursor:
# TODO: base on connectors/asyncio.py
@@ -132,26 +106,21 @@ class AsyncAdapt_aiosqlite_cursor:

server_side = False

def __init__(self, adapt_connection: AsyncAdapt_aiosqlite_connection):
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
self.arraysize = 1
self.rowcount = -1
self.description: Optional[_DBAPICursorDescription] = None
self._rows: Deque[Any] = deque()
self.description = None
self._rows = []

def close(self) -> None:
self._rows.clear()

def execute(
self,
operation: Any,
parameters: Optional[_DBAPISingleExecuteParams] = None,
) -> Any:
def close(self):
self._rows[:] = []

def execute(self, operation, parameters=None):
try:
_cursor: AsyncIODBAPICursor = self.await_(self._connection.cursor())  # type: ignore[arg-type] # noqa: E501
_cursor = self.await_(self._connection.cursor())

if parameters is None:
self.await_(_cursor.execute(operation))
@@ -163,7 +132,7 @@ class AsyncAdapt_aiosqlite_cursor:
self.lastrowid = self.rowcount = -1

if not self.server_side:
self._rows = deque(self.await_(_cursor.fetchall()))
self._rows = self.await_(_cursor.fetchall())
else:
self.description = None
self.lastrowid = _cursor.lastrowid
@@ -172,17 +141,13 @@ class AsyncAdapt_aiosqlite_cursor:
if not self.server_side:
self.await_(_cursor.close())
else:
self._cursor = _cursor  # type: ignore[misc]
self._cursor = _cursor
except Exception as error:
self._adapt_connection._handle_exception(error)

def executemany(
self,
operation: Any,
seq_of_parameters: _DBAPIMultiExecuteParams,
) -> Any:
def executemany(self, operation, seq_of_parameters):
try:
_cursor: AsyncIODBAPICursor = self.await_(self._connection.cursor())  # type: ignore[arg-type] # noqa: E501
_cursor = self.await_(self._connection.cursor())
self.await_(_cursor.executemany(operation, seq_of_parameters))
self.description = None
self.lastrowid = _cursor.lastrowid
@@ -191,29 +156,30 @@ class AsyncAdapt_aiosqlite_cursor:
except Exception as error:
self._adapt_connection._handle_exception(error)

def setinputsizes(self, *inputsizes: Any) -> None:
def setinputsizes(self, *inputsizes):
pass

def __iter__(self) -> Iterator[Any]:
def __iter__(self):
while self._rows:
yield self._rows.popleft()
yield self._rows.pop(0)

def fetchone(self) -> Optional[Any]:
def fetchone(self):
if self._rows:
return self._rows.popleft()
return self._rows.pop(0)
else:
return None

def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
def fetchmany(self, size=None):
if size is None:
size = self.arraysize

rr = self._rows
return [rr.popleft() for _ in range(min(size, len(rr)))]
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval

def fetchall(self) -> Sequence[Any]:
retval = list(self._rows)
self._rows.clear()
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval


@@ -224,27 +190,24 @@ class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):

server_side = True

def __init__(self, *arg: Any, **kw: Any) -> None:
def __init__(self, *arg, **kw):
super().__init__(*arg, **kw)
self._cursor: Optional[AsyncIODBAPICursor] = None
self._cursor = None

def close(self) -> None:
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None

def fetchone(self) -> Optional[Any]:
assert self._cursor is not None
def fetchone(self):
return self.await_(self._cursor.fetchone())

def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]:
assert self._cursor is not None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
return self.await_(self._cursor.fetchmany(size=size))

def fetchall(self) -> Sequence[Any]:
assert self._cursor is not None
def fetchall(self):
return self.await_(self._cursor.fetchall())


@@ -252,24 +215,22 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
await_ = staticmethod(await_only)
__slots__ = ("dbapi",)

def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection) -> None:
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection

@property
def isolation_level(self) -> Optional[str]:
return cast(str, self._connection.isolation_level)
def isolation_level(self):
return self._connection.isolation_level

@isolation_level.setter
def isolation_level(self, value: Optional[str]) -> None:
def isolation_level(self, value):
# aiosqlite's isolation_level setter works outside the Thread
# that it's supposed to, necessitating setting check_same_thread=False.
# for improved stability, we instead invent our own awaitable version
# using aiosqlite's async queue directly.

def set_iso(
connection: AsyncAdapt_aiosqlite_connection, value: Optional[str]
) -> None:
def set_iso(connection, value):
connection.isolation_level = value

function = partial(set_iso, self._connection._conn, value)
@@ -278,38 +239,38 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
self._connection._tx.put_nowait((future, function))

try:
self.await_(future)
return self.await_(future)
except Exception as error:
self._handle_exception(error)

def create_function(self, *args: Any, **kw: Any) -> None:
def create_function(self, *args, **kw):
try:
self.await_(self._connection.create_function(*args, **kw))
except Exception as error:
self._handle_exception(error)

def cursor(self, server_side: bool = False) -> AsyncAdapt_aiosqlite_cursor:
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_aiosqlite_ss_cursor(self)
else:
return AsyncAdapt_aiosqlite_cursor(self)

def execute(self, *args: Any, **kw: Any) -> Any:
def execute(self, *args, **kw):
return self.await_(self._connection.execute(*args, **kw))

def rollback(self) -> None:
def rollback(self):
try:
self.await_(self._connection.rollback())
except Exception as error:
self._handle_exception(error)

def commit(self) -> None:
def commit(self):
try:
self.await_(self._connection.commit())
except Exception as error:
self._handle_exception(error)

def close(self) -> None:
def close(self):
try:
self.await_(self._connection.close())
except ValueError:
@@ -325,7 +286,7 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
except Exception as error:
self._handle_exception(error)

def _handle_exception(self, error: Exception) -> NoReturn:
def _handle_exception(self, error):
if (
isinstance(error, ValueError)
and error.args[0] == "no active connection"
@@ -343,14 +304,14 @@ class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
await_ = staticmethod(await_fallback)


class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):
def __init__(self, aiosqlite: ModuleType, sqlite: ModuleType):
class AsyncAdapt_aiosqlite_dbapi:
def __init__(self, aiosqlite, sqlite):
self.aiosqlite = aiosqlite
self.sqlite = sqlite
self.paramstyle = "qmark"
self._init_dbapi_attributes()

def _init_dbapi_attributes(self) -> None:
def _init_dbapi_attributes(self):
for name in (
"DatabaseError",
"Error",
@@ -369,7 +330,7 @@ class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):
for name in ("Binary",):
setattr(self, name, getattr(self.sqlite, name))

def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiosqlite_connection:
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)

creator_fn = kw.pop("async_creator_fn", None)
@@ -393,7 +354,7 @@ class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):


class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
def create_server_side_cursor(self) -> DBAPICursor:
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(server_side=True)


@@ -408,25 +369,19 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
execution_ctx_cls = SQLiteExecutionContext_aiosqlite

@classmethod
def import_dbapi(cls) -> AsyncAdapt_aiosqlite_dbapi:
def import_dbapi(cls):
return AsyncAdapt_aiosqlite_dbapi(
__import__("aiosqlite"), __import__("sqlite3")
)

@classmethod
def get_pool_class(cls, url: URL) -> type[pool.Pool]:
def get_pool_class(cls, url):
if cls._is_url_file_db(url):
return pool.AsyncAdaptedQueuePool
return pool.NullPool
else:
return pool.StaticPool

def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
self.dbapi = cast("DBAPIModule", self.dbapi)
def is_disconnect(self, e, connection, cursor):
if isinstance(
e, self.dbapi.OperationalError
) and "no active connection" in str(e):
@@ -434,10 +389,8 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):

return super().is_disconnect(e, connection, cursor)

def get_driver_connection(
self, connection: DBAPIConnection
) -> AsyncIODBAPIConnection:
return connection._connection  # type: ignore[no-any-return]
def get_driver_connection(self, connection):
return connection._connection


dialect = SQLiteDialect_aiosqlite

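Editorial note: the revised pooling docs above state that file databases now default to :class:`.AsyncAdaptedQueuePool` and that :class:`.NullPool` can still be requested through the ``create_engine.poolclass`` parameter. A short sketch of that override, assuming a local file database named ``myfile.db``::

    from sqlalchemy import NullPool
    from sqlalchemy.ext.asyncio import create_async_engine

    # file database: AsyncAdaptedQueuePool by default as of 2.0.38;
    # passing NullPool restores the previous connect-per-use behavior
    engine = create_async_engine(
        "sqlite+aiosqlite:///myfile.db", poolclass=NullPool
    )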
File diff suppressed because it is too large
@@ -1,5 +1,5 @@
# dialects/sqlite/dml.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# sqlite/dml.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -7,10 +7,6 @@
from __future__ import annotations

from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from .._typing import _OnConflictIndexElementsT
from .._typing import _OnConflictIndexWhereT
@@ -19,7 +15,6 @@ from .._typing import _OnConflictWhereT
from ... import util
from ...sql import coercions
from ...sql import roles
from ...sql import schema
from ...sql._typing import _DMLTableArgument
from ...sql.base import _exclusive_against
from ...sql.base import _generative
@@ -27,9 +22,7 @@ from ...sql.base import ColumnCollection
from ...sql.base import ReadOnlyColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.elements import ColumnElement
from ...sql.elements import KeyedColumnElement
from ...sql.elements import TextClause
from ...sql.expression import alias
from ...util.typing import Self

@@ -148,10 +141,11 @@ class Insert(StandardInsert):
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.

:param where:
Optional argument. An expression object representing a ``WHERE``
clause that restricts the rows affected by ``DO UPDATE SET``. Rows not
meeting the ``WHERE`` condition will not be updated (effectively a
``DO NOTHING`` for those rows).
Optional argument. If present, can be a literal SQL
string or an acceptable expression for a ``WHERE`` clause
that restricts the rows affected by ``DO UPDATE SET``. Rows
not meeting the ``WHERE`` condition will not be updated
(effectively a ``DO NOTHING`` for those rows).

"""

@@ -190,10 +184,9 @@ class Insert(StandardInsert):
class OnConflictClause(ClauseElement):
stringify_dialect = "sqlite"

inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]]
inferred_target_whereclause: Optional[
Union[ColumnElement[Any], TextClause]
]
constraint_target: None
inferred_target_elements: _OnConflictIndexElementsT
inferred_target_whereclause: _OnConflictIndexWhereT

def __init__(
self,
@@ -201,22 +194,13 @@ class OnConflictClause(ClauseElement):
index_where: _OnConflictIndexWhereT = None,
):
if index_elements is not None:
self.inferred_target_elements = [
coercions.expect(roles.DDLConstraintColumnRole, column)
for column in index_elements
]
self.inferred_target_whereclause = (
coercions.expect(
roles.WhereHavingRole,
index_where,
)
if index_where is not None
else None
)
self.constraint_target = None
self.inferred_target_elements = index_elements
self.inferred_target_whereclause = index_where
else:
self.inferred_target_elements = (
self.inferred_target_whereclause
) = None
self.constraint_target = (
self.inferred_target_elements
) = self.inferred_target_whereclause = None


class OnConflictDoNothing(OnConflictClause):
@@ -226,9 +210,6 @@ class OnConflictDoNothing(OnConflictClause):
class OnConflictDoUpdate(OnConflictClause):
__visit_name__ = "on_conflict_do_update"

update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]]
update_whereclause: Optional[ColumnElement[Any]]

def __init__(
self,
index_elements: _OnConflictIndexElementsT = None,
@@ -256,8 +237,4 @@ class OnConflictDoUpdate(OnConflictClause):
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
]
self.update_whereclause = (
coercions.expect(roles.WhereHavingRole, where)
if where is not None
else None
)
self.update_whereclause = where

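Editorial note: the docstring change above tightens the ``where`` parameter of ``on_conflict_do_update()`` to an expression object that is now coerced up front. A usage sketch for the SQLite dialect, where ``my_table`` is an assumed :class:`.Table` with ``id`` and ``data`` columns::

    from sqlalchemy.dialects.sqlite import insert

    stmt = insert(my_table).values(id=1, data="new value")
    stmt = stmt.on_conflict_do_update(
        index_elements=[my_table.c.id],
        set_=dict(data=stmt.excluded.data),
        # rows failing this WHERE are left untouched
        # (effectively DO NOTHING for them)
        where=my_table.c.data != stmt.excluded.data,
    )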
@@ -1,9 +1,3 @@
# dialects/sqlite/json.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from ... import types as sqltypes

@@ -1,9 +1,3 @@
# dialects/sqlite/provision.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

import os
@@ -52,6 +46,8 @@ def _format_url(url, driver, ident):
assert "test_schema" not in filename
tokens = re.split(r"[_\.]", filename)

new_filename = f"{driver}"

for token in tokens:
if token in _drivernames:
if driver is None:

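Editorial note: the visible fragment of ``_format_url()`` above splits a test database filename on ``_`` and ``.`` and rebuilds it around the requested driver name. A hedged sketch of that rewriting idea in isolation (the helper name, the ``_drivernames`` contents, and the exact reassembly are assumptions for illustration only)::

    import re

    _drivernames = {"pysqlite", "aiosqlite", "pysqlcipher"}

    def swap_driver_token(filename: str, driver: str) -> str:
        # split "test_aiosqlite.db" into tokens, drop any driver token
        # and the extension, then rebuild the name around `driver`
        tokens = re.split(r"[_\.]", filename)
        rest = [t for t in tokens if t not in _drivernames and t != "db"]
        return "_".join([driver] + rest) + ".db"

    print(swap_driver_token("test_aiosqlite.db", "pysqlite"))
    # pysqlite_test.db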
@@ -1,5 +1,5 @@
# dialects/sqlite/pysqlcipher.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# sqlite/pysqlcipher.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -39,7 +39,7 @@ Current dialect selection logic is:

e = create_engine(
"sqlite+pysqlcipher://:password@/dbname.db",
module=sqlcipher_compatible_driver,
module=sqlcipher_compatible_driver
)

These drivers make use of the SQLCipher engine. This system essentially
@@ -55,12 +55,12 @@ The format of the connect string is in every way the same as that
of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the
"password" field is now accepted, which should contain a passphrase::

e = create_engine("sqlite+pysqlcipher://:testing@/foo.db")
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db')

For an absolute file path, two leading slashes should be used for the
database name::

e = create_engine("sqlite+pysqlcipher://:testing@//path/to/foo.db")
e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db')

A selection of additional encryption-related pragmas supported by SQLCipher
as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed
@@ -68,9 +68,7 @@ in the query string, and will result in that PRAGMA being called for each
new connection. Currently, ``cipher``, ``kdf_iter``
``cipher_page_size`` and ``cipher_use_hmac`` are supported::

e = create_engine(
"sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000"
)
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000')

.. warning:: Previous versions of sqlalchemy did not take into consideration
the encryption-related pragmas passed in the url string, that were silently

@@ -1,5 +1,5 @@
# dialects/sqlite/pysqlite.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# sqlite/pysqlite.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
@@ -28,9 +28,7 @@ Connect Strings
---------------

The file specification for the SQLite database is taken as the "database"
portion of the URL. Note that the format of a SQLAlchemy url is:

.. sourcecode:: text
portion of the URL. Note that the format of a SQLAlchemy url is::

driver://user:pass@host/database

@@ -39,28 +37,25 @@ the **right** of the third slash. So connecting to a relative filepath
looks like::

# relative path
e = create_engine("sqlite:///path/to/database.db")
e = create_engine('sqlite:///path/to/database.db')

An absolute path, which is denoted by starting with a slash, means you
need **four** slashes::

# absolute path
e = create_engine("sqlite:////path/to/database.db")
e = create_engine('sqlite:////path/to/database.db')

To use a Windows path, regular drive specifications and backslashes can be
used. Double backslashes are probably needed::

# absolute path on Windows
e = create_engine("sqlite:///C:\\path\\to\\database.db")
e = create_engine('sqlite:///C:\\path\\to\\database.db')

To use sqlite ``:memory:`` database specify it as the filename using
``sqlite:///:memory:``. It's also the default if no filepath is
present, specifying only ``sqlite://`` and nothing else::
The sqlite ``:memory:`` identifier is the default if no filepath is
present. Specify ``sqlite://`` and nothing else::

# in-memory database (note three slashes)
e = create_engine("sqlite:///:memory:")
# also in-memory database
e2 = create_engine("sqlite://")
# in-memory database
e = create_engine('sqlite://')

.. _pysqlite_uri_connections:

@@ -100,9 +95,7 @@ Above, the pysqlite / sqlite3 DBAPI would be passed arguments as::

sqlite3.connect(
"file:path/to/database?mode=ro&nolock=1",
check_same_thread=True,
timeout=10,
uri=True,
check_same_thread=True, timeout=10, uri=True
)

Regarding future parameters added to either the Python or native drivers. new
@@ -148,11 +141,8 @@ as follows::
def regexp(a, b):
return re.search(a, b) is not None

sqlite_connection.create_function(
"regexp",
2,
regexp,
"regexp", 2, regexp,
)

There is currently no support for regular expression flags as a separate
@@ -193,12 +183,10 @@ Keeping in mind that pysqlite's parsing option is not recommended,
nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES
can be forced if one configures "native_datetime=True" on create_engine()::

engine = create_engine(
"sqlite://",
connect_args={
"detect_types": sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
},
native_datetime=True,
engine = create_engine('sqlite://',
connect_args={'detect_types':
sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES},
native_datetime=True
)

With this flag enabled, the DATE and TIMESTAMP types (but note - not the
@@ -253,7 +241,6 @@ Pooling may be disabled for a file based database by specifying the
parameter::

from sqlalchemy import NullPool

engine = create_engine("sqlite:///myfile.db", poolclass=NullPool)

It's been observed that the :class:`.NullPool` implementation incurs an
@@ -273,12 +260,9 @@ globally, and the ``check_same_thread`` flag can be passed to Pysqlite
as ``False``::

from sqlalchemy.pool import StaticPool

engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
engine = create_engine('sqlite://',
connect_args={'check_same_thread':False},
poolclass=StaticPool)

Note that using a ``:memory:`` database in multiple threads requires a recent
version of SQLite.
@@ -297,14 +281,14 @@ needed within multiple threads for this case::

# maintain the same connection per thread
from sqlalchemy.pool import SingletonThreadPool

engine = create_engine("sqlite:///mydb.db", poolclass=SingletonThreadPool)
engine = create_engine('sqlite:///mydb.db',
poolclass=SingletonThreadPool)


# maintain the same connection across all threads
from sqlalchemy.pool import StaticPool

engine = create_engine("sqlite:///mydb.db", poolclass=StaticPool)
engine = create_engine('sqlite:///mydb.db',
poolclass=StaticPool)

Note that :class:`.SingletonThreadPool` should be configured for the number
of threads that are to be used; beyond that number, connections will be
@@ -333,14 +317,13 @@ same column, use a custom type that will check each row individually::
from sqlalchemy import String
from sqlalchemy import TypeDecorator


class MixedBinary(TypeDecorator):
impl = String
cache_ok = True

def process_result_value(self, value, dialect):
if isinstance(value, str):
value = bytes(value, "utf-8")
value = bytes(value, 'utf-8')
elif value is not None:
value = bytes(value)

@@ -354,10 +337,74 @@ Then use the above ``MixedBinary`` datatype in the place where
Serializable isolation / Savepoints / Transactional DDL
-------------------------------------------------------

A newly revised version of this important section is now available
at the top level of the SQLAlchemy SQLite documentation, in the section
:ref:`sqlite_transactions`.
In the section :ref:`sqlite_concurrency`, we refer to the pysqlite
driver's assortment of issues that prevent several features of SQLite
from working correctly. The pysqlite DBAPI driver has several
long-standing bugs which impact the correctness of its transactional
behavior. In its default mode of operation, SQLite features such as
SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are
non-functional, and in order to use these features, workarounds must
be taken.

The issue is essentially that the driver attempts to second-guess the user's
intent, failing to start transactions and sometimes ending them prematurely, in
an effort to minimize the SQLite database's file locking behavior, even
though SQLite itself uses "shared" locks for read-only activities.

SQLAlchemy chooses to not alter this behavior by default, as it is the
long-expected behavior of the pysqlite driver; if and when the pysqlite
driver attempts to repair these issues, that will be more of a driver towards
defaults for SQLAlchemy.

The good news is that with a few events, we can implement transactional
support fully, by disabling pysqlite's feature entirely and emitting BEGIN
ourselves. This is achieved using two event listeners::

from sqlalchemy import create_engine, event

engine = create_engine("sqlite:///myfile.db")

@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable pysqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None

@event.listens_for(engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.exec_driver_sql("BEGIN")

.. warning:: When using the above recipe, it is advised to not use the
:paramref:`.Connection.execution_options.isolation_level` setting on
:class:`_engine.Connection` and :func:`_sa.create_engine`
with the SQLite driver,
as this function necessarily will also alter the ".isolation_level" setting.


Above, we intercept a new pysqlite connection and disable any transactional
integration. Then, at the point at which SQLAlchemy knows that transaction
scope is to begin, we emit ``"BEGIN"`` ourselves.

When we take control of ``"BEGIN"``, we can also control directly SQLite's
locking modes, introduced at
`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_,
by adding the desired locking mode to our ``"BEGIN"``::

@event.listens_for(engine, "begin")
def do_begin(conn):
conn.exec_driver_sql("BEGIN EXCLUSIVE")

.. seealso::

`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_ -
on the SQLite site

`sqlite3 SELECT does not BEGIN a transaction <https://bugs.python.org/issue9924>`_ -
on the Python bug tracker

`sqlite3 module breaks transactions and potentially corrupts data <https://bugs.python.org/issue10740>`_ -
on the Python bug tracker

.. _pysqlite_udfs:

@@ -392,16 +439,12 @@ connection when it is created. That is accomplished with an event listener::
with engine.connect() as conn:
print(conn.scalar(text("SELECT UDF()")))


""" # noqa
from __future__ import annotations

import math
import os
import re
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from .base import DATE
from .base import DATETIME
@@ -411,13 +454,6 @@ from ... import pool
from ... import types as sqltypes
from ... import util

if TYPE_CHECKING:
from ...engine.interfaces import DBAPIConnection
from ...engine.interfaces import DBAPICursor
from ...engine.interfaces import DBAPIModule
from ...engine.url import URL
from ...pool.base import PoolProxiedConnection


class _SQLite_pysqliteTimeStamp(DATETIME):
def bind_processor(self, dialect):
@@ -471,7 +507,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
return sqlite

@classmethod
def _is_url_file_db(cls, url: URL):
def _is_url_file_db(cls, url):
if (url.database and url.database != ":memory:") and (
url.query.get("mode", None) != "memory"
):
@@ -502,9 +538,6 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
dbapi_connection.isolation_level = ""
return super().set_isolation_level(dbapi_connection, level)

def detect_autocommit_setting(self, dbapi_connection):
return dbapi_connection.isolation_level is None

def on_connect(self):
def regexp(a, b):
if b is None:
@@ -604,13 +637,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect):

return ([filename], pysqlite_opts)

def is_disconnect(
self,
e: DBAPIModule.Error,
connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]],
cursor: Optional[DBAPICursor],
) -> bool:
self.dbapi = cast("DBAPIModule", self.dbapi)
def is_disconnect(self, e, connection, cursor):
return isinstance(
e, self.dbapi.ProgrammingError
) and "Cannot operate on a closed database." in str(e)
