main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -4,7 +4,6 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
from .connection import connect, Connection # NOQA
from .exceptions import * # NOQA
@@ -15,10 +14,6 @@ from .types import * # NOQA
from ._version import __version__ # NOQA
from . import exceptions
__all__: tuple[str, ...] = (
'connect', 'create_pool', 'Pool', 'Record', 'Connection'
)
__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection')
__all__ += exceptions.__all__ # NOQA

View File

@@ -4,25 +4,18 @@
#
# SPDX-License-Identifier: PSF-2.0
from __future__ import annotations
import asyncio
import functools
import sys
import typing
if typing.TYPE_CHECKING:
from . import compat
if sys.version_info < (3, 11):
from async_timeout import timeout as timeout_ctx
else:
from asyncio import timeout as timeout_ctx
_T = typing.TypeVar('_T')
async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
async def wait_for(fut, timeout):
"""Wait for the single Future or coroutine to complete, with timeout.
Coroutine will be wrapped in Task.
@@ -72,7 +65,7 @@ async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
return await fut
async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
async def _cancel_and_wait(fut):
"""Cancel the *fut* future or task and wait until it completes."""
loop = asyncio.get_running_loop()
@@ -89,6 +82,6 @@ async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
fut.remove_done_callback(cb)
def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None:
def _release_waiter(waiter, *args):
if not waiter.done():
waiter.set_result(None)

View File

@@ -117,22 +117,10 @@ class TestCase(unittest.TestCase, metaclass=TestCaseMeta):
self.__unhandled_exceptions = []
def tearDown(self):
excs = []
for exc in self.__unhandled_exceptions:
if isinstance(exc, ConnectionResetError):
texc = traceback.TracebackException.from_exception(
exc, lookup_lines=False)
if texc.stack[-1].name == "_call_connection_lost":
# On Windows calling socket.shutdown may raise
# ConnectionResetError, which happens in the
# finally block of _call_connection_lost.
continue
excs.append(exc)
if excs:
if self.__unhandled_exceptions:
formatted = []
for i, context in enumerate(excs):
for i, context in enumerate(self.__unhandled_exceptions):
formatted.append(self._format_loop_exception(context, i + 1))
self.fail(
@@ -226,6 +214,13 @@ def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
return cluster
def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
initdb_options=None):
cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
cluster.start(port='dynamic', server_settings=server_settings)
return cluster
def _get_initdb_options(initdb_options=None):
if not initdb_options:
initdb_options = {}
@@ -249,12 +244,8 @@ def _init_default_cluster(initdb_options=None):
_default_cluster = pg_cluster.RunningCluster()
else:
_default_cluster = _init_cluster(
pg_cluster.TempCluster,
cluster_kwargs={
"data_dir_suffix": ".apgtest",
},
initdb_options=_get_initdb_options(initdb_options),
)
pg_cluster.TempCluster, cluster_kwargs={},
initdb_options=_get_initdb_options(initdb_options))
return _default_cluster
@@ -271,7 +262,6 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=60.0,
connect=None,
setup=None,
init=None,
loop=None,
@@ -281,18 +271,12 @@ def create_pool(dsn=None, *,
**connect_kwargs):
return pool_class(
dsn,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
loop=loop,
connect=connect,
setup=setup,
init=init,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
record_class=record_class,
**connect_kwargs,
)
**connect_kwargs)
class ClusterTestCase(TestCase):

View File

@@ -10,8 +10,4 @@
# supported platforms, publish the packages on PyPI, merge the PR
# to the target branch, create a Git tag pointing to the commit.
from __future__ import annotations
import typing
__version__: typing.Final = '0.30.0'
__version__ = '0.29.0'

View File

@@ -9,11 +9,9 @@ import asyncio
import os
import os.path
import platform
import random
import re
import shutil
import socket
import string
import subprocess
import sys
import tempfile
@@ -47,29 +45,6 @@ def find_available_port():
sock.close()
def _world_readable_mkdtemp(suffix=None, prefix=None, dir=None):
name = "".join(random.choices(string.ascii_lowercase, k=8))
if dir is None:
dir = tempfile.gettempdir()
if prefix is None:
prefix = tempfile.gettempprefix()
if suffix is None:
suffix = ""
fn = os.path.join(dir, prefix + name + suffix)
os.mkdir(fn, 0o755)
return fn
def _mkdtemp(suffix=None, prefix=None, dir=None):
if _system == 'Windows' and os.environ.get("GITHUB_ACTIONS"):
# Due to mitigations introduced in python/cpython#118486
# when Python runs in a session created via an SSH connection
# tempfile.mkdtemp creates directories that are not accessible.
return _world_readable_mkdtemp(suffix, prefix, dir)
else:
return tempfile.mkdtemp(suffix, prefix, dir)
class ClusterError(Exception):
pass
@@ -147,13 +122,9 @@ class Cluster:
else:
extra_args = []
os.makedirs(self._data_dir, exist_ok=True)
process = subprocess.run(
[self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
cwd=self._data_dir,
)
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
output = process.stdout
@@ -228,10 +199,7 @@ class Cluster:
process = subprocess.run(
[self._pg_ctl, 'start', '-D', self._data_dir,
'-o', ' '.join(extra_args)],
stdout=stdout,
stderr=subprocess.STDOUT,
cwd=self._data_dir,
)
stdout=stdout, stderr=subprocess.STDOUT)
if process.returncode != 0:
if process.stderr:
@@ -250,10 +218,7 @@ class Cluster:
self._daemon_process = \
subprocess.Popen(
[self._postgres, '-D', self._data_dir, *extra_args],
stdout=stdout,
stderr=subprocess.STDOUT,
cwd=self._data_dir,
)
stdout=stdout, stderr=subprocess.STDOUT)
self._daemon_pid = self._daemon_process.pid
@@ -267,10 +232,7 @@ class Cluster:
process = subprocess.run(
[self._pg_ctl, 'reload', '-D', self._data_dir],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=self._data_dir,
)
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stderr = process.stderr
@@ -283,10 +245,7 @@ class Cluster:
process = subprocess.run(
[self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
'-m', 'fast'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=self._data_dir,
)
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stderr = process.stderr
@@ -624,9 +583,9 @@ class TempCluster(Cluster):
def __init__(self, *,
data_dir_suffix=None, data_dir_prefix=None,
data_dir_parent=None, pg_config_path=None):
self._data_dir = _mkdtemp(suffix=data_dir_suffix,
prefix=data_dir_prefix,
dir=data_dir_parent)
self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
prefix=data_dir_prefix,
dir=data_dir_parent)
super().__init__(self._data_dir, pg_config_path=pg_config_path)

View File

@@ -4,26 +4,22 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import enum
import pathlib
import platform
import typing
import sys
if typing.TYPE_CHECKING:
import asyncio
SYSTEM: typing.Final = platform.uname().system
SYSTEM = platform.uname().system
if sys.platform == 'win32':
if SYSTEM == 'Windows':
import ctypes.wintypes
CSIDL_APPDATA: typing.Final = 0x001a
CSIDL_APPDATA = 0x001a
def get_pg_home_directory() -> pathlib.Path | None:
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
# We cannot simply use expanduser() as that returns the user's
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
@@ -35,14 +31,14 @@ if sys.platform == 'win32':
return pathlib.Path(buf.value) / 'postgresql'
else:
def get_pg_home_directory() -> pathlib.Path | None:
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
try:
return pathlib.Path.home()
except (RuntimeError, KeyError):
return None
async def wait_closed(stream: asyncio.StreamWriter) -> None:
async def wait_closed(stream):
# Not all asyncio versions have StreamWriter.wait_closed().
if hasattr(stream, 'wait_closed'):
try:
@@ -53,13 +49,6 @@ async def wait_closed(stream: asyncio.StreamWriter) -> None:
pass
if sys.version_info < (3, 12):
def markcoroutinefunction(c): # type: ignore
pass
else:
from inspect import markcoroutinefunction # noqa: F401
if sys.version_info < (3, 12):
from ._asyncio_compat import wait_for as wait_for # noqa: F401
else:
@@ -70,19 +59,3 @@ if sys.version_info < (3, 11):
from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
else:
from asyncio import timeout as timeout # noqa: F401
if sys.version_info < (3, 9):
from typing import ( # noqa: F401
Awaitable as Awaitable,
)
else:
from collections.abc import ( # noqa: F401
Awaitable as Awaitable,
)
if sys.version_info < (3, 11):
class StrEnum(str, enum.Enum):
__str__ = str.__str__
__repr__ = enum.Enum.__repr__
else:
from enum import StrEnum as StrEnum # noqa: F401

View File

@@ -45,11 +45,6 @@ class SSLMode(enum.IntEnum):
return getattr(cls, sslmode.replace('-', '_'))
class SSLNegotiation(compat.StrEnum):
postgres = "postgres"
direct = "direct"
_ConnectionParameters = collections.namedtuple(
'ConnectionParameters',
[
@@ -58,11 +53,9 @@ _ConnectionParameters = collections.namedtuple(
'database',
'ssl',
'sslmode',
'ssl_negotiation',
'direct_tls',
'server_settings',
'target_session_attrs',
'krbsrvname',
'gsslib',
])
@@ -268,13 +261,12 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
target_session_attrs):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
ssl_min_protocol_version = ssl_max_protocol_version = None
sslnegotiation = None
if dsn:
parsed = urllib.parse.urlparse(dsn)
@@ -368,9 +360,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if 'sslrootcert' in query:
sslrootcert = query.pop('sslrootcert')
if 'sslnegotiation' in query:
sslnegotiation = query.pop('sslnegotiation')
if 'sslcrl' in query:
sslcrl = query.pop('sslcrl')
@@ -394,16 +383,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if target_session_attrs is None:
target_session_attrs = dsn_target_session_attrs
if 'krbsrvname' in query:
val = query.pop('krbsrvname')
if krbsrvname is None:
krbsrvname = val
if 'gsslib' in query:
val = query.pop('gsslib')
if gsslib is None:
gsslib = val
if query:
if server_settings is None:
server_settings = query
@@ -512,36 +491,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if ssl is None and have_tcp_addrs:
ssl = 'prefer'
if direct_tls is not None:
sslneg = (
SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
)
else:
if sslnegotiation is None:
sslnegotiation = os.environ.get("PGSSLNEGOTIATION")
if sslnegotiation is not None:
try:
sslneg = SSLNegotiation(sslnegotiation)
except ValueError:
modes = ', '.join(
m.name.replace('_', '-')
for m in SSLNegotiation
)
raise exceptions.ClientConfigurationError(
f'`sslnegotiation` parameter must be one of: {modes}'
) from None
else:
sslneg = SSLNegotiation.postgres
if isinstance(ssl, (str, SSLMode)):
try:
sslmode = SSLMode.parse(ssl)
except AttributeError:
modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
raise exceptions.ClientConfigurationError(
'`sslmode` parameter must be one of: {}'.format(modes)
) from None
'`sslmode` parameter must be one of: {}'.format(modes))
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
if sslmode < SSLMode.allow:
@@ -694,24 +650,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
)
) from None
if krbsrvname is None:
krbsrvname = os.getenv('PGKRBSRVNAME')
if gsslib is None:
gsslib = os.getenv('PGGSSLIB')
if gsslib is None:
gsslib = 'sspi' if _system == 'Windows' else 'gssapi'
if gsslib not in {'gssapi', 'sspi'}:
raise exceptions.ClientConfigurationError(
"gsslib parameter must be either 'gssapi' or 'sspi'"
", got {!r}".format(gsslib))
params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, ssl_negotiation=sslneg,
sslmode=sslmode, direct_tls=direct_tls,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
target_session_attrs=target_session_attrs)
return addrs, params
@@ -722,7 +665,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
target_session_attrs):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
@@ -751,8 +694,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
password=password, passfile=passfile, ssl=ssl,
direct_tls=direct_tls, database=database,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
target_session_attrs=target_session_attrs)
config = _ClientConfiguration(
command_timeout=command_timeout,
@@ -914,9 +856,9 @@ async def __connect_addr(
# UNIX socket
connector = loop.create_unix_connection(proto_factory, addr)
elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
# if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
# direct SSL connection
elif params.ssl and params.direct_tls:
# if ssl and direct_tls are given, skip STARTTLS and perform direct
# SSL connection
connector = loop.create_connection(
proto_factory, *addr, ssl=params.ssl
)

View File

@@ -231,8 +231,9 @@ class Connection(metaclass=ConnectionMeta):
:param callable callback:
A callable or a coroutine function receiving one argument:
**record**, a LoggedQuery containing `query`, `args`, `timeout`,
`elapsed`, `exception`, `conn_addr`, and `conn_params`.
**record**: a LoggedQuery containing `query`, `args`, `timeout`,
`elapsed`, `exception`, `conn_addr`, and
`conn_params`.
.. versionadded:: 0.29.0
"""
@@ -756,44 +757,6 @@ class Connection(metaclass=ConnectionMeta):
return None
return data[0]
async def fetchmany(
self, query, args, *, timeout: float=None, record_class=None
):
"""Run a query for each sequence of arguments in *args*
and return the results as a list of :class:`Record`.
:param query:
Query to execute.
:param args:
An iterable containing sequences of arguments for the query.
:param float timeout:
Optional timeout value in seconds.
:param type record_class:
If specified, the class to use for records returned by this method.
Must be a subclass of :class:`~asyncpg.Record`. If not specified,
a per-connection *record_class* is used.
:return list:
A list of :class:`~asyncpg.Record` instances. If specified, the
actual type of list elements would be *record_class*.
Example:
.. code-block:: pycon
>>> rows = await con.fetchmany('''
... INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a;
... ''', [('x', 1), ('y', 2), ('z', 3)])
>>> rows
[<Record row=('x',)>, <Record row=('y',)>, <Record row=('z',)>]
.. versionadded:: 0.30.0
"""
self._check_open()
return await self._executemany(
query, args, timeout, return_rows=True, record_class=record_class
)
async def copy_from_table(self, table_name, *, output,
columns=None, schema_name=None, timeout=None,
format=None, oids=None, delimiter=None,
@@ -837,7 +800,7 @@ class Connection(metaclass=ConnectionMeta):
... output='file.csv', format='csv')
... print(result)
...
>>> asyncio.run(run())
>>> asyncio.get_event_loop().run_until_complete(run())
'COPY 100'
.. _`COPY statement documentation`:
@@ -906,7 +869,7 @@ class Connection(metaclass=ConnectionMeta):
... output='file.csv', format='csv')
... print(result)
...
>>> asyncio.run(run())
>>> asyncio.get_event_loop().run_until_complete(run())
'COPY 10'
.. _`COPY statement documentation`:
@@ -982,7 +945,7 @@ class Connection(metaclass=ConnectionMeta):
... 'mytable', source='datafile.tbl')
... print(result)
...
>>> asyncio.run(run())
>>> asyncio.get_event_loop().run_until_complete(run())
'COPY 140000'
.. _`COPY statement documentation`:
@@ -1064,7 +1027,7 @@ class Connection(metaclass=ConnectionMeta):
... (2, 'ham', 'spam')])
... print(result)
...
>>> asyncio.run(run())
>>> asyncio.get_event_loop().run_until_complete(run())
'COPY 2'
Asynchronous record iterables are also supported:
@@ -1082,7 +1045,7 @@ class Connection(metaclass=ConnectionMeta):
... 'mytable', records=record_gen(100))
... print(result)
...
>>> asyncio.run(run())
>>> asyncio.get_event_loop().run_until_complete(run())
'COPY 100'
.. versionadded:: 0.11.0
@@ -1342,7 +1305,7 @@ class Connection(metaclass=ConnectionMeta):
... print(result)
... print(datetime.datetime(2002, 1, 1) + result)
...
>>> asyncio.run(run())
>>> asyncio.get_event_loop().run_until_complete(run())
relativedelta(years=+2, months=+3, days=+1)
2004-04-02 00:00:00
@@ -1515,10 +1478,11 @@ class Connection(metaclass=ConnectionMeta):
self._abort()
self._cleanup()
async def _reset(self):
async def reset(self, *, timeout=None):
self._check_open()
self._listeners.clear()
self._log_listeners.clear()
reset_query = self._get_reset_query()
if self._protocol.is_in_transaction() or self._top_xact is not None:
if self._top_xact is None or not self._top_xact._managed:
@@ -1530,36 +1494,10 @@ class Connection(metaclass=ConnectionMeta):
})
self._top_xact = None
await self.execute("ROLLBACK")
reset_query = 'ROLLBACK;\n' + reset_query
async def reset(self, *, timeout=None):
"""Reset the connection state.
Calling this will reset the connection session state to a state
resembling that of a newly obtained connection. Namely, an open
transaction (if any) is rolled back, open cursors are closed,
all `LISTEN <https://www.postgresql.org/docs/current/sql-listen.html>`_
registrations are removed, all session configuration
variables are reset to their default values, and all advisory locks
are released.
Note that the above describes the default query returned by
:meth:`Connection.get_reset_query`. If one overloads the method
by subclassing ``Connection``, then this method will do whatever
the overloaded method returns, except open transactions are always
terminated and any callbacks registered by
:meth:`Connection.add_listener` or :meth:`Connection.add_log_listener`
are removed.
:param float timeout:
A timeout for resetting the connection. If not specified, defaults
to no timeout.
"""
async with compat.timeout(timeout):
await self._reset()
reset_query = self.get_reset_query()
if reset_query:
await self.execute(reset_query)
if reset_query:
await self.execute(reset_query, timeout=timeout)
def _abort(self):
# Put the connection into the aborted state.
@@ -1720,15 +1658,7 @@ class Connection(metaclass=ConnectionMeta):
con_ref = self._proxy
return con_ref
def get_reset_query(self):
"""Return the query sent to server on connection release.
The query returned by this method is used by :meth:`Connection.reset`,
which is, in turn, used by :class:`~asyncpg.pool.Pool` before making
the connection available to another acquirer.
.. versionadded:: 0.30.0
"""
def _get_reset_query(self):
if self._reset_query is not None:
return self._reset_query
@@ -1842,7 +1772,7 @@ class Connection(metaclass=ConnectionMeta):
... await con.execute('LOCK TABLE tbl')
... await change_type(con)
...
>>> asyncio.run(run())
>>> asyncio.get_event_loop().run_until_complete(run())
.. versionadded:: 0.14.0
"""
@@ -1879,8 +1809,9 @@ class Connection(metaclass=ConnectionMeta):
:param callable callback:
A callable or a coroutine function receiving one argument:
**record**, a LoggedQuery containing `query`, `args`, `timeout`,
`elapsed`, `exception`, `conn_addr`, and `conn_params`.
**record**: a LoggedQuery containing `query`, `args`, `timeout`,
`elapsed`, `exception`, `conn_addr`, and
`conn_params`.
Example:
@@ -1967,27 +1898,17 @@ class Connection(metaclass=ConnectionMeta):
)
return result, stmt
async def _executemany(
self,
query,
args,
timeout,
return_rows=False,
record_class=None,
):
async def _executemany(self, query, args, timeout):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
state=stmt,
args=args,
portal_name='',
timeout=timeout,
return_rows=return_rows,
)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
with self._time_and_log(query, args, timeout):
result, _ = await self._do_execute(
query, executor, timeout, record_class=record_class
)
result, _ = await self._do_execute(query, executor, timeout)
return result
async def _do_execute(
@@ -2082,13 +2003,11 @@ async def connect(dsn=None, *,
max_cacheable_statement_size=1024 * 15,
command_timeout=None,
ssl=None,
direct_tls=None,
direct_tls=False,
connection_class=Connection,
record_class=protocol.Record,
server_settings=None,
target_session_attrs=None,
krbsrvname=None,
gsslib=None):
target_session_attrs=None):
r"""A coroutine to establish a connection to a PostgreSQL server.
The connection parameters may be specified either as a connection
@@ -2113,7 +2032,7 @@ async def connect(dsn=None, *,
.. note::
The URI must be *valid*, which means that all components must
be properly quoted with :py:func:`urllib.parse.quote_plus`, and
be properly quoted with :py:func:`urllib.parse.quote`, and
any literal IPv6 addresses must be enclosed in square brackets.
For example:
@@ -2316,14 +2235,6 @@ async def connect(dsn=None, *,
or the value of the ``PGTARGETSESSIONATTRS`` environment variable,
or ``"any"`` if neither is specified.
:param str krbsrvname:
Kerberos service name to use when authenticating with GSSAPI. This
must match the server configuration. Defaults to 'postgres'.
:param str gsslib:
GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi'
or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise.
:return: A :class:`~asyncpg.connection.Connection` instance.
Example:
@@ -2337,7 +2248,7 @@ async def connect(dsn=None, *,
... types = await con.fetch('SELECT * FROM pg_type')
... print(types)
...
>>> asyncio.run(run())
>>> asyncio.get_event_loop().run_until_complete(run())
[<Record typname='bool' typnamespace=11 ...
.. versionadded:: 0.10.0
@@ -2392,9 +2303,6 @@ async def connect(dsn=None, *,
.. versionchanged:: 0.28.0
Added the *target_session_attrs* parameter.
.. versionchanged:: 0.30.0
Added the *krbsrvname* and *gsslib* parameters.
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
.. _create_default_context:
https://docs.python.org/3/library/ssl.html#ssl.create_default_context
@@ -2436,9 +2344,7 @@ async def connect(dsn=None, *,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname,
gsslib=gsslib,
target_session_attrs=target_session_attrs
)

View File

@@ -121,10 +121,6 @@ class StackedDiagnosticsAccessedWithoutActiveHandlerError(DiagnosticsError):
sqlstate = '0Z002'
class InvalidArgumentForXqueryError(_base.PostgresError):
sqlstate = '10608'
class CaseNotFoundError(_base.PostgresError):
sqlstate = '20000'
@@ -489,10 +485,6 @@ class IdleInTransactionSessionTimeoutError(InvalidTransactionStateError):
sqlstate = '25P03'
class TransactionTimeoutError(InvalidTransactionStateError):
sqlstate = '25P04'
class InvalidSQLStatementNameError(_base.PostgresError):
sqlstate = '26000'
@@ -908,10 +900,6 @@ class DuplicateFileError(PostgresSystemError):
sqlstate = '58P02'
class FileNameTooLongError(PostgresSystemError):
sqlstate = '58P03'
class SnapshotTooOldError(_base.PostgresError):
sqlstate = '72000'
@@ -1107,9 +1095,9 @@ __all__ = (
'FDWTableNotFoundError', 'FDWTooManyHandlesError',
'FDWUnableToCreateExecutionError', 'FDWUnableToCreateReplyError',
'FDWUnableToEstablishConnectionError', 'FeatureNotSupportedError',
'FileNameTooLongError', 'ForeignKeyViolationError',
'FunctionExecutedNoReturnStatementError', 'GeneratedAlwaysError',
'GroupingError', 'HeldCursorRequiresSameIsolationLevelError',
'ForeignKeyViolationError', 'FunctionExecutedNoReturnStatementError',
'GeneratedAlwaysError', 'GroupingError',
'HeldCursorRequiresSameIsolationLevelError',
'IdleInTransactionSessionTimeoutError', 'IdleSessionTimeoutError',
'ImplicitZeroBitPadding', 'InFailedSQLTransactionError',
'InappropriateAccessModeForBranchTransactionError',
@@ -1124,7 +1112,6 @@ __all__ = (
'InvalidArgumentForPowerFunctionError',
'InvalidArgumentForSQLJsonDatetimeFunctionError',
'InvalidArgumentForWidthBucketFunctionError',
'InvalidArgumentForXqueryError',
'InvalidAuthorizationSpecificationError',
'InvalidBinaryRepresentationError', 'InvalidCachedStatementError',
'InvalidCatalogNameError', 'InvalidCharacterValueForCastError',
@@ -1197,9 +1184,9 @@ __all__ = (
'TooManyJsonObjectMembersError', 'TooManyRowsError',
'TransactionIntegrityConstraintViolationError',
'TransactionResolutionUnknownError', 'TransactionRollbackError',
'TransactionTimeoutError', 'TriggerProtocolViolatedError',
'TriggeredActionError', 'TriggeredDataChangeViolationError',
'TrimError', 'UndefinedColumnError', 'UndefinedFileError',
'TriggerProtocolViolatedError', 'TriggeredActionError',
'TriggeredDataChangeViolationError', 'TrimError',
'UndefinedColumnError', 'UndefinedFileError',
'UndefinedFunctionError', 'UndefinedObjectError',
'UndefinedParameterError', 'UndefinedTableError',
'UniqueViolationError', 'UnsafeNewEnumValueUsageError',

View File

@@ -4,14 +4,8 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import typing
if typing.TYPE_CHECKING:
from . import protocol
_TYPEINFO_13: typing.Final = '''\
_TYPEINFO_13 = '''\
(
SELECT
t.oid AS oid,
@@ -130,7 +124,7 @@ ORDER BY
'''.format(typeinfo=_TYPEINFO_13)
_TYPEINFO: typing.Final = '''\
_TYPEINFO = '''\
(
SELECT
t.oid AS oid,
@@ -254,7 +248,7 @@ ORDER BY
'''.format(typeinfo=_TYPEINFO)
TYPE_BY_NAME: typing.Final = '''\
TYPE_BY_NAME = '''\
SELECT
t.oid,
t.typelem AS elemtype,
@@ -283,16 +277,16 @@ WHERE
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
def is_scalar_type(typeinfo: protocol.Record) -> bool:
def is_scalar_type(typeinfo) -> bool:
return (
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)
def is_domain_type(typeinfo: protocol.Record) -> bool:
return typeinfo['kind'] == b'd' # type: ignore[no-any-return]
def is_domain_type(typeinfo) -> bool:
return typeinfo['kind'] == b'd'
def is_composite_type(typeinfo: protocol.Record) -> bool:
return typeinfo['kind'] == b'c' # type: ignore[no-any-return]
def is_composite_type(typeinfo) -> bool:
return typeinfo['kind'] == b'c'

View File

@@ -1,13 +0,0 @@
import codecs
import typing
import uuid
class CodecContext:
def get_text_codec(self) -> codecs.CodecInfo: ...
class ReadBuffer: ...
class WriteBuffer: ...
class BufferError(Exception): ...
class UUID(uuid.UUID):
def __init__(self, inp: typing.AnyStr) -> None: ...

View File

@@ -33,8 +33,7 @@ class PoolConnectionProxyMeta(type):
if not inspect.isfunction(meth):
continue
iscoroutine = inspect.iscoroutinefunction(meth)
wrapper = mcls._wrap_connection_method(attrname, iscoroutine)
wrapper = mcls._wrap_connection_method(attrname)
wrapper = functools.update_wrapper(wrapper, meth)
dct[attrname] = wrapper
@@ -44,7 +43,7 @@ class PoolConnectionProxyMeta(type):
return super().__new__(mcls, name, bases, dct)
@staticmethod
def _wrap_connection_method(meth_name, iscoroutine):
def _wrap_connection_method(meth_name):
def call_con_method(self, *args, **kwargs):
# This method will be owned by PoolConnectionProxy class.
if self._con is None:
@@ -56,9 +55,6 @@ class PoolConnectionProxyMeta(type):
meth = getattr(self._con.__class__, meth_name)
return meth(self._con, *args, **kwargs)
if iscoroutine:
compat.markcoroutinefunction(call_con_method)
return call_con_method
@@ -210,12 +206,7 @@ class PoolConnectionHolder:
if budget is not None:
budget -= time.monotonic() - started
if self._pool._reset is not None:
async with compat.timeout(budget):
await self._con._reset()
await self._pool._reset(self._con)
else:
await self._con.reset(timeout=budget)
await self._con.reset(timeout=budget)
except (Exception, asyncio.CancelledError) as ex:
# If the `reset` call failed, terminate the connection.
# A new one will be created when `acquire` is called
@@ -318,7 +309,7 @@ class Pool:
__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
'_init', '_connect_args', '_connect_kwargs',
'_holders', '_initialized', '_initializing', '_closing',
'_closed', '_connection_class', '_record_class', '_generation',
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
@@ -329,10 +320,8 @@ class Pool:
max_size,
max_queries,
max_inactive_connection_lifetime,
connect=None,
setup=None,
init=None,
reset=None,
setup,
init,
loop,
connection_class,
record_class,
@@ -392,22 +381,18 @@ class Pool:
self._closing = False
self._closed = False
self._generation = 0
self._connect = connect if connect is not None else connection.connect
self._init = init
self._connect_args = connect_args
self._connect_kwargs = connect_kwargs
self._setup = setup
self._init = init
self._reset = reset
self._max_queries = max_queries
self._max_inactive_connection_lifetime = \
max_inactive_connection_lifetime
async def _async__init__(self):
if self._initialized:
return self
return
if self._initializing:
raise exceptions.InterfaceError(
'pool is being initialized in another task')
@@ -514,25 +499,13 @@ class Pool:
self._connect_kwargs = connect_kwargs
async def _get_new_connection(self):
con = await self._connect(
con = await connection.connect(
*self._connect_args,
loop=self._loop,
connection_class=self._connection_class,
record_class=self._record_class,
**self._connect_kwargs,
)
if not isinstance(con, self._connection_class):
good = self._connection_class
good_n = f'{good.__module__}.{good.__name__}'
bad = type(con)
if bad.__module__ == "builtins":
bad_n = bad.__name__
else:
bad_n = f'{bad.__module__}.{bad.__name__}'
raise exceptions.InterfaceError(
"expected pool connect callback to return an instance of "
f"'{good_n}', got " f"'{bad_n}'"
)
if self._init is not None:
try:
@@ -632,22 +605,6 @@ class Pool:
record_class=record_class
)
async def fetchmany(self, query, args, *, timeout=None, record_class=None):
"""Run a query for each sequence of arguments in *args*
and return the results as a list of :class:`Record`.
Pool performs this operation using one of its connections. Other than
that, it behaves identically to
:meth:`Connection.fetchmany()
<asyncpg.connection.Connection.fetchmany>`.
.. versionadded:: 0.30.0
"""
async with self.acquire() as con:
return await con.fetchmany(
query, args, timeout=timeout, record_class=record_class
)
async def copy_from_table(
self,
table_name,
@@ -1040,10 +997,8 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
connect=None,
setup=None,
init=None,
reset=None,
loop=None,
connection_class=connection.Connection,
record_class=protocol.Record,
@@ -1124,16 +1079,9 @@ def create_pool(dsn=None, *,
Number of seconds after which inactive connections in the
pool will be closed. Pass ``0`` to disable this mechanism.
:param coroutine connect:
A coroutine that is called instead of
:func:`~asyncpg.connection.connect` whenever the pool needs to make a
new connection. Must return an instance of type specified by
*connection_class* or :class:`~asyncpg.connection.Connection` if
*connection_class* was not specified.
:param coroutine setup:
A coroutine to prepare a connection right before it is returned
from :meth:`Pool.acquire()`. An example use
from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
case would be to automatically set up notifications listeners for
all connections of a pool.
@@ -1145,25 +1093,6 @@ def create_pool(dsn=None, *,
or :meth:`Connection.set_type_codec() <\
asyncpg.connection.Connection.set_type_codec>`.
:param coroutine reset:
A coroutine to reset a connection before it is returned to the pool by
:meth:`Pool.release()`. The function is supposed
to reset any changes made to the database session so that the next
acquirer gets the connection in a well-defined state.
The default implementation calls :meth:`Connection.reset() <\
asyncpg.connection.Connection.reset>`, which runs the following::
SELECT pg_advisory_unlock_all();
CLOSE ALL;
UNLISTEN *;
RESET ALL;
The exact reset query is determined by detected server capabilities,
and a custom *reset* implementation can obtain the default query
by calling :meth:`Connection.get_reset_query() <\
asyncpg.connection.Connection.get_reset_query>`.
:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
@@ -1190,22 +1119,12 @@ def create_pool(dsn=None, *,
.. versionchanged:: 0.22.0
Added the *record_class* parameter.
.. versionchanged:: 0.30.0
Added the *connect* and *reset* parameters.
"""
return Pool(
dsn,
connection_class=connection_class,
record_class=record_class,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
loop=loop,
connect=connect,
setup=setup,
init=init,
reset=reset,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
**connect_kwargs,
)
**connect_kwargs)

View File

@@ -147,8 +147,8 @@ class PreparedStatement(connresource.ConnectionResource):
# will discard any output that a SELECT would return, other
# side effects of the statement will happen as usual. If you
# wish to use EXPLAIN ANALYZE on an INSERT, UPDATE, DELETE,
# MERGE, CREATE TABLE AS, or EXECUTE statement without letting
# the command affect your data, use this approach:
# CREATE TABLE AS, or EXECUTE statement without letting the
# command affect your data, use this approach:
# BEGIN;
# EXPLAIN ANALYZE ...;
# ROLLBACK;
@@ -210,27 +210,6 @@ class PreparedStatement(connresource.ConnectionResource):
return None
return data[0]
@connresource.guarded
async def fetchmany(self, args, *, timeout=None):
"""Execute the statement and return a list of :class:`Record` objects.
:param args: Query arguments.
:param float timeout: Optional timeout value in seconds.
:return: A list of :class:`Record` instances.
.. versionadded:: 0.30.0
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state,
args,
portal_name='',
timeout=timeout,
return_rows=True,
)
)
@connresource.guarded
async def executemany(self, args, *, timeout: float=None):
"""Execute the statement for each sequence of arguments in *args*.
@@ -243,12 +222,7 @@ class PreparedStatement(connresource.ConnectionResource):
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state,
args,
portal_name='',
timeout=timeout,
return_rows=False,
))
self._state, args, '', timeout))
async def __do_execute(self, executor):
protocol = self._connection._protocol

View File

@@ -6,6 +6,4 @@
# flake8: NOQA
from __future__ import annotations
from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP

View File

@@ -483,7 +483,7 @@ cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
cdef class DataCodecConfig:
def __init__(self):
def __init__(self, cache_key):
# Codec instance cache for derived types:
# composites, arrays, ranges, domains and their combinations.
self._derived_type_codecs = {}

View File

@@ -51,6 +51,16 @@ cdef enum AuthenticationMessage:
AUTH_SASL_FINAL = 12
AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}
cdef enum ResultType:
RESULT_OK = 1
RESULT_FAILED = 2
@@ -86,13 +96,10 @@ cdef class CoreProtocol:
object transport
object address
# Instance of _ConnectionParameters
object con_params
# Instance of SCRAMAuthentication
SCRAMAuthentication scram
# Instance of gssapi.SecurityContext or sspilib.SecurityContext
object gss_ctx
readonly int32_t backend_pid
readonly int32_t backend_secret
@@ -138,10 +145,6 @@ cdef class CoreProtocol:
cdef _auth_password_message_md5(self, bytes salt)
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
cdef _auth_password_message_sasl_continue(self, bytes server_response)
cdef _auth_gss_init_gssapi(self)
cdef _auth_gss_init_sspi(self, bint negotiate)
cdef _auth_gss_get_service(self)
cdef _auth_gss_step(self, bytes server_response)
cdef _write(self, buf)
cdef _writelines(self, list buffers)
@@ -171,7 +174,7 @@ cdef class CoreProtocol:
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data, bint return_rows)
object bind_data)
cdef bint _bind_execute_many_more(self, bint first=*)
cdef _bind_execute_many_fail(self, object error, bint first=*)
cdef _bind(self, str portal_name, str stmt_name,

View File

@@ -11,20 +11,9 @@ import hashlib
include "scram.pyx"
cdef dict AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}
cdef class CoreProtocol:
def __init__(self, addr, con_params):
self.address = addr
def __init__(self, con_params):
# type of `con_params` is `_ConnectionParameters`
self.buffer = ReadBuffer()
self.user = con_params.user
@@ -37,9 +26,6 @@ cdef class CoreProtocol:
self.encoding = 'utf-8'
# type of `scram` is `SCRAMAuthentcation`
self.scram = None
# type of `gss_ctx` is `gssapi.SecurityContext` or
# `sspilib.SecurityContext`
self.gss_ctx = None
self._reset_result()
@@ -633,35 +619,22 @@ cdef class CoreProtocol:
'could not verify server signature for '
'SCRAM authentciation: scram-sha-256',
)
self.scram = None
elif status in (AUTH_REQUIRED_GSS, AUTH_REQUIRED_SSPI):
# AUTH_REQUIRED_SSPI is the same as AUTH_REQUIRED_GSS, except that
# it uses protocol negotiation with SSPI clients. Both methods use
# AUTH_REQUIRED_GSS_CONTINUE for subsequent authentication steps.
if self.gss_ctx is not None:
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
'duplicate GSSAPI/SSPI authentication request')
else:
if self.con_params.gsslib == 'gssapi':
self._auth_gss_init_gssapi()
else:
self._auth_gss_init_sspi(status == AUTH_REQUIRED_SSPI)
self.auth_msg = self._auth_gss_step(None)
elif status == AUTH_REQUIRED_GSS_CONTINUE:
server_response = self.buffer.consume_message()
self.auth_msg = self._auth_gss_step(server_response)
elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
AUTH_REQUIRED_SSPI):
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
'unsupported authentication method requested by the '
'server: {!r}'.format(AUTH_METHOD_NAME[status]))
else:
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
'unsupported authentication method requested by the '
'server: {!r}'.format(AUTH_METHOD_NAME.get(status, status)))
'server: {}'.format(status))
if status not in (AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
AUTH_REQUIRED_GSS_CONTINUE):
if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
self.buffer.discard_message()
cdef _auth_password_message_cleartext(self):
@@ -718,59 +691,6 @@ cdef class CoreProtocol:
return msg
cdef _auth_gss_init_gssapi(self):
try:
import gssapi
except ModuleNotFoundError:
raise apg_exc.InterfaceError(
'gssapi module not found; please install asyncpg[gssauth] to '
'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
) from None
service_name, host = self._auth_gss_get_service()
self.gss_ctx = gssapi.SecurityContext(
name=gssapi.Name(
f'{service_name}@{host}', gssapi.NameType.hostbased_service),
usage='initiate')
cdef _auth_gss_init_sspi(self, bint negotiate):
try:
import sspilib
except ModuleNotFoundError:
raise apg_exc.InterfaceError(
'sspilib module not found; please install asyncpg[gssauth] to '
'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
) from None
service_name, host = self._auth_gss_get_service()
self.gss_ctx = sspilib.ClientSecurityContext(
target_name=f'{service_name}/{host}',
credential=sspilib.UserCredential(
protocol='Negotiate' if negotiate else 'Kerberos'))
cdef _auth_gss_get_service(self):
service_name = self.con_params.krbsrvname or 'postgres'
if isinstance(self.address, str):
raise apg_exc.InternalClientError(
'GSSAPI/SSPI authentication is only supported for TCP/IP '
'connections')
return service_name, self.address[0]
cdef _auth_gss_step(self, bytes server_response):
cdef:
WriteBuffer msg
token = self.gss_ctx.step(server_response)
if not token:
self.gss_ctx = None
return None
msg = WriteBuffer.new_message(b'p')
msg.write_bytes(token)
msg.end_message()
return msg
cdef _parse_msg_ready_for_query(self):
cdef char status = self.buffer.read_byte()
@@ -1020,12 +940,12 @@ cdef class CoreProtocol:
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data, bint return_rows):
object bind_data):
self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
self.result = [] if return_rows else None
self._discard_data = not return_rows
self.result = None
self._discard_data = True
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name

View File

@@ -142,7 +142,7 @@ cdef class PreparedStatementState:
# that the user tried to parametrize a statement that does
# not support parameters.
hint += (r' Note that parameters are supported only in'
r' SELECT, INSERT, UPDATE, DELETE, MERGE and VALUES'
r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
r' statements, and will *not* work in statements '
r' like CREATE VIEW or DECLARE CURSOR.')

View File

@@ -31,6 +31,7 @@ cdef class BaseProtocol(CoreProtocol):
cdef:
object loop
object address
ConnectionSettings settings
object cancel_sent_waiter
object cancel_waiter

View File

@@ -1,300 +0,0 @@
import asyncio
import asyncio.protocols
import hmac
from codecs import CodecInfo
from collections.abc import Callable, Iterable, Iterator, Sequence
from hashlib import md5, sha256
from typing import (
Any,
ClassVar,
Final,
Generic,
Literal,
NewType,
TypeVar,
final,
overload,
)
from typing_extensions import TypeAlias
import asyncpg.pgproto.pgproto
from ..connect_utils import _ConnectionParameters
from ..pgproto.pgproto import WriteBuffer
from ..types import Attribute, Type
_T = TypeVar('_T')
_Record = TypeVar('_Record', bound=Record)
_OtherRecord = TypeVar('_OtherRecord', bound=Record)
_PreparedStatementState = TypeVar(
'_PreparedStatementState', bound=PreparedStatementState[Any]
)
_NoTimeoutType = NewType('_NoTimeoutType', object)
_TimeoutType: TypeAlias = float | None | _NoTimeoutType
BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]]
BUILTIN_TYPE_OID_MAP: Final[dict[int, str]]
NO_TIMEOUT: Final[_NoTimeoutType]
hashlib_md5 = md5
@final
class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext):
__pyx_vtable__: Any
def __init__(self, conn_key: object) -> None: ...
def add_python_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typeinfos: Iterable[object],
typekind: str,
encoder: Callable[[Any], Any],
decoder: Callable[[Any], Any],
format: object,
) -> Any: ...
def clear_type_cache(self) -> None: ...
def get_data_codec(
self, oid: int, format: object = ..., ignore_custom_codec: bool = ...
) -> Any: ...
def get_text_codec(self) -> CodecInfo: ...
def register_data_types(self, types: Iterable[object]) -> None: ...
def remove_python_codec(
self, typeoid: int, typename: str, typeschema: str
) -> None: ...
def set_builtin_type_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typekind: str,
alias_to: str,
format: object = ...,
) -> Any: ...
def __getattr__(self, name: str) -> Any: ...
def __reduce__(self) -> Any: ...
@final
class PreparedStatementState(Generic[_Record]):
closed: bool
prepared: bool
name: str
query: str
refs: int
record_class: type[_Record]
ignore_custom_codec: bool
__pyx_vtable__: Any
def __init__(
self,
name: str,
query: str,
protocol: BaseProtocol[Any],
record_class: type[_Record],
ignore_custom_codec: bool,
) -> None: ...
def _get_parameters(self) -> tuple[Type, ...]: ...
def _get_attributes(self) -> tuple[Attribute, ...]: ...
def _init_types(self) -> set[int]: ...
def _init_codecs(self) -> None: ...
def attach(self) -> None: ...
def detach(self) -> None: ...
def mark_closed(self) -> None: ...
def mark_unprepared(self) -> None: ...
def __reduce__(self) -> Any: ...
class CoreProtocol:
backend_pid: Any
backend_secret: Any
__pyx_vtable__: Any
def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ...
def is_in_transaction(self) -> bool: ...
def __reduce__(self) -> Any: ...
class BaseProtocol(CoreProtocol, Generic[_Record]):
queries_count: Any
is_ssl: bool
__pyx_vtable__: Any
def __init__(
self,
addr: object,
connected_fut: object,
con_params: _ConnectionParameters,
record_class: type[_Record],
loop: object,
) -> None: ...
def set_connection(self, connection: object) -> None: ...
def get_server_pid(self, *args: object, **kwargs: object) -> int: ...
def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ...
def get_record_class(self) -> type[_Record]: ...
def abort(self) -> None: ...
async def bind(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
timeout: _TimeoutType,
) -> Any: ...
@overload
async def bind_execute(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
limit: int,
return_extra: Literal[False],
timeout: _TimeoutType,
) -> list[_OtherRecord]: ...
@overload
async def bind_execute(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
limit: int,
return_extra: Literal[True],
timeout: _TimeoutType,
) -> tuple[list[_OtherRecord], bytes, bool]: ...
@overload
async def bind_execute(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
limit: int,
return_extra: bool,
timeout: _TimeoutType,
) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ...
async def bind_execute_many(
self,
state: PreparedStatementState[_OtherRecord],
args: Iterable[Sequence[object]],
portal_name: str,
timeout: _TimeoutType,
) -> None: ...
async def close(self, timeout: _TimeoutType) -> None: ...
def _get_timeout(self, timeout: _TimeoutType) -> float | None: ...
def _is_cancelling(self) -> bool: ...
async def _wait_for_cancellation(self) -> None: ...
async def close_statement(
self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType
) -> Any: ...
async def copy_in(self, *args: object, **kwargs: object) -> str: ...
async def copy_out(self, *args: object, **kwargs: object) -> str: ...
async def execute(self, *args: object, **kwargs: object) -> Any: ...
def is_closed(self, *args: object, **kwargs: object) -> Any: ...
def is_connected(self, *args: object, **kwargs: object) -> Any: ...
def data_received(self, data: object) -> None: ...
def connection_made(self, transport: object) -> None: ...
def connection_lost(self, exc: Exception | None) -> None: ...
def pause_writing(self, *args: object, **kwargs: object) -> Any: ...
@overload
async def prepare(
self,
stmt_name: str,
query: str,
timeout: float | None = ...,
*,
state: _PreparedStatementState,
ignore_custom_codec: bool = ...,
record_class: None,
) -> _PreparedStatementState: ...
@overload
async def prepare(
self,
stmt_name: str,
query: str,
timeout: float | None = ...,
*,
state: None = ...,
ignore_custom_codec: bool = ...,
record_class: type[_OtherRecord],
) -> PreparedStatementState[_OtherRecord]: ...
async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ...
async def query(self, *args: object, **kwargs: object) -> str: ...
def resume_writing(self, *args: object, **kwargs: object) -> Any: ...
def __reduce__(self) -> Any: ...
@final
class Codec:
__pyx_vtable__: Any
def __reduce__(self) -> Any: ...
class DataCodecConfig:
__pyx_vtable__: Any
def __init__(self) -> None: ...
def add_python_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typekind: str,
typeinfos: Iterable[object],
encoder: Callable[[ConnectionSettings, WriteBuffer, object], object],
decoder: Callable[..., object],
format: object,
xformat: object,
) -> Any: ...
def add_types(self, types: Iterable[object]) -> Any: ...
def clear_type_cache(self) -> None: ...
def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ...
def remove_python_codec(
self, typeoid: int, typename: str, typeschema: str
) -> Any: ...
def set_builtin_type_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typekind: str,
alias_to: str,
format: object = ...,
) -> Any: ...
def __reduce__(self) -> Any: ...
class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ...
class Record:
@overload
def get(self, key: str) -> Any | None: ...
@overload
def get(self, key: str, default: _T) -> Any | _T: ...
def items(self) -> Iterator[tuple[str, Any]]: ...
def keys(self) -> Iterator[str]: ...
def values(self) -> Iterator[Any]: ...
@overload
def __getitem__(self, index: str) -> Any: ...
@overload
def __getitem__(self, index: int) -> Any: ...
@overload
def __getitem__(self, index: slice) -> tuple[Any, ...]: ...
def __iter__(self) -> Iterator[Any]: ...
def __contains__(self, x: object) -> bool: ...
def __len__(self) -> int: ...
class Timer:
def __init__(self, budget: float | None) -> None: ...
def __enter__(self) -> None: ...
def __exit__(self, et: object, e: object, tb: object) -> None: ...
def get_remaining_budget(self) -> float: ...
def has_budget_greater_than(self, amount: float) -> bool: ...
@final
class SCRAMAuthentication:
AUTHENTICATION_METHODS: ClassVar[list[str]]
DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int]
DIGEST = sha256
REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]]
REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]]
SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]]
authentication_method: bytes
authorization_message: bytes | None
client_channel_binding: bytes
client_first_message_bare: bytes | None
client_nonce: bytes | None
client_proof: bytes | None
password_salt: bytes | None
password_iterations: int
server_first_message: bytes | None
server_key: hmac.HMAC | None
server_nonce: bytes | None

View File

@@ -75,7 +75,7 @@ NO_TIMEOUT = object()
cdef class BaseProtocol(CoreProtocol):
def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
# type of `con_params` is `_ConnectionParameters`
CoreProtocol.__init__(self, addr, con_params)
CoreProtocol.__init__(self, con_params)
self.loop = loop
self.transport = None
@@ -83,7 +83,8 @@ cdef class BaseProtocol(CoreProtocol):
self.cancel_waiter = None
self.cancel_sent_waiter = None
self.settings = ConnectionSettings((addr, con_params.database))
self.address = addr
self.settings = ConnectionSettings((self.address, con_params.database))
self.record_class = record_class
self.statement = None
@@ -212,7 +213,6 @@ cdef class BaseProtocol(CoreProtocol):
args,
portal_name: str,
timeout,
return_rows: bool,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
@@ -238,8 +238,7 @@ cdef class BaseProtocol(CoreProtocol):
more = self._bind_execute_many(
portal_name,
state.name,
arg_bufs,
return_rows) # network op
arg_bufs) # network op
self.last_query = state.query
self.statement = state

View File

@@ -11,12 +11,12 @@ from asyncpg import exceptions
@cython.final
cdef class ConnectionSettings(pgproto.CodecContext):
def __cinit__(self):
def __cinit__(self, conn_key):
self._encoding = 'utf-8'
self._is_utf8 = True
self._settings = {}
self._codec = codecs.lookup('utf-8')
self._data_codecs = DataCodecConfig()
self._data_codecs = DataCodecConfig(conn_key)
cdef add_setting(self, str name, str val):
self._settings[name] = val

View File

@@ -4,14 +4,12 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import re
import typing
from .types import ServerVersion
version_regex: typing.Final = re.compile(
version_regex = re.compile(
r"(Postgre[^\s]*)?\s*"
r"(?P<major>[0-9]+)\.?"
r"((?P<minor>[0-9]+)\.?)?"
@@ -21,15 +19,7 @@ version_regex: typing.Final = re.compile(
)
class _VersionDict(typing.TypedDict):
major: int
minor: int | None
micro: int | None
releaselevel: str | None
serial: int | None
def split_server_version_string(version_string: str) -> ServerVersion:
def split_server_version_string(version_string):
version_match = version_regex.search(version_string)
if version_match is None:
@@ -38,17 +28,17 @@ def split_server_version_string(version_string: str) -> ServerVersion:
f'version from "{version_string}"'
)
version: _VersionDict = version_match.groupdict() # type: ignore[assignment] # noqa: E501
version = version_match.groupdict()
for ver_key, ver_value in version.items():
# Cast all possible versions parts to int
try:
version[ver_key] = int(ver_value) # type: ignore[literal-required, call-overload] # noqa: E501
version[ver_key] = int(ver_value)
except (TypeError, ValueError):
pass
if version["major"] < 10:
if version.get("major") < 10:
return ServerVersion(
version["major"],
version.get("major"),
version.get("minor") or 0,
version.get("micro") or 0,
version.get("releaselevel") or "final",
@@ -62,7 +52,7 @@ def split_server_version_string(version_string: str) -> ServerVersion:
# want to keep that behaviour consistent, i.e not fail
# a major version check due to a bugfix release.
return ServerVersion(
version["major"],
version.get("major"),
0,
version.get("minor") or 0,
version.get("releaselevel") or "final",

View File

@@ -4,18 +4,14 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import typing
import collections
from asyncpg.pgproto.types import (
BitString, Point, Path, Polygon,
Box, Line, LineSegment, Circle,
)
if typing.TYPE_CHECKING:
from typing_extensions import Self
__all__ = (
'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon',
@@ -23,13 +19,7 @@ __all__ = (
)
class Type(typing.NamedTuple):
oid: int
name: str
kind: str
schema: str
Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema'])
Type.__doc__ = 'Database data type.'
Type.oid.__doc__ = 'OID of the type.'
Type.name.__doc__ = 'Type name. For example "int2".'
@@ -38,61 +28,25 @@ Type.kind.__doc__ = \
Type.schema.__doc__ = 'Name of the database schema that defines the type.'
class Attribute(typing.NamedTuple):
name: str
type: Type
Attribute = collections.namedtuple('Attribute', ['name', 'type'])
Attribute.__doc__ = 'Database relation attribute.'
Attribute.name.__doc__ = 'Attribute name.'
Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.'
class ServerVersion(typing.NamedTuple):
major: int
minor: int
micro: int
releaselevel: str
serial: int
ServerVersion = collections.namedtuple(
'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial'])
ServerVersion.__doc__ = 'PostgreSQL server version tuple.'
class _RangeValue(typing.Protocol):
def __eq__(self, __value: object) -> bool:
...
def __lt__(self, __other: _RangeValue) -> bool:
...
def __gt__(self, __other: _RangeValue) -> bool:
...
_RV = typing.TypeVar('_RV', bound=_RangeValue)
class Range(typing.Generic[_RV]):
class Range:
"""Immutable representation of PostgreSQL `range` type."""
__slots__ = ('_lower', '_upper', '_lower_inc', '_upper_inc', '_empty')
__slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty'
_lower: _RV | None
_upper: _RV | None
_lower_inc: bool
_upper_inc: bool
_empty: bool
def __init__(
self,
lower: _RV | None = None,
upper: _RV | None = None,
*,
lower_inc: bool = True,
upper_inc: bool = False,
empty: bool = False
) -> None:
def __init__(self, lower=None, upper=None, *,
lower_inc=True, upper_inc=False,
empty=False):
self._empty = empty
if empty:
self._lower = self._upper = None
@@ -104,34 +58,34 @@ class Range(typing.Generic[_RV]):
self._upper_inc = upper is not None and upper_inc
@property
def lower(self) -> _RV | None:
def lower(self):
return self._lower
@property
def lower_inc(self) -> bool:
def lower_inc(self):
return self._lower_inc
@property
def lower_inf(self) -> bool:
def lower_inf(self):
return self._lower is None and not self._empty
@property
def upper(self) -> _RV | None:
def upper(self):
return self._upper
@property
def upper_inc(self) -> bool:
def upper_inc(self):
return self._upper_inc
@property
def upper_inf(self) -> bool:
def upper_inf(self):
return self._upper is None and not self._empty
@property
def isempty(self) -> bool:
def isempty(self):
return self._empty
def _issubset_lower(self, other: Self) -> bool:
def _issubset_lower(self, other):
if other._lower is None:
return True
if self._lower is None:
@@ -142,7 +96,7 @@ class Range(typing.Generic[_RV]):
and (other._lower_inc or not self._lower_inc)
)
def _issubset_upper(self, other: Self) -> bool:
def _issubset_upper(self, other):
if other._upper is None:
return True
if self._upper is None:
@@ -153,7 +107,7 @@ class Range(typing.Generic[_RV]):
and (other._upper_inc or not self._upper_inc)
)
def issubset(self, other: Self) -> bool:
def issubset(self, other):
if self._empty:
return True
if other._empty:
@@ -161,13 +115,13 @@ class Range(typing.Generic[_RV]):
return self._issubset_lower(other) and self._issubset_upper(other)
def issuperset(self, other: Self) -> bool:
def issuperset(self, other):
return other.issubset(self)
def __bool__(self) -> bool:
def __bool__(self):
return not self._empty
def __eq__(self, other: object) -> bool:
def __eq__(self, other):
if not isinstance(other, Range):
return NotImplemented
@@ -178,14 +132,14 @@ class Range(typing.Generic[_RV]):
self._upper_inc,
self._empty
) == (
other._lower, # pyright: ignore [reportUnknownMemberType]
other._upper, # pyright: ignore [reportUnknownMemberType]
other._lower,
other._upper,
other._lower_inc,
other._upper_inc,
other._empty
)
def __hash__(self) -> int:
def __hash__(self):
return hash((
self._lower,
self._upper,
@@ -194,7 +148,7 @@ class Range(typing.Generic[_RV]):
self._empty
))
def __repr__(self) -> str:
def __repr__(self):
if self._empty:
desc = 'empty'
else:

View File

@@ -42,11 +42,4 @@ async def _mogrify(conn, query, args):
# Finally, replace $n references with text values.
return re.sub(
r"\$(\d+)\b",
lambda m: (
textified[int(m.group(1)) - 1]
if textified[int(m.group(1)) - 1] is not None
else "NULL"
),
query,
)
r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query)