This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# dialects/postgresql/asyncpg.py
|
||||
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
|
||||
# postgresql/asyncpg.py
|
||||
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
|
||||
# file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
@@ -23,10 +23,18 @@ This dialect should normally be used only with the
|
||||
:func:`_asyncio.create_async_engine` engine creation function::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname")
|
||||
|
||||
The dialect can also be run as a "synchronous" dialect within the
|
||||
:func:`_sa.create_engine` function, which will pass "await" calls into
|
||||
an ad-hoc event loop. This mode of operation is of **limited use**
|
||||
and is for special testing scenarios only. The mode can be enabled by
|
||||
adding the SQLAlchemy-specific flag ``async_fallback`` to the URL
|
||||
in conjunction with :func:`_sa.create_engine`::
|
||||
|
||||
# for testing purposes only; do not use in production!
|
||||
engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true")
|
||||
|
||||
engine = create_async_engine(
|
||||
"postgresql+asyncpg://user:pass@hostname/dbname"
|
||||
)
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
@@ -81,15 +89,11 @@ asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect
|
||||
argument)::
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
"postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500"
|
||||
)
|
||||
engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500")
|
||||
|
||||
To disable the prepared statement cache, use a value of zero::
|
||||
|
||||
engine = create_async_engine(
|
||||
"postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0"
|
||||
)
|
||||
engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0")
|
||||
|
||||
.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.
|
||||
|
||||
@@ -119,8 +123,8 @@ To disable the prepared statement cache, use a value of zero::
|
||||
|
||||
.. _asyncpg_prepared_statement_name:
|
||||
|
||||
Prepared Statement Name with PGBouncer
|
||||
--------------------------------------
|
||||
Prepared Statement Name
|
||||
-----------------------
|
||||
|
||||
By default, asyncpg enumerates prepared statements in numeric order, which
|
||||
can lead to errors if a name has already been taken for another prepared
|
||||
@@ -135,10 +139,10 @@ a prepared statement is prepared::
|
||||
from uuid import uuid4
|
||||
|
||||
engine = create_async_engine(
|
||||
"postgresql+asyncpg://user:pass@somepgbouncer/dbname",
|
||||
"postgresql+asyncpg://user:pass@hostname/dbname",
|
||||
poolclass=NullPool,
|
||||
connect_args={
|
||||
"prepared_statement_name_func": lambda: f"__asyncpg_{uuid4()}__",
|
||||
'prepared_statement_name_func': lambda: f'__asyncpg_{uuid4()}__',
|
||||
},
|
||||
)
|
||||
|
||||
@@ -148,7 +152,7 @@ a prepared statement is prepared::
|
||||
|
||||
https://github.com/sqlalchemy/sqlalchemy/issues/6467
|
||||
|
||||
.. warning:: When using PGBouncer, to prevent a buildup of useless prepared statements in
|
||||
.. warning:: To prevent a buildup of useless prepared statements in
|
||||
your application, it's important to use the :class:`.NullPool` pool
|
||||
class, and to configure PgBouncer to use `DISCARD <https://www.postgresql.org/docs/current/sql-discard.html>`_
|
||||
when returning connections. The DISCARD command is used to release resources held by the db connection,
|
||||
@@ -178,11 +182,13 @@ client using this setting passed to :func:`_asyncio.create_async_engine`::
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
import collections
|
||||
import decimal
|
||||
import json as _py_json
|
||||
import re
|
||||
import time
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from . import json
|
||||
from . import ranges
|
||||
@@ -212,6 +218,9 @@ from ...util.concurrency import asyncio
|
||||
from ...util.concurrency import await_fallback
|
||||
from ...util.concurrency import await_only
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
class AsyncpgARRAY(PGARRAY):
|
||||
render_bind_cast = True
|
||||
@@ -265,20 +274,20 @@ class AsyncpgInteger(sqltypes.Integer):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class AsyncpgSmallInteger(sqltypes.SmallInteger):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class AsyncpgBigInteger(sqltypes.BigInteger):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class AsyncpgJSON(json.JSON):
|
||||
render_bind_cast = True
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
|
||||
class AsyncpgJSONB(json.JSONB):
|
||||
render_bind_cast = True
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
@@ -363,7 +372,7 @@ class AsyncpgCHAR(sqltypes.CHAR):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _AsyncpgRange(ranges.AbstractSingleRangeImpl):
|
||||
class _AsyncpgRange(ranges.AbstractRangeImpl):
|
||||
def bind_processor(self, dialect):
|
||||
asyncpg_Range = dialect.dbapi.asyncpg.Range
|
||||
|
||||
@@ -417,7 +426,10 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
|
||||
)
|
||||
return value
|
||||
|
||||
return [to_range(element) for element in value]
|
||||
return [
|
||||
to_range(element)
|
||||
for element in cast("Iterable[ranges.Range]", value)
|
||||
]
|
||||
|
||||
return to_range
|
||||
|
||||
@@ -436,7 +448,7 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
|
||||
return rvalue
|
||||
|
||||
if value is not None:
|
||||
value = ranges.MultiRange(to_range(elem) for elem in value)
|
||||
value = [to_range(elem) for elem in value]
|
||||
|
||||
return value
|
||||
|
||||
@@ -494,7 +506,7 @@ class AsyncAdapt_asyncpg_cursor:
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self._rows = deque()
|
||||
self._rows = []
|
||||
self._cursor = None
|
||||
self.description = None
|
||||
self.arraysize = 1
|
||||
@@ -502,7 +514,7 @@ class AsyncAdapt_asyncpg_cursor:
|
||||
self._invalidate_schema_cache_asof = 0
|
||||
|
||||
def close(self):
|
||||
self._rows.clear()
|
||||
self._rows[:] = []
|
||||
|
||||
def _handle_exception(self, error):
|
||||
self._adapt_connection._handle_exception(error)
|
||||
@@ -542,12 +554,11 @@ class AsyncAdapt_asyncpg_cursor:
|
||||
self._cursor = await prepared_stmt.cursor(*parameters)
|
||||
self.rowcount = -1
|
||||
else:
|
||||
self._rows = deque(await prepared_stmt.fetch(*parameters))
|
||||
self._rows = await prepared_stmt.fetch(*parameters)
|
||||
status = prepared_stmt.get_statusmsg()
|
||||
|
||||
reg = re.match(
|
||||
r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)",
|
||||
status or "",
|
||||
r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status
|
||||
)
|
||||
if reg:
|
||||
self.rowcount = int(reg.group(1))
|
||||
@@ -591,11 +602,11 @@ class AsyncAdapt_asyncpg_cursor:
|
||||
|
||||
def __iter__(self):
|
||||
while self._rows:
|
||||
yield self._rows.popleft()
|
||||
yield self._rows.pop(0)
|
||||
|
||||
def fetchone(self):
|
||||
if self._rows:
|
||||
return self._rows.popleft()
|
||||
return self._rows.pop(0)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -603,12 +614,13 @@ class AsyncAdapt_asyncpg_cursor:
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
|
||||
rr = self._rows
|
||||
return [rr.popleft() for _ in range(min(size, len(rr)))]
|
||||
retval = self._rows[0:size]
|
||||
self._rows[:] = self._rows[size:]
|
||||
return retval
|
||||
|
||||
def fetchall(self):
|
||||
retval = list(self._rows)
|
||||
self._rows.clear()
|
||||
retval = self._rows[:]
|
||||
self._rows[:] = []
|
||||
return retval
|
||||
|
||||
|
||||
@@ -618,21 +630,23 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
|
||||
|
||||
def __init__(self, adapt_connection):
|
||||
super().__init__(adapt_connection)
|
||||
self._rowbuffer = deque()
|
||||
self._rowbuffer = None
|
||||
|
||||
def close(self):
|
||||
self._cursor = None
|
||||
self._rowbuffer.clear()
|
||||
self._rowbuffer = None
|
||||
|
||||
def _buffer_rows(self):
|
||||
assert self._cursor is not None
|
||||
new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
|
||||
self._rowbuffer.extend(new_rows)
|
||||
self._rowbuffer = collections.deque(new_rows)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._rowbuffer:
|
||||
self._buffer_rows()
|
||||
|
||||
while True:
|
||||
while self._rowbuffer:
|
||||
yield self._rowbuffer.popleft()
|
||||
@@ -655,19 +669,21 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
|
||||
if not self._rowbuffer:
|
||||
self._buffer_rows()
|
||||
|
||||
assert self._cursor is not None
|
||||
rb = self._rowbuffer
|
||||
lb = len(rb)
|
||||
buf = list(self._rowbuffer)
|
||||
lb = len(buf)
|
||||
if size > lb:
|
||||
rb.extend(
|
||||
buf.extend(
|
||||
self._adapt_connection.await_(self._cursor.fetch(size - lb))
|
||||
)
|
||||
|
||||
return [rb.popleft() for _ in range(min(size, len(rb)))]
|
||||
result = buf[0:size]
|
||||
self._rowbuffer = collections.deque(buf[size:])
|
||||
return result
|
||||
|
||||
def fetchall(self):
|
||||
ret = list(self._rowbuffer)
|
||||
ret.extend(self._adapt_connection.await_(self._all()))
|
||||
ret = list(self._rowbuffer) + list(
|
||||
self._adapt_connection.await_(self._all())
|
||||
)
|
||||
self._rowbuffer.clear()
|
||||
return ret
|
||||
|
||||
@@ -717,7 +733,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
|
||||
):
|
||||
self.dbapi = dbapi
|
||||
self._connection = connection
|
||||
self.isolation_level = self._isolation_setting = None
|
||||
self.isolation_level = self._isolation_setting = "read_committed"
|
||||
self.readonly = False
|
||||
self.deferrable = False
|
||||
self._transaction = None
|
||||
@@ -786,9 +802,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
|
||||
translated_error = exception_mapping[super_](
|
||||
"%s: %s" % (type(error), error)
|
||||
)
|
||||
translated_error.pgcode = translated_error.sqlstate = (
|
||||
getattr(error, "sqlstate", None)
|
||||
)
|
||||
translated_error.pgcode = (
|
||||
translated_error.sqlstate
|
||||
) = getattr(error, "sqlstate", None)
|
||||
raise translated_error from error
|
||||
else:
|
||||
raise error
|
||||
@@ -852,45 +868,25 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
|
||||
else:
|
||||
return AsyncAdapt_asyncpg_cursor(self)
|
||||
|
||||
async def _rollback_and_discard(self):
|
||||
try:
|
||||
await self._transaction.rollback()
|
||||
finally:
|
||||
# if asyncpg .rollback() was actually called, then whether or
|
||||
# not it raised or succeeded, the transation is done, discard it
|
||||
self._transaction = None
|
||||
self._started = False
|
||||
|
||||
async def _commit_and_discard(self):
|
||||
try:
|
||||
await self._transaction.commit()
|
||||
finally:
|
||||
# if asyncpg .commit() was actually called, then whether or
|
||||
# not it raised or succeeded, the transation is done, discard it
|
||||
self._transaction = None
|
||||
self._started = False
|
||||
|
||||
def rollback(self):
|
||||
if self._started:
|
||||
try:
|
||||
self.await_(self._rollback_and_discard())
|
||||
self.await_(self._transaction.rollback())
|
||||
except Exception as error:
|
||||
self._handle_exception(error)
|
||||
finally:
|
||||
self._transaction = None
|
||||
self._started = False
|
||||
except Exception as error:
|
||||
# don't dereference asyncpg transaction if we didn't
|
||||
# actually try to call rollback() on it
|
||||
self._handle_exception(error)
|
||||
|
||||
def commit(self):
|
||||
if self._started:
|
||||
try:
|
||||
self.await_(self._commit_and_discard())
|
||||
self.await_(self._transaction.commit())
|
||||
except Exception as error:
|
||||
self._handle_exception(error)
|
||||
finally:
|
||||
self._transaction = None
|
||||
self._started = False
|
||||
except Exception as error:
|
||||
# don't dereference asyncpg transaction if we didn't
|
||||
# actually try to call commit() on it
|
||||
self._handle_exception(error)
|
||||
|
||||
def close(self):
|
||||
self.rollback()
|
||||
@@ -898,31 +894,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
|
||||
self.await_(self._connection.close())
|
||||
|
||||
def terminate(self):
|
||||
if util.concurrency.in_greenlet():
|
||||
# in a greenlet; this is the connection was invalidated
|
||||
# case.
|
||||
try:
|
||||
# try to gracefully close; see #10717
|
||||
# timeout added in asyncpg 0.14.0 December 2017
|
||||
self.await_(asyncio.shield(self._connection.close(timeout=2)))
|
||||
except (
|
||||
asyncio.TimeoutError,
|
||||
asyncio.CancelledError,
|
||||
OSError,
|
||||
self.dbapi.asyncpg.PostgresError,
|
||||
) as e:
|
||||
# in the case where we are recycling an old connection
|
||||
# that may have already been disconnected, close() will
|
||||
# fail with the above timeout. in this case, terminate
|
||||
# the connection without any further waiting.
|
||||
# see issue #8419
|
||||
self._connection.terminate()
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
# re-raise CancelledError if we were cancelled
|
||||
raise
|
||||
else:
|
||||
# not in a greenlet; this is the gc cleanup case
|
||||
self._connection.terminate()
|
||||
self._connection.terminate()
|
||||
self._started = False
|
||||
|
||||
@staticmethod
|
||||
@@ -1059,7 +1031,6 @@ class PGDialect_asyncpg(PGDialect):
|
||||
INTERVAL: AsyncPgInterval,
|
||||
sqltypes.Boolean: AsyncpgBoolean,
|
||||
sqltypes.Integer: AsyncpgInteger,
|
||||
sqltypes.SmallInteger: AsyncpgSmallInteger,
|
||||
sqltypes.BigInteger: AsyncpgBigInteger,
|
||||
sqltypes.Numeric: AsyncpgNumeric,
|
||||
sqltypes.Float: AsyncpgFloat,
|
||||
@@ -1074,7 +1045,7 @@ class PGDialect_asyncpg(PGDialect):
|
||||
OID: AsyncpgOID,
|
||||
REGCLASS: AsyncpgREGCLASS,
|
||||
sqltypes.CHAR: AsyncpgCHAR,
|
||||
ranges.AbstractSingleRange: _AsyncpgRange,
|
||||
ranges.AbstractRange: _AsyncpgRange,
|
||||
ranges.AbstractMultiRange: _AsyncpgMultiRange,
|
||||
},
|
||||
)
|
||||
@@ -1117,9 +1088,6 @@ class PGDialect_asyncpg(PGDialect):
|
||||
def set_isolation_level(self, dbapi_connection, level):
|
||||
dbapi_connection.set_isolation_level(self._isolation_lookup[level])
|
||||
|
||||
def detect_autocommit_setting(self, dbapi_conn) -> bool:
|
||||
return bool(dbapi_conn.autocommit)
|
||||
|
||||
def set_readonly(self, connection, value):
|
||||
connection.readonly = value
|
||||
|
||||
|
||||
Reference in New Issue
Block a user