main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -6,6 +6,4 @@
# flake8: NOQA
from __future__ import annotations
from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP

View File

@@ -483,7 +483,7 @@ cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
cdef class DataCodecConfig:
def __init__(self):
def __init__(self, cache_key):
# Codec instance cache for derived types:
# composites, arrays, ranges, domains and their combinations.
self._derived_type_codecs = {}

View File

@@ -51,6 +51,16 @@ cdef enum AuthenticationMessage:
AUTH_SASL_FINAL = 12
AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}
cdef enum ResultType:
RESULT_OK = 1
RESULT_FAILED = 2
@@ -86,13 +96,10 @@ cdef class CoreProtocol:
object transport
object address
# Instance of _ConnectionParameters
object con_params
# Instance of SCRAMAuthentication
SCRAMAuthentication scram
# Instance of gssapi.SecurityContext or sspilib.SecurityContext
object gss_ctx
readonly int32_t backend_pid
readonly int32_t backend_secret
@@ -138,10 +145,6 @@ cdef class CoreProtocol:
cdef _auth_password_message_md5(self, bytes salt)
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
cdef _auth_password_message_sasl_continue(self, bytes server_response)
cdef _auth_gss_init_gssapi(self)
cdef _auth_gss_init_sspi(self, bint negotiate)
cdef _auth_gss_get_service(self)
cdef _auth_gss_step(self, bytes server_response)
cdef _write(self, buf)
cdef _writelines(self, list buffers)
@@ -171,7 +174,7 @@ cdef class CoreProtocol:
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data, bint return_rows)
object bind_data)
cdef bint _bind_execute_many_more(self, bint first=*)
cdef _bind_execute_many_fail(self, object error, bint first=*)
cdef _bind(self, str portal_name, str stmt_name,

View File

@@ -11,20 +11,9 @@ import hashlib
include "scram.pyx"
cdef dict AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}
cdef class CoreProtocol:
def __init__(self, addr, con_params):
self.address = addr
def __init__(self, con_params):
# type of `con_params` is `_ConnectionParameters`
self.buffer = ReadBuffer()
self.user = con_params.user
@@ -37,9 +26,6 @@ cdef class CoreProtocol:
self.encoding = 'utf-8'
# type of `scram` is `SCRAMAuthentcation`
self.scram = None
# type of `gss_ctx` is `gssapi.SecurityContext` or
# `sspilib.SecurityContext`
self.gss_ctx = None
self._reset_result()
@@ -633,35 +619,22 @@ cdef class CoreProtocol:
'could not verify server signature for '
'SCRAM authentciation: scram-sha-256',
)
self.scram = None
elif status in (AUTH_REQUIRED_GSS, AUTH_REQUIRED_SSPI):
# AUTH_REQUIRED_SSPI is the same as AUTH_REQUIRED_GSS, except that
# it uses protocol negotiation with SSPI clients. Both methods use
# AUTH_REQUIRED_GSS_CONTINUE for subsequent authentication steps.
if self.gss_ctx is not None:
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
'duplicate GSSAPI/SSPI authentication request')
else:
if self.con_params.gsslib == 'gssapi':
self._auth_gss_init_gssapi()
else:
self._auth_gss_init_sspi(status == AUTH_REQUIRED_SSPI)
self.auth_msg = self._auth_gss_step(None)
elif status == AUTH_REQUIRED_GSS_CONTINUE:
server_response = self.buffer.consume_message()
self.auth_msg = self._auth_gss_step(server_response)
elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
AUTH_REQUIRED_SSPI):
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
'unsupported authentication method requested by the '
'server: {!r}'.format(AUTH_METHOD_NAME[status]))
else:
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
'unsupported authentication method requested by the '
'server: {!r}'.format(AUTH_METHOD_NAME.get(status, status)))
'server: {}'.format(status))
if status not in (AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
AUTH_REQUIRED_GSS_CONTINUE):
if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
self.buffer.discard_message()
cdef _auth_password_message_cleartext(self):
@@ -718,59 +691,6 @@ cdef class CoreProtocol:
return msg
cdef _auth_gss_init_gssapi(self):
try:
import gssapi
except ModuleNotFoundError:
raise apg_exc.InterfaceError(
'gssapi module not found; please install asyncpg[gssauth] to '
'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
) from None
service_name, host = self._auth_gss_get_service()
self.gss_ctx = gssapi.SecurityContext(
name=gssapi.Name(
f'{service_name}@{host}', gssapi.NameType.hostbased_service),
usage='initiate')
cdef _auth_gss_init_sspi(self, bint negotiate):
try:
import sspilib
except ModuleNotFoundError:
raise apg_exc.InterfaceError(
'sspilib module not found; please install asyncpg[gssauth] to '
'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
) from None
service_name, host = self._auth_gss_get_service()
self.gss_ctx = sspilib.ClientSecurityContext(
target_name=f'{service_name}/{host}',
credential=sspilib.UserCredential(
protocol='Negotiate' if negotiate else 'Kerberos'))
cdef _auth_gss_get_service(self):
service_name = self.con_params.krbsrvname or 'postgres'
if isinstance(self.address, str):
raise apg_exc.InternalClientError(
'GSSAPI/SSPI authentication is only supported for TCP/IP '
'connections')
return service_name, self.address[0]
cdef _auth_gss_step(self, bytes server_response):
cdef:
WriteBuffer msg
token = self.gss_ctx.step(server_response)
if not token:
self.gss_ctx = None
return None
msg = WriteBuffer.new_message(b'p')
msg.write_bytes(token)
msg.end_message()
return msg
cdef _parse_msg_ready_for_query(self):
cdef char status = self.buffer.read_byte()
@@ -1020,12 +940,12 @@ cdef class CoreProtocol:
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data, bint return_rows):
object bind_data):
self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
self.result = [] if return_rows else None
self._discard_data = not return_rows
self.result = None
self._discard_data = True
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name

View File

@@ -142,7 +142,7 @@ cdef class PreparedStatementState:
# that the user tried to parametrize a statement that does
# not support parameters.
hint += (r' Note that parameters are supported only in'
r' SELECT, INSERT, UPDATE, DELETE, MERGE and VALUES'
r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
r' statements, and will *not* work in statements '
r' like CREATE VIEW or DECLARE CURSOR.')

View File

@@ -31,6 +31,7 @@ cdef class BaseProtocol(CoreProtocol):
cdef:
object loop
object address
ConnectionSettings settings
object cancel_sent_waiter
object cancel_waiter

View File

@@ -1,300 +0,0 @@
import asyncio
import asyncio.protocols
import hmac
from codecs import CodecInfo
from collections.abc import Callable, Iterable, Iterator, Sequence
from hashlib import md5, sha256
from typing import (
Any,
ClassVar,
Final,
Generic,
Literal,
NewType,
TypeVar,
final,
overload,
)
from typing_extensions import TypeAlias
import asyncpg.pgproto.pgproto
from ..connect_utils import _ConnectionParameters
from ..pgproto.pgproto import WriteBuffer
from ..types import Attribute, Type
_T = TypeVar('_T')
_Record = TypeVar('_Record', bound=Record)
_OtherRecord = TypeVar('_OtherRecord', bound=Record)
_PreparedStatementState = TypeVar(
'_PreparedStatementState', bound=PreparedStatementState[Any]
)
_NoTimeoutType = NewType('_NoTimeoutType', object)
_TimeoutType: TypeAlias = float | None | _NoTimeoutType
BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]]
BUILTIN_TYPE_OID_MAP: Final[dict[int, str]]
NO_TIMEOUT: Final[_NoTimeoutType]
hashlib_md5 = md5
@final
class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext):
__pyx_vtable__: Any
def __init__(self, conn_key: object) -> None: ...
def add_python_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typeinfos: Iterable[object],
typekind: str,
encoder: Callable[[Any], Any],
decoder: Callable[[Any], Any],
format: object,
) -> Any: ...
def clear_type_cache(self) -> None: ...
def get_data_codec(
self, oid: int, format: object = ..., ignore_custom_codec: bool = ...
) -> Any: ...
def get_text_codec(self) -> CodecInfo: ...
def register_data_types(self, types: Iterable[object]) -> None: ...
def remove_python_codec(
self, typeoid: int, typename: str, typeschema: str
) -> None: ...
def set_builtin_type_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typekind: str,
alias_to: str,
format: object = ...,
) -> Any: ...
def __getattr__(self, name: str) -> Any: ...
def __reduce__(self) -> Any: ...
@final
class PreparedStatementState(Generic[_Record]):
closed: bool
prepared: bool
name: str
query: str
refs: int
record_class: type[_Record]
ignore_custom_codec: bool
__pyx_vtable__: Any
def __init__(
self,
name: str,
query: str,
protocol: BaseProtocol[Any],
record_class: type[_Record],
ignore_custom_codec: bool,
) -> None: ...
def _get_parameters(self) -> tuple[Type, ...]: ...
def _get_attributes(self) -> tuple[Attribute, ...]: ...
def _init_types(self) -> set[int]: ...
def _init_codecs(self) -> None: ...
def attach(self) -> None: ...
def detach(self) -> None: ...
def mark_closed(self) -> None: ...
def mark_unprepared(self) -> None: ...
def __reduce__(self) -> Any: ...
class CoreProtocol:
backend_pid: Any
backend_secret: Any
__pyx_vtable__: Any
def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ...
def is_in_transaction(self) -> bool: ...
def __reduce__(self) -> Any: ...
class BaseProtocol(CoreProtocol, Generic[_Record]):
queries_count: Any
is_ssl: bool
__pyx_vtable__: Any
def __init__(
self,
addr: object,
connected_fut: object,
con_params: _ConnectionParameters,
record_class: type[_Record],
loop: object,
) -> None: ...
def set_connection(self, connection: object) -> None: ...
def get_server_pid(self, *args: object, **kwargs: object) -> int: ...
def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ...
def get_record_class(self) -> type[_Record]: ...
def abort(self) -> None: ...
async def bind(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
timeout: _TimeoutType,
) -> Any: ...
@overload
async def bind_execute(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
limit: int,
return_extra: Literal[False],
timeout: _TimeoutType,
) -> list[_OtherRecord]: ...
@overload
async def bind_execute(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
limit: int,
return_extra: Literal[True],
timeout: _TimeoutType,
) -> tuple[list[_OtherRecord], bytes, bool]: ...
@overload
async def bind_execute(
self,
state: PreparedStatementState[_OtherRecord],
args: Sequence[object],
portal_name: str,
limit: int,
return_extra: bool,
timeout: _TimeoutType,
) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ...
async def bind_execute_many(
self,
state: PreparedStatementState[_OtherRecord],
args: Iterable[Sequence[object]],
portal_name: str,
timeout: _TimeoutType,
) -> None: ...
async def close(self, timeout: _TimeoutType) -> None: ...
def _get_timeout(self, timeout: _TimeoutType) -> float | None: ...
def _is_cancelling(self) -> bool: ...
async def _wait_for_cancellation(self) -> None: ...
async def close_statement(
self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType
) -> Any: ...
async def copy_in(self, *args: object, **kwargs: object) -> str: ...
async def copy_out(self, *args: object, **kwargs: object) -> str: ...
async def execute(self, *args: object, **kwargs: object) -> Any: ...
def is_closed(self, *args: object, **kwargs: object) -> Any: ...
def is_connected(self, *args: object, **kwargs: object) -> Any: ...
def data_received(self, data: object) -> None: ...
def connection_made(self, transport: object) -> None: ...
def connection_lost(self, exc: Exception | None) -> None: ...
def pause_writing(self, *args: object, **kwargs: object) -> Any: ...
@overload
async def prepare(
self,
stmt_name: str,
query: str,
timeout: float | None = ...,
*,
state: _PreparedStatementState,
ignore_custom_codec: bool = ...,
record_class: None,
) -> _PreparedStatementState: ...
@overload
async def prepare(
self,
stmt_name: str,
query: str,
timeout: float | None = ...,
*,
state: None = ...,
ignore_custom_codec: bool = ...,
record_class: type[_OtherRecord],
) -> PreparedStatementState[_OtherRecord]: ...
async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ...
async def query(self, *args: object, **kwargs: object) -> str: ...
def resume_writing(self, *args: object, **kwargs: object) -> Any: ...
def __reduce__(self) -> Any: ...
@final
class Codec:
__pyx_vtable__: Any
def __reduce__(self) -> Any: ...
class DataCodecConfig:
__pyx_vtable__: Any
def __init__(self) -> None: ...
def add_python_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typekind: str,
typeinfos: Iterable[object],
encoder: Callable[[ConnectionSettings, WriteBuffer, object], object],
decoder: Callable[..., object],
format: object,
xformat: object,
) -> Any: ...
def add_types(self, types: Iterable[object]) -> Any: ...
def clear_type_cache(self) -> None: ...
def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ...
def remove_python_codec(
self, typeoid: int, typename: str, typeschema: str
) -> Any: ...
def set_builtin_type_codec(
self,
typeoid: int,
typename: str,
typeschema: str,
typekind: str,
alias_to: str,
format: object = ...,
) -> Any: ...
def __reduce__(self) -> Any: ...
class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ...
class Record:
@overload
def get(self, key: str) -> Any | None: ...
@overload
def get(self, key: str, default: _T) -> Any | _T: ...
def items(self) -> Iterator[tuple[str, Any]]: ...
def keys(self) -> Iterator[str]: ...
def values(self) -> Iterator[Any]: ...
@overload
def __getitem__(self, index: str) -> Any: ...
@overload
def __getitem__(self, index: int) -> Any: ...
@overload
def __getitem__(self, index: slice) -> tuple[Any, ...]: ...
def __iter__(self) -> Iterator[Any]: ...
def __contains__(self, x: object) -> bool: ...
def __len__(self) -> int: ...
class Timer:
def __init__(self, budget: float | None) -> None: ...
def __enter__(self) -> None: ...
def __exit__(self, et: object, e: object, tb: object) -> None: ...
def get_remaining_budget(self) -> float: ...
def has_budget_greater_than(self, amount: float) -> bool: ...
@final
class SCRAMAuthentication:
AUTHENTICATION_METHODS: ClassVar[list[str]]
DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int]
DIGEST = sha256
REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]]
REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]]
SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]]
authentication_method: bytes
authorization_message: bytes | None
client_channel_binding: bytes
client_first_message_bare: bytes | None
client_nonce: bytes | None
client_proof: bytes | None
password_salt: bytes | None
password_iterations: int
server_first_message: bytes | None
server_key: hmac.HMAC | None
server_nonce: bytes | None

View File

@@ -75,7 +75,7 @@ NO_TIMEOUT = object()
cdef class BaseProtocol(CoreProtocol):
def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
# type of `con_params` is `_ConnectionParameters`
CoreProtocol.__init__(self, addr, con_params)
CoreProtocol.__init__(self, con_params)
self.loop = loop
self.transport = None
@@ -83,7 +83,8 @@ cdef class BaseProtocol(CoreProtocol):
self.cancel_waiter = None
self.cancel_sent_waiter = None
self.settings = ConnectionSettings((addr, con_params.database))
self.address = addr
self.settings = ConnectionSettings((self.address, con_params.database))
self.record_class = record_class
self.statement = None
@@ -212,7 +213,6 @@ cdef class BaseProtocol(CoreProtocol):
args,
portal_name: str,
timeout,
return_rows: bool,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
@@ -238,8 +238,7 @@ cdef class BaseProtocol(CoreProtocol):
more = self._bind_execute_many(
portal_name,
state.name,
arg_bufs,
return_rows) # network op
arg_bufs) # network op
self.last_query = state.query
self.statement = state

View File

@@ -11,12 +11,12 @@ from asyncpg import exceptions
@cython.final
cdef class ConnectionSettings(pgproto.CodecContext):
def __cinit__(self):
def __cinit__(self, conn_key):
self._encoding = 'utf-8'
self._is_utf8 = True
self._settings = {}
self._codec = codecs.lookup('utf-8')
self._data_codecs = DataCodecConfig()
self._data_codecs = DataCodecConfig(conn_key)
cdef add_setting(self, str name, str val):
self._settings[name] = val