@@ -4,7 +4,6 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

from .connection import connect, Connection # NOQA
from .exceptions import * # NOQA
@@ -15,10 +14,6 @@ from .types import * # NOQA

from ._version import __version__ # NOQA

from . import exceptions


__all__: tuple[str, ...] = (
    'connect', 'create_pool', 'Pool', 'Record', 'Connection'
)
__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection')
__all__ += exceptions.__all__ # NOQA

@@ -4,25 +4,18 @@
#
# SPDX-License-Identifier: PSF-2.0

from __future__ import annotations

import asyncio
import functools
import sys
import typing

if typing.TYPE_CHECKING:
    from . import compat

if sys.version_info < (3, 11):
    from async_timeout import timeout as timeout_ctx
else:
    from asyncio import timeout as timeout_ctx

_T = typing.TypeVar('_T')


async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
async def wait_for(fut, timeout):
    """Wait for the single Future or coroutine to complete, with timeout.

    Coroutine will be wrapped in Task.
@@ -72,7 +65,7 @@ async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
    return await fut


async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
async def _cancel_and_wait(fut):
    """Cancel the *fut* future or task and wait until it completes."""

    loop = asyncio.get_running_loop()
@@ -89,6 +82,6 @@ async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
    fut.remove_done_callback(cb)


def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None:
def _release_waiter(waiter, *args):
    if not waiter.done():
        waiter.set_result(None)

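Note: the hunk above version-gates the timeout helper: on Python 3.11+ the
stdlib ``asyncio.timeout`` is used, while older interpreters fall back to the
third-party ``async_timeout`` package. A minimal usage sketch of the shim
(assuming ``async_timeout`` is installed on pre-3.11 interpreters)::

    import sys

    if sys.version_info < (3, 11):
        from async_timeout import timeout as timeout_ctx
    else:
        from asyncio import timeout as timeout_ctx

    async def fetch_with_deadline(coro, deadline: float):
        # Raises TimeoutError if the awaited work exceeds the deadline.
        async with timeout_ctx(deadline):
            return await coro
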
@@ -117,22 +117,10 @@ class TestCase(unittest.TestCase, metaclass=TestCaseMeta):
        self.__unhandled_exceptions = []

    def tearDown(self):
        excs = []
        for exc in self.__unhandled_exceptions:
            if isinstance(exc, ConnectionResetError):
                texc = traceback.TracebackException.from_exception(
                    exc, lookup_lines=False)
                if texc.stack[-1].name == "_call_connection_lost":
                    # On Windows calling socket.shutdown may raise
                    # ConnectionResetError, which happens in the
                    # finally block of _call_connection_lost.
                    continue
            excs.append(exc)

        if excs:
        if self.__unhandled_exceptions:
            formatted = []

            for i, context in enumerate(excs):
            for i, context in enumerate(self.__unhandled_exceptions):
                formatted.append(self._format_loop_exception(context, i + 1))

            self.fail(
@@ -226,6 +214,13 @@ def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
    return cluster


def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
                   initdb_options=None):
    cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
    cluster.start(port='dynamic', server_settings=server_settings)
    return cluster


def _get_initdb_options(initdb_options=None):
    if not initdb_options:
        initdb_options = {}
@@ -249,12 +244,8 @@ def _init_default_cluster(initdb_options=None):
            _default_cluster = pg_cluster.RunningCluster()
        else:
            _default_cluster = _init_cluster(
                pg_cluster.TempCluster,
                cluster_kwargs={
                    "data_dir_suffix": ".apgtest",
                },
                initdb_options=_get_initdb_options(initdb_options),
            )
                pg_cluster.TempCluster, cluster_kwargs={},
                initdb_options=_get_initdb_options(initdb_options))

    return _default_cluster

@@ -271,7 +262,6 @@ def create_pool(dsn=None, *,
                max_size=10,
                max_queries=50000,
                max_inactive_connection_lifetime=60.0,
                connect=None,
                setup=None,
                init=None,
                loop=None,
@@ -281,18 +271,12 @@ def create_pool(dsn=None, *,
                **connect_kwargs):
    return pool_class(
        dsn,
        min_size=min_size,
        max_size=max_size,
        max_queries=max_queries,
        loop=loop,
        connect=connect,
        setup=setup,
        init=init,
        min_size=min_size, max_size=max_size,
        max_queries=max_queries, loop=loop, setup=setup, init=init,
        max_inactive_connection_lifetime=max_inactive_connection_lifetime,
        connection_class=connection_class,
        record_class=record_class,
        **connect_kwargs,
    )
        **connect_kwargs)


class ClusterTestCase(TestCase):

@@ -10,8 +10,4 @@
# supported platforms, publish the packages on PyPI, merge the PR
# to the target branch, create a Git tag pointing to the commit.

from __future__ import annotations

import typing

__version__: typing.Final = '0.30.0'
__version__ = '0.29.0'

@@ -9,11 +9,9 @@ import asyncio
import os
import os.path
import platform
import random
import re
import shutil
import socket
import string
import subprocess
import sys
import tempfile
@@ -47,29 +45,6 @@ def find_available_port():
        sock.close()


def _world_readable_mkdtemp(suffix=None, prefix=None, dir=None):
    name = "".join(random.choices(string.ascii_lowercase, k=8))
    if dir is None:
        dir = tempfile.gettempdir()
    if prefix is None:
        prefix = tempfile.gettempprefix()
    if suffix is None:
        suffix = ""
    fn = os.path.join(dir, prefix + name + suffix)
    os.mkdir(fn, 0o755)
    return fn


def _mkdtemp(suffix=None, prefix=None, dir=None):
    if _system == 'Windows' and os.environ.get("GITHUB_ACTIONS"):
        # Due to mitigations introduced in python/cpython#118486
        # when Python runs in a session created via an SSH connection
        # tempfile.mkdtemp creates directories that are not accessible.
        return _world_readable_mkdtemp(suffix, prefix, dir)
    else:
        return tempfile.mkdtemp(suffix, prefix, dir)


class ClusterError(Exception):
    pass

@@ -147,13 +122,9 @@ class Cluster:
        else:
            extra_args = []

        os.makedirs(self._data_dir, exist_ok=True)
        process = subprocess.run(
            [self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            cwd=self._data_dir,
        )
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

        output = process.stdout

@@ -228,10 +199,7 @@ class Cluster:
            process = subprocess.run(
                [self._pg_ctl, 'start', '-D', self._data_dir,
                 '-o', ' '.join(extra_args)],
                stdout=stdout,
                stderr=subprocess.STDOUT,
                cwd=self._data_dir,
            )
                stdout=stdout, stderr=subprocess.STDOUT)

            if process.returncode != 0:
                if process.stderr:
@@ -250,10 +218,7 @@ class Cluster:
            self._daemon_process = \
                subprocess.Popen(
                    [self._postgres, '-D', self._data_dir, *extra_args],
                    stdout=stdout,
                    stderr=subprocess.STDOUT,
                    cwd=self._data_dir,
                )
                    stdout=stdout, stderr=subprocess.STDOUT)

            self._daemon_pid = self._daemon_process.pid

@@ -267,10 +232,7 @@ class Cluster:

        process = subprocess.run(
            [self._pg_ctl, 'reload', '-D', self._data_dir],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=self._data_dir,
        )
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        stderr = process.stderr

@@ -283,10 +245,7 @@ class Cluster:
        process = subprocess.run(
            [self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
             '-m', 'fast'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=self._data_dir,
        )
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        stderr = process.stderr

@@ -624,9 +583,9 @@ class TempCluster(Cluster):
    def __init__(self, *,
                 data_dir_suffix=None, data_dir_prefix=None,
                 data_dir_parent=None, pg_config_path=None):
        self._data_dir = _mkdtemp(suffix=data_dir_suffix,
                                  prefix=data_dir_prefix,
                                  dir=data_dir_parent)
        self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
                                          prefix=data_dir_prefix,
                                          dir=data_dir_parent)
        super().__init__(self._data_dir, pg_config_path=pg_config_path)


@@ -4,26 +4,22 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import enum
import pathlib
import platform
import typing
import sys

if typing.TYPE_CHECKING:
    import asyncio

SYSTEM: typing.Final = platform.uname().system
SYSTEM = platform.uname().system


if sys.platform == 'win32':
if SYSTEM == 'Windows':
    import ctypes.wintypes

    CSIDL_APPDATA: typing.Final = 0x001a
    CSIDL_APPDATA = 0x001a

    def get_pg_home_directory() -> pathlib.Path | None:
    def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
        # We cannot simply use expanduser() as that returns the user's
        # home directory, whereas Postgres stores its config in
        # %AppData% on Windows.
@@ -35,14 +31,14 @@ if sys.platform == 'win32':
        return pathlib.Path(buf.value) / 'postgresql'

else:
    def get_pg_home_directory() -> pathlib.Path | None:
    def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
        try:
            return pathlib.Path.home()
        except (RuntimeError, KeyError):
            return None


async def wait_closed(stream: asyncio.StreamWriter) -> None:
async def wait_closed(stream):
    # Not all asyncio versions have StreamWriter.wait_closed().
    if hasattr(stream, 'wait_closed'):
        try:
@@ -53,13 +49,6 @@ async def wait_closed(stream: asyncio.StreamWriter) -> None:
            pass


if sys.version_info < (3, 12):
    def markcoroutinefunction(c):  # type: ignore
        pass
else:
    from inspect import markcoroutinefunction  # noqa: F401


if sys.version_info < (3, 12):
    from ._asyncio_compat import wait_for as wait_for  # noqa: F401
else:
@@ -70,19 +59,3 @@ if sys.version_info < (3, 11):
    from ._asyncio_compat import timeout_ctx as timeout  # noqa: F401
else:
    from asyncio import timeout as timeout  # noqa: F401

if sys.version_info < (3, 9):
    from typing import (  # noqa: F401
        Awaitable as Awaitable,
    )
else:
    from collections.abc import (  # noqa: F401
        Awaitable as Awaitable,
    )

if sys.version_info < (3, 11):
    class StrEnum(str, enum.Enum):
        __str__ = str.__str__
        __repr__ = enum.Enum.__repr__
else:
    from enum import StrEnum as StrEnum  # noqa: F401

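Note: the ``StrEnum`` shim in the hunk above mirrors Python 3.11's
``enum.StrEnum`` for older interpreters: members are ``str`` subclasses, so
they compare equal to their string values and stringify without the enum
prefix. A small illustration (the ``Mode`` enum is hypothetical, not part of
the diff)::

    import enum

    class StrEnum(str, enum.Enum):
        __str__ = str.__str__
        __repr__ = enum.Enum.__repr__

    class Mode(StrEnum):
        direct = "direct"

    assert Mode.direct == "direct"
    assert str(Mode.direct) == "direct"
    assert Mode("direct") is Mode.direct
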
@@ -45,11 +45,6 @@ class SSLMode(enum.IntEnum):
        return getattr(cls, sslmode.replace('-', '_'))


class SSLNegotiation(compat.StrEnum):
    postgres = "postgres"
    direct = "direct"


_ConnectionParameters = collections.namedtuple(
    'ConnectionParameters',
    [
@@ -58,11 +53,9 @@ _ConnectionParameters = collections.namedtuple(
        'database',
        'ssl',
        'sslmode',
        'ssl_negotiation',
        'direct_tls',
        'server_settings',
        'target_session_attrs',
        'krbsrvname',
        'gsslib',
    ])


@@ -268,13 +261,12 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
                                password, passfile, database, ssl,
                                direct_tls, server_settings,
                                target_session_attrs, krbsrvname, gsslib):
                                target_session_attrs):
    # `auth_hosts` is the version of host information for the purposes
    # of reading the pgpass file.
    auth_hosts = None
    sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
    ssl_min_protocol_version = ssl_max_protocol_version = None
    sslnegotiation = None

    if dsn:
        parsed = urllib.parse.urlparse(dsn)
@@ -368,9 +360,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
            if 'sslrootcert' in query:
                sslrootcert = query.pop('sslrootcert')

            if 'sslnegotiation' in query:
                sslnegotiation = query.pop('sslnegotiation')

            if 'sslcrl' in query:
                sslcrl = query.pop('sslcrl')

@@ -394,16 +383,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
                if target_session_attrs is None:
                    target_session_attrs = dsn_target_session_attrs

            if 'krbsrvname' in query:
                val = query.pop('krbsrvname')
                if krbsrvname is None:
                    krbsrvname = val

            if 'gsslib' in query:
                val = query.pop('gsslib')
                if gsslib is None:
                    gsslib = val

            if query:
                if server_settings is None:
                    server_settings = query
@@ -512,36 +491,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
    if ssl is None and have_tcp_addrs:
        ssl = 'prefer'

    if direct_tls is not None:
        sslneg = (
            SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
        )
    else:
        if sslnegotiation is None:
            sslnegotiation = os.environ.get("PGSSLNEGOTIATION")

        if sslnegotiation is not None:
            try:
                sslneg = SSLNegotiation(sslnegotiation)
            except ValueError:
                modes = ', '.join(
                    m.name.replace('_', '-')
                    for m in SSLNegotiation
                )
                raise exceptions.ClientConfigurationError(
                    f'`sslnegotiation` parameter must be one of: {modes}'
                ) from None
        else:
            sslneg = SSLNegotiation.postgres

    if isinstance(ssl, (str, SSLMode)):
        try:
            sslmode = SSLMode.parse(ssl)
        except AttributeError:
            modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
            raise exceptions.ClientConfigurationError(
                '`sslmode` parameter must be one of: {}'.format(modes)
            ) from None
                '`sslmode` parameter must be one of: {}'.format(modes))

        # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
        if sslmode < SSLMode.allow:
@@ -694,24 +650,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
            )
        ) from None

    if krbsrvname is None:
        krbsrvname = os.getenv('PGKRBSRVNAME')

    if gsslib is None:
        gsslib = os.getenv('PGGSSLIB')
    if gsslib is None:
        gsslib = 'sspi' if _system == 'Windows' else 'gssapi'
    if gsslib not in {'gssapi', 'sspi'}:
        raise exceptions.ClientConfigurationError(
            "gsslib parameter must be either 'gssapi' or 'sspi'"
            ", got {!r}".format(gsslib))

    params = _ConnectionParameters(
        user=user, password=password, database=database, ssl=ssl,
        sslmode=sslmode, ssl_negotiation=sslneg,
        sslmode=sslmode, direct_tls=direct_tls,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs,
        krbsrvname=krbsrvname, gsslib=gsslib)
        target_session_attrs=target_session_attrs)

    return addrs, params

@@ -722,7 +665,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
                             max_cached_statement_lifetime,
                             max_cacheable_statement_size,
                             ssl, direct_tls, server_settings,
                             target_session_attrs, krbsrvname, gsslib):
                             target_session_attrs):
    local_vars = locals()
    for var_name in {'max_cacheable_statement_size',
                     'max_cached_statement_lifetime',
@@ -751,8 +694,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
        password=password, passfile=passfile, ssl=ssl,
        direct_tls=direct_tls, database=database,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs,
        krbsrvname=krbsrvname, gsslib=gsslib)
        target_session_attrs=target_session_attrs)

    config = _ClientConfiguration(
        command_timeout=command_timeout,
@@ -914,9 +856,9 @@ async def __connect_addr(
        # UNIX socket
        connector = loop.create_unix_connection(proto_factory, addr)

    elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
        # if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
        # direct SSL connection
    elif params.ssl and params.direct_tls:
        # if ssl and direct_tls are given, skip STARTTLS and perform direct
        # SSL connection
        connector = loop.create_connection(
            proto_factory, *addr, ssl=params.ssl
        )

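Note: the hunks above switch between a boolean *direct_tls* flag and an
``SSLNegotiation`` enum resolved from *direct_tls*, the ``sslnegotiation``
DSN parameter, or the ``PGSSLNEGOTIATION`` environment variable. A hedged
sketch of requesting direct TLS with the enum-based API (the DSN is a
placeholder, and the server must support direct SSL negotiation)::

    import asyncio
    import asyncpg

    async def main():
        con = await asyncpg.connect(
            'postgresql://app@db.example.com/app'
            '?sslmode=require&sslnegotiation=direct'
        )
        try:
            print(await con.fetchval('SELECT version()'))
        finally:
            await con.close()

    asyncio.run(main())
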
@@ -231,8 +231,9 @@ class Connection(metaclass=ConnectionMeta):

        :param callable callback:
            A callable or a coroutine function receiving one argument:
            **record**, a LoggedQuery containing `query`, `args`, `timeout`,
            `elapsed`, `exception`, `conn_addr`, and `conn_params`.
            **record**: a LoggedQuery containing `query`, `args`, `timeout`,
                        `elapsed`, `exception`, `conn_addr`, and
                        `conn_params`.

        .. versionadded:: 0.29.0
        """
@@ -756,44 +757,6 @@ class Connection(metaclass=ConnectionMeta):
            return None
        return data[0]

    async def fetchmany(
        self, query, args, *, timeout: float=None, record_class=None
    ):
        """Run a query for each sequence of arguments in *args*
        and return the results as a list of :class:`Record`.

        :param query:
            Query to execute.
        :param args:
            An iterable containing sequences of arguments for the query.
        :param float timeout:
            Optional timeout value in seconds.
        :param type record_class:
            If specified, the class to use for records returned by this method.
            Must be a subclass of :class:`~asyncpg.Record`. If not specified,
            a per-connection *record_class* is used.

        :return list:
            A list of :class:`~asyncpg.Record` instances. If specified, the
            actual type of list elements would be *record_class*.

        Example:

        .. code-block:: pycon

            >>> rows = await con.fetchmany('''
            ...     INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a;
            ... ''', [('x', 1), ('y', 2), ('z', 3)])
            >>> rows
            [<Record row=('x',)>, <Record row=('y',)>, <Record row=('z',)>]

        .. versionadded:: 0.30.0
        """
        self._check_open()
        return await self._executemany(
            query, args, timeout, return_rows=True, record_class=record_class
        )

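Note: the 0.30 ``Connection.fetchmany()`` shown above reuses the
``executemany`` machinery but asks the protocol to keep the rows
(``return_rows=True``). A short sketch contrasting the two (table and data
are illustrative)::

    async def demo(con):
        data = [('x', 1), ('y', 2), ('z', 3)]
        # executemany() discards any output and returns when done.
        await con.executemany(
            'INSERT INTO mytab (a, b) VALUES ($1, $2)', data)
        # fetchmany() runs the same way but collects RETURNING rows.
        return await con.fetchmany(
            'INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a', data)
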
    async def copy_from_table(self, table_name, *, output,
                              columns=None, schema_name=None, timeout=None,
                              format=None, oids=None, delimiter=None,
@@ -837,7 +800,7 @@ class Connection(metaclass=ConnectionMeta):
            ...     output='file.csv', format='csv')
            ...     print(result)
            ...
            >>> asyncio.run(run())
            >>> asyncio.get_event_loop().run_until_complete(run())
            'COPY 100'

        .. _`COPY statement documentation`:
@@ -906,7 +869,7 @@ class Connection(metaclass=ConnectionMeta):
            ...     output='file.csv', format='csv')
            ...     print(result)
            ...
            >>> asyncio.run(run())
            >>> asyncio.get_event_loop().run_until_complete(run())
            'COPY 10'

        .. _`COPY statement documentation`:
@@ -982,7 +945,7 @@ class Connection(metaclass=ConnectionMeta):
            ...     'mytable', source='datafile.tbl')
            ...     print(result)
            ...
            >>> asyncio.run(run())
            >>> asyncio.get_event_loop().run_until_complete(run())
            'COPY 140000'

        .. _`COPY statement documentation`:
@@ -1064,7 +1027,7 @@ class Connection(metaclass=ConnectionMeta):
            ...     (2, 'ham', 'spam')])
            ...     print(result)
            ...
            >>> asyncio.run(run())
            >>> asyncio.get_event_loop().run_until_complete(run())
            'COPY 2'

        Asynchronous record iterables are also supported:
@@ -1082,7 +1045,7 @@ class Connection(metaclass=ConnectionMeta):
            ...     'mytable', records=record_gen(100))
            ...     print(result)
            ...
            >>> asyncio.run(run())
            >>> asyncio.get_event_loop().run_until_complete(run())
            'COPY 100'

        .. versionadded:: 0.11.0
@@ -1342,7 +1305,7 @@ class Connection(metaclass=ConnectionMeta):
            ...     print(result)
            ...     print(datetime.datetime(2002, 1, 1) + result)
            ...
            >>> asyncio.run(run())
            >>> asyncio.get_event_loop().run_until_complete(run())
            relativedelta(years=+2, months=+3, days=+1)
            2004-04-02 00:00:00

@@ -1515,10 +1478,11 @@ class Connection(metaclass=ConnectionMeta):
        self._abort()
        self._cleanup()

    async def _reset(self):
    async def reset(self, *, timeout=None):
        self._check_open()
        self._listeners.clear()
        self._log_listeners.clear()
        reset_query = self._get_reset_query()

        if self._protocol.is_in_transaction() or self._top_xact is not None:
            if self._top_xact is None or not self._top_xact._managed:
@@ -1530,36 +1494,10 @@ class Connection(metaclass=ConnectionMeta):
                })

            self._top_xact = None
            await self.execute("ROLLBACK")
            reset_query = 'ROLLBACK;\n' + reset_query

    async def reset(self, *, timeout=None):
        """Reset the connection state.

        Calling this will reset the connection session state to a state
        resembling that of a newly obtained connection.  Namely, an open
        transaction (if any) is rolled back, open cursors are closed,
        all `LISTEN <https://www.postgresql.org/docs/current/sql-listen.html>`_
        registrations are removed, all session configuration
        variables are reset to their default values, and all advisory locks
        are released.

        Note that the above describes the default query returned by
        :meth:`Connection.get_reset_query`.  If one overloads the method
        by subclassing ``Connection``, then this method will do whatever
        the overloaded method returns, except open transactions are always
        terminated and any callbacks registered by
        :meth:`Connection.add_listener` or :meth:`Connection.add_log_listener`
        are removed.

        :param float timeout:
            A timeout for resetting the connection.  If not specified, defaults
            to no timeout.
        """
        async with compat.timeout(timeout):
            await self._reset()
        reset_query = self.get_reset_query()
        if reset_query:
            await self.execute(reset_query)
        if reset_query:
            await self.execute(reset_query, timeout=timeout)

    def _abort(self):
        # Put the connection into the aborted state.
@@ -1720,15 +1658,7 @@ class Connection(metaclass=ConnectionMeta):
            con_ref = self._proxy
        return con_ref

    def get_reset_query(self):
        """Return the query sent to server on connection release.

        The query returned by this method is used by :meth:`Connection.reset`,
        which is, in turn, used by :class:`~asyncpg.pool.Pool` before making
        the connection available to another acquirer.

        .. versionadded:: 0.30.0
        """
    def _get_reset_query(self):
        if self._reset_query is not None:
            return self._reset_query

@@ -1842,7 +1772,7 @@ class Connection(metaclass=ConnectionMeta):
            ...     await con.execute('LOCK TABLE tbl')
            ...     await change_type(con)
            ...
            >>> asyncio.run(run())
            >>> asyncio.get_event_loop().run_until_complete(run())

        .. versionadded:: 0.14.0
        """
@@ -1879,8 +1809,9 @@ class Connection(metaclass=ConnectionMeta):

        :param callable callback:
            A callable or a coroutine function receiving one argument:
            **record**, a LoggedQuery containing `query`, `args`, `timeout`,
            `elapsed`, `exception`, `conn_addr`, and `conn_params`.
            **record**: a LoggedQuery containing `query`, `args`, `timeout`,
                        `elapsed`, `exception`, `conn_addr`, and
                        `conn_params`.

        Example:

@@ -1967,27 +1898,17 @@ class Connection(metaclass=ConnectionMeta):
        )
        return result, stmt

    async def _executemany(
        self,
        query,
        args,
        timeout,
        return_rows=False,
        record_class=None,
    ):
    async def _executemany(self, query, args, timeout):
        executor = lambda stmt, timeout: self._protocol.bind_execute_many(
            state=stmt,
            args=args,
            portal_name='',
            timeout=timeout,
            return_rows=return_rows,
        )
        timeout = self._protocol._get_timeout(timeout)
        with self._stmt_exclusive_section:
            with self._time_and_log(query, args, timeout):
                result, _ = await self._do_execute(
                    query, executor, timeout, record_class=record_class
                )
                result, _ = await self._do_execute(query, executor, timeout)
        return result

    async def _do_execute(
@@ -2082,13 +2003,11 @@ async def connect(dsn=None, *,
                  max_cacheable_statement_size=1024 * 15,
                  command_timeout=None,
                  ssl=None,
                  direct_tls=None,
                  direct_tls=False,
                  connection_class=Connection,
                  record_class=protocol.Record,
                  server_settings=None,
                  target_session_attrs=None,
                  krbsrvname=None,
                  gsslib=None):
                  target_session_attrs=None):
    r"""A coroutine to establish a connection to a PostgreSQL server.

    The connection parameters may be specified either as a connection
@@ -2113,7 +2032,7 @@ async def connect(dsn=None, *,
    .. note::

       The URI must be *valid*, which means that all components must
       be properly quoted with :py:func:`urllib.parse.quote_plus`, and
       be properly quoted with :py:func:`urllib.parse.quote`, and
       any literal IPv6 addresses must be enclosed in square brackets.
       For example:

@@ -2316,14 +2235,6 @@ async def connect(dsn=None, *,
        or the value of the ``PGTARGETSESSIONATTRS`` environment variable,
        or ``"any"`` if neither is specified.

    :param str krbsrvname:
        Kerberos service name to use when authenticating with GSSAPI. This
        must match the server configuration. Defaults to 'postgres'.

    :param str gsslib:
        GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi'
        or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise.

    :return: A :class:`~asyncpg.connection.Connection` instance.

    Example:

@@ -2337,7 +2248,7 @@ async def connect(dsn=None, *,
    ...     types = await con.fetch('SELECT * FROM pg_type')
    ...     print(types)
    ...
    >>> asyncio.run(run())
    >>> asyncio.get_event_loop().run_until_complete(run())
    [<Record typname='bool' typnamespace=11 ...

    .. versionadded:: 0.10.0
@@ -2392,9 +2303,6 @@ async def connect(dsn=None, *,
    .. versionchanged:: 0.28.0
       Added the *target_session_attrs* parameter.

    .. versionchanged:: 0.30.0
       Added the *krbsrvname* and *gsslib* parameters.

    .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
    .. _create_default_context:
        https://docs.python.org/3/library/ssl.html#ssl.create_default_context
@@ -2436,9 +2344,7 @@ async def connect(dsn=None, *,
        statement_cache_size=statement_cache_size,
        max_cached_statement_lifetime=max_cached_statement_lifetime,
        max_cacheable_statement_size=max_cacheable_statement_size,
        target_session_attrs=target_session_attrs,
        krbsrvname=krbsrvname,
        gsslib=gsslib,
        target_session_attrs=target_session_attrs
    )


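Note: the ``connect()`` hunks above cover the 0.30 Kerberos support:
*krbsrvname* and *gsslib* are documented in the docstring and forwarded to
the connection parameters. A hedged example of a GSSAPI connection (host and
user are placeholders; per the hunks, the ``gssapi``/``sspilib`` modules are
expected to be installed via ``asyncpg[gssauth]``)::

    import asyncpg

    async def connect_kerberos():
        return await asyncpg.connect(
            host='db.example.com',
            user='alice',
            krbsrvname='postgres',  # must match the server's service name
            gsslib='gssapi',        # 'sspi' is the default on Windows
        )
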
@@ -121,10 +121,6 @@ class StackedDiagnosticsAccessedWithoutActiveHandlerError(DiagnosticsError):
    sqlstate = '0Z002'


class InvalidArgumentForXqueryError(_base.PostgresError):
    sqlstate = '10608'


class CaseNotFoundError(_base.PostgresError):
    sqlstate = '20000'

@@ -489,10 +485,6 @@ class IdleInTransactionSessionTimeoutError(InvalidTransactionStateError):
    sqlstate = '25P03'


class TransactionTimeoutError(InvalidTransactionStateError):
    sqlstate = '25P04'


class InvalidSQLStatementNameError(_base.PostgresError):
    sqlstate = '26000'

@@ -908,10 +900,6 @@ class DuplicateFileError(PostgresSystemError):
    sqlstate = '58P02'


class FileNameTooLongError(PostgresSystemError):
    sqlstate = '58P03'


class SnapshotTooOldError(_base.PostgresError):
    sqlstate = '72000'

@@ -1107,9 +1095,9 @@ __all__ = (
    'FDWTableNotFoundError', 'FDWTooManyHandlesError',
    'FDWUnableToCreateExecutionError', 'FDWUnableToCreateReplyError',
    'FDWUnableToEstablishConnectionError', 'FeatureNotSupportedError',
    'FileNameTooLongError', 'ForeignKeyViolationError',
    'FunctionExecutedNoReturnStatementError', 'GeneratedAlwaysError',
    'GroupingError', 'HeldCursorRequiresSameIsolationLevelError',
    'ForeignKeyViolationError', 'FunctionExecutedNoReturnStatementError',
    'GeneratedAlwaysError', 'GroupingError',
    'HeldCursorRequiresSameIsolationLevelError',
    'IdleInTransactionSessionTimeoutError', 'IdleSessionTimeoutError',
    'ImplicitZeroBitPadding', 'InFailedSQLTransactionError',
    'InappropriateAccessModeForBranchTransactionError',
@@ -1124,7 +1112,6 @@ __all__ = (
    'InvalidArgumentForPowerFunctionError',
    'InvalidArgumentForSQLJsonDatetimeFunctionError',
    'InvalidArgumentForWidthBucketFunctionError',
    'InvalidArgumentForXqueryError',
    'InvalidAuthorizationSpecificationError',
    'InvalidBinaryRepresentationError', 'InvalidCachedStatementError',
    'InvalidCatalogNameError', 'InvalidCharacterValueForCastError',
@@ -1197,9 +1184,9 @@ __all__ = (
    'TooManyJsonObjectMembersError', 'TooManyRowsError',
    'TransactionIntegrityConstraintViolationError',
    'TransactionResolutionUnknownError', 'TransactionRollbackError',
    'TransactionTimeoutError', 'TriggerProtocolViolatedError',
    'TriggeredActionError', 'TriggeredDataChangeViolationError',
    'TrimError', 'UndefinedColumnError', 'UndefinedFileError',
    'TriggerProtocolViolatedError', 'TriggeredActionError',
    'TriggeredDataChangeViolationError', 'TrimError',
    'UndefinedColumnError', 'UndefinedFileError',
    'UndefinedFunctionError', 'UndefinedObjectError',
    'UndefinedParameterError', 'UndefinedTableError',
    'UniqueViolationError', 'UnsafeNewEnumValueUsageError',

@@ -4,14 +4,8 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import typing

if typing.TYPE_CHECKING:
    from . import protocol

_TYPEINFO_13: typing.Final = '''\
_TYPEINFO_13 = '''\
(
    SELECT
        t.oid AS oid,
@@ -130,7 +124,7 @@ ORDER BY
'''.format(typeinfo=_TYPEINFO_13)


_TYPEINFO: typing.Final = '''\
_TYPEINFO = '''\
(
    SELECT
        t.oid AS oid,
@@ -254,7 +248,7 @@ ORDER BY
'''.format(typeinfo=_TYPEINFO)


TYPE_BY_NAME: typing.Final = '''\
TYPE_BY_NAME = '''\
SELECT
    t.oid,
    t.typelem AS elemtype,
@@ -283,16 +277,16 @@ WHERE
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')


def is_scalar_type(typeinfo: protocol.Record) -> bool:
def is_scalar_type(typeinfo) -> bool:
    return (
        typeinfo['kind'] in SCALAR_TYPE_KINDS and
        not typeinfo['elemtype']
    )


def is_domain_type(typeinfo: protocol.Record) -> bool:
    return typeinfo['kind'] == b'd'  # type: ignore[no-any-return]
def is_domain_type(typeinfo) -> bool:
    return typeinfo['kind'] == b'd'


def is_composite_type(typeinfo: protocol.Record) -> bool:
    return typeinfo['kind'] == b'c'  # type: ignore[no-any-return]
def is_composite_type(typeinfo) -> bool:
    return typeinfo['kind'] == b'c'

Binary file not shown.
@@ -1,13 +0,0 @@
import codecs
import typing
import uuid

class CodecContext:
    def get_text_codec(self) -> codecs.CodecInfo: ...

class ReadBuffer: ...
class WriteBuffer: ...
class BufferError(Exception): ...

class UUID(uuid.UUID):
    def __init__(self, inp: typing.AnyStr) -> None: ...
@@ -33,8 +33,7 @@ class PoolConnectionProxyMeta(type):
            if not inspect.isfunction(meth):
                continue

            iscoroutine = inspect.iscoroutinefunction(meth)
            wrapper = mcls._wrap_connection_method(attrname, iscoroutine)
            wrapper = mcls._wrap_connection_method(attrname)
            wrapper = functools.update_wrapper(wrapper, meth)
            dct[attrname] = wrapper

@@ -44,7 +43,7 @@ class PoolConnectionProxyMeta(type):
        return super().__new__(mcls, name, bases, dct)

    @staticmethod
    def _wrap_connection_method(meth_name, iscoroutine):
    def _wrap_connection_method(meth_name):
        def call_con_method(self, *args, **kwargs):
            # This method will be owned by PoolConnectionProxy class.
            if self._con is None:
@@ -56,9 +55,6 @@ class PoolConnectionProxyMeta(type):
            meth = getattr(self._con.__class__, meth_name)
            return meth(self._con, *args, **kwargs)

        if iscoroutine:
            compat.markcoroutinefunction(call_con_method)

        return call_con_method


@@ -210,12 +206,7 @@ class PoolConnectionHolder:
                if budget is not None:
                    budget -= time.monotonic() - started

                if self._pool._reset is not None:
                    async with compat.timeout(budget):
                        await self._con._reset()
                        await self._pool._reset(self._con)
                else:
                    await self._con.reset(timeout=budget)
                await self._con.reset(timeout=budget)
            except (Exception, asyncio.CancelledError) as ex:
                # If the `reset` call failed, terminate the connection.
                # A new one will be created when `acquire` is called
@@ -318,7 +309,7 @@ class Pool:

    __slots__ = (
        '_queue', '_loop', '_minsize', '_maxsize',
        '_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
        '_init', '_connect_args', '_connect_kwargs',
        '_holders', '_initialized', '_initializing', '_closing',
        '_closed', '_connection_class', '_record_class', '_generation',
        '_setup', '_max_queries', '_max_inactive_connection_lifetime'
@@ -329,10 +320,8 @@ class Pool:
                 max_size,
                 max_queries,
                 max_inactive_connection_lifetime,
                 connect=None,
                 setup=None,
                 init=None,
                 reset=None,
                 setup,
                 init,
                 loop,
                 connection_class,
                 record_class,
@@ -392,22 +381,18 @@ class Pool:
        self._closing = False
        self._closed = False
        self._generation = 0

        self._connect = connect if connect is not None else connection.connect
        self._init = init
        self._connect_args = connect_args
        self._connect_kwargs = connect_kwargs

        self._setup = setup
        self._init = init
        self._reset = reset

        self._max_queries = max_queries
        self._max_inactive_connection_lifetime = \
            max_inactive_connection_lifetime

    async def _async__init__(self):
        if self._initialized:
            return self
            return
        if self._initializing:
            raise exceptions.InterfaceError(
                'pool is being initialized in another task')
@@ -514,25 +499,13 @@ class Pool:
        self._connect_kwargs = connect_kwargs

    async def _get_new_connection(self):
        con = await self._connect(
        con = await connection.connect(
            *self._connect_args,
            loop=self._loop,
            connection_class=self._connection_class,
            record_class=self._record_class,
            **self._connect_kwargs,
        )
        if not isinstance(con, self._connection_class):
            good = self._connection_class
            good_n = f'{good.__module__}.{good.__name__}'
            bad = type(con)
            if bad.__module__ == "builtins":
                bad_n = bad.__name__
            else:
                bad_n = f'{bad.__module__}.{bad.__name__}'
            raise exceptions.InterfaceError(
                "expected pool connect callback to return an instance of "
                f"'{good_n}', got " f"'{bad_n}'"
            )

        if self._init is not None:
            try:
@@ -632,22 +605,6 @@ class Pool:
            record_class=record_class
        )

    async def fetchmany(self, query, args, *, timeout=None, record_class=None):
        """Run a query for each sequence of arguments in *args*
        and return the results as a list of :class:`Record`.

        Pool performs this operation using one of its connections.  Other than
        that, it behaves identically to
        :meth:`Connection.fetchmany()
        <asyncpg.connection.Connection.fetchmany>`.

        .. versionadded:: 0.30.0
        """
        async with self.acquire() as con:
            return await con.fetchmany(
                query, args, timeout=timeout, record_class=record_class
            )

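Note: ``Pool.fetchmany()`` above simply acquires a connection and delegates
to ``Connection.fetchmany()``. Usage sketch (table and data are
illustrative)::

    async def insert_returning(pool):
        return await pool.fetchmany(
            'INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a',
            [('x', 1), ('y', 2)])
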
    async def copy_from_table(
        self,
        table_name,
@@ -1040,10 +997,8 @@ def create_pool(dsn=None, *,
                max_size=10,
                max_queries=50000,
                max_inactive_connection_lifetime=300.0,
                connect=None,
                setup=None,
                init=None,
                reset=None,
                loop=None,
                connection_class=connection.Connection,
                record_class=protocol.Record,
@@ -1124,16 +1079,9 @@ def create_pool(dsn=None, *,
        Number of seconds after which inactive connections in the
        pool will be closed.  Pass ``0`` to disable this mechanism.

    :param coroutine connect:
        A coroutine that is called instead of
        :func:`~asyncpg.connection.connect` whenever the pool needs to make a
        new connection.  Must return an instance of type specified by
        *connection_class* or :class:`~asyncpg.connection.Connection` if
        *connection_class* was not specified.

    :param coroutine setup:
        A coroutine to prepare a connection right before it is returned
        from :meth:`Pool.acquire()`.  An example use
        from :meth:`Pool.acquire() <pool.Pool.acquire>`.  An example use
        case would be to automatically set up notifications listeners for
        all connections of a pool.

@@ -1145,25 +1093,6 @@ def create_pool(dsn=None, *,
        or :meth:`Connection.set_type_codec() <\
        asyncpg.connection.Connection.set_type_codec>`.

    :param coroutine reset:
        A coroutine to reset a connection before it is returned to the pool by
        :meth:`Pool.release()`.  The function is supposed
        to reset any changes made to the database session so that the next
        acquirer gets the connection in a well-defined state.

        The default implementation calls :meth:`Connection.reset() <\
        asyncpg.connection.Connection.reset>`, which runs the following::

            SELECT pg_advisory_unlock_all();
            CLOSE ALL;
            UNLISTEN *;
            RESET ALL;

        The exact reset query is determined by detected server capabilities,
        and a custom *reset* implementation can obtain the default query
        by calling :meth:`Connection.get_reset_query() <\
        asyncpg.connection.Connection.get_reset_query>`.

    :param loop:
        An asyncio event loop instance.  If ``None``, the default
        event loop will be used.
@@ -1190,22 +1119,12 @@ def create_pool(dsn=None, *,

    .. versionchanged:: 0.22.0
       Added the *record_class* parameter.

    .. versionchanged:: 0.30.0
       Added the *connect* and *reset* parameters.
    """
    return Pool(
        dsn,
        connection_class=connection_class,
        record_class=record_class,
        min_size=min_size,
        max_size=max_size,
        max_queries=max_queries,
        loop=loop,
        connect=connect,
        setup=setup,
        init=init,
        reset=reset,
        min_size=min_size, max_size=max_size,
        max_queries=max_queries, loop=loop, setup=setup, init=init,
        max_inactive_connection_lifetime=max_inactive_connection_lifetime,
        **connect_kwargs,
    )
        **connect_kwargs)

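Note: the ``create_pool()`` hunks above document the 0.30 *connect* and
*reset* hooks. A hedged sketch of both (``get_db_token()`` is a hypothetical
helper, not part of asyncpg)::

    import asyncpg

    async def custom_connect(*args, **kwargs):
        # E.g. fetch a short-lived credential before each new connection;
        # get_db_token() is hypothetical.
        kwargs['password'] = await get_db_token()
        return await asyncpg.connect(*args, **kwargs)

    async def custom_reset(con):
        # Run the default reset query; a custom implementation could do
        # less (e.g. keep LISTEN registrations) or more.
        await con.execute(con.get_reset_query())

    async def make_pool():
        return await asyncpg.create_pool(
            'postgresql://app@localhost/app',
            connect=custom_connect,
            reset=custom_reset,
        )
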
@@ -147,8 +147,8 @@ class PreparedStatement(connresource.ConnectionResource):
        # will discard any output that a SELECT would return, other
        # side effects of the statement will happen as usual.  If you
        # wish to use EXPLAIN ANALYZE on an INSERT, UPDATE, DELETE,
        # MERGE, CREATE TABLE AS, or EXECUTE statement without letting
        # the command affect your data, use this approach:
        # CREATE TABLE AS, or EXECUTE statement without letting the
        # command affect your data, use this approach:
        #     BEGIN;
        #     EXPLAIN ANALYZE ...;
        #     ROLLBACK;
@@ -210,27 +210,6 @@ class PreparedStatement(connresource.ConnectionResource):
            return None
        return data[0]

    @connresource.guarded
    async def fetchmany(self, args, *, timeout=None):
        """Execute the statement and return a list of :class:`Record` objects.

        :param args: Query arguments.
        :param float timeout: Optional timeout value in seconds.

        :return: A list of :class:`Record` instances.

        .. versionadded:: 0.30.0
        """
        return await self.__do_execute(
            lambda protocol: protocol.bind_execute_many(
                self._state,
                args,
                portal_name='',
                timeout=timeout,
                return_rows=True,
            )
        )

    @connresource.guarded
    async def executemany(self, args, *, timeout: float=None):
        """Execute the statement for each sequence of arguments in *args*.

@@ -243,12 +222,7 @@ class PreparedStatement(connresource.ConnectionResource):
        """
        return await self.__do_execute(
            lambda protocol: protocol.bind_execute_many(
                self._state,
                args,
                portal_name='',
                timeout=timeout,
                return_rows=False,
            ))
                self._state, args, '', timeout))

    async def __do_execute(self, executor):
        protocol = self._connection._protocol

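Note: ``PreparedStatement.fetchmany()`` above is the prepared-statement
counterpart of ``Connection.fetchmany()``; both bottom out in
``bind_execute_many(..., return_rows=True)``. Illustrative usage::

    async def demo(con):
        stmt = await con.prepare(
            'INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a')
        return await stmt.fetchmany([('x', 1), ('y', 2)])
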
@@ -6,6 +6,4 @@

# flake8: NOQA

from __future__ import annotations

from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP

@@ -483,7 +483,7 @@ cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:


cdef class DataCodecConfig:
    def __init__(self):
    def __init__(self, cache_key):
        # Codec instance cache for derived types:
        # composites, arrays, ranges, domains and their combinations.
        self._derived_type_codecs = {}

@@ -51,6 +51,16 @@ cdef enum AuthenticationMessage:
    AUTH_SASL_FINAL = 12


AUTH_METHOD_NAME = {
    AUTH_REQUIRED_KERBEROS: 'kerberosv5',
    AUTH_REQUIRED_PASSWORD: 'password',
    AUTH_REQUIRED_PASSWORDMD5: 'md5',
    AUTH_REQUIRED_GSS: 'gss',
    AUTH_REQUIRED_SASL: 'scram-sha-256',
    AUTH_REQUIRED_SSPI: 'sspi',
}


cdef enum ResultType:
    RESULT_OK = 1
    RESULT_FAILED = 2
@@ -86,13 +96,10 @@ cdef class CoreProtocol:

        object transport

        object address
        # Instance of _ConnectionParameters
        object con_params
        # Instance of SCRAMAuthentication
        SCRAMAuthentication scram
        # Instance of gssapi.SecurityContext or sspilib.SecurityContext
        object gss_ctx

        readonly int32_t backend_pid
        readonly int32_t backend_secret
@@ -138,10 +145,6 @@ cdef class CoreProtocol:
    cdef _auth_password_message_md5(self, bytes salt)
    cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
    cdef _auth_password_message_sasl_continue(self, bytes server_response)
    cdef _auth_gss_init_gssapi(self)
    cdef _auth_gss_init_sspi(self, bint negotiate)
    cdef _auth_gss_get_service(self)
    cdef _auth_gss_step(self, bytes server_response)

    cdef _write(self, buf)
    cdef _writelines(self, list buffers)
@@ -171,7 +174,7 @@ cdef class CoreProtocol:
    cdef _bind_execute(self, str portal_name, str stmt_name,
                       WriteBuffer bind_data, int32_t limit)
    cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
                                 object bind_data, bint return_rows)
                                 object bind_data)
    cdef bint _bind_execute_many_more(self, bint first=*)
    cdef _bind_execute_many_fail(self, object error, bint first=*)
    cdef _bind(self, str portal_name, str stmt_name,

@@ -11,20 +11,9 @@ import hashlib
include "scram.pyx"


cdef dict AUTH_METHOD_NAME = {
    AUTH_REQUIRED_KERBEROS: 'kerberosv5',
    AUTH_REQUIRED_PASSWORD: 'password',
    AUTH_REQUIRED_PASSWORDMD5: 'md5',
    AUTH_REQUIRED_GSS: 'gss',
    AUTH_REQUIRED_SASL: 'scram-sha-256',
    AUTH_REQUIRED_SSPI: 'sspi',
}


cdef class CoreProtocol:

    def __init__(self, addr, con_params):
        self.address = addr
    def __init__(self, con_params):
        # type of `con_params` is `_ConnectionParameters`
        self.buffer = ReadBuffer()
        self.user = con_params.user
@@ -37,9 +26,6 @@ cdef class CoreProtocol:
        self.encoding = 'utf-8'
        # type of `scram` is `SCRAMAuthentcation`
        self.scram = None
        # type of `gss_ctx` is `gssapi.SecurityContext` or
        # `sspilib.SecurityContext`
        self.gss_ctx = None

        self._reset_result()

@@ -633,35 +619,22 @@ cdef class CoreProtocol:
                    'could not verify server signature for '
                    'SCRAM authentciation: scram-sha-256',
                )
            self.scram = None

        elif status in (AUTH_REQUIRED_GSS, AUTH_REQUIRED_SSPI):
            # AUTH_REQUIRED_SSPI is the same as AUTH_REQUIRED_GSS, except that
            # it uses protocol negotiation with SSPI clients. Both methods use
            # AUTH_REQUIRED_GSS_CONTINUE for subsequent authentication steps.
            if self.gss_ctx is not None:
                self.result_type = RESULT_FAILED
                self.result = apg_exc.InterfaceError(
                    'duplicate GSSAPI/SSPI authentication request')
            else:
                if self.con_params.gsslib == 'gssapi':
                    self._auth_gss_init_gssapi()
                else:
                    self._auth_gss_init_sspi(status == AUTH_REQUIRED_SSPI)
                self.auth_msg = self._auth_gss_step(None)

        elif status == AUTH_REQUIRED_GSS_CONTINUE:
            server_response = self.buffer.consume_message()
            self.auth_msg = self._auth_gss_step(server_response)
        elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
                        AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
                        AUTH_REQUIRED_SSPI):
            self.result_type = RESULT_FAILED
            self.result = apg_exc.InterfaceError(
                'unsupported authentication method requested by the '
                'server: {!r}'.format(AUTH_METHOD_NAME[status]))

        else:
            self.result_type = RESULT_FAILED
            self.result = apg_exc.InterfaceError(
                'unsupported authentication method requested by the '
                'server: {!r}'.format(AUTH_METHOD_NAME.get(status, status)))
                'server: {}'.format(status))

        if status not in (AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
                          AUTH_REQUIRED_GSS_CONTINUE):
        if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
            self.buffer.discard_message()

    cdef _auth_password_message_cleartext(self):
@@ -718,59 +691,6 @@ cdef class CoreProtocol:

        return msg

    cdef _auth_gss_init_gssapi(self):
        try:
            import gssapi
        except ModuleNotFoundError:
            raise apg_exc.InterfaceError(
                'gssapi module not found; please install asyncpg[gssauth] to '
                'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
            ) from None

        service_name, host = self._auth_gss_get_service()
        self.gss_ctx = gssapi.SecurityContext(
            name=gssapi.Name(
                f'{service_name}@{host}', gssapi.NameType.hostbased_service),
            usage='initiate')

    cdef _auth_gss_init_sspi(self, bint negotiate):
        try:
            import sspilib
        except ModuleNotFoundError:
            raise apg_exc.InterfaceError(
                'sspilib module not found; please install asyncpg[gssauth] to '
                'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
            ) from None

        service_name, host = self._auth_gss_get_service()
        self.gss_ctx = sspilib.ClientSecurityContext(
            target_name=f'{service_name}/{host}',
            credential=sspilib.UserCredential(
                protocol='Negotiate' if negotiate else 'Kerberos'))

    cdef _auth_gss_get_service(self):
        service_name = self.con_params.krbsrvname or 'postgres'
        if isinstance(self.address, str):
            raise apg_exc.InternalClientError(
                'GSSAPI/SSPI authentication is only supported for TCP/IP '
                'connections')

        return service_name, self.address[0]

    cdef _auth_gss_step(self, bytes server_response):
        cdef:
            WriteBuffer msg

        token = self.gss_ctx.step(server_response)
        if not token:
            self.gss_ctx = None
            return None
        msg = WriteBuffer.new_message(b'p')
        msg.write_bytes(token)
        msg.end_message()

        return msg

    cdef _parse_msg_ready_for_query(self):
        cdef char status = self.buffer.read_byte()

@@ -1020,12 +940,12 @@ cdef class CoreProtocol:
            self._send_bind_message(portal_name, stmt_name, bind_data, limit)

    cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
                                 object bind_data, bint return_rows):
                                 object bind_data):
        self._ensure_connected()
        self._set_state(PROTOCOL_BIND_EXECUTE_MANY)

        self.result = [] if return_rows else None
        self._discard_data = not return_rows
        self.result = None
        self._discard_data = True
        self._execute_iter = bind_data
        self._execute_portal_name = portal_name
        self._execute_stmt_name = stmt_name

@@ -142,7 +142,7 @@ cdef class PreparedStatementState:
            # that the user tried to parametrize a statement that does
            # not support parameters.
            hint += (r' Note that parameters are supported only in'
                     r' SELECT, INSERT, UPDATE, DELETE, MERGE and VALUES'
                     r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
                     r' statements, and will *not* work in statements '
                     r' like CREATE VIEW or DECLARE CURSOR.')

Binary file not shown.
@@ -31,6 +31,7 @@ cdef class BaseProtocol(CoreProtocol):

    cdef:
        object loop
        object address
        ConnectionSettings settings
        object cancel_sent_waiter
        object cancel_waiter

@@ -1,300 +0,0 @@
|
||||
import asyncio
|
||||
import asyncio.protocols
|
||||
import hmac
|
||||
from codecs import CodecInfo
|
||||
from collections.abc import Callable, Iterable, Iterator, Sequence
|
||||
from hashlib import md5, sha256
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Final,
|
||||
Generic,
|
||||
Literal,
|
||||
NewType,
|
||||
TypeVar,
|
||||
final,
|
||||
overload,
|
||||
)
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import asyncpg.pgproto.pgproto
|
||||
|
||||
from ..connect_utils import _ConnectionParameters
|
||||
from ..pgproto.pgproto import WriteBuffer
|
||||
from ..types import Attribute, Type
|
||||
|
||||
_T = TypeVar('_T')
_Record = TypeVar('_Record', bound=Record)
_OtherRecord = TypeVar('_OtherRecord', bound=Record)
_PreparedStatementState = TypeVar(
    '_PreparedStatementState', bound=PreparedStatementState[Any]
)

_NoTimeoutType = NewType('_NoTimeoutType', object)
_TimeoutType: TypeAlias = float | None | _NoTimeoutType

BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]]
BUILTIN_TYPE_OID_MAP: Final[dict[int, str]]
NO_TIMEOUT: Final[_NoTimeoutType]

hashlib_md5 = md5

@final
class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext):
    __pyx_vtable__: Any
    def __init__(self, conn_key: object) -> None: ...
    def add_python_codec(
        self,
        typeoid: int,
        typename: str,
        typeschema: str,
        typeinfos: Iterable[object],
        typekind: str,
        encoder: Callable[[Any], Any],
        decoder: Callable[[Any], Any],
        format: object,
    ) -> Any: ...
    def clear_type_cache(self) -> None: ...
    def get_data_codec(
        self, oid: int, format: object = ..., ignore_custom_codec: bool = ...
    ) -> Any: ...
    def get_text_codec(self) -> CodecInfo: ...
    def register_data_types(self, types: Iterable[object]) -> None: ...
    def remove_python_codec(
        self, typeoid: int, typename: str, typeschema: str
    ) -> None: ...
    def set_builtin_type_codec(
        self,
        typeoid: int,
        typename: str,
        typeschema: str,
        typekind: str,
        alias_to: str,
        format: object = ...,
    ) -> Any: ...
    def __getattr__(self, name: str) -> Any: ...
    def __reduce__(self) -> Any: ...

@final
class PreparedStatementState(Generic[_Record]):
    closed: bool
    prepared: bool
    name: str
    query: str
    refs: int
    record_class: type[_Record]
    ignore_custom_codec: bool
    __pyx_vtable__: Any
    def __init__(
        self,
        name: str,
        query: str,
        protocol: BaseProtocol[Any],
        record_class: type[_Record],
        ignore_custom_codec: bool,
    ) -> None: ...
    def _get_parameters(self) -> tuple[Type, ...]: ...
    def _get_attributes(self) -> tuple[Attribute, ...]: ...
    def _init_types(self) -> set[int]: ...
    def _init_codecs(self) -> None: ...
    def attach(self) -> None: ...
    def detach(self) -> None: ...
    def mark_closed(self) -> None: ...
    def mark_unprepared(self) -> None: ...
    def __reduce__(self) -> Any: ...

class CoreProtocol:
    backend_pid: Any
    backend_secret: Any
    __pyx_vtable__: Any
    def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ...
    def is_in_transaction(self) -> bool: ...
    def __reduce__(self) -> Any: ...

class BaseProtocol(CoreProtocol, Generic[_Record]):
    queries_count: Any
    is_ssl: bool
    __pyx_vtable__: Any
    def __init__(
        self,
        addr: object,
        connected_fut: object,
        con_params: _ConnectionParameters,
        record_class: type[_Record],
        loop: object,
    ) -> None: ...
    def set_connection(self, connection: object) -> None: ...
    def get_server_pid(self, *args: object, **kwargs: object) -> int: ...
    def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ...
    def get_record_class(self) -> type[_Record]: ...
    def abort(self) -> None: ...
    async def bind(
        self,
        state: PreparedStatementState[_OtherRecord],
        args: Sequence[object],
        portal_name: str,
        timeout: _TimeoutType,
    ) -> Any: ...
    @overload
    async def bind_execute(
        self,
        state: PreparedStatementState[_OtherRecord],
        args: Sequence[object],
        portal_name: str,
        limit: int,
        return_extra: Literal[False],
        timeout: _TimeoutType,
    ) -> list[_OtherRecord]: ...
    @overload
    async def bind_execute(
        self,
        state: PreparedStatementState[_OtherRecord],
        args: Sequence[object],
        portal_name: str,
        limit: int,
        return_extra: Literal[True],
        timeout: _TimeoutType,
    ) -> tuple[list[_OtherRecord], bytes, bool]: ...
    @overload
    async def bind_execute(
        self,
        state: PreparedStatementState[_OtherRecord],
        args: Sequence[object],
        portal_name: str,
        limit: int,
        return_extra: bool,
        timeout: _TimeoutType,
    ) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ...
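
    # The three overloads above let a type checker narrow the return type
    # from the value of ``return_extra``. A hedged sketch of the call-site
    # typing (``proto``, ``st`` and ``args`` are hypothetical placeholders):
    #
    #     rows = await proto.bind_execute(st, args, '', 0, False, NO_TIMEOUT)
    #     # inferred: list[_OtherRecord]
    #     extra = await proto.bind_execute(st, args, '', 0, True, NO_TIMEOUT)
    #     # inferred: tuple[list[_OtherRecord], bytes, bool]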
    async def bind_execute_many(
        self,
        state: PreparedStatementState[_OtherRecord],
        args: Iterable[Sequence[object]],
        portal_name: str,
        timeout: _TimeoutType,
    ) -> None: ...
    async def close(self, timeout: _TimeoutType) -> None: ...
    def _get_timeout(self, timeout: _TimeoutType) -> float | None: ...
    def _is_cancelling(self) -> bool: ...
    async def _wait_for_cancellation(self) -> None: ...
    async def close_statement(
        self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType
    ) -> Any: ...
    async def copy_in(self, *args: object, **kwargs: object) -> str: ...
    async def copy_out(self, *args: object, **kwargs: object) -> str: ...
    async def execute(self, *args: object, **kwargs: object) -> Any: ...
    def is_closed(self, *args: object, **kwargs: object) -> Any: ...
    def is_connected(self, *args: object, **kwargs: object) -> Any: ...
    def data_received(self, data: object) -> None: ...
    def connection_made(self, transport: object) -> None: ...
    def connection_lost(self, exc: Exception | None) -> None: ...
    def pause_writing(self, *args: object, **kwargs: object) -> Any: ...
    @overload
    async def prepare(
        self,
        stmt_name: str,
        query: str,
        timeout: float | None = ...,
        *,
        state: _PreparedStatementState,
        ignore_custom_codec: bool = ...,
        record_class: None,
    ) -> _PreparedStatementState: ...
    @overload
    async def prepare(
        self,
        stmt_name: str,
        query: str,
        timeout: float | None = ...,
        *,
        state: None = ...,
        ignore_custom_codec: bool = ...,
        record_class: type[_OtherRecord],
    ) -> PreparedStatementState[_OtherRecord]: ...
    async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ...
    async def query(self, *args: object, **kwargs: object) -> str: ...
    def resume_writing(self, *args: object, **kwargs: object) -> Any: ...
    def __reduce__(self) -> Any: ...

@final
class Codec:
    __pyx_vtable__: Any
    def __reduce__(self) -> Any: ...

class DataCodecConfig:
    __pyx_vtable__: Any
    def __init__(self) -> None: ...
    def add_python_codec(
        self,
        typeoid: int,
        typename: str,
        typeschema: str,
        typekind: str,
        typeinfos: Iterable[object],
        encoder: Callable[[ConnectionSettings, WriteBuffer, object], object],
        decoder: Callable[..., object],
        format: object,
        xformat: object,
    ) -> Any: ...
    def add_types(self, types: Iterable[object]) -> Any: ...
    def clear_type_cache(self) -> None: ...
    def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ...
    def remove_python_codec(
        self, typeoid: int, typename: str, typeschema: str
    ) -> Any: ...
    def set_builtin_type_codec(
        self,
        typeoid: int,
        typename: str,
        typeschema: str,
        typekind: str,
        alias_to: str,
        format: object = ...,
    ) -> Any: ...
    def __reduce__(self) -> Any: ...

class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ...

class Record:
    @overload
    def get(self, key: str) -> Any | None: ...
    @overload
    def get(self, key: str, default: _T) -> Any | _T: ...
    def items(self) -> Iterator[tuple[str, Any]]: ...
    def keys(self) -> Iterator[str]: ...
    def values(self) -> Iterator[Any]: ...
    @overload
    def __getitem__(self, index: str) -> Any: ...
    @overload
    def __getitem__(self, index: int) -> Any: ...
    @overload
    def __getitem__(self, index: slice) -> tuple[Any, ...]: ...
    def __iter__(self) -> Iterator[Any]: ...
    def __contains__(self, x: object) -> bool: ...
    def __len__(self) -> int: ...
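
    # Record is typed as both a sequence and a mapping over one row. A
    # minimal usage sketch under these stubs (``row`` is a hypothetical
    # fetched record):
    #
    #     row['id']          # Any -- lookup by column name
    #     row[0]             # Any -- lookup by position
    #     row[:2]            # tuple[Any, ...] -- slice of values
    #     dict(row.items())  # column-name -> value view of the row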

class Timer:
    def __init__(self, budget: float | None) -> None: ...
    def __enter__(self) -> None: ...
    def __exit__(self, et: object, e: object, tb: object) -> None: ...
    def get_remaining_budget(self) -> float: ...
    def has_budget_greater_than(self, amount: float) -> bool: ...
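
    # Timer models a shrinking time budget shared by a sequence of steps.
    # A hedged sketch of the intended pattern (``steps`` and ``run_step``
    # are hypothetical):
    #
    #     timer = Timer(overall_timeout)
    #     for step in steps:
    #         with timer:
    #             await run_step(step, timeout=timer.get_remaining_budget())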

@final
class SCRAMAuthentication:
    AUTHENTICATION_METHODS: ClassVar[list[str]]
    DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int]
    DIGEST = sha256
    REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]]
    REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]]
    SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]]
    authentication_method: bytes
    authorization_message: bytes | None
    client_channel_binding: bytes
    client_first_message_bare: bytes | None
    client_nonce: bytes | None
    client_proof: bytes | None
    password_salt: bytes | None
    password_iterations: int
    server_first_message: bytes | None
    server_key: hmac.HMAC | None
    server_nonce: bytes | None

@@ -75,7 +75,7 @@ NO_TIMEOUT = object()
cdef class BaseProtocol(CoreProtocol):
    def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
        # type of `con_params` is `_ConnectionParameters`
        CoreProtocol.__init__(self, addr, con_params)
        CoreProtocol.__init__(self, con_params)

        self.loop = loop
        self.transport = None
@@ -83,7 +83,8 @@ cdef class BaseProtocol(CoreProtocol):
        self.cancel_waiter = None
        self.cancel_sent_waiter = None

        self.settings = ConnectionSettings((addr, con_params.database))
        self.address = addr
        self.settings = ConnectionSettings((self.address, con_params.database))
        self.record_class = record_class

        self.statement = None
@@ -212,7 +213,6 @@ cdef class BaseProtocol(CoreProtocol):
            args,
            portal_name: str,
            timeout,
            return_rows: bool,
    ):
        if self.cancel_waiter is not None:
            await self.cancel_waiter
@@ -238,8 +238,7 @@ cdef class BaseProtocol(CoreProtocol):
            more = self._bind_execute_many(
                portal_name,
                state.name,
                arg_bufs,
                return_rows)  # network op
                arg_bufs)  # network op

            self.last_query = state.query
            self.statement = state

@@ -11,12 +11,12 @@ from asyncpg import exceptions
@cython.final
cdef class ConnectionSettings(pgproto.CodecContext):

    def __cinit__(self):
    def __cinit__(self, conn_key):
        self._encoding = 'utf-8'
        self._is_utf8 = True
        self._settings = {}
        self._codec = codecs.lookup('utf-8')
        self._data_codecs = DataCodecConfig()
        self._data_codecs = DataCodecConfig(conn_key)

    cdef add_setting(self, str name, str val):
        self._settings[name] = val

@@ -4,14 +4,12 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import re
import typing

from .types import ServerVersion

version_regex: typing.Final = re.compile(
version_regex = re.compile(
    r"(Postgre[^\s]*)?\s*"
    r"(?P<major>[0-9]+)\.?"
    r"((?P<minor>[0-9]+)\.?)?"
@@ -21,15 +19,7 @@ version_regex: typing.Final = re.compile(
)


class _VersionDict(typing.TypedDict):
    major: int
    minor: int | None
    micro: int | None
    releaselevel: str | None
    serial: int | None


def split_server_version_string(version_string: str) -> ServerVersion:
def split_server_version_string(version_string):
    version_match = version_regex.search(version_string)

    if version_match is None:
@@ -38,17 +28,17 @@ def split_server_version_string(version_string: str) -> ServerVersion:
            f'version from "{version_string}"'
        )

    version: _VersionDict = version_match.groupdict()  # type: ignore[assignment]  # noqa: E501
    version = version_match.groupdict()
    for ver_key, ver_value in version.items():
        # Cast all possible versions parts to int
        try:
            version[ver_key] = int(ver_value)  # type: ignore[literal-required, call-overload]  # noqa: E501
            version[ver_key] = int(ver_value)
        except (TypeError, ValueError):
            pass

    if version["major"] < 10:
    if version.get("major") < 10:
        return ServerVersion(
            version["major"],
            version.get("major"),
            version.get("minor") or 0,
            version.get("micro") or 0,
            version.get("releaselevel") or "final",
@@ -62,7 +52,7 @@ def split_server_version_string(version_string: str) -> ServerVersion:
    # want to keep that behaviour consistent, i.e not fail
    # a major version check due to a bugfix release.
    return ServerVersion(
        version["major"],
        version.get("major"),
        0,
        version.get("minor") or 0,
        version.get("releaselevel") or "final",

@@ -4,18 +4,14 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

from __future__ import annotations

import typing
import collections

from asyncpg.pgproto.types import (
    BitString, Point, Path, Polygon,
    Box, Line, LineSegment, Circle,
)

if typing.TYPE_CHECKING:
    from typing_extensions import Self


__all__ = (
    'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon',
@@ -23,13 +19,7 @@ __all__ = (
)


class Type(typing.NamedTuple):
    oid: int
    name: str
    kind: str
    schema: str


Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema'])
Type.__doc__ = 'Database data type.'
Type.oid.__doc__ = 'OID of the type.'
Type.name.__doc__ = 'Type name. For example "int2".'
@@ -38,61 +28,25 @@ Type.kind.__doc__ = \
Type.schema.__doc__ = 'Name of the database schema that defines the type.'
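
# Both spellings above define the same runtime tuple; a hedged illustration
# (oid 23 is the PostgreSQL OID for "int4"):
#
#     t = Type(oid=23, name='int4', kind='scalar', schema='pg_catalog')
#     t.name      # 'int4'
#     tuple(t)    # (23, 'int4', 'scalar', 'pg_catalog')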


class Attribute(typing.NamedTuple):
    name: str
    type: Type


Attribute = collections.namedtuple('Attribute', ['name', 'type'])
Attribute.__doc__ = 'Database relation attribute.'
Attribute.name.__doc__ = 'Attribute name.'
Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.'


class ServerVersion(typing.NamedTuple):
    major: int
    minor: int
    micro: int
    releaselevel: str
    serial: int


ServerVersion = collections.namedtuple(
    'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial'])
ServerVersion.__doc__ = 'PostgreSQL server version tuple.'


class _RangeValue(typing.Protocol):
    def __eq__(self, __value: object) -> bool:
        ...

    def __lt__(self, __other: _RangeValue) -> bool:
        ...

    def __gt__(self, __other: _RangeValue) -> bool:
        ...


_RV = typing.TypeVar('_RV', bound=_RangeValue)
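
# _RangeValue is a structural protocol: any type whose values support
# ==, < and > (int, decimal.Decimal, datetime.date, ...) satisfies it, so
# under the typed variant both Range[int] and Range[datetime.date] check.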


class Range(typing.Generic[_RV]):
class Range:
    """Immutable representation of PostgreSQL `range` type."""

    __slots__ = ('_lower', '_upper', '_lower_inc', '_upper_inc', '_empty')
    __slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty'

    _lower: _RV | None
    _upper: _RV | None
    _lower_inc: bool
    _upper_inc: bool
    _empty: bool

    def __init__(
        self,
        lower: _RV | None = None,
        upper: _RV | None = None,
        *,
        lower_inc: bool = True,
        upper_inc: bool = False,
        empty: bool = False
    ) -> None:
    def __init__(self, lower=None, upper=None, *,
                 lower_inc=True, upper_inc=False,
                 empty=False):
        self._empty = empty
        if empty:
            self._lower = self._upper = None
@@ -104,34 +58,34 @@ class Range(typing.Generic[_RV]):
            self._upper_inc = upper is not None and upper_inc

    @property
    def lower(self) -> _RV | None:
    def lower(self):
        return self._lower

    @property
    def lower_inc(self) -> bool:
    def lower_inc(self):
        return self._lower_inc

    @property
    def lower_inf(self) -> bool:
    def lower_inf(self):
        return self._lower is None and not self._empty

    @property
    def upper(self) -> _RV | None:
    def upper(self):
        return self._upper

    @property
    def upper_inc(self) -> bool:
    def upper_inc(self):
        return self._upper_inc

    @property
    def upper_inf(self) -> bool:
    def upper_inf(self):
        return self._upper is None and not self._empty

    @property
    def isempty(self) -> bool:
    def isempty(self):
        return self._empty

    def _issubset_lower(self, other: Self) -> bool:
    def _issubset_lower(self, other):
        if other._lower is None:
            return True
        if self._lower is None:
@@ -142,7 +96,7 @@ class Range(typing.Generic[_RV]):
            and (other._lower_inc or not self._lower_inc)
        )

    def _issubset_upper(self, other: Self) -> bool:
    def _issubset_upper(self, other):
        if other._upper is None:
            return True
        if self._upper is None:
@@ -153,7 +107,7 @@ class Range(typing.Generic[_RV]):
            and (other._upper_inc or not self._upper_inc)
        )

    def issubset(self, other: Self) -> bool:
    def issubset(self, other):
        if self._empty:
            return True
        if other._empty:
@@ -161,13 +115,13 @@ class Range(typing.Generic[_RV]):

        return self._issubset_lower(other) and self._issubset_upper(other)

    def issuperset(self, other: Self) -> bool:
    def issuperset(self, other):
        return other.issubset(self)

    def __bool__(self) -> bool:
    def __bool__(self):
        return not self._empty

    def __eq__(self, other: object) -> bool:
    def __eq__(self, other):
        if not isinstance(other, Range):
            return NotImplemented

@@ -178,14 +132,14 @@ class Range(typing.Generic[_RV]):
            self._upper_inc,
            self._empty
        ) == (
            other._lower,  # pyright: ignore [reportUnknownMemberType]
            other._upper,  # pyright: ignore [reportUnknownMemberType]
            other._lower,
            other._upper,
            other._lower_inc,
            other._upper_inc,
            other._empty
        )

    def __hash__(self) -> int:
    def __hash__(self):
        return hash((
            self._lower,
            self._upper,
@@ -194,7 +148,7 @@ class Range(typing.Generic[_RV]):
            self._empty
        ))

    def __repr__(self) -> str:
    def __repr__(self):
        if self._empty:
            desc = 'empty'
        else:
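
As a hedged usage sketch of the containment logic above (values invented;
by default the lower bound is inclusive and the upper bound exclusive):

    week = Range(1, 8)
    day = Range(2, 3)
    day.issubset(week)               # True
    week.issuperset(day)             # True
    Range(empty=True).issubset(day)  # True: the empty range is in any range
    bool(Range(empty=True))          # False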

@@ -42,11 +42,4 @@ async def _mogrify(conn, query, args):

    # Finally, replace $n references with text values.
    return re.sub(
        r"\$(\d+)\b",
        lambda m: (
            textified[int(m.group(1)) - 1]
            if textified[int(m.group(1)) - 1] is not None
            else "NULL"
        ),
        query,
    )
        r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query)
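
For context, a standalone sketch of the substitution performed by the
multi-line branch above; the sample query and textified values are invented:

    import re

    textified = ["'alice'", None]
    query = "SELECT * FROM users WHERE name = $1 AND nickname = $2"
    print(re.sub(
        r"\$(\d+)\b",
        lambda m: (
            textified[int(m.group(1)) - 1]
            if textified[int(m.group(1)) - 1] is not None
            else "NULL"
        ),
        query,
    ))
    # SELECT * FROM users WHERE name = 'alice' AND nickname = NULL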