This commit is contained in:
@@ -6,6 +6,4 @@
|
||||
|
||||
# flake8: NOQA
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP
|
||||
|
||||
@@ -483,7 +483,7 @@ cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
|
||||
|
||||
|
||||
cdef class DataCodecConfig:
|
||||
def __init__(self):
|
||||
def __init__(self, cache_key):
|
||||
# Codec instance cache for derived types:
|
||||
# composites, arrays, ranges, domains and their combinations.
|
||||
self._derived_type_codecs = {}
|
||||
|
||||
@@ -51,6 +51,16 @@ cdef enum AuthenticationMessage:
|
||||
AUTH_SASL_FINAL = 12
|
||||
|
||||
|
||||
AUTH_METHOD_NAME = {
|
||||
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
|
||||
AUTH_REQUIRED_PASSWORD: 'password',
|
||||
AUTH_REQUIRED_PASSWORDMD5: 'md5',
|
||||
AUTH_REQUIRED_GSS: 'gss',
|
||||
AUTH_REQUIRED_SASL: 'scram-sha-256',
|
||||
AUTH_REQUIRED_SSPI: 'sspi',
|
||||
}
|
||||
|
||||
|
||||
cdef enum ResultType:
|
||||
RESULT_OK = 1
|
||||
RESULT_FAILED = 2
|
||||
@@ -86,13 +96,10 @@ cdef class CoreProtocol:
|
||||
|
||||
object transport
|
||||
|
||||
object address
|
||||
# Instance of _ConnectionParameters
|
||||
object con_params
|
||||
# Instance of SCRAMAuthentication
|
||||
SCRAMAuthentication scram
|
||||
# Instance of gssapi.SecurityContext or sspilib.SecurityContext
|
||||
object gss_ctx
|
||||
|
||||
readonly int32_t backend_pid
|
||||
readonly int32_t backend_secret
|
||||
@@ -138,10 +145,6 @@ cdef class CoreProtocol:
|
||||
cdef _auth_password_message_md5(self, bytes salt)
|
||||
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
|
||||
cdef _auth_password_message_sasl_continue(self, bytes server_response)
|
||||
cdef _auth_gss_init_gssapi(self)
|
||||
cdef _auth_gss_init_sspi(self, bint negotiate)
|
||||
cdef _auth_gss_get_service(self)
|
||||
cdef _auth_gss_step(self, bytes server_response)
|
||||
|
||||
cdef _write(self, buf)
|
||||
cdef _writelines(self, list buffers)
|
||||
@@ -171,7 +174,7 @@ cdef class CoreProtocol:
|
||||
cdef _bind_execute(self, str portal_name, str stmt_name,
|
||||
WriteBuffer bind_data, int32_t limit)
|
||||
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
|
||||
object bind_data, bint return_rows)
|
||||
object bind_data)
|
||||
cdef bint _bind_execute_many_more(self, bint first=*)
|
||||
cdef _bind_execute_many_fail(self, object error, bint first=*)
|
||||
cdef _bind(self, str portal_name, str stmt_name,
|
||||
|
||||
@@ -11,20 +11,9 @@ import hashlib
|
||||
include "scram.pyx"
|
||||
|
||||
|
||||
cdef dict AUTH_METHOD_NAME = {
|
||||
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
|
||||
AUTH_REQUIRED_PASSWORD: 'password',
|
||||
AUTH_REQUIRED_PASSWORDMD5: 'md5',
|
||||
AUTH_REQUIRED_GSS: 'gss',
|
||||
AUTH_REQUIRED_SASL: 'scram-sha-256',
|
||||
AUTH_REQUIRED_SSPI: 'sspi',
|
||||
}
|
||||
|
||||
|
||||
cdef class CoreProtocol:
|
||||
|
||||
def __init__(self, addr, con_params):
|
||||
self.address = addr
|
||||
def __init__(self, con_params):
|
||||
# type of `con_params` is `_ConnectionParameters`
|
||||
self.buffer = ReadBuffer()
|
||||
self.user = con_params.user
|
||||
@@ -37,9 +26,6 @@ cdef class CoreProtocol:
|
||||
self.encoding = 'utf-8'
|
||||
# type of `scram` is `SCRAMAuthentcation`
|
||||
self.scram = None
|
||||
# type of `gss_ctx` is `gssapi.SecurityContext` or
|
||||
# `sspilib.SecurityContext`
|
||||
self.gss_ctx = None
|
||||
|
||||
self._reset_result()
|
||||
|
||||
@@ -633,35 +619,22 @@ cdef class CoreProtocol:
|
||||
'could not verify server signature for '
|
||||
'SCRAM authentciation: scram-sha-256',
|
||||
)
|
||||
self.scram = None
|
||||
|
||||
elif status in (AUTH_REQUIRED_GSS, AUTH_REQUIRED_SSPI):
|
||||
# AUTH_REQUIRED_SSPI is the same as AUTH_REQUIRED_GSS, except that
|
||||
# it uses protocol negotiation with SSPI clients. Both methods use
|
||||
# AUTH_REQUIRED_GSS_CONTINUE for subsequent authentication steps.
|
||||
if self.gss_ctx is not None:
|
||||
self.result_type = RESULT_FAILED
|
||||
self.result = apg_exc.InterfaceError(
|
||||
'duplicate GSSAPI/SSPI authentication request')
|
||||
else:
|
||||
if self.con_params.gsslib == 'gssapi':
|
||||
self._auth_gss_init_gssapi()
|
||||
else:
|
||||
self._auth_gss_init_sspi(status == AUTH_REQUIRED_SSPI)
|
||||
self.auth_msg = self._auth_gss_step(None)
|
||||
|
||||
elif status == AUTH_REQUIRED_GSS_CONTINUE:
|
||||
server_response = self.buffer.consume_message()
|
||||
self.auth_msg = self._auth_gss_step(server_response)
|
||||
elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
|
||||
AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
|
||||
AUTH_REQUIRED_SSPI):
|
||||
self.result_type = RESULT_FAILED
|
||||
self.result = apg_exc.InterfaceError(
|
||||
'unsupported authentication method requested by the '
|
||||
'server: {!r}'.format(AUTH_METHOD_NAME[status]))
|
||||
|
||||
else:
|
||||
self.result_type = RESULT_FAILED
|
||||
self.result = apg_exc.InterfaceError(
|
||||
'unsupported authentication method requested by the '
|
||||
'server: {!r}'.format(AUTH_METHOD_NAME.get(status, status)))
|
||||
'server: {}'.format(status))
|
||||
|
||||
if status not in (AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
|
||||
AUTH_REQUIRED_GSS_CONTINUE):
|
||||
if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
|
||||
self.buffer.discard_message()
|
||||
|
||||
cdef _auth_password_message_cleartext(self):
|
||||
@@ -718,59 +691,6 @@ cdef class CoreProtocol:
|
||||
|
||||
return msg
|
||||
|
||||
cdef _auth_gss_init_gssapi(self):
|
||||
try:
|
||||
import gssapi
|
||||
except ModuleNotFoundError:
|
||||
raise apg_exc.InterfaceError(
|
||||
'gssapi module not found; please install asyncpg[gssauth] to '
|
||||
'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
|
||||
) from None
|
||||
|
||||
service_name, host = self._auth_gss_get_service()
|
||||
self.gss_ctx = gssapi.SecurityContext(
|
||||
name=gssapi.Name(
|
||||
f'{service_name}@{host}', gssapi.NameType.hostbased_service),
|
||||
usage='initiate')
|
||||
|
||||
cdef _auth_gss_init_sspi(self, bint negotiate):
|
||||
try:
|
||||
import sspilib
|
||||
except ModuleNotFoundError:
|
||||
raise apg_exc.InterfaceError(
|
||||
'sspilib module not found; please install asyncpg[gssauth] to '
|
||||
'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
|
||||
) from None
|
||||
|
||||
service_name, host = self._auth_gss_get_service()
|
||||
self.gss_ctx = sspilib.ClientSecurityContext(
|
||||
target_name=f'{service_name}/{host}',
|
||||
credential=sspilib.UserCredential(
|
||||
protocol='Negotiate' if negotiate else 'Kerberos'))
|
||||
|
||||
cdef _auth_gss_get_service(self):
|
||||
service_name = self.con_params.krbsrvname or 'postgres'
|
||||
if isinstance(self.address, str):
|
||||
raise apg_exc.InternalClientError(
|
||||
'GSSAPI/SSPI authentication is only supported for TCP/IP '
|
||||
'connections')
|
||||
|
||||
return service_name, self.address[0]
|
||||
|
||||
cdef _auth_gss_step(self, bytes server_response):
|
||||
cdef:
|
||||
WriteBuffer msg
|
||||
|
||||
token = self.gss_ctx.step(server_response)
|
||||
if not token:
|
||||
self.gss_ctx = None
|
||||
return None
|
||||
msg = WriteBuffer.new_message(b'p')
|
||||
msg.write_bytes(token)
|
||||
msg.end_message()
|
||||
|
||||
return msg
|
||||
|
||||
cdef _parse_msg_ready_for_query(self):
|
||||
cdef char status = self.buffer.read_byte()
|
||||
|
||||
@@ -1020,12 +940,12 @@ cdef class CoreProtocol:
|
||||
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
|
||||
|
||||
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
|
||||
object bind_data, bint return_rows):
|
||||
object bind_data):
|
||||
self._ensure_connected()
|
||||
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
|
||||
|
||||
self.result = [] if return_rows else None
|
||||
self._discard_data = not return_rows
|
||||
self.result = None
|
||||
self._discard_data = True
|
||||
self._execute_iter = bind_data
|
||||
self._execute_portal_name = portal_name
|
||||
self._execute_stmt_name = stmt_name
|
||||
|
||||
@@ -142,7 +142,7 @@ cdef class PreparedStatementState:
|
||||
# that the user tried to parametrize a statement that does
|
||||
# not support parameters.
|
||||
hint += (r' Note that parameters are supported only in'
|
||||
r' SELECT, INSERT, UPDATE, DELETE, MERGE and VALUES'
|
||||
r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
|
||||
r' statements, and will *not* work in statements '
|
||||
r' like CREATE VIEW or DECLARE CURSOR.')
|
||||
|
||||
|
||||
Binary file not shown.
@@ -31,6 +31,7 @@ cdef class BaseProtocol(CoreProtocol):
|
||||
|
||||
cdef:
|
||||
object loop
|
||||
object address
|
||||
ConnectionSettings settings
|
||||
object cancel_sent_waiter
|
||||
object cancel_waiter
|
||||
|
||||
@@ -1,300 +0,0 @@
|
||||
import asyncio
|
||||
import asyncio.protocols
|
||||
import hmac
|
||||
from codecs import CodecInfo
|
||||
from collections.abc import Callable, Iterable, Iterator, Sequence
|
||||
from hashlib import md5, sha256
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Final,
|
||||
Generic,
|
||||
Literal,
|
||||
NewType,
|
||||
TypeVar,
|
||||
final,
|
||||
overload,
|
||||
)
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import asyncpg.pgproto.pgproto
|
||||
|
||||
from ..connect_utils import _ConnectionParameters
|
||||
from ..pgproto.pgproto import WriteBuffer
|
||||
from ..types import Attribute, Type
|
||||
|
||||
_T = TypeVar('_T')
|
||||
_Record = TypeVar('_Record', bound=Record)
|
||||
_OtherRecord = TypeVar('_OtherRecord', bound=Record)
|
||||
_PreparedStatementState = TypeVar(
|
||||
'_PreparedStatementState', bound=PreparedStatementState[Any]
|
||||
)
|
||||
|
||||
_NoTimeoutType = NewType('_NoTimeoutType', object)
|
||||
_TimeoutType: TypeAlias = float | None | _NoTimeoutType
|
||||
|
||||
BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]]
|
||||
BUILTIN_TYPE_OID_MAP: Final[dict[int, str]]
|
||||
NO_TIMEOUT: Final[_NoTimeoutType]
|
||||
|
||||
hashlib_md5 = md5
|
||||
|
||||
@final
|
||||
class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext):
|
||||
__pyx_vtable__: Any
|
||||
def __init__(self, conn_key: object) -> None: ...
|
||||
def add_python_codec(
|
||||
self,
|
||||
typeoid: int,
|
||||
typename: str,
|
||||
typeschema: str,
|
||||
typeinfos: Iterable[object],
|
||||
typekind: str,
|
||||
encoder: Callable[[Any], Any],
|
||||
decoder: Callable[[Any], Any],
|
||||
format: object,
|
||||
) -> Any: ...
|
||||
def clear_type_cache(self) -> None: ...
|
||||
def get_data_codec(
|
||||
self, oid: int, format: object = ..., ignore_custom_codec: bool = ...
|
||||
) -> Any: ...
|
||||
def get_text_codec(self) -> CodecInfo: ...
|
||||
def register_data_types(self, types: Iterable[object]) -> None: ...
|
||||
def remove_python_codec(
|
||||
self, typeoid: int, typename: str, typeschema: str
|
||||
) -> None: ...
|
||||
def set_builtin_type_codec(
|
||||
self,
|
||||
typeoid: int,
|
||||
typename: str,
|
||||
typeschema: str,
|
||||
typekind: str,
|
||||
alias_to: str,
|
||||
format: object = ...,
|
||||
) -> Any: ...
|
||||
def __getattr__(self, name: str) -> Any: ...
|
||||
def __reduce__(self) -> Any: ...
|
||||
|
||||
@final
|
||||
class PreparedStatementState(Generic[_Record]):
|
||||
closed: bool
|
||||
prepared: bool
|
||||
name: str
|
||||
query: str
|
||||
refs: int
|
||||
record_class: type[_Record]
|
||||
ignore_custom_codec: bool
|
||||
__pyx_vtable__: Any
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
query: str,
|
||||
protocol: BaseProtocol[Any],
|
||||
record_class: type[_Record],
|
||||
ignore_custom_codec: bool,
|
||||
) -> None: ...
|
||||
def _get_parameters(self) -> tuple[Type, ...]: ...
|
||||
def _get_attributes(self) -> tuple[Attribute, ...]: ...
|
||||
def _init_types(self) -> set[int]: ...
|
||||
def _init_codecs(self) -> None: ...
|
||||
def attach(self) -> None: ...
|
||||
def detach(self) -> None: ...
|
||||
def mark_closed(self) -> None: ...
|
||||
def mark_unprepared(self) -> None: ...
|
||||
def __reduce__(self) -> Any: ...
|
||||
|
||||
class CoreProtocol:
|
||||
backend_pid: Any
|
||||
backend_secret: Any
|
||||
__pyx_vtable__: Any
|
||||
def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ...
|
||||
def is_in_transaction(self) -> bool: ...
|
||||
def __reduce__(self) -> Any: ...
|
||||
|
||||
class BaseProtocol(CoreProtocol, Generic[_Record]):
|
||||
queries_count: Any
|
||||
is_ssl: bool
|
||||
__pyx_vtable__: Any
|
||||
def __init__(
|
||||
self,
|
||||
addr: object,
|
||||
connected_fut: object,
|
||||
con_params: _ConnectionParameters,
|
||||
record_class: type[_Record],
|
||||
loop: object,
|
||||
) -> None: ...
|
||||
def set_connection(self, connection: object) -> None: ...
|
||||
def get_server_pid(self, *args: object, **kwargs: object) -> int: ...
|
||||
def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ...
|
||||
def get_record_class(self) -> type[_Record]: ...
|
||||
def abort(self) -> None: ...
|
||||
async def bind(
|
||||
self,
|
||||
state: PreparedStatementState[_OtherRecord],
|
||||
args: Sequence[object],
|
||||
portal_name: str,
|
||||
timeout: _TimeoutType,
|
||||
) -> Any: ...
|
||||
@overload
|
||||
async def bind_execute(
|
||||
self,
|
||||
state: PreparedStatementState[_OtherRecord],
|
||||
args: Sequence[object],
|
||||
portal_name: str,
|
||||
limit: int,
|
||||
return_extra: Literal[False],
|
||||
timeout: _TimeoutType,
|
||||
) -> list[_OtherRecord]: ...
|
||||
@overload
|
||||
async def bind_execute(
|
||||
self,
|
||||
state: PreparedStatementState[_OtherRecord],
|
||||
args: Sequence[object],
|
||||
portal_name: str,
|
||||
limit: int,
|
||||
return_extra: Literal[True],
|
||||
timeout: _TimeoutType,
|
||||
) -> tuple[list[_OtherRecord], bytes, bool]: ...
|
||||
@overload
|
||||
async def bind_execute(
|
||||
self,
|
||||
state: PreparedStatementState[_OtherRecord],
|
||||
args: Sequence[object],
|
||||
portal_name: str,
|
||||
limit: int,
|
||||
return_extra: bool,
|
||||
timeout: _TimeoutType,
|
||||
) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ...
|
||||
async def bind_execute_many(
|
||||
self,
|
||||
state: PreparedStatementState[_OtherRecord],
|
||||
args: Iterable[Sequence[object]],
|
||||
portal_name: str,
|
||||
timeout: _TimeoutType,
|
||||
) -> None: ...
|
||||
async def close(self, timeout: _TimeoutType) -> None: ...
|
||||
def _get_timeout(self, timeout: _TimeoutType) -> float | None: ...
|
||||
def _is_cancelling(self) -> bool: ...
|
||||
async def _wait_for_cancellation(self) -> None: ...
|
||||
async def close_statement(
|
||||
self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType
|
||||
) -> Any: ...
|
||||
async def copy_in(self, *args: object, **kwargs: object) -> str: ...
|
||||
async def copy_out(self, *args: object, **kwargs: object) -> str: ...
|
||||
async def execute(self, *args: object, **kwargs: object) -> Any: ...
|
||||
def is_closed(self, *args: object, **kwargs: object) -> Any: ...
|
||||
def is_connected(self, *args: object, **kwargs: object) -> Any: ...
|
||||
def data_received(self, data: object) -> None: ...
|
||||
def connection_made(self, transport: object) -> None: ...
|
||||
def connection_lost(self, exc: Exception | None) -> None: ...
|
||||
def pause_writing(self, *args: object, **kwargs: object) -> Any: ...
|
||||
@overload
|
||||
async def prepare(
|
||||
self,
|
||||
stmt_name: str,
|
||||
query: str,
|
||||
timeout: float | None = ...,
|
||||
*,
|
||||
state: _PreparedStatementState,
|
||||
ignore_custom_codec: bool = ...,
|
||||
record_class: None,
|
||||
) -> _PreparedStatementState: ...
|
||||
@overload
|
||||
async def prepare(
|
||||
self,
|
||||
stmt_name: str,
|
||||
query: str,
|
||||
timeout: float | None = ...,
|
||||
*,
|
||||
state: None = ...,
|
||||
ignore_custom_codec: bool = ...,
|
||||
record_class: type[_OtherRecord],
|
||||
) -> PreparedStatementState[_OtherRecord]: ...
|
||||
async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ...
|
||||
async def query(self, *args: object, **kwargs: object) -> str: ...
|
||||
def resume_writing(self, *args: object, **kwargs: object) -> Any: ...
|
||||
def __reduce__(self) -> Any: ...
|
||||
|
||||
@final
|
||||
class Codec:
|
||||
__pyx_vtable__: Any
|
||||
def __reduce__(self) -> Any: ...
|
||||
|
||||
class DataCodecConfig:
|
||||
__pyx_vtable__: Any
|
||||
def __init__(self) -> None: ...
|
||||
def add_python_codec(
|
||||
self,
|
||||
typeoid: int,
|
||||
typename: str,
|
||||
typeschema: str,
|
||||
typekind: str,
|
||||
typeinfos: Iterable[object],
|
||||
encoder: Callable[[ConnectionSettings, WriteBuffer, object], object],
|
||||
decoder: Callable[..., object],
|
||||
format: object,
|
||||
xformat: object,
|
||||
) -> Any: ...
|
||||
def add_types(self, types: Iterable[object]) -> Any: ...
|
||||
def clear_type_cache(self) -> None: ...
|
||||
def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ...
|
||||
def remove_python_codec(
|
||||
self, typeoid: int, typename: str, typeschema: str
|
||||
) -> Any: ...
|
||||
def set_builtin_type_codec(
|
||||
self,
|
||||
typeoid: int,
|
||||
typename: str,
|
||||
typeschema: str,
|
||||
typekind: str,
|
||||
alias_to: str,
|
||||
format: object = ...,
|
||||
) -> Any: ...
|
||||
def __reduce__(self) -> Any: ...
|
||||
|
||||
class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ...
|
||||
|
||||
class Record:
|
||||
@overload
|
||||
def get(self, key: str) -> Any | None: ...
|
||||
@overload
|
||||
def get(self, key: str, default: _T) -> Any | _T: ...
|
||||
def items(self) -> Iterator[tuple[str, Any]]: ...
|
||||
def keys(self) -> Iterator[str]: ...
|
||||
def values(self) -> Iterator[Any]: ...
|
||||
@overload
|
||||
def __getitem__(self, index: str) -> Any: ...
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> Any: ...
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> tuple[Any, ...]: ...
|
||||
def __iter__(self) -> Iterator[Any]: ...
|
||||
def __contains__(self, x: object) -> bool: ...
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
class Timer:
|
||||
def __init__(self, budget: float | None) -> None: ...
|
||||
def __enter__(self) -> None: ...
|
||||
def __exit__(self, et: object, e: object, tb: object) -> None: ...
|
||||
def get_remaining_budget(self) -> float: ...
|
||||
def has_budget_greater_than(self, amount: float) -> bool: ...
|
||||
|
||||
@final
|
||||
class SCRAMAuthentication:
|
||||
AUTHENTICATION_METHODS: ClassVar[list[str]]
|
||||
DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int]
|
||||
DIGEST = sha256
|
||||
REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]]
|
||||
REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]]
|
||||
SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]]
|
||||
authentication_method: bytes
|
||||
authorization_message: bytes | None
|
||||
client_channel_binding: bytes
|
||||
client_first_message_bare: bytes | None
|
||||
client_nonce: bytes | None
|
||||
client_proof: bytes | None
|
||||
password_salt: bytes | None
|
||||
password_iterations: int
|
||||
server_first_message: bytes | None
|
||||
server_key: hmac.HMAC | None
|
||||
server_nonce: bytes | None
|
||||
@@ -75,7 +75,7 @@ NO_TIMEOUT = object()
|
||||
cdef class BaseProtocol(CoreProtocol):
|
||||
def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
|
||||
# type of `con_params` is `_ConnectionParameters`
|
||||
CoreProtocol.__init__(self, addr, con_params)
|
||||
CoreProtocol.__init__(self, con_params)
|
||||
|
||||
self.loop = loop
|
||||
self.transport = None
|
||||
@@ -83,7 +83,8 @@ cdef class BaseProtocol(CoreProtocol):
|
||||
self.cancel_waiter = None
|
||||
self.cancel_sent_waiter = None
|
||||
|
||||
self.settings = ConnectionSettings((addr, con_params.database))
|
||||
self.address = addr
|
||||
self.settings = ConnectionSettings((self.address, con_params.database))
|
||||
self.record_class = record_class
|
||||
|
||||
self.statement = None
|
||||
@@ -212,7 +213,6 @@ cdef class BaseProtocol(CoreProtocol):
|
||||
args,
|
||||
portal_name: str,
|
||||
timeout,
|
||||
return_rows: bool,
|
||||
):
|
||||
if self.cancel_waiter is not None:
|
||||
await self.cancel_waiter
|
||||
@@ -238,8 +238,7 @@ cdef class BaseProtocol(CoreProtocol):
|
||||
more = self._bind_execute_many(
|
||||
portal_name,
|
||||
state.name,
|
||||
arg_bufs,
|
||||
return_rows) # network op
|
||||
arg_bufs) # network op
|
||||
|
||||
self.last_query = state.query
|
||||
self.statement = state
|
||||
|
||||
@@ -11,12 +11,12 @@ from asyncpg import exceptions
|
||||
@cython.final
|
||||
cdef class ConnectionSettings(pgproto.CodecContext):
|
||||
|
||||
def __cinit__(self):
|
||||
def __cinit__(self, conn_key):
|
||||
self._encoding = 'utf-8'
|
||||
self._is_utf8 = True
|
||||
self._settings = {}
|
||||
self._codec = codecs.lookup('utf-8')
|
||||
self._data_codecs = DataCodecConfig()
|
||||
self._data_codecs = DataCodecConfig(conn_key)
|
||||
|
||||
cdef add_setting(self, str name, str val):
|
||||
self._settings[name] = val
|
||||
|
||||
Reference in New Issue
Block a user