Major fixes and new features
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-09-25 15:51:48 +09:00
parent dd7349bb4c
commit ddce9f5125
5586 changed files with 1470941 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
# flake8: NOQA
from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP

View File

@@ -0,0 +1,875 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from collections.abc import (Iterable as IterableABC,
Mapping as MappingABC,
Sized as SizedABC)
from asyncpg import exceptions
# Maximum number of array dimensions supported by PostgreSQL
# (defined in postgresql/src/includes/c.h).
DEF ARRAY_MAXDIM = 6  # defined in postgresql/src/includes/c.h

# "NULL" as a NUL-terminated UCS4 string, used to recognize unquoted
# NULL elements in text array literals.
cdef Py_UCS4 *APG_NULL = [0x004E, 0x0055, 0x004C, 0x004C, 0x0000]

# Element-encoder signature: writes `obj` into `buf`; `arg` is an
# opaque per-codec pointer (e.g. the element Codec object).
ctypedef object (*encode_func_ex)(ConnectionSettings settings,
                                  WriteBuffer buf,
                                  object obj,
                                  const void *arg)

# Element-decoder signature: reads one value from `buf`; `arg` is an
# opaque per-codec pointer.
ctypedef object (*decode_func_ex)(ConnectionSettings settings,
                                  FRBuffer *buf,
                                  const void *arg)
cdef inline bint _is_trivial_container(object obj):
    # str/bytes/bytearray/memoryview are iterable and sized, but must
    # never be interpreted as array containers.
    return cpython.PyUnicode_Check(obj) or cpython.PyBytes_Check(obj) or \
        cpythonx.PyByteArray_Check(obj) or cpythonx.PyMemoryView_Check(obj)
cdef inline _is_array_iterable(object obj):
    # True for sized, iterable containers that may be encoded as a
    # Postgres array; strings/bytes-likes and mappings are excluded.
    if not isinstance(obj, IterableABC):
        return False
    if not isinstance(obj, SizedABC):
        return False
    if _is_trivial_container(obj):
        return False
    return not isinstance(obj, MappingABC)
cdef inline _is_sub_array_iterable(object obj):
    # Like _is_array_iterable(), except that tuples are rejected:
    # nested tuples are treated as composite (record) values rather
    # than sub-arrays.
    if cpython.PyTuple_Check(obj):
        return False
    return _is_array_iterable(obj)
cdef _get_array_shape(object obj, int32_t *dims, int32_t *ndims):
    # Recursively walk `obj`, recording each dimension's length into
    # `dims` and the total dimension count into `ndims[0]` (caller
    # initializes ndims[0] to 1).  Raises ValueError on ragged
    # (non-homogeneous) input, too many dimensions, or overlong axes.
    cdef:
        ssize_t mylen = len(obj)
        ssize_t elemlen = -2  # -2: not seen yet; -1: elements are scalars
        object it

    if mylen > _MAXINT32:
        raise ValueError('too many elements in array value')

    if ndims[0] > ARRAY_MAXDIM:
        raise ValueError(
            'number of array dimensions ({}) exceed the maximum expected ({})'.
            format(ndims[0], ARRAY_MAXDIM))

    dims[ndims[0] - 1] = <int32_t>mylen

    for elem in obj:
        if _is_sub_array_iterable(elem):
            if elemlen == -2:
                # First sub-array fixes the expected sibling length
                # and adds a dimension; recurse to measure deeper.
                elemlen = len(elem)
                if elemlen > _MAXINT32:
                    raise ValueError('too many elements in array value')
                ndims[0] += 1
                _get_array_shape(elem, dims, ndims)
            else:
                if len(elem) != elemlen:
                    raise ValueError('non-homogeneous array')
        else:
            if elemlen >= 0:
                # A scalar sibling where previous elements were
                # sub-arrays: ragged input.
                raise ValueError('non-homogeneous array')
            else:
                elemlen = -1
cdef _write_array_data(ConnectionSettings settings, object obj, int32_t ndims,
                       int32_t dim, WriteBuffer elem_data,
                       encode_func_ex encoder, const void *encoder_arg):
    # Write the elements of (possibly nested) `obj` into `elem_data`
    # in row-major order.  `dim` is the current 0-based dimension;
    # NULL elements are encoded as a -1 length.
    if dim < ndims - 1:
        # Not at the innermost dimension yet: recurse per sub-array.
        for item in obj:
            _write_array_data(settings, item, ndims, dim + 1, elem_data,
                              encoder, encoder_arg)
    else:
        for item in obj:
            if item is None:
                elem_data.write_int32(-1)
            else:
                try:
                    encoder(settings, elem_data, item, encoder_arg)
                except TypeError as e:
                    raise ValueError(
                        'invalid array element: {}'.format(e.args[0])) from None
cdef inline array_encode(ConnectionSettings settings, WriteBuffer buf,
                         object obj, uint32_t elem_oid,
                         encode_func_ex encoder, const void *encoder_arg):
    # Encode `obj` as a binary-format Postgres array of `elem_oid`
    # elements.  Wire layout: total length, ndims, flags, element OID,
    # then (length, lower bound) per dimension, then the element data.
    cdef:
        WriteBuffer elem_data
        int32_t dims[ARRAY_MAXDIM]
        int32_t ndims = 1
        int32_t i

    if not _is_array_iterable(obj):
        raise TypeError(
            'a sized iterable container expected (got type {!r})'.format(
                type(obj).__name__))

    # Validates homogeneity and fills dims/ndims.
    _get_array_shape(obj, dims, &ndims)

    elem_data = WriteBuffer.new()

    if ndims > 1:
        _write_array_data(settings, obj, ndims, 0, elem_data,
                          encoder, encoder_arg)
    else:
        # Flat fast path; index is included in error messages.
        for i, item in enumerate(obj):
            if item is None:
                elem_data.write_int32(-1)
            else:
                try:
                    encoder(settings, elem_data, item, encoder_arg)
                except TypeError as e:
                    raise ValueError(
                        'invalid array element at index {}: {}'.format(
                            i, e.args[0])) from None

    # Total size: 12-byte header + 8 bytes per dimension + elements.
    buf.write_int32(12 + 8 * ndims + elem_data.len())
    # Number of dimensions
    buf.write_int32(ndims)
    # flags
    buf.write_int32(0)
    # element type
    buf.write_int32(<int32_t>elem_oid)
    # upper / lower bounds
    for i in range(ndims):
        buf.write_int32(dims[i])
        # Postgres arrays default to a lower bound of 1.
        buf.write_int32(1)
    # element data
    buf.write_buffer(elem_data)
cdef _write_textarray_data(ConnectionSettings settings, object obj,
                           int32_t ndims, int32_t dim, WriteBuffer array_data,
                           encode_func_ex encoder, const void *encoder_arg,
                           Py_UCS4 typdelim):
    # Recursively render `obj` into `array_data` as a text-format
    # array literal ('{a, b, ...}'), quoting/escaping element text as
    # required by the array literal syntax.
    cdef:
        ssize_t i = 0
        int8_t delim = <int8_t>typdelim
        WriteBuffer elem_data
        Py_buffer pybuf
        const char *elem_str
        char ch
        ssize_t elem_len
        ssize_t quoted_elem_len
        bint need_quoting

    array_data.write_byte(b'{')

    if dim < ndims - 1:
        # Not the innermost dimension: recurse per sub-array.
        for item in obj:
            if i > 0:
                array_data.write_byte(delim)
                array_data.write_byte(b' ')
            _write_textarray_data(settings, item, ndims, dim + 1, array_data,
                                  encoder, encoder_arg, typdelim)
            i += 1
    else:
        for item in obj:
            elem_data = WriteBuffer.new()

            if i > 0:
                array_data.write_byte(delim)
                array_data.write_byte(b' ')

            if item is None:
                array_data.write_bytes(b'NULL')
                i += 1
                continue
            else:
                try:
                    encoder(settings, elem_data, item, encoder_arg)
                except TypeError as e:
                    raise ValueError(
                        'invalid array element: {}'.format(
                            e.args[0])) from None

            # element string length (first four bytes are the encoded length.)
            elem_len = elem_data.len() - 4

            if elem_len == 0:
                # Empty string
                array_data.write_bytes(b'""')
            else:
                cpython.PyObject_GetBuffer(
                    elem_data, &pybuf, cpython.PyBUF_SIMPLE)

                # Skip the 4-byte length prefix written by the encoder.
                elem_str = <const char*>(pybuf.buf) + 4

                try:
                    if not apg_strcasecmp_char(elem_str, b'NULL'):
                        # A literal "NULL" string must be quoted so it
                        # is not mistaken for an actual NULL element.
                        array_data.write_byte(b'"')
                        array_data.write_cstr(elem_str, 4)
                        array_data.write_byte(b'"')
                    else:
                        quoted_elem_len = elem_len
                        need_quoting = False

                        # First pass: decide whether quoting/escaping
                        # is needed and compute the escaped length.
                        # NOTE(review): this reuses `i` as the scan
                        # index; afterwards `i` is only tested for
                        # `> 0`, so the delimiter logic above still
                        # behaves correctly for subsequent elements.
                        for i in range(elem_len):
                            ch = elem_str[i]
                            if ch == b'"' or ch == b'\\':
                                # Quotes and backslashes need escaping.
                                quoted_elem_len += 1
                                need_quoting = True
                            elif (ch == b'{' or ch == b'}' or ch == delim or
                                    apg_ascii_isspace(<uint32_t>ch)):
                                need_quoting = True

                        if need_quoting:
                            array_data.write_byte(b'"')

                            if quoted_elem_len == elem_len:
                                # Quoting without escaping.
                                array_data.write_cstr(elem_str, elem_len)
                            else:
                                # Escaping required.
                                for i in range(elem_len):
                                    ch = elem_str[i]
                                    if ch == b'"' or ch == b'\\':
                                        array_data.write_byte(b'\\')
                                    array_data.write_byte(ch)

                            array_data.write_byte(b'"')
                        else:
                            array_data.write_cstr(elem_str, elem_len)
                finally:
                    cpython.PyBuffer_Release(&pybuf)

            i += 1

    array_data.write_byte(b'}')
cdef inline textarray_encode(ConnectionSettings settings, WriteBuffer buf,
                             object obj, encode_func_ex encoder,
                             const void *encoder_arg, Py_UCS4 typdelim):
    # Encode `obj` as a text-format Postgres array literal, prefixed
    # with its 4-byte length.
    cdef:
        WriteBuffer array_data
        int32_t dims[ARRAY_MAXDIM]
        int32_t ndims = 1
        int32_t i

    if not _is_array_iterable(obj):
        raise TypeError(
            'a sized iterable container expected (got type {!r})'.format(
                type(obj).__name__))

    # Validates shape/homogeneity and computes ndims for the writer.
    _get_array_shape(obj, dims, &ndims)

    array_data = WriteBuffer.new()
    _write_textarray_data(settings, obj, ndims, 0, array_data,
                          encoder, encoder_arg, typdelim)
    buf.write_int32(array_data.len())
    buf.write_buffer(array_data)
cdef inline array_decode(ConnectionSettings settings, FRBuffer *buf,
                         decode_func_ex decoder, const void *decoder_arg):
    # Decode a binary-format Postgres array from `buf` into (possibly
    # nested) Python lists.
    cdef:
        int32_t ndims = hton.unpack_int32(frb_read(buf, 4))
        int32_t flags = hton.unpack_int32(frb_read(buf, 4))
        uint32_t elem_oid = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
        list result
        int i
        int32_t elem_len
        int32_t elem_count = 1
        FRBuffer elem_buf
        int32_t dims[ARRAY_MAXDIM]
        Codec elem_codec

    if ndims == 0:
        return []

    if ndims > ARRAY_MAXDIM:
        raise exceptions.ProtocolError(
            'number of array dimensions ({}) exceed the maximum expected ({})'.
            format(ndims, ARRAY_MAXDIM))
    elif ndims < 0:
        raise exceptions.ProtocolError(
            'unexpected array dimensions value: {}'.format(ndims))

    for i in range(ndims):
        dims[i] = hton.unpack_int32(frb_read(buf, 4))
        if dims[i] < 0:
            raise exceptions.ProtocolError(
                'unexpected array dimension size: {}'.format(dims[i]))
        # Ignore the lower bound information
        frb_read(buf, 4)

    if ndims == 1:
        # Fast path for flat arrays
        elem_count = dims[0]
        result = cpython.PyList_New(elem_count)

        for i in range(elem_count):
            elem_len = hton.unpack_int32(frb_read(buf, 4))
            if elem_len == -1:
                # A -1 length denotes a NULL element.
                elem = None
            else:
                frb_slice_from(&elem_buf, buf, elem_len)
                elem = decoder(settings, &elem_buf, decoder_arg)

            # PyList_SET_ITEM steals a reference.
            cpython.Py_INCREF(elem)
            cpython.PyList_SET_ITEM(result, i, elem)
    else:
        result = _nested_array_decode(settings, buf,
                                      decoder, decoder_arg, ndims, dims,
                                      &elem_buf)

    return result
cdef _nested_array_decode(ConnectionSettings settings,
                          FRBuffer *buf,
                          decode_func_ex decoder,
                          const void *decoder_arg,
                          int32_t ndims, int32_t *dims,
                          FRBuffer *elem_buf):
    # Decode a multidimensional (ndims > 1) binary array body into
    # nested lists without recursion, keeping one partially-filled
    # list ("stride") per dimension.
    cdef:
        int32_t elem_len
        int64_t i, j
        int64_t array_len = 1
        object elem, stride

        # An array of pointers to lists for each current array level.
        void *strides[ARRAY_MAXDIM]
        # An array of current positions at each array level.
        int32_t indexes[ARRAY_MAXDIM]

    for i in range(ndims):
        array_len *= dims[i]
        indexes[i] = 0
        strides[i] = NULL

    if array_len == 0:
        # A multidimensional array with a zero-sized dimension?
        return []
    elif array_len < 0:
        # Array length overflow
        raise exceptions.ProtocolError('array length overflow')

    for i in range(array_len):
        # Decode the element.
        elem_len = hton.unpack_int32(frb_read(buf, 4))
        if elem_len == -1:
            elem = None
        else:
            elem = decoder(settings,
                           frb_slice_from(elem_buf, buf, elem_len),
                           decoder_arg)

        # Take an explicit reference since PyList_SET_ITEM in the
        # loop below steals one.
        cpython.Py_INCREF(elem)

        # Iterate over array dimensions and put the element in
        # the correctly nested sublist.
        for j in reversed(range(ndims)):
            if indexes[j] == 0:
                # Allocate the list for this array level.
                stride = cpython.PyList_New(dims[j])

                strides[j] = <void*><cpython.PyObject>stride
                # Take an explicit reference since PyList_SET_ITEM
                # below steals one.
                cpython.Py_INCREF(stride)

            stride = <object><cpython.PyObject*>strides[j]
            cpython.PyList_SET_ITEM(stride, indexes[j], elem)
            indexes[j] += 1

            if indexes[j] == dims[j] and j != 0:
                # This array level is full, continue the
                # ascent in the dimensions so that this level
                # sublist will be appended to the parent list.
                elem = stride
                # Reset the index, this will cause the
                # new list to be allocated on the next
                # iteration on this array axis.
                indexes[j] = 0
            else:
                break

    stride = <object><cpython.PyObject*>strides[0]
    # Since each element in strides has a refcount of 1,
    # returning strides[0] will increment it to 2, so
    # balance that.
    cpython.Py_DECREF(stride)
    return stride
cdef textarray_decode(ConnectionSettings settings, FRBuffer *buf,
                      decode_func_ex decoder, const void *decoder_arg,
                      Py_UCS4 typdelim):
    # Decode a text-format Postgres array literal from `buf` into
    # (possibly nested) Python lists.
    cdef:
        Py_UCS4 *array_text
        str s

    # Make a copy of array data since we will be mutating it for
    # the purposes of element decoding.
    s = pgproto.text_decode(settings, buf)
    array_text = cpythonx.PyUnicode_AsUCS4Copy(s)

    try:
        return _textarray_decode(
            settings, array_text, decoder, decoder_arg, typdelim)
    except ValueError as e:
        # Re-raise parser errors as protocol errors, including the
        # offending literal for context.
        raise exceptions.ProtocolError(
            'malformed array literal {!r}: {}'.format(s, e.args[0]))
    finally:
        # The UCS4 copy is a raw CPython allocation and must always
        # be freed explicitly.
        cpython.PyMem_Free(array_text)
cdef _textarray_decode(ConnectionSettings settings,
                       Py_UCS4 *array_text,
                       decode_func_ex decoder,
                       const void *decoder_arg,
                       Py_UCS4 typdelim):
    # Parse a NUL-terminated array literal (optionally prefixed with
    # an explicit '[lb:ub]...=' dimension spec) and decode each
    # element with `decoder`.  The literal is mutated in place to
    # fold escapes and delimit element substrings.  Raises ValueError
    # on malformed input (wrapped by the caller into ProtocolError).
    cdef:
        bytearray array_bytes
        list result
        list new_stride
        Py_UCS4 *ptr
        int32_t ndims = 0
        int32_t ubound = 0
        int32_t lbound = 0
        int32_t dims[ARRAY_MAXDIM]
        int32_t inferred_dims[ARRAY_MAXDIM]
        int32_t inferred_ndims = 0
        void *strides[ARRAY_MAXDIM]
        int32_t indexes[ARRAY_MAXDIM]
        int32_t nest_level = 0
        int32_t item_level = 0
        bint end_of_array = False

        bint end_of_item = False
        bint has_quoting = False
        bint strip_spaces = False
        bint in_quotes = False
        Py_UCS4 *item_start
        Py_UCS4 *item_ptr
        Py_UCS4 *item_end

        int i
        object item
        str item_text

        FRBuffer item_buf

        char *pg_item_str
        ssize_t pg_item_len

    ptr = array_text

    while True:
        while apg_ascii_isspace(ptr[0]):
            ptr += 1

        if ptr[0] != '[':
            # Finished parsing dimensions spec.
            break

        ptr += 1  # '['

        # FIX: this was `ndims > ARRAY_MAXDIM`, which allowed the loop
        # body to run with ndims == ARRAY_MAXDIM and then write
        # dims[ARRAY_MAXDIM] below — one element past the end of the
        # stack array.  Use `>=`, consistent with the equivalent bound
        # check in _infer_array_dims().
        if ndims >= ARRAY_MAXDIM:
            raise ValueError(
                'number of array dimensions ({}) exceed the '
                'maximum expected ({})'.format(ndims, ARRAY_MAXDIM))

        ptr = apg_parse_int32(ptr, &ubound)
        if ptr == NULL:
            raise ValueError('missing array dimension value')

        if ptr[0] == ':':
            ptr += 1
            lbound = ubound

            # [lower:upper] spec.  We disregard the lbound for decoding.
            ptr = apg_parse_int32(ptr, &ubound)
            if ptr == NULL:
                raise ValueError('missing array dimension value')
        else:
            lbound = 1

        if ptr[0] != ']':
            raise ValueError('missing \']\' after array dimensions')

        ptr += 1  # ']'

        dims[ndims] = ubound - lbound + 1
        ndims += 1

    if ndims != 0:
        # If dimensions were given, the '=' token is expected.
        if ptr[0] != '=':
            raise ValueError('missing \'=\' after array dimensions')

        ptr += 1  # '='

        # Skip any whitespace after the '=', whitespace
        # before was consumed in the above loop.
        while apg_ascii_isspace(ptr[0]):
            ptr += 1

        # Infer the dimensions from the brace structure in the
        # array literal body, and check that it matches the explicit
        # spec.  This also validates that the array literal is sane.
        _infer_array_dims(ptr, typdelim, inferred_dims, &inferred_ndims)

        if inferred_ndims != ndims:
            raise ValueError(
                'specified array dimensions do not match array content')

        for i in range(ndims):
            if inferred_dims[i] != dims[i]:
                raise ValueError(
                    'specified array dimensions do not match array content')
    else:
        # Infer the dimensions from the brace structure in the array literal
        # body.  This also validates that the array literal is sane.
        _infer_array_dims(ptr, typdelim, dims, &ndims)

    while not end_of_array:
        # We iterate over the literal character by character
        # and modify the string in-place removing the array-specific
        # quoting and determining the boundaries of each element.
        end_of_item = has_quoting = in_quotes = False
        strip_spaces = True

        # Pointers to array element start, end, and the current pointer
        # tracking the position where characters are written when
        # escaping is folded.
        item_start = item_end = item_ptr = ptr
        item_level = 0

        while not end_of_item:
            if ptr[0] == '"':
                in_quotes = not in_quotes
                if in_quotes:
                    strip_spaces = False
                else:
                    item_end = item_ptr
                has_quoting = True

            elif ptr[0] == '\\':
                # Quoted character, collapse the backslash.
                ptr += 1
                has_quoting = True
                item_ptr[0] = ptr[0]
                item_ptr += 1
                strip_spaces = False
                item_end = item_ptr

            elif in_quotes:
                # Consume the string until we see the closing quote.
                item_ptr[0] = ptr[0]
                item_ptr += 1

            elif ptr[0] == '{':
                # Nesting level increase.
                nest_level += 1

                indexes[nest_level - 1] = 0

                new_stride = cpython.PyList_New(dims[nest_level - 1])
                strides[nest_level - 1] = \
                    <void*>(<cpython.PyObject>new_stride)

                if nest_level > 1:
                    # PyList_SET_ITEM steals a reference.
                    cpython.Py_INCREF(new_stride)
                    cpython.PyList_SET_ITEM(
                        <object><cpython.PyObject*>strides[nest_level - 2],
                        indexes[nest_level - 2],
                        new_stride)
                else:
                    result = new_stride

            elif ptr[0] == '}':
                if item_level == 0:
                    # Make sure we keep track of which nesting
                    # level the item belongs to, as the loop
                    # will continue to consume closing braces
                    # until the delimiter or the end of input.
                    item_level = nest_level

                nest_level -= 1

                if nest_level == 0:
                    end_of_array = end_of_item = True

            elif ptr[0] == typdelim:
                # Array element delimiter,
                end_of_item = True
                if item_level == 0:
                    item_level = nest_level

            elif apg_ascii_isspace(ptr[0]):
                if not strip_spaces:
                    item_ptr[0] = ptr[0]
                    item_ptr += 1
                # Ignore the leading literal whitespace.

            else:
                item_ptr[0] = ptr[0]
                item_ptr += 1
                strip_spaces = False
                item_end = item_ptr

            ptr += 1

        # end while not end_of_item

        if item_end == item_start:
            # Empty array
            continue

        item_end[0] = '\0'

        if not has_quoting and apg_strcasecmp(item_start, APG_NULL) == 0:
            # NULL element.
            item = None
        else:
            # XXX: find a way to avoid the redundant encode/decode
            # cycle here.
            item_text = cpythonx.PyUnicode_FromKindAndData(
                cpythonx.PyUnicode_4BYTE_KIND,
                <void *>item_start,
                item_end - item_start)

            # Prepare the element buffer and call the text decoder
            # for the element type.
            pgproto.as_pg_string_and_size(
                settings, item_text, &pg_item_str, &pg_item_len)
            frb_init(&item_buf, pg_item_str, pg_item_len)
            item = decoder(settings, &item_buf, decoder_arg)

        # Place the decoded element in the array.
        cpython.Py_INCREF(item)
        cpython.PyList_SET_ITEM(
            <object><cpython.PyObject*>strides[item_level - 1],
            indexes[item_level - 1],
            item)

        if nest_level > 0:
            indexes[nest_level - 1] += 1

    return result
cdef enum _ArrayParseState:
    # States of the array-literal scanner in _infer_array_dims().
    APS_START = 1             # nothing consumed yet
    APS_STRIDE_STARTED = 2    # just saw '{'
    APS_STRIDE_DONE = 3       # just saw '}'
    APS_STRIDE_DELIMITED = 4  # delimiter following '}'
    APS_ELEM_STARTED = 5      # inside an element value
    APS_ELEM_DELIMITED = 6    # delimiter following an element
cdef _UnexpectedCharacter(const Py_UCS4 *array_text, const Py_UCS4 *ptr):
    # Build (but do not raise) a ValueError reporting the unexpected
    # character and its 1-based position in the literal.
    return ValueError('unexpected character {!r} at position {}'.format(
        cpython.PyUnicode_FromOrdinal(<int>ptr[0]), ptr - array_text + 1))
cdef _infer_array_dims(const Py_UCS4 *array_text,
                       Py_UCS4 typdelim,
                       int32_t *dims,
                       int32_t *ndims):
    # Single-pass validation of an array literal body: computes each
    # dimension's size into `dims` and the dimension count into
    # `ndims[0]`, raising ValueError on malformed input.  An empty
    # array yields ndims[0] == 0.
    cdef:
        const Py_UCS4 *ptr = array_text
        int i
        int nest_level = 0
        bint end_of_array = False
        bint end_of_item = False
        bint in_quotes = False
        bint array_is_empty = True
        int stride_len[ARRAY_MAXDIM]
        int prev_stride_len[ARRAY_MAXDIM]
        _ArrayParseState parse_state = APS_START

    for i in range(ARRAY_MAXDIM):
        dims[i] = prev_stride_len[i] = 0
        # Each stride starts with one element; delimiters add more.
        stride_len[i] = 1

    while not end_of_array:
        end_of_item = False

        while not end_of_item:
            if ptr[0] == '\0':
                raise ValueError('unexpected end of string')

            elif ptr[0] == '"':
                if (parse_state not in (APS_STRIDE_STARTED,
                                        APS_ELEM_DELIMITED) and
                        not (parse_state == APS_ELEM_STARTED and in_quotes)):
                    raise _UnexpectedCharacter(array_text, ptr)

                in_quotes = not in_quotes
                if in_quotes:
                    parse_state = APS_ELEM_STARTED
                    array_is_empty = False

            elif ptr[0] == '\\':
                if parse_state not in (APS_STRIDE_STARTED,
                                       APS_ELEM_STARTED,
                                       APS_ELEM_DELIMITED):
                    raise _UnexpectedCharacter(array_text, ptr)

                parse_state = APS_ELEM_STARTED
                array_is_empty = False

                # Skip the escaped character.
                if ptr[1] != '\0':
                    ptr += 1
                else:
                    raise ValueError('unexpected end of string')

            elif in_quotes:
                # Ignore everything inside the quotes.
                pass

            elif ptr[0] == '{':
                if parse_state not in (APS_START,
                                       APS_STRIDE_STARTED,
                                       APS_STRIDE_DELIMITED):
                    raise _UnexpectedCharacter(array_text, ptr)

                parse_state = APS_STRIDE_STARTED
                if nest_level >= ARRAY_MAXDIM:
                    raise ValueError(
                        'number of array dimensions ({}) exceed the '
                        'maximum expected ({})'.format(
                            nest_level, ARRAY_MAXDIM))

                dims[nest_level] = 0
                nest_level += 1
                if ndims[0] < nest_level:
                    ndims[0] = nest_level

            elif ptr[0] == '}':
                if (parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE) and
                        not (nest_level == 1 and
                             parse_state == APS_STRIDE_STARTED)):
                    raise _UnexpectedCharacter(array_text, ptr)

                parse_state = APS_STRIDE_DONE

                if nest_level == 0:
                    raise _UnexpectedCharacter(array_text, ptr)

                nest_level -= 1

                # All sibling sub-arrays must have the same length.
                if (prev_stride_len[nest_level] != 0 and
                        stride_len[nest_level] != prev_stride_len[nest_level]):
                    raise ValueError(
                        'inconsistent sub-array dimensions'
                        ' at position {}'.format(
                            ptr - array_text + 1))

                prev_stride_len[nest_level] = stride_len[nest_level]
                stride_len[nest_level] = 1

                if nest_level == 0:
                    end_of_array = end_of_item = True
                else:
                    dims[nest_level - 1] += 1

            elif ptr[0] == typdelim:
                if parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE):
                    raise _UnexpectedCharacter(array_text, ptr)

                if parse_state == APS_STRIDE_DONE:
                    parse_state = APS_STRIDE_DELIMITED
                else:
                    parse_state = APS_ELEM_DELIMITED

                end_of_item = True
                stride_len[nest_level - 1] += 1

            elif not apg_ascii_isspace(ptr[0]):
                # Any other non-whitespace character is element text.
                if parse_state not in (APS_STRIDE_STARTED,
                                       APS_ELEM_STARTED,
                                       APS_ELEM_DELIMITED):
                    raise _UnexpectedCharacter(array_text, ptr)

                parse_state = APS_ELEM_STARTED
                array_is_empty = False

            if not end_of_item:
                ptr += 1

    if not array_is_empty:
        # Count the final element of the innermost dimension.
        dims[ndims[0] - 1] += 1

    ptr += 1

    # only whitespace is allowed after the closing brace
    while ptr[0] != '\0':
        if not apg_ascii_isspace(ptr[0]):
            raise _UnexpectedCharacter(array_text, ptr)
        ptr += 1

    if array_is_empty:
        ndims[0] = 0
cdef uint4_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj,
                     const void *arg):
    # encode_func_ex adapter for pgproto.uint4_encode (`arg` unused).
    return pgproto.uint4_encode(settings, buf, obj)


cdef uint4_decode_ex(ConnectionSettings settings, FRBuffer *buf,
                     const void *arg):
    # decode_func_ex adapter for pgproto.uint4_decode (`arg` unused).
    return pgproto.uint4_decode(settings, buf)


cdef arrayoid_encode(ConnectionSettings settings, WriteBuffer buf, items):
    # Binary encoder for oid[].
    array_encode(settings, buf, items, OIDOID,
                 <encode_func_ex>&uint4_encode_ex, NULL)


cdef arrayoid_decode(ConnectionSettings settings, FRBuffer *buf):
    # Binary decoder for oid[].
    return array_decode(settings, buf, <decode_func_ex>&uint4_decode_ex, NULL)
cdef text_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj,
                    const void *arg):
    # encode_func_ex adapter for pgproto.text_encode (`arg` unused).
    return pgproto.text_encode(settings, buf, obj)


cdef text_decode_ex(ConnectionSettings settings, FRBuffer *buf,
                    const void *arg):
    # decode_func_ex adapter for pgproto.text_decode (`arg` unused).
    return pgproto.text_decode(settings, buf)


cdef arraytext_encode(ConnectionSettings settings, WriteBuffer buf, items):
    # Binary encoder for text[].
    array_encode(settings, buf, items, TEXTOID,
                 <encode_func_ex>&text_encode_ex, NULL)


cdef arraytext_decode(ConnectionSettings settings, FRBuffer *buf):
    # Binary decoder for text[].
    return array_decode(settings, buf, <decode_func_ex>&text_decode_ex, NULL)
cdef init_array_codecs():
    # Register the array codecs needed at bootstrap; other array
    # codecs are derived from introspection.
    # oid[] and text[] are registered as core codecs
    # to make type introspection query work
    #
    register_core_codec(_OIDOID,
                        <encode_func>&arrayoid_encode,
                        <decode_func>&arrayoid_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(_TEXTOID,
                        <encode_func>&arraytext_encode,
                        <decode_func>&arraytext_decode,
                        PG_FORMAT_BINARY)


# Populate the registry at module import time.
init_array_codecs()

View File

@@ -0,0 +1,187 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
# Plain C-level encoder/decoder signatures used by scalar core codecs.
ctypedef object (*encode_func)(ConnectionSettings settings,
                               WriteBuffer buf,
                               object obj)
ctypedef object (*decode_func)(ConnectionSettings settings,
                               FRBuffer *buf)

# Member-function-style signatures used for per-Codec dispatch
# (the Codec instance is passed explicitly as the first argument).
ctypedef object (*codec_encode_func)(Codec codec,
                                     ConnectionSettings settings,
                                     WriteBuffer buf,
                                     object obj)
ctypedef object (*codec_decode_func)(Codec codec,
                                     ConnectionSettings settings,
                                     FRBuffer *buf)
cdef enum CodecType:
    # How a Codec performs its encoding/decoding.
    CODEC_UNDEFINED = 0   # __cinit__ done, init() not yet called
    CODEC_C = 1           # builtin C codec pair
    CODEC_PY = 2          # user-supplied Python callables
    CODEC_ARRAY = 3       # array of an element codec
    CODEC_COMPOSITE = 4   # composite (row) type
    CODEC_RANGE = 5       # range over an element codec
    CODEC_MULTIRANGE = 6  # multirange over an element codec


cdef enum ServerDataFormat:
    # PostgreSQL wire data format.
    PG_FORMAT_ANY = -1    # no preference
    PG_FORMAT_TEXT = 0
    PG_FORMAT_BINARY = 1


cdef enum ClientExchangeFormat:
    # Format of values exchanged with user-supplied Python codecs.
    PG_XFORMAT_OBJECT = 1  # final wire representation (str/bytes)
    PG_XFORMAT_TUPLE = 2   # intermediate value for a base/C codec
cdef class Codec:
    # Encoder/decoder pair for a single Postgres type OID, together
    # with the metadata needed to dispatch to the right implementation.
    cdef:
        uint32_t oid

        str name
        str schema
        str kind

        CodecType type                # see CodecType enum
        ServerDataFormat format       # wire format this codec speaks
        ClientExchangeFormat xformat  # Python-side exchange format

        encode_func c_encoder
        decode_func c_decoder
        # Underlying codec for Python codecs (mutually exclusive with
        # c_encoder/c_decoder — enforced in init()).
        Codec base_codec

        object py_encoder
        object py_decoder

        # arrays
        Codec element_codec
        Py_UCS4 element_delimiter

        # composite types
        tuple element_type_oids
        object element_names
        object record_desc
        list element_codecs

        # Pointers to actual encoder/decoder functions for this codec
        codec_encode_func encoder
        codec_decode_func decoder

    # Second-stage initializer; fills in all metadata and selects the
    # encoder/decoder member-function pointers.
    cdef init(self, str name, str schema, str kind,
              CodecType type, ServerDataFormat format,
              ClientExchangeFormat xformat,
              encode_func c_encoder, decode_func c_decoder,
              Codec base_codec,
              object py_encoder, object py_decoder,
              Codec element_codec, tuple element_type_oids,
              object element_names, list element_codecs,
              Py_UCS4 element_delimiter)

    # Per-codec-type encode implementations.
    cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf,
                       object obj)

    cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf,
                      object obj)

    cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf,
                           object obj)

    cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf,
                      object obj)

    cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf,
                           object obj)

    cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf,
                          object obj)

    cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf,
                          object obj)

    # Per-codec-type decode implementations.
    cdef decode_scalar(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_array(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_array_text(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_composite(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_in_python(self, ConnectionSettings settings, FRBuffer *buf)

    # Public dispatch entry points (call through self.encoder/decoder).
    cdef inline encode(self,
                       ConnectionSettings settings,
                       WriteBuffer buf,
                       object obj)

    cdef inline decode(self, ConnectionSettings settings, FRBuffer *buf)

    cdef has_encoder(self)
    cdef has_decoder(self)
    cdef is_binary(self)

    cdef inline Codec copy(self)

    # Alternate constructors for derived codecs.
    @staticmethod
    cdef Codec new_array_codec(uint32_t oid,
                               str name,
                               str schema,
                               Codec element_codec,
                               Py_UCS4 element_delimiter)

    @staticmethod
    cdef Codec new_range_codec(uint32_t oid,
                               str name,
                               str schema,
                               Codec element_codec)

    @staticmethod
    cdef Codec new_multirange_codec(uint32_t oid,
                                    str name,
                                    str schema,
                                    Codec element_codec)

    @staticmethod
    cdef Codec new_composite_codec(uint32_t oid,
                                   str name,
                                   str schema,
                                   ServerDataFormat format,
                                   list element_codecs,
                                   tuple element_type_oids,
                                   object element_names)

    @staticmethod
    cdef Codec new_python_codec(uint32_t oid,
                                str name,
                                str schema,
                                str kind,
                                object encoder,
                                object decoder,
                                encode_func c_encoder,
                                decode_func c_decoder,
                                Codec base_codec,
                                ServerDataFormat format,
                                ClientExchangeFormat xformat)
cdef class DataCodecConfig:
    # Registry mapping type OIDs to Codec instances.
    cdef:
        # Codecs derived from type introspection (arrays, composites,
        # ranges, ...).
        dict _derived_type_codecs
        # Explicitly registered custom type codecs.
        dict _custom_type_codecs

    # Look up a codec for `oid` in the requested server format;
    # `ignore_custom_codec` bypasses the custom-codec registry.
    cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
                                bint ignore_custom_codec=*)
    cdef inline Codec get_custom_codec(self, uint32_t oid,
                                       ServerDataFormat format)

View File

@@ -0,0 +1,895 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from collections.abc import Mapping as MappingABC
import asyncpg
from asyncpg import exceptions
# Core-codec lookup tables sized (MAXSUPPORTEDOID + 1) * 2 — one pair
# of slots per supported OID (presumably one per server data format;
# populated elsewhere via register_core_codec — confirm there).
cdef void* binary_codec_map[(MAXSUPPORTEDOID + 1) * 2]
cdef void* text_codec_map[(MAXSUPPORTEDOID + 1) * 2]

# Registry of additional (non-core) codec definitions.
cdef dict EXTRA_CODECS = {}
@cython.final
cdef class Codec:
    def __cinit__(self, uint32_t oid):
        # A Codec starts out undefined; init() must be called before
        # it can encode or decode anything.
        self.oid = oid
        self.type = CODEC_UNDEFINED
    cdef init(
        self,
        str name,
        str schema,
        str kind,
        CodecType type,
        ServerDataFormat format,
        ClientExchangeFormat xformat,
        encode_func c_encoder,
        decode_func c_decoder,
        Codec base_codec,
        object py_encoder,
        object py_decoder,
        Codec element_codec,
        tuple element_type_oids,
        object element_names,
        list element_codecs,
        Py_UCS4 element_delimiter,
    ):
        # Second-stage initializer (__cinit__ only records the OID):
        # stores the codec metadata and binds self.encoder/self.decoder
        # to the member functions matching the codec type and server
        # data format.  Raises InternalClientError on inconsistent
        # arguments and UnsupportedClientFeatureError for type/format
        # combinations the client cannot handle.
        self.name = name
        self.schema = schema
        self.kind = kind
        self.type = type
        self.format = format
        self.xformat = xformat
        self.c_encoder = c_encoder
        self.c_decoder = c_decoder
        self.base_codec = base_codec
        self.py_encoder = py_encoder
        self.py_decoder = py_decoder
        self.element_codec = element_codec
        self.element_type_oids = element_type_oids
        self.element_codecs = element_codecs
        self.element_delimiter = element_delimiter
        self.element_names = element_names

        if base_codec is not None:
            # A Python codec delegates either to a base Codec or to a
            # raw C encoder/decoder pair — never both.
            if c_encoder != NULL or c_decoder != NULL:
                raise exceptions.InternalClientError(
                    'base_codec is mutually exclusive with c_encoder/c_decoder'
                )

        if element_names is not None:
            # Pre-build the record descriptor used for composite rows.
            self.record_desc = record.ApgRecordDesc_New(
                element_names, tuple(element_names))
        else:
            self.record_desc = None

        # Select the dispatch targets for encode()/decode().
        if type == CODEC_C:
            self.encoder = <codec_encode_func>&self.encode_scalar
            self.decoder = <codec_decode_func>&self.decode_scalar
        elif type == CODEC_ARRAY:
            if format == PG_FORMAT_BINARY:
                self.encoder = <codec_encode_func>&self.encode_array
                self.decoder = <codec_decode_func>&self.decode_array
            else:
                self.encoder = <codec_encode_func>&self.encode_array_text
                self.decoder = <codec_decode_func>&self.decode_array_text
        elif type == CODEC_RANGE:
            if format != PG_FORMAT_BINARY:
                raise exceptions.UnsupportedClientFeatureError(
                    'cannot decode type "{}"."{}": text encoding of '
                    'range types is not supported'.format(schema, name))
            self.encoder = <codec_encode_func>&self.encode_range
            self.decoder = <codec_decode_func>&self.decode_range
        elif type == CODEC_MULTIRANGE:
            if format != PG_FORMAT_BINARY:
                raise exceptions.UnsupportedClientFeatureError(
                    'cannot decode type "{}"."{}": text encoding of '
                    'range types is not supported'.format(schema, name))
            self.encoder = <codec_encode_func>&self.encode_multirange
            self.decoder = <codec_decode_func>&self.decode_multirange
        elif type == CODEC_COMPOSITE:
            if format != PG_FORMAT_BINARY:
                raise exceptions.UnsupportedClientFeatureError(
                    'cannot decode type "{}"."{}": text encoding of '
                    'composite types is not supported'.format(schema, name))
            self.encoder = <codec_encode_func>&self.encode_composite
            self.decoder = <codec_decode_func>&self.decode_composite
        elif type == CODEC_PY:
            self.encoder = <codec_encode_func>&self.encode_in_python
            self.decoder = <codec_decode_func>&self.decode_in_python
        else:
            raise exceptions.InternalClientError(
                'unexpected codec type: {}'.format(type))
    cdef Codec copy(self):
        # Return a new, fully-initialized Codec with identical
        # configuration to this one.
        cdef Codec codec

        codec = Codec(self.oid)
        codec.init(self.name, self.schema, self.kind,
                   self.type, self.format, self.xformat,
                   self.c_encoder, self.c_decoder, self.base_codec,
                   self.py_encoder, self.py_decoder,
                   self.element_codec,
                   self.element_type_oids, self.element_names,
                   self.element_codecs, self.element_delimiter)

        return codec
    cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf,
                       object obj):
        # CODEC_C path: delegate directly to the builtin C encoder.
        self.c_encoder(settings, buf, obj)
    cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf,
                      object obj):
        # Binary array encoding; the element Codec is smuggled through
        # the opaque `arg` pointer of codec_encode_func_ex.
        array_encode(settings, buf, obj, self.element_codec.oid,
                     codec_encode_func_ex,
                     <void*>(<cpython.PyObject>self.element_codec))
    cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf,
                           object obj):
        # Text array-literal encoding; the element Codec rides in the
        # opaque `arg` pointer.
        return textarray_encode(settings, buf, obj,
                                codec_encode_func_ex,
                                <void*>(<cpython.PyObject>self.element_codec),
                                self.element_delimiter)
    cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf,
                      object obj):
        # Binary range encoding over the subtype's element Codec.
        range_encode(settings, buf, obj, self.element_codec.oid,
                     codec_encode_func_ex,
                     <void*>(<cpython.PyObject>self.element_codec))
    cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf,
                           object obj):
        # Binary multirange encoding over the subtype's element Codec.
        multirange_encode(settings, buf, obj, self.element_codec.oid,
                          codec_encode_func_ex,
                          <void*>(<cpython.PyObject>self.element_codec))
    cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf,
                          object obj):
        # Encode a composite (row) value.  `obj` may be a sequence of
        # field values or a mapping of field name -> value; mappings
        # are first converted to a positional tuple with unspecified
        # fields set to None (encoded as NULL).
        cdef:
            WriteBuffer elem_data
            int i
            list elem_codecs = self.element_codecs
            ssize_t count
            ssize_t composite_size
            tuple rec

        if isinstance(obj, MappingABC):
            # Input is dict-like, form a tuple
            composite_size = len(self.element_type_oids)
            rec = cpython.PyTuple_New(composite_size)

            # Pre-fill with None so omitted fields become NULL
            # (PyTuple_SET_ITEM steals a reference, hence Py_INCREF).
            for i in range(composite_size):
                cpython.Py_INCREF(None)
                cpython.PyTuple_SET_ITEM(rec, i, None)

            for field in obj:
                try:
                    # element_names maps field name -> position.
                    i = self.element_names[field]
                except KeyError:
                    raise ValueError(
                        '{!r} is not a valid element of composite '
                        'type {}'.format(field, self.name)) from None

                item = obj[field]
                cpython.Py_INCREF(item)
                cpython.PyTuple_SET_ITEM(rec, i, item)

            obj = rec

        count = len(obj)
        if count > _MAXINT32:
            raise ValueError('too many elements in composite type record')

        elem_data = WriteBuffer.new()
        i = 0
        # Each field is written as (type OID, length-prefixed value).
        for item in obj:
            elem_data.write_int32(<int32_t>self.element_type_oids[i])
            if item is None:
                elem_data.write_int32(-1)
            else:
                (<Codec>elem_codecs[i]).encode(settings, elem_data, item)
            i += 1

        record_encode_frame(settings, buf, elem_data, <int32_t>count)
    cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf,
                          object obj):
        # Encode via a user-supplied Python encoder.  In OBJECT exchange
        # format the encoder's output is sent as raw bytes (binary) or
        # text; in TUPLE exchange format the output is fed to the
        # underlying base/core codec.
        data = self.py_encoder(obj)
        if self.xformat == PG_XFORMAT_OBJECT:
            if self.format == PG_FORMAT_BINARY:
                pgproto.bytea_encode(settings, buf, data)
            elif self.format == PG_FORMAT_TEXT:
                pgproto.text_encode(settings, buf, data)
            else:
                raise exceptions.InternalClientError(
                    'unexpected data format: {}'.format(self.format))
        elif self.xformat == PG_XFORMAT_TUPLE:
            if self.base_codec is not None:
                self.base_codec.encode(settings, buf, data)
            else:
                self.c_encoder(settings, buf, data)
        else:
            raise exceptions.InternalClientError(
                'unexpected exchange format: {}'.format(self.xformat))
    cdef encode(self, ConnectionSettings settings, WriteBuffer buf,
                object obj):
        # Dispatch through the function pointer selected in init().
        return self.encoder(self, settings, buf, obj)
    cdef decode_scalar(self, ConnectionSettings settings, FRBuffer *buf):
        # Scalar types decode directly through the C-level core decoder.
        return self.c_decoder(settings, buf)
    cdef decode_array(self, ConnectionSettings settings, FRBuffer *buf):
        # Binary array decoding; element codec passed as an opaque pointer.
        return array_decode(settings, buf, codec_decode_func_ex,
                            <void*>(<cpython.PyObject>self.element_codec))
    cdef decode_array_text(self, ConnectionSettings settings,
                           FRBuffer *buf):
        # Text array decoding using the array's declared delimiter.
        return textarray_decode(settings, buf, codec_decode_func_ex,
                                <void*>(<cpython.PyObject>self.element_codec),
                                self.element_delimiter)
    cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf):
        # Range decoding delegates bound parsing to the subtype codec.
        return range_decode(settings, buf, codec_decode_func_ex,
                            <void*>(<cpython.PyObject>self.element_codec))
    cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf):
        # Multirange decoding; each member range uses the subtype codec.
        return multirange_decode(settings, buf, codec_decode_func_ex,
                                 <void*>(<cpython.PyObject>self.element_codec))
    cdef decode_composite(self, ConnectionSettings settings,
                          FRBuffer *buf):
        # Decode a composite (row) value into an asyncpg.Record.
        # The wire layout is: int32 attribute count, then per attribute
        # an int32 type OID and a length-prefixed datum (-1 => NULL).
        # Raises OutdatedSchemaCacheError if the server's attribute
        # count or types disagree with the cached type introspection,
        # which indicates the type was altered after introspection.
        cdef:
            object result
            ssize_t elem_count
            ssize_t i
            int32_t elem_len
            uint32_t elem_typ
            uint32_t received_elem_typ
            Codec elem_codec
            FRBuffer elem_buf
        elem_count = <ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4))
        if elem_count != len(self.element_type_oids):
            raise exceptions.OutdatedSchemaCacheError(
                'unexpected number of attributes of composite type: '
                '{}, expected {}'
                .format(
                    elem_count,
                    len(self.element_type_oids),
                ),
                schema=self.schema,
                data_type=self.name,
            )
        result = record.ApgRecord_New(asyncpg.Record, self.record_desc, elem_count)
        for i in range(elem_count):
            elem_typ = self.element_type_oids[i]
            received_elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
            if received_elem_typ != elem_typ:
                raise exceptions.OutdatedSchemaCacheError(
                    'unexpected data type of composite type attribute {}: '
                    '{!r}, expected {!r}'
                    .format(
                        i,
                        BUILTIN_TYPE_OID_MAP.get(
                            received_elem_typ, received_elem_typ),
                        BUILTIN_TYPE_OID_MAP.get(
                            elem_typ, elem_typ)
                    ),
                    schema=self.schema,
                    data_type=self.name,
                    position=i,
                )
            elem_len = hton.unpack_int32(frb_read(buf, 4))
            if elem_len == -1:
                elem = None
            else:
                # Decode from a zero-copy sub-slice of the frame buffer.
                elem_codec = self.element_codecs[i]
                elem = elem_codec.decode(
                    settings, frb_slice_from(&elem_buf, buf, elem_len))
            # ApgRecord_SET_ITEM steals a reference, hence the INCREF.
            cpython.Py_INCREF(elem)
            record.ApgRecord_SET_ITEM(result, i, elem)
        return result
    cdef decode_in_python(self, ConnectionSettings settings,
                          FRBuffer *buf):
        # Decode via a user-supplied Python decoder.  Mirrors
        # encode_in_python: OBJECT exchange format hands the raw
        # bytes/text to the decoder; TUPLE exchange format first decodes
        # through the base/core codec.
        if self.xformat == PG_XFORMAT_OBJECT:
            if self.format == PG_FORMAT_BINARY:
                data = pgproto.bytea_decode(settings, buf)
            elif self.format == PG_FORMAT_TEXT:
                data = pgproto.text_decode(settings, buf)
            else:
                raise exceptions.InternalClientError(
                    'unexpected data format: {}'.format(self.format))
        elif self.xformat == PG_XFORMAT_TUPLE:
            if self.base_codec is not None:
                data = self.base_codec.decode(settings, buf)
            else:
                data = self.c_decoder(settings, buf)
        else:
            raise exceptions.InternalClientError(
                'unexpected exchange format: {}'.format(self.xformat))
        return self.py_decoder(data)
    cdef inline decode(self, ConnectionSettings settings, FRBuffer *buf):
        # Dispatch through the function pointer selected in init().
        return self.decoder(self, settings, buf)
    cdef inline has_encoder(self):
        # Report whether this codec can encode values.  Container codecs
        # (array/range/multirange) are encodable iff their element codec
        # is; composites require every attribute codec to be encodable.
        cdef Codec elem_codec
        if self.c_encoder is not NULL or self.py_encoder is not None:
            return True
        elif (
            self.type == CODEC_ARRAY
            or self.type == CODEC_RANGE
            or self.type == CODEC_MULTIRANGE
        ):
            return self.element_codec.has_encoder()
        elif self.type == CODEC_COMPOSITE:
            for elem_codec in self.element_codecs:
                if not elem_codec.has_encoder():
                    return False
            return True
        else:
            return False
    cdef has_decoder(self):
        # Report whether this codec can decode values; mirrors
        # has_encoder() for the decode direction.
        cdef Codec elem_codec
        if self.c_decoder is not NULL or self.py_decoder is not None:
            return True
        elif (
            self.type == CODEC_ARRAY
            or self.type == CODEC_RANGE
            or self.type == CODEC_MULTIRANGE
        ):
            return self.element_codec.has_decoder()
        elif self.type == CODEC_COMPOSITE:
            for elem_codec in self.element_codecs:
                if not elem_codec.has_decoder():
                    return False
            return True
        else:
            return False
    cdef is_binary(self):
        # True if this codec uses the binary wire format.
        return self.format == PG_FORMAT_BINARY
def __repr__(self):
return '<Codec oid={} elem_oid={} core={}>'.format(
self.oid,
'NA' if self.element_codec is None else self.element_codec.oid,
has_core_codec(self.oid))
    @staticmethod
    cdef Codec new_array_codec(uint32_t oid,
                               str name,
                               str schema,
                               Codec element_codec,
                               Py_UCS4 element_delimiter):
        # Factory for an array codec; the array inherits the wire format
        # of its element codec.
        cdef Codec codec
        codec = Codec(oid)
        codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format,
                   PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
                   element_codec, None, None, None, element_delimiter)
        return codec
    @staticmethod
    cdef Codec new_range_codec(uint32_t oid,
                               str name,
                               str schema,
                               Codec element_codec):
        # Factory for a range codec; format follows the subtype codec.
        cdef Codec codec
        codec = Codec(oid)
        codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format,
                   PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
                   element_codec, None, None, None, 0)
        return codec
    @staticmethod
    cdef Codec new_multirange_codec(uint32_t oid,
                                    str name,
                                    str schema,
                                    Codec element_codec):
        # Factory for a multirange codec; format follows the subtype codec.
        cdef Codec codec
        codec = Codec(oid)
        codec.init(name, schema, 'multirange', CODEC_MULTIRANGE,
                   element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None,
                   None, None, element_codec, None, None, None, 0)
        return codec
    @staticmethod
    cdef Codec new_composite_codec(uint32_t oid,
                                   str name,
                                   str schema,
                                   ServerDataFormat format,
                                   list element_codecs,
                                   tuple element_type_oids,
                                   object element_names):
        # Factory for a composite (row type) codec.  *element_names*
        # maps attribute name -> position, used for mapping inputs in
        # encode_composite().
        cdef Codec codec
        codec = Codec(oid)
        codec.init(name, schema, 'composite', CODEC_COMPOSITE,
                   format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
                   None, element_type_oids, element_names, element_codecs, 0)
        return codec
    @staticmethod
    cdef Codec new_python_codec(uint32_t oid,
                                str name,
                                str schema,
                                str kind,
                                object encoder,
                                object decoder,
                                encode_func c_encoder,
                                decode_func c_decoder,
                                Codec base_codec,
                                ServerDataFormat format,
                                ClientExchangeFormat xformat):
        # Factory for a user-defined codec backed by Python callables
        # (see Connection.set_type_codec()).  c_encoder/c_decoder and
        # base_codec are only used for the TUPLE exchange format.
        cdef Codec codec
        codec = Codec(oid)
        codec.init(name, schema, kind, CODEC_PY, format, xformat,
                   c_encoder, c_decoder, base_codec, encoder, decoder,
                   None, None, None, None, 0)
        return codec
# Encode callback for arrays: recovers the Codec object from the opaque
# pointer threaded through the container encoding helpers.
cdef codec_encode_func_ex(ConnectionSettings settings, WriteBuffer buf,
                          object obj, const void *arg):
    return (<Codec>arg).encode(settings, buf, obj)
# Decode callback for arrays: recovers the Codec object from the opaque
# pointer threaded through the container decoding helpers.
cdef codec_decode_func_ex(ConnectionSettings settings, FRBuffer *buf,
                          const void *arg):
    return (<Codec>arg).decode(settings, buf)
cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
    # Convert a Python integer to a PostgreSQL OID (uint32), raising
    # OverflowError for values outside [0, UINT32_MAX].
    cdef:
        int64_t oid = 0
        bint overflow = False
    try:
        oid = cpython.PyLong_AsLongLong(val)
    except OverflowError:
        overflow = True
    if overflow or (oid < 0 or oid > UINT32_MAX):
        raise OverflowError('OID value too large: {!r}'.format(val))
    # Return the already-validated C value instead of re-converting the
    # original Python object (the previous `<uint32_t>val` performed a
    # second, redundant conversion of *val*).
    return <uint32_t>oid
cdef class DataCodecConfig:
    """Per-connection registry of type codecs.

    Maintains two layers: codecs derived from server type introspection
    (composites, arrays, ranges, domains) and codecs explicitly
    installed by the user via set_type_codec()/set_builtin_type_codec().
    """

    def __init__(self, cache_key):
        # Codec instance cache for derived types:
        # composites, arrays, ranges, domains and their combinations.
        self._derived_type_codecs = {}
        # Codec instances set up by the user for the connection.
        self._custom_type_codecs = {}

    def add_types(self, types):
        # Build and cache codecs for a batch of introspected type
        # records.  Each record's 'kind' selects the codec family:
        # arrays (elemtype set), 'c' composite, 'd' domain, 'r' range,
        # 'm' multirange, 'e' enum; anything else falls back to text.
        cdef:
            Codec elem_codec
            list comp_elem_codecs
            ServerDataFormat format
            ServerDataFormat elem_format
            bint has_text_elements
            Py_UCS4 elem_delim

        for ti in types:
            oid = ti['oid']

            if self.get_codec(oid, PG_FORMAT_ANY) is not None:
                continue

            name = ti['name']
            schema = ti['ns']
            array_element_oid = ti['elemtype']
            range_subtype_oid = ti['range_subtype']
            if ti['attrtypoids']:
                comp_type_attrs = tuple(ti['attrtypoids'])
            else:
                comp_type_attrs = None
            base_type = ti['basetype']

            if array_element_oid:
                # Array type (note, there is no separate 'kind' for arrays)
                # Canonicalize type name to "elemtype[]"
                if name.startswith('_'):
                    name = name[1:]
                name = '{}[]'.format(name)

                elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY)
                if elem_codec is None:
                    elem_codec = self.declare_fallback_codec(
                        array_element_oid, ti['elemtype_name'], schema)

                elem_delim = <Py_UCS4>ti['elemdelim'][0]

                self._derived_type_codecs[oid, elem_codec.format] = \
                    Codec.new_array_codec(
                        oid, name, schema, elem_codec, elem_delim)

            elif ti['kind'] == b'c':
                # Composite type
                if not comp_type_attrs:
                    raise exceptions.InternalClientError(
                        f'type record missing field types for composite {oid}')

                comp_elem_codecs = []
                has_text_elements = False

                for typoid in comp_type_attrs:
                    elem_codec = self.get_codec(typoid, PG_FORMAT_ANY)
                    if elem_codec is None:
                        raise exceptions.InternalClientError(
                            f'no codec for composite attribute type {typoid}')
                    if elem_codec.format is PG_FORMAT_TEXT:
                        has_text_elements = True
                    comp_elem_codecs.append(elem_codec)

                element_names = collections.OrderedDict()
                for i, attrname in enumerate(ti['attrnames']):
                    element_names[attrname] = i

                # If at least one element is text-encoded, we must
                # encode the whole composite as text.
                if has_text_elements:
                    elem_format = PG_FORMAT_TEXT
                else:
                    elem_format = PG_FORMAT_BINARY

                self._derived_type_codecs[oid, elem_format] = \
                    Codec.new_composite_codec(
                        oid, name, schema, elem_format, comp_elem_codecs,
                        comp_type_attrs, element_names)

            elif ti['kind'] == b'd':
                # Domain type: reuse the base type's codec directly.
                if not base_type:
                    raise exceptions.InternalClientError(
                        f'type record missing base type for domain {oid}')

                elem_codec = self.get_codec(base_type, PG_FORMAT_ANY)
                if elem_codec is None:
                    elem_codec = self.declare_fallback_codec(
                        base_type, ti['basetype_name'], schema)

                self._derived_type_codecs[oid, elem_codec.format] = elem_codec

            elif ti['kind'] == b'r':
                # Range type
                if not range_subtype_oid:
                    raise exceptions.InternalClientError(
                        f'type record missing base type for range {oid}')

                elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
                if elem_codec is None:
                    elem_codec = self.declare_fallback_codec(
                        range_subtype_oid, ti['range_subtype_name'], schema)

                self._derived_type_codecs[oid, elem_codec.format] = \
                    Codec.new_range_codec(oid, name, schema, elem_codec)

            elif ti['kind'] == b'm':
                # Multirange type
                if not range_subtype_oid:
                    raise exceptions.InternalClientError(
                        f'type record missing base type for multirange {oid}')

                elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
                if elem_codec is None:
                    elem_codec = self.declare_fallback_codec(
                        range_subtype_oid, ti['range_subtype_name'], schema)

                self._derived_type_codecs[oid, elem_codec.format] = \
                    Codec.new_multirange_codec(oid, name, schema, elem_codec)

            elif ti['kind'] == b'e':
                # Enum types are essentially text
                self._set_builtin_type_codec(oid, name, schema, 'scalar',
                                             TEXTOID, PG_FORMAT_ANY)
            else:
                self.declare_fallback_codec(oid, name, schema)

    def add_python_codec(self, typeoid, typename, typeschema, typekind,
                         typeinfos, encoder, decoder, format, xformat):
        # Install a user-defined Python codec for *typeoid*, replacing
        # any previous override.  For the TUPLE exchange format a core
        # (scalar) or derived (composite) base codec must exist; raises
        # InterfaceError otherwise.
        cdef:
            Codec core_codec = None
            encode_func c_encoder = NULL
            decode_func c_decoder = NULL
            Codec base_codec = None
            uint32_t oid = pylong_as_oid(typeoid)
            bint codec_set = False

        # Clear all previous overrides (this also clears type cache).
        self.remove_python_codec(typeoid, typename, typeschema)

        if typeinfos:
            self.add_types(typeinfos)

        if format == PG_FORMAT_ANY:
            formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY)
        else:
            formats = (format,)

        for fmt in formats:
            if xformat == PG_XFORMAT_TUPLE:
                if typekind == "scalar":
                    core_codec = get_core_codec(oid, fmt, xformat)
                    if core_codec is None:
                        continue
                    c_encoder = core_codec.c_encoder
                    c_decoder = core_codec.c_decoder
                elif typekind == "composite":
                    base_codec = self.get_codec(oid, fmt)
                    if base_codec is None:
                        continue

            self._custom_type_codecs[typeoid, fmt] = \
                Codec.new_python_codec(oid, typename, typeschema, typekind,
                                       encoder, decoder, c_encoder, c_decoder,
                                       base_codec, fmt, xformat)
            codec_set = True

        if not codec_set:
            raise exceptions.InterfaceError(
                "{} type does not support the 'tuple' exchange format".format(
                    typename))

    def remove_python_codec(self, typeoid, typename, typeschema):
        # Drop user overrides for both wire formats and invalidate the
        # derived-codec cache (cached containers may reference them).
        for fmt in (PG_FORMAT_BINARY, PG_FORMAT_TEXT):
            self._custom_type_codecs.pop((typeoid, fmt), None)
        self.clear_type_cache()

    def _set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
                                alias_to, format=PG_FORMAT_ANY):
        # Alias *typeoid* to an existing codec, identified either by OID
        # (int), a builtin type name, or an extra-codec name (e.g.
        # 'pg_contrib.hstore').  The target codec is cloned and relabeled.
        cdef:
            Codec codec
            Codec target_codec
            uint32_t oid = pylong_as_oid(typeoid)
            uint32_t alias_oid = 0
            bint codec_set = False

        if format == PG_FORMAT_ANY:
            formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT)
        else:
            formats = (format,)

        if isinstance(alias_to, int):
            alias_oid = pylong_as_oid(alias_to)
        else:
            alias_oid = BUILTIN_TYPE_NAME_MAP.get(alias_to, 0)

        for format in formats:
            if alias_oid != 0:
                target_codec = self.get_codec(alias_oid, format)
            else:
                target_codec = get_extra_codec(alias_to, format)

            if target_codec is None:
                continue

            codec = target_codec.copy()
            codec.oid = typeoid
            codec.name = typename
            codec.schema = typeschema
            codec.kind = typekind

            self._custom_type_codecs[typeoid, format] = codec
            codec_set = True

        if not codec_set:
            if format == PG_FORMAT_BINARY:
                codec_str = 'binary'
            elif format == PG_FORMAT_TEXT:
                codec_str = 'text'
            else:
                codec_str = 'text or binary'

            raise exceptions.InterfaceError(
                f'cannot alias {typename} to {alias_to}: '
                f'there is no {codec_str} codec for {alias_to}')

    def set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
                               alias_to, format=PG_FORMAT_ANY):
        # Public aliasing entry point; also invalidates the derived cache.
        self._set_builtin_type_codec(typeoid, typename, typeschema, typekind,
                                     alias_to, format)
        self.clear_type_cache()

    def clear_type_cache(self):
        # Drop all derived codecs; they will be rebuilt on next use.
        self._derived_type_codecs.clear()

    def declare_fallback_codec(self, uint32_t oid, str name, str schema):
        # Register a text-format fallback codec for a type asyncpg has
        # no specific support for.  Built-in (BKI) types without a core
        # codec are an error rather than a fallback.
        cdef Codec codec

        if oid <= MAXBUILTINOID:
            # This is a BKI type, for which asyncpg has no
            # defined codec.  This should only happen for newly
            # added builtin types, for which this version of
            # asyncpg is lacking support.
            #
            raise exceptions.UnsupportedClientFeatureError(
                f'unhandled standard data type {name!r} (OID {oid})')
        else:
            # This is a non-BKI type, and as such, has no
            # stable OID, so no possibility of a builtin codec.
            # In this case, fallback to text format.  Applications
            # can avoid this by specifying a codec for this type
            # using Connection.set_type_codec().
            #
            self._set_builtin_type_codec(oid, name, schema, 'scalar',
                                         TEXTOID, PG_FORMAT_TEXT)

            codec = self.get_codec(oid, PG_FORMAT_TEXT)

        return codec

    cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
                                bint ignore_custom_codec=False):
        # Resolve a codec for (oid, format), checking user overrides
        # first, then core codecs, then the derived-type cache.
        # PG_FORMAT_ANY prefers binary over text.
        cdef Codec codec

        if format == PG_FORMAT_ANY:
            codec = self.get_codec(
                oid, PG_FORMAT_BINARY, ignore_custom_codec)
            if codec is None:
                codec = self.get_codec(
                    oid, PG_FORMAT_TEXT, ignore_custom_codec)
            return codec
        else:
            if not ignore_custom_codec:
                codec = self.get_custom_codec(oid, PG_FORMAT_ANY)
                if codec is not None:
                    if codec.format != format:
                        # The codec for this OID has been overridden by
                        # set_{builtin}_type_codec with a different format.
                        # We must respect that and not return a core codec.
                        return None
                    else:
                        return codec

            codec = get_core_codec(oid, format)
            if codec is not None:
                return codec
            else:
                try:
                    return self._derived_type_codecs[oid, format]
                except KeyError:
                    return None

    cdef inline Codec get_custom_codec(
        self,
        uint32_t oid,
        ServerDataFormat format
    ):
        # Look up a user-installed codec; PG_FORMAT_ANY prefers binary.
        cdef Codec codec

        if format == PG_FORMAT_ANY:
            codec = self.get_custom_codec(oid, PG_FORMAT_BINARY)
            if codec is None:
                codec = self.get_custom_codec(oid, PG_FORMAT_TEXT)
        else:
            codec = self._custom_type_codecs.get((oid, format))

        return codec
cdef inline Codec get_core_codec(
        uint32_t oid, ServerDataFormat format,
        ClientExchangeFormat xformat=PG_XFORMAT_OBJECT):
    # Look up a built-in codec in the global C arrays populated by
    # register_core_codec().  The arrays are indexed by oid * xformat
    # (matching the write side in register_core_codec); returns None if
    # no codec was registered for this (oid, format, xformat).
    cdef:
        void *ptr = NULL
    if oid > MAXSUPPORTEDOID:
        return None

    if format == PG_FORMAT_BINARY:
        ptr = binary_codec_map[oid * xformat]
    elif format == PG_FORMAT_TEXT:
        ptr = text_codec_map[oid * xformat]

    if ptr is NULL:
        return None
    else:
        return <Codec>ptr
cdef inline Codec get_any_core_codec(
        uint32_t oid, ServerDataFormat format,
        ClientExchangeFormat xformat=PG_XFORMAT_OBJECT):
    """A version of get_core_codec that accepts PG_FORMAT_ANY."""
    # PG_FORMAT_ANY prefers the binary codec, falling back to text.
    cdef:
        Codec codec

    if format == PG_FORMAT_ANY:
        codec = get_core_codec(oid, PG_FORMAT_BINARY, xformat)
        if codec is None:
            codec = get_core_codec(oid, PG_FORMAT_TEXT, xformat)
    else:
        codec = get_core_codec(oid, format, xformat)

    return codec
cdef inline int has_core_codec(uint32_t oid):
    # True if a built-in codec (either wire format) exists for *oid*.
    return binary_codec_map[oid] != NULL or text_codec_map[oid] != NULL
cdef register_core_codec(uint32_t oid,
                         encode_func encode,
                         decode_func decode,
                         ServerDataFormat format,
                         ClientExchangeFormat xformat=PG_XFORMAT_OBJECT):
    # Register a built-in codec for a fixed-OID (BKI) type in the global
    # binary/text codec arrays, indexed by oid * xformat.  Called at
    # module import time only (see the init_*_codecs functions).
    if oid > MAXSUPPORTEDOID:
        raise exceptions.InternalClientError(
            'cannot register core codec for OID {}: it is greater '
            'than MAXSUPPORTEDOID ({})'.format(oid, MAXSUPPORTEDOID))

    cdef:
        Codec codec
        str name
        str kind

    name = BUILTIN_TYPE_OID_MAP[oid]
    kind = 'array' if oid in ARRAY_TYPES else 'scalar'

    codec = Codec(oid)
    codec.init(name, 'pg_catalog', kind, CODEC_C, format, xformat,
               encode, decode, None, None, None, None, None, None, None, 0)
    # The raw pointer stored in the C array does not keep the object
    # alive, so take a permanent reference.
    cpython.Py_INCREF(codec)  # immortalize

    if format == PG_FORMAT_BINARY:
        binary_codec_map[oid * xformat] = <void*>codec
    elif format == PG_FORMAT_TEXT:
        text_codec_map[oid * xformat] = <void*>codec
    else:
        raise exceptions.InternalClientError(
            'invalid data format: {}'.format(format))
cdef register_extra_codec(str name,
                          encode_func encode,
                          decode_func decode,
                          ServerDataFormat format):
    # Register a named codec with no fixed OID (e.g. contrib extension
    # types such as hstore); resolved later by name via get_extra_codec().
    cdef:
        Codec codec
        str kind

    kind = 'scalar'

    codec = Codec(INVALIDOID)
    codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT,
               encode, decode, None, None, None, None, None, None, None, 0)
    EXTRA_CODECS[name, format] = codec
cdef inline Codec get_extra_codec(str name, ServerDataFormat format):
    # Look up a named (OID-less) codec registered by register_extra_codec().
    return EXTRA_CODECS.get((name, format))

View File

@@ -0,0 +1,484 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef init_bits_codecs():
    # BIT and VARBIT share the same binary wire representation.
    for oid in (BITOID, VARBITOID):
        register_core_codec(oid,
                            <encode_func>pgproto.bits_encode,
                            <decode_func>pgproto.bits_decode,
                            PG_FORMAT_BINARY)
cdef init_bytea_codecs():
    # BYTEA and single-byte "char" both pass raw bytes through unchanged.
    for oid in (BYTEAOID, CHAROID):
        register_core_codec(oid,
                            <encode_func>pgproto.bytea_encode,
                            <decode_func>pgproto.bytea_decode,
                            PG_FORMAT_BINARY)
cdef init_datetime_codecs():
    # Date/time types: each is registered twice, once for the default
    # OBJECT exchange format (datetime objects) and once for the TUPLE
    # exchange format (raw protocol tuples).
    register_core_codec(DATEOID,
                        <encode_func>pgproto.date_encode,
                        <decode_func>pgproto.date_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(DATEOID,
                        <encode_func>pgproto.date_encode_tuple,
                        <decode_func>pgproto.date_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    register_core_codec(TIMEOID,
                        <encode_func>pgproto.time_encode,
                        <decode_func>pgproto.time_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(TIMEOID,
                        <encode_func>pgproto.time_encode_tuple,
                        <decode_func>pgproto.time_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    register_core_codec(TIMETZOID,
                        <encode_func>pgproto.timetz_encode,
                        <decode_func>pgproto.timetz_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(TIMETZOID,
                        <encode_func>pgproto.timetz_encode_tuple,
                        <decode_func>pgproto.timetz_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    register_core_codec(TIMESTAMPOID,
                        <encode_func>pgproto.timestamp_encode,
                        <decode_func>pgproto.timestamp_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(TIMESTAMPOID,
                        <encode_func>pgproto.timestamp_encode_tuple,
                        <decode_func>pgproto.timestamp_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    register_core_codec(TIMESTAMPTZOID,
                        <encode_func>pgproto.timestamptz_encode,
                        <decode_func>pgproto.timestamptz_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(TIMESTAMPTZOID,
                        <encode_func>pgproto.timestamp_encode_tuple,
                        <decode_func>pgproto.timestamp_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    register_core_codec(INTERVALOID,
                        <encode_func>pgproto.interval_encode,
                        <decode_func>pgproto.interval_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(INTERVALOID,
                        <encode_func>pgproto.interval_encode_tuple,
                        <decode_func>pgproto.interval_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    # For obsolete abstime/reltime/tinterval, we do not bother to
    # interpret the value, and simply return and pass it as text.
    #
    register_core_codec(ABSTIMEOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    register_core_codec(RELTIMEOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    register_core_codec(TINTERVALOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)
cdef init_float_codecs():
    # IEEE-754 float4/float8 binary codecs.
    register_core_codec(FLOAT4OID,
                        <encode_func>pgproto.float4_encode,
                        <decode_func>pgproto.float4_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(FLOAT8OID,
                        <encode_func>pgproto.float8_encode,
                        <decode_func>pgproto.float8_decode,
                        PG_FORMAT_BINARY)
cdef init_geometry_codecs():
    # Geometric types (box, line, lseg, point, path, polygon, circle),
    # all binary.
    register_core_codec(BOXOID,
                        <encode_func>pgproto.box_encode,
                        <decode_func>pgproto.box_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(LINEOID,
                        <encode_func>pgproto.line_encode,
                        <decode_func>pgproto.line_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(LSEGOID,
                        <encode_func>pgproto.lseg_encode,
                        <decode_func>pgproto.lseg_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(POINTOID,
                        <encode_func>pgproto.point_encode,
                        <decode_func>pgproto.point_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(PATHOID,
                        <encode_func>pgproto.path_encode,
                        <decode_func>pgproto.path_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(POLYGONOID,
                        <encode_func>pgproto.poly_encode,
                        <decode_func>pgproto.poly_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(CIRCLEOID,
                        <encode_func>pgproto.circle_encode,
                        <decode_func>pgproto.circle_decode,
                        PG_FORMAT_BINARY)
cdef init_hstore_codecs():
    # hstore is a contrib extension with no fixed OID; register by name
    # so it can be aliased per-connection once its OID is known.
    register_extra_codec('pg_contrib.hstore',
                         <encode_func>pgproto.hstore_encode,
                         <decode_func>pgproto.hstore_decode,
                         PG_FORMAT_BINARY)
cdef init_json_codecs():
    # json is plain text on the wire; jsonb/jsonpath have a one-byte
    # version prefix handled by their dedicated codecs.
    register_core_codec(JSONOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_BINARY)
    register_core_codec(JSONBOID,
                        <encode_func>pgproto.jsonb_encode,
                        <decode_func>pgproto.jsonb_decode,
                        PG_FORMAT_BINARY)
    register_core_codec(JSONPATHOID,
                        <encode_func>pgproto.jsonpath_encode,
                        <decode_func>pgproto.jsonpath_decode,
                        PG_FORMAT_BINARY)
cdef init_int_codecs():
    # Boolean and signed integer binary codecs.
    register_core_codec(BOOLOID,
                        <encode_func>pgproto.bool_encode,
                        <decode_func>pgproto.bool_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(INT2OID,
                        <encode_func>pgproto.int2_encode,
                        <decode_func>pgproto.int2_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(INT4OID,
                        <encode_func>pgproto.int4_encode,
                        <decode_func>pgproto.int4_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(INT8OID,
                        <encode_func>pgproto.int8_encode,
                        <decode_func>pgproto.int8_decode,
                        PG_FORMAT_BINARY)
cdef init_pseudo_codecs():
    # System and pseudo-types: OIDs, reg* aliases, void/unknown, and
    # assorted catalog-internal types.  Types with no server-side I/O
    # are registered with NULL encoder/decoder so that attempting to
    # use them produces a clear error rather than a protocol failure.

    # Void type is returned by SELECT void_returning_function()
    register_core_codec(VOIDOID,
                        <encode_func>pgproto.void_encode,
                        <decode_func>pgproto.void_decode,
                        PG_FORMAT_BINARY)

    # Unknown type, always decoded as text
    register_core_codec(UNKNOWNOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    # OID and friends
    oid_types = [
        OIDOID, XIDOID, CIDOID
    ]

    for oid_type in oid_types:
        register_core_codec(oid_type,
                            <encode_func>pgproto.uint4_encode,
                            <decode_func>pgproto.uint4_decode,
                            PG_FORMAT_BINARY)

    # 64-bit OID types
    oid8_types = [
        XID8OID,
    ]

    for oid_type in oid8_types:
        register_core_codec(oid_type,
                            <encode_func>pgproto.uint8_encode,
                            <decode_func>pgproto.uint8_decode,
                            PG_FORMAT_BINARY)

    # reg* types -- these are really system catalog OIDs, but
    # allow the catalog object name as an input.  We could just
    # decode these as OIDs, but handling them as text seems more
    # useful.
    #
    reg_types = [
        REGPROCOID, REGPROCEDUREOID, REGOPEROID, REGOPERATOROID,
        REGCLASSOID, REGTYPEOID, REGCONFIGOID, REGDICTIONARYOID,
        REGNAMESPACEOID, REGROLEOID, REFCURSOROID, REGCOLLATIONOID,
    ]

    for reg_type in reg_types:
        register_core_codec(reg_type,
                            <encode_func>pgproto.text_encode,
                            <decode_func>pgproto.text_decode,
                            PG_FORMAT_TEXT)

    # cstring type is used by Postgres' I/O functions
    register_core_codec(CSTRINGOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_BINARY)

    # various system pseudotypes with no I/O
    no_io_types = [
        ANYOID, TRIGGEROID, EVENT_TRIGGEROID, LANGUAGE_HANDLEROID,
        FDW_HANDLEROID, TSM_HANDLEROID, INTERNALOID, OPAQUEOID,
        ANYELEMENTOID, ANYNONARRAYOID, ANYCOMPATIBLEOID,
        ANYCOMPATIBLEARRAYOID, ANYCOMPATIBLENONARRAYOID,
        ANYCOMPATIBLERANGEOID, ANYCOMPATIBLEMULTIRANGEOID,
        ANYRANGEOID, ANYMULTIRANGEOID, ANYARRAYOID,
        PG_DDL_COMMANDOID, INDEX_AM_HANDLEROID, TABLE_AM_HANDLEROID,
    ]

    register_core_codec(ANYENUMOID,
                        NULL,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    for no_io_type in no_io_types:
        register_core_codec(no_io_type,
                            NULL,
                            NULL,
                            PG_FORMAT_BINARY)

    # ACL specification string
    register_core_codec(ACLITEMOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    # Postgres' serialized expression tree type
    register_core_codec(PG_NODE_TREEOID,
                        NULL,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    # pg_lsn type -- a pointer to a location in the XLOG.
    register_core_codec(PG_LSNOID,
                        <encode_func>pgproto.int8_encode,
                        <decode_func>pgproto.int8_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(SMGROID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    # pg_dependencies and pg_ndistinct are special types
    # used in pg_statistic_ext columns.
    register_core_codec(PG_DEPENDENCIESOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    register_core_codec(PG_NDISTINCTOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    # pg_mcv_list is a special type used in pg_statistic_ext_data
    # system catalog
    register_core_codec(PG_MCV_LISTOID,
                        <encode_func>pgproto.bytea_encode,
                        <decode_func>pgproto.bytea_decode,
                        PG_FORMAT_BINARY)

    # These two are internal to BRIN index support and are unlikely
    # to be sent, but since I/O functions for these exist, add decoders
    # nonetheless.
    register_core_codec(PG_BRIN_BLOOM_SUMMARYOID,
                        NULL,
                        <decode_func>pgproto.bytea_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(PG_BRIN_MINMAX_MULTI_SUMMARYOID,
                        NULL,
                        <decode_func>pgproto.bytea_decode,
                        PG_FORMAT_BINARY)
cdef init_text_codecs():
    # Character types are plain text in both wire formats, so register
    # the same encode/decode pair for binary and text.
    textoids = [
        NAMEOID,
        BPCHAROID,
        VARCHAROID,
        TEXTOID,
        XMLOID
    ]

    for oid in textoids:
        for fmt in (PG_FORMAT_BINARY, PG_FORMAT_TEXT):
            register_core_codec(oid,
                                <encode_func>pgproto.text_encode,
                                <decode_func>pgproto.text_decode,
                                fmt)
cdef init_tid_codecs():
    # Tuple identifier (block number, tuple index) binary codec.
    register_core_codec(TIDOID,
                        <encode_func>pgproto.tid_encode,
                        <decode_func>pgproto.tid_decode,
                        PG_FORMAT_BINARY)
cdef init_txid_codecs():
    # txid_snapshot (pre-PG13 name) and pg_snapshot share the same
    # wire representation.
    register_core_codec(TXID_SNAPSHOTOID,
                        <encode_func>pgproto.pg_snapshot_encode,
                        <decode_func>pgproto.pg_snapshot_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(PG_SNAPSHOTOID,
                        <encode_func>pgproto.pg_snapshot_encode,
                        <decode_func>pgproto.pg_snapshot_decode,
                        PG_FORMAT_BINARY)
cdef init_tsearch_codecs():
    # Full-text search types are handled as text; gtsvector is an
    # index-internal type and is decode-only.
    ts_oids = [
        TSQUERYOID,
        TSVECTOROID,
    ]

    for oid in ts_oids:
        register_core_codec(oid,
                            <encode_func>pgproto.text_encode,
                            <decode_func>pgproto.text_decode,
                            PG_FORMAT_TEXT)

    register_core_codec(GTSVECTOROID,
                        NULL,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)
cdef init_uuid_codecs():
    # UUID binary codec (16-byte wire value).
    register_core_codec(UUIDOID,
                        <encode_func>pgproto.uuid_encode,
                        <decode_func>pgproto.uuid_decode,
                        PG_FORMAT_BINARY)
cdef init_numeric_codecs():
    # numeric has both text and binary codecs registered.
    register_core_codec(NUMERICOID,
                        <encode_func>pgproto.numeric_encode_text,
                        <decode_func>pgproto.numeric_decode_text,
                        PG_FORMAT_TEXT)

    register_core_codec(NUMERICOID,
                        <encode_func>pgproto.numeric_encode_binary,
                        <decode_func>pgproto.numeric_decode_binary,
                        PG_FORMAT_BINARY)
cdef init_network_codecs():
    # Network address types: cidr/inet use dedicated binary codecs;
    # the MAC address types are passed through as text.
    register_core_codec(CIDROID,
                        <encode_func>pgproto.cidr_encode,
                        <decode_func>pgproto.cidr_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(INETOID,
                        <encode_func>pgproto.inet_encode,
                        <decode_func>pgproto.inet_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(MACADDROID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    register_core_codec(MACADDR8OID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)
cdef init_monetary_codecs():
    # money is locale-dependent on the server; pass it through as text.
    register_core_codec(MONEYOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)
cdef init_all_pgproto_codecs():
    # Register every built-in codec group.  Runs once at module import
    # (see the call below); registration order is not significant.

    # Builtin types, in lexicographical order.
    init_bits_codecs()
    init_bytea_codecs()
    init_datetime_codecs()
    init_float_codecs()
    init_geometry_codecs()
    init_int_codecs()
    init_json_codecs()
    init_monetary_codecs()
    init_network_codecs()
    init_numeric_codecs()
    init_text_codecs()
    init_tid_codecs()
    init_tsearch_codecs()
    init_txid_codecs()
    init_uuid_codecs()

    # Various pseudotypes and system types
    init_pseudo_codecs()

    # contrib
    init_hstore_codecs()


# Populate the codec maps at module import time.
init_all_pgproto_codecs()

View File

@@ -0,0 +1,207 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from asyncpg import types as apg_types
from collections.abc import Sequence as SequenceABC
# defined in postgresql/src/include/utils/rangetypes.h
# Flag bits of the first byte of a binary range datum.
DEF RANGE_EMPTY = 0x01  # range is empty
DEF RANGE_LB_INC = 0x02  # lower bound is inclusive
DEF RANGE_UB_INC = 0x04  # upper bound is inclusive
DEF RANGE_LB_INF = 0x08  # lower bound is -infinity
DEF RANGE_UB_INF = 0x10  # upper bound is +infinity
# Classification of Python values accepted as range input.
cdef enum _RangeArgumentType:
    _RANGE_ARGUMENT_INVALID = 0
    _RANGE_ARGUMENT_TUPLE = 1   # tuple/list of 0-2 bounds
    _RANGE_ARGUMENT_RANGE = 2   # asyncpg.types.Range instance
cdef inline bint _range_has_lbound(uint8_t flags):
    # A lower bound is present unless the range is empty or
    # unbounded below.
    return (flags & (RANGE_EMPTY | RANGE_LB_INF)) == 0
cdef inline bint _range_has_ubound(uint8_t flags):
    # An upper bound is present unless the range is empty or
    # unbounded above.
    return (flags & (RANGE_EMPTY | RANGE_UB_INF)) == 0
cdef inline _RangeArgumentType _range_type(object obj):
    # Classify a range argument.  Tuples/lists are checked first and
    # interpreted as bound pairs; Range objects carry explicit flags.
    if isinstance(obj, (tuple, list)):
        return _RANGE_ARGUMENT_TUPLE
    if isinstance(obj, apg_types.Range):
        return _RANGE_ARGUMENT_RANGE
    return _RANGE_ARGUMENT_INVALID
cdef range_encode(ConnectionSettings settings, WriteBuffer buf,
                  object obj, uint32_t elem_oid,
                  encode_func_ex encoder, const void *encoder_arg):
    # Encode *obj* as a binary range datum:
    #   int32 datum length, uint8 flags byte, then the encoded lower
    #   and/or upper bounds (only those the flags say are present).
    cdef:
        ssize_t obj_len
        uint8_t flags = 0
        object lower = None
        object upper = None
        WriteBuffer bounds_data = WriteBuffer.new()
        _RangeArgumentType arg_type = _range_type(obj)
    if arg_type == _RANGE_ARGUMENT_INVALID:
        raise TypeError(
            'list, tuple or Range object expected (got type {})'.format(
                type(obj)))
    elif arg_type == _RANGE_ARGUMENT_TUPLE:
        # Tuple/list input: 0 elements -> empty range,
        # 1 element -> [lower, +inf), 2 elements -> [lower, upper].
        obj_len = len(obj)
        if obj_len == 2:
            lower = obj[0]
            upper = obj[1]
            if lower is None:
                flags |= RANGE_LB_INF
            if upper is None:
                flags |= RANGE_UB_INF
            # Tuple-form ranges are inclusive on both ends.
            flags |= RANGE_LB_INC | RANGE_UB_INC
        elif obj_len == 1:
            lower = obj[0]
            flags |= RANGE_LB_INC | RANGE_UB_INF
        elif obj_len == 0:
            flags |= RANGE_EMPTY
        else:
            raise ValueError(
                'expected 0, 1 or 2 elements in range (got {})'.format(
                    obj_len))
    else:
        # Range object input: translate its attributes into flag bits.
        if obj.isempty:
            flags |= RANGE_EMPTY
        else:
            lower = obj.lower
            upper = obj.upper
            if obj.lower_inc:
                flags |= RANGE_LB_INC
            elif lower is None:
                flags |= RANGE_LB_INF
            if obj.upper_inc:
                flags |= RANGE_UB_INC
            elif upper is None:
                flags |= RANGE_UB_INF
    # Encode only the bounds that are actually present.
    if _range_has_lbound(flags):
        encoder(settings, bounds_data, lower, encoder_arg)
    if _range_has_ubound(flags):
        encoder(settings, bounds_data, upper, encoder_arg)
    # Datum length: 1 byte of flags + encoded bounds.
    buf.write_int32(1 + bounds_data.len())
    buf.write_byte(<int8_t>flags)
    buf.write_buffer(bounds_data)
cdef range_decode(ConnectionSettings settings, FRBuffer *buf,
                  decode_func_ex decoder, const void *decoder_arg):
    # Decode a binary range datum (flags byte followed by the bounds
    # the flags indicate) into an asyncpg.types.Range.
    cdef:
        uint8_t flags = <uint8_t>frb_read(buf, 1)[0]
        int32_t bound_len
        object lower = None
        object upper = None
        FRBuffer bound_buf
    if _range_has_lbound(flags):
        # Each present bound is an int32 length followed by its payload;
        # -1 means NULL.
        bound_len = hton.unpack_int32(frb_read(buf, 4))
        if bound_len == -1:
            lower = None
        else:
            # Slice the bound's bytes off so the element decoder cannot
            # read past it.
            frb_slice_from(&bound_buf, buf, bound_len)
            lower = decoder(settings, &bound_buf, decoder_arg)
    if _range_has_ubound(flags):
        bound_len = hton.unpack_int32(frb_read(buf, 4))
        if bound_len == -1:
            upper = None
        else:
            frb_slice_from(&bound_buf, buf, bound_len)
            upper = decoder(settings, &bound_buf, decoder_arg)
    return apg_types.Range(lower=lower, upper=upper,
                           lower_inc=(flags & RANGE_LB_INC) != 0,
                           upper_inc=(flags & RANGE_UB_INC) != 0,
                           empty=(flags & RANGE_EMPTY) != 0)
cdef multirange_encode(ConnectionSettings settings, WriteBuffer buf,
                       object obj, uint32_t elem_oid,
                       encode_func_ex encoder, const void *encoder_arg):
    # Encode a sequence of ranges as a binary multirange datum:
    #   int32 datum length, int32 range count, then each range encoded
    #   via range_encode().
    cdef:
        WriteBuffer elem_data
        ssize_t elem_data_len
        ssize_t elem_count
    if not isinstance(obj, SequenceABC):
        raise TypeError(
            'expected a sequence (got type {!r})'.format(type(obj).__name__)
        )
    elem_data = WriteBuffer.new()
    for elem in obj:
        range_encode(settings, elem_data, elem, elem_oid, encoder, encoder_arg)
    # Both the element count and the datum length fields are int32 on
    # the wire, so guard against overflow before casting.
    elem_count = len(obj)
    if elem_count > INT32_MAX:
        raise OverflowError(f'too many elements in multirange value')
    elem_data_len = elem_data.len()
    if elem_data_len > INT32_MAX - 4:
        raise OverflowError(
            f'size of encoded multirange datum exceeds the maximum allowed'
            f' {INT32_MAX - 4} bytes')
    # Datum length
    buf.write_int32(4 + <int32_t>elem_data_len)
    # Number of elements in multirange
    buf.write_int32(<int32_t>elem_count)
    buf.write_buffer(elem_data)
cdef multirange_decode(ConnectionSettings settings, FRBuffer *buf,
                       decode_func_ex decoder, const void *decoder_arg):
    # Decode a binary multirange datum into a list of Range objects.
    # Wire layout: int32 range count, then per range an int32 length
    # followed by a range_decode()-able payload.
    cdef:
        int32_t nelems = hton.unpack_int32(frb_read(buf, 4))
        FRBuffer elem_buf
        int32_t elem_len
        int i
        list result
    if nelems == 0:
        return []
    if nelems < 0:
        raise exceptions.ProtocolError(
            'unexpected multirange size value: {}'.format(nelems))
    # Preallocate the list; PyList_SET_ITEM below steals a reference,
    # hence the explicit Py_INCREF on each element.
    result = cpython.PyList_New(nelems)
    for i in range(nelems):
        elem_len = hton.unpack_int32(frb_read(buf, 4))
        if elem_len == -1:
            raise exceptions.ProtocolError(
                'unexpected NULL element in multirange value')
        else:
            # Confine the range decoder to this element's bytes.
            frb_slice_from(&elem_buf, buf, elem_len)
            elem = range_decode(settings, &elem_buf, decoder, decoder_arg)
        cpython.Py_INCREF(elem)
        cpython.PyList_SET_ITEM(result, i, elem)
    return result

View File

@@ -0,0 +1,71 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from asyncpg import exceptions
cdef inline record_encode_frame(ConnectionSettings settings, WriteBuffer buf,
                                WriteBuffer elem_data, int32_t elem_count):
    # Write the framing of a binary composite (record) datum around
    # already-encoded attribute data.
    # Datum length: 4 bytes for the attribute count + the payload.
    buf.write_int32(4 + elem_data.len())
    # attribute count
    buf.write_int32(elem_count)
    # encoded attribute data
    buf.write_buffer(elem_data)
cdef anonymous_record_decode(ConnectionSettings settings, FRBuffer *buf):
    # Decode a binary composite datum into a plain tuple.
    # Wire layout: int32 attribute count, then per attribute an int32
    # type OID, int32 length (-1 for NULL), and the attribute payload.
    cdef:
        tuple result
        ssize_t elem_count
        ssize_t i
        int32_t elem_len
        uint32_t elem_typ
        Codec elem_codec
        FRBuffer elem_buf
    elem_count = <ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4))
    result = cpython.PyTuple_New(elem_count)
    for i in range(elem_count):
        elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
        elem_len = hton.unpack_int32(frb_read(buf, 4))
        if elem_len == -1:
            elem = None
        else:
            # Element codecs are looked up per attribute by OID.
            elem_codec = settings.get_data_codec(elem_typ)
            if elem_codec is None or not elem_codec.has_decoder():
                raise exceptions.InternalClientError(
                    'no decoder for composite type element in '
                    'position {} of type OID {}'.format(i, elem_typ))
            elem = elem_codec.decode(settings,
                                     frb_slice_from(&elem_buf, buf, elem_len))
        # PyTuple_SET_ITEM steals a reference, so INCREF first.
        cpython.Py_INCREF(elem)
        cpython.PyTuple_SET_ITEM(result, i, elem)
    return result
cdef anonymous_record_encode(ConnectionSettings settings, WriteBuffer buf, obj):
    # Unconditionally reject input: the backend has no input routine
    # for anonymous composites, so encoding is impossible by design.
    raise exceptions.UnsupportedClientFeatureError(
        'input of anonymous composite types is not supported',
        detail='PostgreSQL does not implement anonymous composite type input.',
        hint=(
            'Consider declaring an explicit composite type and '
            'using it to cast the argument.'
        ),
    )
cdef init_record_codecs():
    # Anonymous records are decodable (binary), but the encoder always
    # raises -- see anonymous_record_encode above.
    register_core_codec(RECORDOID,
                        <encode_func>anonymous_record_encode,
                        <decode_func>anonymous_record_decode,
                        PG_FORMAT_BINARY)
# Register at import time.
init_record_codecs()

View File

@@ -0,0 +1,99 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef inline uint32_t _apg_tolower(uint32_t c):
    # ASCII-only lowercasing; every non-[A-Z] code point passes
    # through unchanged.
    if <uint32_t><Py_UCS4>'A' <= c <= <uint32_t><Py_UCS4>'Z':
        return c + (<uint32_t><Py_UCS4>'a' - <uint32_t><Py_UCS4>'A')
    return c
cdef int apg_strcasecmp(const Py_UCS4 *s1, const Py_UCS4 *s2):
    # strcasecmp() over NUL-terminated UCS4 strings with ASCII-only
    # case folding.  Returns <0, 0 or >0 like strcmp().
    cdef:
        uint32_t c1
        uint32_t c2
        int i = 0
    while True:
        c1 = s1[i]
        c2 = s2[i]
        if c1 != c2:
            # Only fold case when the raw code points differ.
            c1 = _apg_tolower(c1)
            c2 = _apg_tolower(c2)
            if c1 != c2:
                return <int32_t>c1 - <int32_t>c2
        # Either string ending (NUL) terminates the comparison.
        if c1 == 0 or c2 == 0:
            break
        i += 1
    return 0
cdef int apg_strcasecmp_char(const char *s1, const char *s2):
    # Byte-string counterpart of apg_strcasecmp(): ASCII-only
    # case-insensitive compare of NUL-terminated char strings.
    cdef:
        uint8_t c1
        uint8_t c2
        int i = 0
    while True:
        c1 = <uint8_t>s1[i]
        c2 = <uint8_t>s2[i]
        if c1 != c2:
            c1 = <uint8_t>_apg_tolower(c1)
            c2 = <uint8_t>_apg_tolower(c2)
            if c1 != c2:
                return <int8_t>c1 - <int8_t>c2
        if c1 == 0 or c2 == 0:
            break
        i += 1
    return 0
cdef inline bint apg_ascii_isspace(Py_UCS4 ch):
    # True for the six ASCII whitespace characters (the C locale's
    # isspace() set).
    return (
        ch == ' ' or
        ch == '\t' or
        ch == '\n' or
        ch == '\v' or
        ch == '\f' or
        ch == '\r'
    )
cdef Py_UCS4 *apg_parse_int32(Py_UCS4 *buf, int32_t *num):
    # Parse an optionally signed decimal int32 from *buf*, storing the
    # value in *num and returning a pointer to the first unparsed
    # character, or NULL if no digits were consumed.
    cdef:
        Py_UCS4 *p
        int32_t n = 0
        int32_t neg = 0
    if buf[0] == '-':
        neg = 1
        buf += 1
    elif buf[0] == '+':
        buf += 1
    p = buf
    while <int>p[0] >= <int><Py_UCS4>'0' and <int>p[0] <= <int><Py_UCS4>'9':
        # Accumulate as a negative number: -2147483648 is representable
        # in int32 while +2147483648 is not, so this direction avoids
        # overflowing on INT32_MIN.  (No general overflow check here.)
        n = 10 * n - (<int>p[0] - <int32_t><Py_UCS4>'0')
        p += 1
    if p == buf:
        # Not a single digit was read.
        return NULL
    if not neg:
        n = -n
    num[0] = n
    return p

View File

@@ -0,0 +1,12 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
DEF _MAXINT32 = 2**31 - 1  # largest value representable in a signed int32
DEF _COPY_BUFFER_SIZE = 524288  # 512 KiB; per the name, COPY data chunk size
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"  # binary COPY header signature
# Per their names, the number and size of the write buffers used when
# batching executemany() data -- confirm against the protocol module.
DEF _EXECUTE_MANY_BUF_NUM = 4
DEF _EXECUTE_MANY_BUF_SIZE = 32768

View File

@@ -0,0 +1,195 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
include "scram.pxd"
# Coarse connection liveness state (loosely modelled on libpq's
# ConnStatusType; the numeric values differ).
cdef enum ConnectionStatus:
    CONNECTION_OK = 1
    CONNECTION_BAD = 2
    CONNECTION_STARTED = 3           # Waiting for connection to be made.
# State machine of CoreProtocol.  Values below 10 are terminal or
# bookkeeping states; 10 and up name the operation currently in flight.
cdef enum ProtocolState:
    PROTOCOL_IDLE = 0
    PROTOCOL_FAILED = 1
    PROTOCOL_ERROR_CONSUME = 2
    PROTOCOL_CANCELLED = 3
    PROTOCOL_TERMINATING = 4
    PROTOCOL_AUTH = 10
    PROTOCOL_PREPARE = 11
    PROTOCOL_BIND_EXECUTE = 12
    PROTOCOL_BIND_EXECUTE_MANY = 13
    PROTOCOL_CLOSE_STMT_PORTAL = 14
    PROTOCOL_SIMPLE_QUERY = 15
    PROTOCOL_EXECUTE = 16
    PROTOCOL_BIND = 17
    PROTOCOL_COPY_OUT = 18
    PROTOCOL_COPY_OUT_DATA = 19
    PROTOCOL_COPY_OUT_DONE = 20
    PROTOCOL_COPY_IN = 21
    PROTOCOL_COPY_IN_DATA = 22
# Status codes carried by the backend Authentication ('R') message,
# matching the codes in the PostgreSQL frontend/backend protocol.
cdef enum AuthenticationMessage:
    AUTH_SUCCESSFUL = 0
    AUTH_REQUIRED_KERBEROS = 2
    AUTH_REQUIRED_PASSWORD = 3
    AUTH_REQUIRED_PASSWORDMD5 = 5
    AUTH_REQUIRED_SCMCRED = 6
    AUTH_REQUIRED_GSS = 7
    AUTH_REQUIRED_GSS_CONTINUE = 8
    AUTH_REQUIRED_SSPI = 9
    AUTH_REQUIRED_SASL = 10
    AUTH_SASL_CONTINUE = 11
    AUTH_SASL_FINAL = 12
# Human-readable names for the auth codes above (per the mapping,
# presumably used when reporting unsupported methods).
AUTH_METHOD_NAME = {
    AUTH_REQUIRED_KERBEROS: 'kerberosv5',
    AUTH_REQUIRED_PASSWORD: 'password',
    AUTH_REQUIRED_PASSWORDMD5: 'md5',
    AUTH_REQUIRED_GSS: 'gss',
    AUTH_REQUIRED_SASL: 'scram-sha-256',
    AUTH_REQUIRED_SSPI: 'sspi',
}
# Outcome of a completed protocol operation.
cdef enum ResultType:
    RESULT_OK = 1
    RESULT_FAILED = 2
# Backend transaction state as reported by ReadyForQuery
# (mirrors libpq's PGTransactionStatusType).
cdef enum TransactionStatus:
    PQTRANS_IDLE = 0                 # connection idle
    PQTRANS_ACTIVE = 1               # command in progress
    PQTRANS_INTRANS = 2              # idle, within transaction block
    PQTRANS_INERROR = 3              # idle, within failed transaction
    PQTRANS_UNKNOWN = 4              # cannot determine status
# Signature of the per-statement DataRow decoder.
ctypedef object (*decode_row_method)(object, const char*, ssize_t)
# Low-level PostgreSQL wire-protocol state machine: message parsing,
# authentication, and the extended-query/COPY flows.  Declarations only;
# implementations live in the corresponding .pyx file.
cdef class CoreProtocol:
    cdef:
        ReadBuffer buffer
        bint _skip_discard
        bint _discard_data
        # executemany support data
        object _execute_iter
        str _execute_portal_name
        str _execute_stmt_name
        ConnectionStatus con_status
        ProtocolState state
        TransactionStatus xact_status
        str encoding
        object transport
        # Instance of _ConnectionParameters
        object con_params
        # Instance of SCRAMAuthentication
        SCRAMAuthentication scram
        readonly int32_t backend_pid
        readonly int32_t backend_secret
        ## Result
        ResultType result_type
        object result
        bytes result_param_desc
        bytes result_row_desc
        bytes result_status_msg
        # True - completed, False - suspended
        bint result_execute_completed
    cpdef is_in_transaction(self)
    # Per-state message dispatchers (one per ProtocolState).
    cdef _process__auth(self, char mtype)
    cdef _process__prepare(self, char mtype)
    cdef _process__bind_execute(self, char mtype)
    cdef _process__bind_execute_many(self, char mtype)
    cdef _process__close_stmt_portal(self, char mtype)
    cdef _process__simple_query(self, char mtype)
    cdef _process__bind(self, char mtype)
    cdef _process__copy_out(self, char mtype)
    cdef _process__copy_out_data(self, char mtype)
    cdef _process__copy_in(self, char mtype)
    cdef _process__copy_in_data(self, char mtype)
    # Backend message parsers.
    cdef _parse_msg_authentication(self)
    cdef _parse_msg_parameter_status(self)
    cdef _parse_msg_notification(self)
    cdef _parse_msg_backend_key_data(self)
    cdef _parse_msg_ready_for_query(self)
    cdef _parse_data_msgs(self)
    cdef _parse_copy_data_msgs(self)
    cdef _parse_msg_error_response(self, is_error)
    cdef _parse_msg_command_complete(self)
    # COPY FROM STDIN frontend messages.
    cdef _write_copy_data_msg(self, object data)
    cdef _write_copy_done_msg(self)
    cdef _write_copy_fail_msg(self, str cause)
    # Password/SASL response builders for the auth exchange.
    cdef _auth_password_message_cleartext(self)
    cdef _auth_password_message_md5(self, bytes salt)
    cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
    cdef _auth_password_message_sasl_continue(self, bytes server_response)
    cdef _write(self, buf)
    cdef _writelines(self, list buffers)
    cdef _read_server_messages(self)
    cdef _push_result(self)
    cdef _reset_result(self)
    cdef _set_state(self, ProtocolState new_state)
    cdef _ensure_connected(self)
    # Extended-query message builders.
    cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
    cdef WriteBuffer _build_bind_message(self, str portal_name,
                                         str stmt_name,
                                         WriteBuffer bind_data)
    cdef WriteBuffer _build_empty_bind_data(self)
    cdef WriteBuffer _build_execute_message(self, str portal_name,
                                            int32_t limit)
    # Operation entry points.
    cdef _connect(self)
    cdef _prepare_and_describe(self, str stmt_name, str query)
    cdef _send_parse_message(self, str stmt_name, str query)
    cdef _send_bind_message(self, str portal_name, str stmt_name,
                            WriteBuffer bind_data, int32_t limit)
    cdef _bind_execute(self, str portal_name, str stmt_name,
                       WriteBuffer bind_data, int32_t limit)
    cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
                                 object bind_data)
    cdef bint _bind_execute_many_more(self, bint first=*)
    cdef _bind_execute_many_fail(self, object error, bint first=*)
    cdef _bind(self, str portal_name, str stmt_name,
               WriteBuffer bind_data)
    cdef _execute(self, str portal_name, int32_t limit)
    cdef _close(self, str name, bint is_portal)
    cdef _simple_query(self, str query)
    cdef _copy_out(self, str copy_stmt)
    cdef _copy_in(self, str copy_stmt)
    cdef _terminate(self)
    cdef _decode_row(self, const char* buf, ssize_t buf_len)
    # Hooks overridden by the higher-level protocol.
    cdef _on_result(self)
    cdef _on_notification(self, pid, channel, payload)
    cdef _on_notice(self, parsed)
    cdef _set_server_parameter(self, name, val)
    cdef _on_connection_lost(self, exc)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,19 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
# CPython C-API declarations used by this package that are not exposed
# by Cython's bundled cpython.* pxd files.
cdef extern from "Python.h":
    int PyByteArray_Check(object)
    int PyMemoryView_Check(object)
    Py_buffer *PyMemoryView_GET_BUFFER(object)
    object PyMemoryView_GetContiguous(object, int buffertype, char order)
    Py_UCS4* PyUnicode_AsUCS4Copy(object) except NULL
    object PyUnicode_FromKindAndData(
        int kind, const void *buffer, Py_ssize_t size)
    int PyUnicode_4BYTE_KIND

View File

@@ -0,0 +1,63 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
'''Map PostgreSQL encoding names to Python encoding names
https://www.postgresql.org/docs/current/static/multibyte.html#CHARSET-TABLE
'''
# Keys are lower-cased PostgreSQL encoding names; values are the Python
# codec names to use.  Names not present here are passed to Python
# unchanged (lower-cased) -- see get_python_encoding().
cdef dict ENCODINGS_MAP = {
    'abc': 'cp1258',
    'alt': 'cp866',
    'euc_cn': 'euccn',
    'euc_jp': 'eucjp',
    'euc_kr': 'euckr',
    'koi8r': 'koi8_r',
    'koi8u': 'koi8_u',
    # Python supports SHIFT_JIS_2004 natively under the same name; the
    # former mapping to 'euc_jis_2004' pointed at a different,
    # incompatible encoding and would silently mis-decode data.
    'shift_jis_2004': 'shift_jis_2004',
    'sjis': 'shift_jis',
    'sql_ascii': 'ascii',
    'vscii': 'cp1258',
    'tcvn': 'cp1258',
    'tcvn5712': 'cp1258',
    'unicode': 'utf_8',
    # 'WIN' is PostgreSQL's obsolete alias for WIN1251.  The previous
    # value 'cp1521' is not a codec Python knows, so any use raised
    # LookupError.
    'win': 'cp1251',
    'win1250': 'cp1250',
    'win1251': 'cp1251',
    'win1252': 'cp1252',
    'win1253': 'cp1253',
    'win1254': 'cp1254',
    'win1255': 'cp1255',
    'win1256': 'cp1256',
    'win1257': 'cp1257',
    'win1258': 'cp1258',
    'win866': 'cp866',
    'win874': 'cp874',
    'win932': 'cp932',
    'win936': 'cp936',
    'win949': 'cp949',
    'win950': 'cp950',
    'windows1250': 'cp1250',
    'windows1251': 'cp1251',
    'windows1252': 'cp1252',
    'windows1253': 'cp1253',
    'windows1254': 'cp1254',
    'windows1255': 'cp1255',
    'windows1256': 'cp1256',
    'windows1257': 'cp1257',
    'windows1258': 'cp1258',
    'windows866': 'cp866',
    'windows874': 'cp874',
    'windows932': 'cp932',
    'windows936': 'cp936',
    'windows949': 'cp949',
    'windows950': 'cp950',
}
cdef get_python_encoding(pg_encoding):
    # Normalize once; encodings with no explicit mapping fall back to
    # their own lower-cased name, which Python usually understands.
    normalized = pg_encoding.lower()
    return ENCODINGS_MAP.get(normalized, normalized)

View File

@@ -0,0 +1,266 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
# GENERATED FROM pg_catalog.pg_type
# DO NOT MODIFY, use tools/generate_type_map.py to update
# Compile-time (DEF) OID constants for pg_catalog.pg_type entries.
DEF INVALIDOID = 0
DEF MAXBUILTINOID = 9999
DEF MAXSUPPORTEDOID = 5080
DEF BOOLOID = 16
DEF BYTEAOID = 17
DEF CHAROID = 18
DEF NAMEOID = 19
DEF INT8OID = 20
DEF INT2OID = 21
DEF INT4OID = 23
DEF REGPROCOID = 24
DEF TEXTOID = 25
DEF OIDOID = 26
DEF TIDOID = 27
DEF XIDOID = 28
DEF CIDOID = 29
DEF PG_DDL_COMMANDOID = 32
DEF JSONOID = 114
DEF XMLOID = 142
DEF PG_NODE_TREEOID = 194
DEF SMGROID = 210
DEF TABLE_AM_HANDLEROID = 269
DEF INDEX_AM_HANDLEROID = 325
DEF POINTOID = 600
DEF LSEGOID = 601
DEF PATHOID = 602
DEF BOXOID = 603
DEF POLYGONOID = 604
DEF LINEOID = 628
DEF CIDROID = 650
DEF FLOAT4OID = 700
DEF FLOAT8OID = 701
DEF ABSTIMEOID = 702
DEF RELTIMEOID = 703
DEF TINTERVALOID = 704
DEF UNKNOWNOID = 705
DEF CIRCLEOID = 718
DEF MACADDR8OID = 774
DEF MONEYOID = 790
DEF MACADDROID = 829
DEF INETOID = 869
DEF _TEXTOID = 1009
DEF _OIDOID = 1028
DEF ACLITEMOID = 1033
DEF BPCHAROID = 1042
DEF VARCHAROID = 1043
DEF DATEOID = 1082
DEF TIMEOID = 1083
DEF TIMESTAMPOID = 1114
DEF TIMESTAMPTZOID = 1184
DEF INTERVALOID = 1186
DEF TIMETZOID = 1266
DEF BITOID = 1560
DEF VARBITOID = 1562
DEF NUMERICOID = 1700
DEF REFCURSOROID = 1790
DEF REGPROCEDUREOID = 2202
DEF REGOPEROID = 2203
DEF REGOPERATOROID = 2204
DEF REGCLASSOID = 2205
DEF REGTYPEOID = 2206
DEF RECORDOID = 2249
DEF CSTRINGOID = 2275
DEF ANYOID = 2276
DEF ANYARRAYOID = 2277
DEF VOIDOID = 2278
DEF TRIGGEROID = 2279
DEF LANGUAGE_HANDLEROID = 2280
DEF INTERNALOID = 2281
DEF OPAQUEOID = 2282
DEF ANYELEMENTOID = 2283
DEF ANYNONARRAYOID = 2776
DEF UUIDOID = 2950
DEF TXID_SNAPSHOTOID = 2970
DEF FDW_HANDLEROID = 3115
DEF PG_LSNOID = 3220
DEF TSM_HANDLEROID = 3310
DEF PG_NDISTINCTOID = 3361
DEF PG_DEPENDENCIESOID = 3402
DEF ANYENUMOID = 3500
DEF TSVECTOROID = 3614
DEF TSQUERYOID = 3615
DEF GTSVECTOROID = 3642
DEF REGCONFIGOID = 3734
DEF REGDICTIONARYOID = 3769
DEF JSONBOID = 3802
DEF ANYRANGEOID = 3831
DEF EVENT_TRIGGEROID = 3838
DEF JSONPATHOID = 4072
DEF REGNAMESPACEOID = 4089
DEF REGROLEOID = 4096
DEF REGCOLLATIONOID = 4191
DEF ANYMULTIRANGEOID = 4537
DEF ANYCOMPATIBLEMULTIRANGEOID = 4538
DEF PG_BRIN_BLOOM_SUMMARYOID = 4600
DEF PG_BRIN_MINMAX_MULTI_SUMMARYOID = 4601
DEF PG_MCV_LISTOID = 5017
DEF PG_SNAPSHOTOID = 5038
DEF XID8OID = 5069
DEF ANYCOMPATIBLEOID = 5077
DEF ANYCOMPATIBLEARRAYOID = 5078
DEF ANYCOMPATIBLENONARRAYOID = 5079
DEF ANYCOMPATIBLERANGEOID = 5080
# Builtin array-type OIDs referenced by this package.
cdef ARRAY_TYPES = (_TEXTOID, _OIDOID,)
# OID -> canonical pg_type name for the builtin types above.
BUILTIN_TYPE_OID_MAP = {
    ABSTIMEOID: 'abstime',
    ACLITEMOID: 'aclitem',
    ANYARRAYOID: 'anyarray',
    ANYCOMPATIBLEARRAYOID: 'anycompatiblearray',
    ANYCOMPATIBLEMULTIRANGEOID: 'anycompatiblemultirange',
    ANYCOMPATIBLENONARRAYOID: 'anycompatiblenonarray',
    ANYCOMPATIBLEOID: 'anycompatible',
    ANYCOMPATIBLERANGEOID: 'anycompatiblerange',
    ANYELEMENTOID: 'anyelement',
    ANYENUMOID: 'anyenum',
    ANYMULTIRANGEOID: 'anymultirange',
    ANYNONARRAYOID: 'anynonarray',
    ANYOID: 'any',
    ANYRANGEOID: 'anyrange',
    BITOID: 'bit',
    BOOLOID: 'bool',
    BOXOID: 'box',
    BPCHAROID: 'bpchar',
    BYTEAOID: 'bytea',
    CHAROID: 'char',
    CIDOID: 'cid',
    CIDROID: 'cidr',
    CIRCLEOID: 'circle',
    CSTRINGOID: 'cstring',
    DATEOID: 'date',
    EVENT_TRIGGEROID: 'event_trigger',
    FDW_HANDLEROID: 'fdw_handler',
    FLOAT4OID: 'float4',
    FLOAT8OID: 'float8',
    GTSVECTOROID: 'gtsvector',
    INDEX_AM_HANDLEROID: 'index_am_handler',
    INETOID: 'inet',
    INT2OID: 'int2',
    INT4OID: 'int4',
    INT8OID: 'int8',
    INTERNALOID: 'internal',
    INTERVALOID: 'interval',
    JSONBOID: 'jsonb',
    JSONOID: 'json',
    JSONPATHOID: 'jsonpath',
    LANGUAGE_HANDLEROID: 'language_handler',
    LINEOID: 'line',
    LSEGOID: 'lseg',
    MACADDR8OID: 'macaddr8',
    MACADDROID: 'macaddr',
    MONEYOID: 'money',
    NAMEOID: 'name',
    NUMERICOID: 'numeric',
    OIDOID: 'oid',
    OPAQUEOID: 'opaque',
    PATHOID: 'path',
    PG_BRIN_BLOOM_SUMMARYOID: 'pg_brin_bloom_summary',
    PG_BRIN_MINMAX_MULTI_SUMMARYOID: 'pg_brin_minmax_multi_summary',
    PG_DDL_COMMANDOID: 'pg_ddl_command',
    PG_DEPENDENCIESOID: 'pg_dependencies',
    PG_LSNOID: 'pg_lsn',
    PG_MCV_LISTOID: 'pg_mcv_list',
    PG_NDISTINCTOID: 'pg_ndistinct',
    PG_NODE_TREEOID: 'pg_node_tree',
    PG_SNAPSHOTOID: 'pg_snapshot',
    POINTOID: 'point',
    POLYGONOID: 'polygon',
    RECORDOID: 'record',
    REFCURSOROID: 'refcursor',
    REGCLASSOID: 'regclass',
    REGCOLLATIONOID: 'regcollation',
    REGCONFIGOID: 'regconfig',
    REGDICTIONARYOID: 'regdictionary',
    REGNAMESPACEOID: 'regnamespace',
    REGOPERATOROID: 'regoperator',
    REGOPEROID: 'regoper',
    REGPROCEDUREOID: 'regprocedure',
    REGPROCOID: 'regproc',
    REGROLEOID: 'regrole',
    REGTYPEOID: 'regtype',
    RELTIMEOID: 'reltime',
    SMGROID: 'smgr',
    TABLE_AM_HANDLEROID: 'table_am_handler',
    TEXTOID: 'text',
    TIDOID: 'tid',
    TIMEOID: 'time',
    TIMESTAMPOID: 'timestamp',
    TIMESTAMPTZOID: 'timestamptz',
    TIMETZOID: 'timetz',
    TINTERVALOID: 'tinterval',
    TRIGGEROID: 'trigger',
    TSM_HANDLEROID: 'tsm_handler',
    TSQUERYOID: 'tsquery',
    TSVECTOROID: 'tsvector',
    TXID_SNAPSHOTOID: 'txid_snapshot',
    UNKNOWNOID: 'unknown',
    UUIDOID: 'uuid',
    VARBITOID: 'varbit',
    VARCHAROID: 'varchar',
    VOIDOID: 'void',
    XID8OID: 'xid8',
    XIDOID: 'xid',
    XMLOID: 'xml',
    _OIDOID: 'oid[]',
    _TEXTOID: 'text[]'
}
# Reverse mapping: canonical type name -> OID, plus common SQL alias
# spellings that differ from the pg_type names.
BUILTIN_TYPE_NAME_MAP = {v: k for k, v in BUILTIN_TYPE_OID_MAP.items()}
BUILTIN_TYPE_NAME_MAP['smallint'] = \
    BUILTIN_TYPE_NAME_MAP['int2']
BUILTIN_TYPE_NAME_MAP['int'] = \
    BUILTIN_TYPE_NAME_MAP['int4']
BUILTIN_TYPE_NAME_MAP['integer'] = \
    BUILTIN_TYPE_NAME_MAP['int4']
BUILTIN_TYPE_NAME_MAP['bigint'] = \
    BUILTIN_TYPE_NAME_MAP['int8']
BUILTIN_TYPE_NAME_MAP['decimal'] = \
    BUILTIN_TYPE_NAME_MAP['numeric']
BUILTIN_TYPE_NAME_MAP['real'] = \
    BUILTIN_TYPE_NAME_MAP['float4']
BUILTIN_TYPE_NAME_MAP['double precision'] = \
    BUILTIN_TYPE_NAME_MAP['float8']
# Standard SQL spells it 'time zone' (two words); accept both that
# spelling and the historical one-word variant below.
BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \
    BUILTIN_TYPE_NAME_MAP['timestamptz']
BUILTIN_TYPE_NAME_MAP['timestamp with time zone'] = \
    BUILTIN_TYPE_NAME_MAP['timestamptz']
BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \
    BUILTIN_TYPE_NAME_MAP['timestamp']
BUILTIN_TYPE_NAME_MAP['timestamp without time zone'] = \
    BUILTIN_TYPE_NAME_MAP['timestamp']
BUILTIN_TYPE_NAME_MAP['time with timezone'] = \
    BUILTIN_TYPE_NAME_MAP['timetz']
BUILTIN_TYPE_NAME_MAP['time with time zone'] = \
    BUILTIN_TYPE_NAME_MAP['timetz']
BUILTIN_TYPE_NAME_MAP['time without timezone'] = \
    BUILTIN_TYPE_NAME_MAP['time']
BUILTIN_TYPE_NAME_MAP['time without time zone'] = \
    BUILTIN_TYPE_NAME_MAP['time']
BUILTIN_TYPE_NAME_MAP['char'] = \
    BUILTIN_TYPE_NAME_MAP['bpchar']
BUILTIN_TYPE_NAME_MAP['character'] = \
    BUILTIN_TYPE_NAME_MAP['bpchar']
BUILTIN_TYPE_NAME_MAP['character varying'] = \
    BUILTIN_TYPE_NAME_MAP['varchar']
BUILTIN_TYPE_NAME_MAP['bit varying'] = \
    BUILTIN_TYPE_NAME_MAP['varbit']

View File

@@ -0,0 +1,39 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
# Declarations for the per-statement state object (implementation in
# the matching .pyx file).
cdef class PreparedStatementState:
    cdef:
        readonly str name
        readonly str query
        readonly bint closed
        readonly bint prepared
        readonly int refs
        readonly type record_class
        readonly bint ignore_custom_codec
        # Parsed RowDescription / ParameterDescription payloads.
        list row_desc
        list parameters_desc
        ConnectionSettings settings
        # Argument-side codec state.
        int16_t args_num
        bint have_text_args
        tuple args_codecs
        # Result-side codec state.
        int16_t cols_num
        object cols_desc
        bint have_text_cols
        tuple rows_codecs
    cdef _encode_bind_msg(self, args, int seqno = ?)
    cpdef _init_codecs(self)
    cdef _ensure_rows_decoder(self)
    cdef _ensure_args_encoder(self)
    cdef _set_row_desc(self, object desc)
    cdef _set_args_desc(self, object desc)
    cdef _decode_row(self, const char* cbuf, ssize_t buf_len)

View File

@@ -0,0 +1,395 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from asyncpg import exceptions
# Holds the client-side state of a single prepared statement: its
# Parse/Describe results and the codecs used to encode arguments and
# decode result rows.
@cython.final
cdef class PreparedStatementState:
    def __cinit__(
        self,
        str name,
        str query,
        BaseProtocol protocol,
        type record_class,
        bint ignore_custom_codec
    ):
        self.name = name
        self.query = query
        self.settings = protocol.settings
        # Descriptions and codecs are filled in lazily once the
        # statement has been described by the server.
        self.row_desc = self.parameters_desc = None
        self.args_codecs = self.rows_codecs = None
        self.args_num = self.cols_num = 0
        self.cols_desc = None
        self.closed = False
        self.prepared = True
        self.refs = 0
        self.record_class = record_class
        self.ignore_custom_codec = ignore_custom_codec
    def _get_parameters(self):
        """Return a tuple of asyncpg Type objects for the statement's
        parameters."""
        cdef Codec codec
        result = []
        for oid in self.parameters_desc:
            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))
            result.append(apg_types.Type(
                oid, codec.name, codec.kind, codec.schema))
        return tuple(result)
    def _get_attributes(self):
        """Return a tuple of asyncpg Attribute objects describing the
        statement's result columns."""
        cdef Codec codec
        if not self.row_desc:
            return ()
        result = []
        for d in self.row_desc:
            # d is a RowDescription field tuple; [0] is the column
            # name (bytes), [3] is the data type OID.
            name = d[0]
            oid = d[3]
            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))
            name = name.decode(self.settings._encoding)
            result.append(
                apg_types.Attribute(name,
                    apg_types.Type(oid, codec.name, codec.kind, codec.schema)))
        return tuple(result)
    def _init_types(self):
        """Return the set of OIDs used by this statement that currently
        have no usable codec (encoder for params, decoder for columns)."""
        cdef:
            Codec codec
            set missing = set()
        if self.parameters_desc:
            for p_oid in self.parameters_desc:
                codec = self.settings.get_data_codec(<uint32_t>p_oid)
                if codec is None or not codec.has_encoder():
                    missing.add(p_oid)
        if self.row_desc:
            for rdesc in self.row_desc:
                codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
                if codec is None or not codec.has_decoder():
                    missing.add(rdesc[3])
        return missing
    cpdef _init_codecs(self):
        # Resolve both argument encoders and row decoders up front.
        self._ensure_args_encoder()
        self._ensure_rows_decoder()
    def attach(self):
        self.refs += 1
    def detach(self):
        self.refs -= 1
    def mark_closed(self):
        self.closed = True
    def mark_unprepared(self):
        # Only the unnamed (anonymous) statement may lose its prepared
        # status; named statements live until explicitly closed.
        if self.name:
            raise exceptions.InternalClientError(
                "named prepared statements cannot be marked unprepared")
        self.prepared = False
    cdef _encode_bind_msg(self, args, int seqno = -1):
        # Build the portion of a Bind message that carries parameter
        # format codes, parameter values and result format codes.
        # seqno >= 0 indicates this call came from executemany() and is
        # used to attribute errors to the offending element.
        cdef:
            int idx
            WriteBuffer writer
            Codec codec
        if not cpython.PySequence_Check(args):
            if seqno >= 0:
                raise exceptions.DataError(
                    f'invalid input in executemany() argument sequence '
                    f'element #{seqno}: expected a sequence, got '
                    f'{type(args).__name__}'
                )
            else:
                # Non executemany() callers do not pass user input directly,
                # so bad input is a bug.
                raise exceptions.InternalClientError(
                    f'Bind: expected a sequence, got {type(args).__name__}')
        # Bind carries the argument count as an int16.
        if len(args) > 32767:
            raise exceptions.InterfaceError(
                'the number of query arguments cannot exceed 32767')
        writer = WriteBuffer.new()
        num_args_passed = len(args)
        if self.args_num != num_args_passed:
            hint = 'Check the query against the passed list of arguments.'
            if self.args_num == 0:
                # If the server was expecting zero arguments, it is likely
                # that the user tried to parametrize a statement that does
                # not support parameters.
                hint += (r' Note that parameters are supported only in'
                         r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
                         r' statements, and will *not* work in statements '
                         r' like CREATE VIEW or DECLARE CURSOR.')
            raise exceptions.InterfaceError(
                'the server expects {x} argument{s} for this query, '
                '{y} {w} passed'.format(
                    x=self.args_num, s='s' if self.args_num != 1 else '',
                    y=num_args_passed,
                    w='was' if num_args_passed == 1 else 'were'),
                hint=hint)
        if self.have_text_args:
            # Mixed formats: emit one format code per argument.
            writer.write_int16(self.args_num)
            for idx in range(self.args_num):
                codec = <Codec>(self.args_codecs[idx])
                writer.write_int16(<int16_t>codec.format)
        else:
            # All arguments are in binary format
            # (shorthand: int16 count=1 + int16 format=1 packed as one
            # int32).
            writer.write_int32(0x00010001)
        writer.write_int16(self.args_num)
        for idx in range(self.args_num):
            arg = args[idx]
            if arg is None:
                writer.write_int32(-1)
            else:
                codec = <Codec>(self.args_codecs[idx])
                try:
                    codec.encode(self.settings, writer, arg)
                except (AssertionError, exceptions.InternalClientError):
                    # These are internal errors and should raise as-is.
                    raise
                except exceptions.InterfaceError as e:
                    # This is already a descriptive error, but annotate
                    # with argument name for clarity.
                    pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    raise e.with_msg(
                        f'query argument {pos}: {e.args[0]}'
                    ) from None
                except Exception as e:
                    # Everything else is assumed to be an encoding error
                    # due to invalid input.
                    pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    value_repr = repr(arg)
                    if len(value_repr) > 40:
                        value_repr = value_repr[:40] + '...'
                    raise exceptions.DataError(
                        f'invalid input for query argument'
                        f' {pos}: {value_repr} ({e})'
                    ) from e
        if self.have_text_cols:
            # Mixed result formats: one format code per column.
            writer.write_int16(self.cols_num)
            for idx in range(self.cols_num):
                codec = <Codec>(self.rows_codecs[idx])
                writer.write_int16(<int16_t>codec.format)
        else:
            # All columns are in binary format
            writer.write_int32(0x00010001)
        return writer
    cdef _ensure_rows_decoder(self):
        # Build the per-column decoder tuple and the Record descriptor
        # (column name -> index mapping).  Idempotent.
        cdef:
            list cols_names
            object cols_mapping
            tuple row
            uint32_t oid
            Codec codec
            list codecs
        if self.cols_desc is not None:
            return
        if self.cols_num == 0:
            self.cols_desc = record.ApgRecordDesc_New({}, ())
            return
        cols_mapping = collections.OrderedDict()
        cols_names = []
        codecs = []
        for i from 0 <= i < self.cols_num:
            row = self.row_desc[i]
            col_name = row[0].decode(self.settings._encoding)
            cols_mapping[col_name] = i
            cols_names.append(col_name)
            oid = row[3]
            codec = self.settings.get_data_codec(
                oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_decoder():
                raise exceptions.InternalClientError(
                    'no decoder for OID {}'.format(oid))
            if not codec.is_binary():
                self.have_text_cols = True
            codecs.append(codec)
        self.cols_desc = record.ApgRecordDesc_New(
            cols_mapping, tuple(cols_names))
        self.rows_codecs = tuple(codecs)
    cdef _ensure_args_encoder(self):
        # Build the per-argument encoder tuple.  Idempotent.
        cdef:
            uint32_t p_oid
            Codec codec
            list codecs = []
        if self.args_num == 0 or self.args_codecs is not None:
            return
        for i from 0 <= i < self.args_num:
            p_oid = self.parameters_desc[i]
            codec = self.settings.get_data_codec(
                p_oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_encoder():
                raise exceptions.InternalClientError(
                    'no encoder for OID {}'.format(p_oid))
            # NOTE(review): membership in an empty set is always false,
            # so this condition is always true and per-argument format
            # codes are always emitted in _encode_bind_msg().  It looks
            # like this was meant to mirror the binary-format check in
            # _ensure_rows_decoder() -- confirm intent.
            if codec.type not in {}:
                self.have_text_args = True
            codecs.append(codec)
        self.args_codecs = tuple(codecs)
    cdef _set_row_desc(self, object desc):
        # Store the parsed RowDescription and cache the column count.
        self.row_desc = _decode_row_desc(desc)
        self.cols_num = <int16_t>(len(self.row_desc))
    cdef _set_args_desc(self, object desc):
        # Store the parsed ParameterDescription and cache the arg count.
        self.parameters_desc = _decode_parameters_desc(desc)
        self.args_num = <int16_t>(len(self.parameters_desc))
    cdef _decode_row(self, const char* cbuf, ssize_t buf_len):
        # Decode a DataRow message payload into a Record using the
        # cached per-column codecs.
        cdef:
            Codec codec
            int16_t fnum
            int32_t flen
            object dec_row
            tuple rows_codecs = self.rows_codecs
            ConnectionSettings settings = self.settings
            int32_t i
            FRBuffer rbuf
            ssize_t bl
        frb_init(&rbuf, cbuf, buf_len)
        fnum = hton.unpack_int16(frb_read(&rbuf, 2))
        if fnum != self.cols_num:
            raise exceptions.ProtocolError(
                'the number of columns in the result row ({}) is '
                'different from what was described ({})'.format(
                    fnum, self.cols_num))
        dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum)
        for i in range(fnum):
            flen = hton.unpack_int32(frb_read(&rbuf, 4))
            if flen == -1:
                val = None
            else:
                # Clamp buffer size to that of the reported field length
                # to make sure that codecs can rely on read_all() working
                # properly.
                bl = frb_get_len(&rbuf)
                if flen > bl:
                    frb_check(&rbuf, flen)
                frb_set_len(&rbuf, flen)
                codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i)
                val = codec.decode(settings, &rbuf)
                # The codec must have consumed the field exactly.
                if frb_get_len(&rbuf) != 0:
                    raise BufferError(
                        'unexpected trailing {} bytes in buffer'.format(
                            frb_get_len(&rbuf)))
                # Restore the remaining message length.
                frb_set_len(&rbuf, bl - flen)
            # ApgRecord_SET_ITEM steals a reference, so INCREF first.
            cpython.Py_INCREF(val)
            record.ApgRecord_SET_ITEM(dec_row, i, val)
        if frb_get_len(&rbuf) != 0:
            raise BufferError('unexpected trailing {} bytes in buffer'.format(
                frb_get_len(&rbuf)))
        return dec_row
cdef _decode_parameters_desc(object desc):
    # Parse a ParameterDescription message payload into a list of
    # parameter type OIDs.
    cdef:
        ReadBuffer reader
        int16_t nparams
        uint32_t p_oid
        list result = []
    reader = ReadBuffer.new_message_parser(desc)
    nparams = reader.read_int16()
    for _ in range(nparams):
        p_oid = <uint32_t>reader.read_int32()
        result.append(p_oid)
    return result
cdef _decode_row_desc(object desc):
    # Parse a RowDescription message payload into a list of field
    # tuples: (name, table_oid, column_num, type_oid, type_size,
    # type_mod, format).
    cdef:
        ReadBuffer reader
        int16_t nfields
        bytes f_name
        uint32_t f_table_oid
        int16_t f_column_num
        uint32_t f_dt_oid
        int16_t f_dt_size
        int32_t f_dt_mod
        int16_t f_format
        list fields = []
    reader = ReadBuffer.new_message_parser(desc)
    nfields = reader.read_int16()
    for _ in range(nfields):
        f_name = reader.read_null_str()
        f_table_oid = <uint32_t>reader.read_int32()
        f_column_num = reader.read_int16()
        f_dt_oid = <uint32_t>reader.read_int32()
        f_dt_size = reader.read_int16()
        f_dt_mod = reader.read_int32()
        f_format = reader.read_int16()
        fields.append(
            (f_name, f_table_oid, f_column_num, f_dt_oid,
             f_dt_size, f_dt_mod, f_format))
    return fields

View File

@@ -0,0 +1,78 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from libc.stdint cimport int16_t, int32_t, uint16_t, \
uint32_t, int64_t, uint64_t
from asyncpg.pgproto.debug cimport PG_DEBUG
from asyncpg.pgproto.pgproto cimport (
WriteBuffer,
ReadBuffer,
FRBuffer,
)
from asyncpg.pgproto cimport pgproto
include "consts.pxi"
include "pgtypes.pxi"
include "codecs/base.pxd"
include "settings.pxd"
include "coreproto.pxd"
include "prepared_stmt.pxd"
# Declarations for the asyncpg protocol state machine; method bodies live
# in the corresponding .pyx file.
cdef class BaseProtocol(CoreProtocol):

    cdef:
        object loop                     # asyncio event loop
        object address                  # server address this protocol connects to
        ConnectionSettings settings     # per-connection settings/codec registry
        object cancel_sent_waiter       # presumably a future tracking a sent cancel request -- confirm in .pyx
        object cancel_waiter
        object waiter                   # future for the operation awaiting a server response
        bint return_extra
        object create_future            # factory for loop futures
        object timeout_handle           # handle for a pending timeout callback, if any
        object conref                   # reference back to the owning connection -- see get_connection()
        type record_class               # class used to instantiate result rows
        bint is_reading
        str last_query                  # text of the most recently executed query
        bint writing_paused             # transport write-flow-control state
        bint closing
        readonly uint64_t queries_count # number of queries executed on this connection
        bint _is_ssl

        PreparedStatementState statement

    cdef get_connection(self)

    cdef _get_timeout_impl(self, timeout)
    cdef _check_state(self)
    cdef _new_waiter(self, timeout)
    cdef _coreproto_error(self)

    # Per-operation completion callbacks invoked when the corresponding
    # protocol exchange finishes.
    cdef _on_result__connect(self, object waiter)
    cdef _on_result__prepare(self, object waiter)
    cdef _on_result__bind_and_exec(self, object waiter)
    cdef _on_result__close_stmt_or_portal(self, object waiter)
    cdef _on_result__simple_query(self, object waiter)
    cdef _on_result__bind(self, object waiter)
    cdef _on_result__copy_out(self, object waiter)
    cdef _on_result__copy_in(self, object waiter)

    cdef _handle_waiter_on_connection_lost(self, cause)

    cdef _dispatch_result(self)

    cdef inline resume_reading(self)
    cdef inline pause_reading(self)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,19 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cimport cpython

# C-level API of asyncpg's tuple-like Record type, implemented in
# record/recordobj.h.
cdef extern from "record/recordobj.h":
    # Initialize/register the record types; NULL on failure raises.
    cpython.PyTypeObject *ApgRecord_InitTypes() except NULL

    int ApgRecord_CheckExact(object)
    # Allocate a record: (record class, column descriptor, item count) --
    # see its use in row decoding.
    object ApgRecord_New(type, object, int)
    # Callers Py_INCREF the value before storing, which suggests this
    # steals a reference -- confirm in recordobj.h.
    void ApgRecord_SET_ITEM(object, int, object)
    object ApgRecordDesc_New(object, object)

View File

@@ -0,0 +1,31 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
# Declarations for the SCRAM-SHA-256 authentication state machine;
# implementation lives in the corresponding .pyx file.
cdef class SCRAMAuthentication:
    cdef:
        readonly bytes authentication_method    # e.g. b"SCRAM-SHA-256"
        readonly bytes authorization_message    # auth message used for signatures
        readonly bytes client_channel_binding   # channel-binding flag (currently disabled)
        readonly bytes client_first_message_bare
        readonly bytes client_nonce
        readonly bytes client_proof
        readonly bytes password_salt            # salt provided by the server
        readonly int password_iterations        # iteration count provided by the server
        readonly bytes server_first_message
        # server_key is an instance of hmac.HMAC
        readonly object server_key
        readonly bytes server_nonce

    cdef create_client_first_message(self, str username)
    cdef create_client_final_message(self, str password)
    cdef parse_server_first_message(self, bytes server_response)
    cdef verify_server_final_message(self, bytes server_final_message)
    cdef _bytes_xor(self, bytes a, bytes b)
    cdef _generate_client_nonce(self, int num_bytes)
    cdef _generate_client_proof(self, str password)
    cdef _generate_salted_password(self, str password, bytes salt, int iterations)
    cdef _normalize_password(self, str original_password)

View File

@@ -0,0 +1,341 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import base64
import hashlib
import hmac
import re
import secrets
import stringprep
import unicodedata
@cython.final
cdef class SCRAMAuthentication:
    """Contains the protocol for generating and a SCRAM hashed password.

    Since PostgreSQL 10, the option to hash passwords using the SCRAM-SHA-256
    method was added. This module follows the defined protocol, which can be
    referenced from here:

    https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256

    libpq references the following RFCs that it uses for implementation:

        * RFC 5802
        * RFC 5803
        * RFC 7677

    The protocol works as such:

    - A client connects to the server. The server requests the client to
    begin SASL authentication using SCRAM and presents a client with the
    methods it supports. At present, those are SCRAM-SHA-256, and, on servers
    that are built with OpenSSL and are PG11+, SCRAM-SHA-256-PLUS (which
    supports channel binding, more on that below)

    - The client sends a "first message" to the server, where it chooses
    which method to authenticate with, and sends, along with the method, an
    indication of channel binding (we disable for now), a nonce, and the
    username. (Technically, PostgreSQL ignores the username as it already has
    it from the initial connection, but we add it for completeness)

    - The server responds with a "first message" in which it extends the
    nonce, as well as a password salt and the number of iterations to hash
    the password with. The client validates that the new nonce contains the
    first part of the client's original nonce

    - The client generates a salted password, but does not send this up to
    the server. Instead, the client follows the SCRAM algorithm (RFC5802) to
    generate a proof. This proof is sent as part of a client "final message"
    to the server for it to validate.

    - The server validates the proof. If it is valid, the server sends a
    verification code for the client to verify that the server came to the
    same proof the client did. PostgreSQL immediately sends an
    AuthenticationOK response right after a valid negotiation. If the
    password the client provided was invalid, then authentication fails.

    (The beauty of this is that the salted password is never transmitted over
    the wire!)

    PostgreSQL 11 added support for channel binding (i.e.
    SCRAM-SHA-256-PLUS) but due to some ongoing discussion, there is a
    conscious decision by several driver authors to not support it as of yet.
    As such, the channel binding parameter is hard-coded to "n" for now, but
    can be updated to support other channel binding methods in the future
    """
    AUTHENTICATION_METHODS = [b"SCRAM-SHA-256"]
    DEFAULT_CLIENT_NONCE_BYTES = 24
    DIGEST = hashlib.sha256
    REQUIREMENTS_CLIENT_FINAL_MESSAGE = ['client_channel_binding',
                                         'server_nonce']
    REQUIREMENTS_CLIENT_PROOF = ['password_iterations', 'password_salt',
                                 'server_first_message', 'server_nonce']
    SASLPREP_PROHIBITED = (
        stringprep.in_table_a1,     # PostgreSQL treats this as prohibited
        stringprep.in_table_c12,
        stringprep.in_table_c21_c22,
        stringprep.in_table_c3,
        stringprep.in_table_c4,
        stringprep.in_table_c5,
        stringprep.in_table_c6,
        stringprep.in_table_c7,
        stringprep.in_table_c8,
        stringprep.in_table_c9,
    )

    def __cinit__(self, bytes authentication_method):
        self.authentication_method = authentication_method
        self.authorization_message = None
        # channel binding is turned off for the time being
        self.client_channel_binding = b"n,,"
        self.client_first_message_bare = None
        self.client_nonce = None
        self.client_proof = None
        self.password_salt = None
        # password_iterations is declared as a C int and therefore cannot be
        # initialized to None; it stays 0 until the server first message
        # is parsed.
        # self.password_iterations = None
        self.server_first_message = None
        self.server_key = None
        self.server_nonce = None

    cdef create_client_first_message(self, str username):
        """Create the initial client message for SCRAM authentication"""
        cdef:
            bytes msg
            bytes client_first_message

        self.client_nonce = \
            self._generate_client_nonce(self.DEFAULT_CLIENT_NONCE_BYTES)
        # set the client first message bare here, as it's used in a later step
        self.client_first_message_bare = b"n=" + username.encode("utf-8") + \
            b",r=" + self.client_nonce
        # put together the full message here: the chosen method, a NUL
        # terminator, then the length-prefixed SASL initial response
        msg = bytes()
        msg += self.authentication_method + b"\0"
        client_first_message = self.client_channel_binding + \
            self.client_first_message_bare
        msg += (len(client_first_message)).to_bytes(4, byteorder='big') + \
            client_first_message
        return msg

    cdef create_client_final_message(self, str password):
        """Create the final client message as part of SCRAM authentication"""
        cdef:
            bytes msg

        if any([getattr(self, val) is None for val in
                self.REQUIREMENTS_CLIENT_FINAL_MESSAGE]):
            raise Exception(
                "you need values from server to generate a client proof")

        # normalize the password using the SASLprep algorithm in RFC 4013
        password = self._normalize_password(password)

        # generate the client proof
        self.client_proof = self._generate_client_proof(password=password)
        msg = bytes()
        msg += b"c=" + base64.b64encode(self.client_channel_binding) + \
            b",r=" + self.server_nonce + \
            b",p=" + base64.b64encode(self.client_proof)
        return msg

    cdef parse_server_first_message(self, bytes server_response):
        """Parse the response from the first message from the server.

        Extracts the combined nonce, the password salt, and the iteration
        count, validating that the server nonce extends the client nonce.
        """
        self.server_first_message = server_response
        # NOTE: re.search() returns None when the attribute is absent, in
        # which case .group() raises AttributeError -- catch it along with
        # IndexError so a malformed message produces the intended error.
        try:
            self.server_nonce = re.search(rb'r=([^,]+),',
                self.server_first_message).group(1)
        except (AttributeError, IndexError):
            raise Exception("could not get nonce")
        if not self.server_nonce.startswith(self.client_nonce):
            raise Exception("invalid nonce")
        try:
            self.password_salt = re.search(rb',s=([^,]+),',
                self.server_first_message).group(1)
        except (AttributeError, IndexError):
            raise Exception("could not get salt")
        # raw bytes pattern: '\d' in a non-raw bytes literal is an invalid
        # escape sequence
        try:
            self.password_iterations = int(re.search(rb',i=(\d+),?',
                self.server_first_message).group(1))
        except (AttributeError, IndexError, TypeError, ValueError):
            raise Exception("could not get iterations")

    cdef verify_server_final_message(self, bytes server_final_message):
        """Verify the final message from the server.

        Returns True if the server's signature matches the one computed
        locally from the server key and the authorization message.
        """
        cdef:
            bytes server_signature

        try:
            server_signature = re.search(rb'v=([^,]+)',
                server_final_message).group(1)
        except (AttributeError, IndexError):
            raise Exception("could not get server signature")

        verify_server_signature = hmac.new(self.server_key.digest(),
            self.authorization_message, self.DIGEST)
        # validate the server signature against the verifier
        return server_signature == base64.b64encode(
            verify_server_signature.digest())

    cdef _bytes_xor(self, bytes a, bytes b):
        """XOR two bytestrings together"""
        return bytes(a_i ^ b_i for a_i, b_i in zip(a, b))

    cdef _generate_client_nonce(self, int num_bytes):
        """Generate a base64-encoded random nonce of *num_bytes* entropy."""
        cdef:
            bytes token

        token = secrets.token_bytes(num_bytes)
        return base64.b64encode(token)

    cdef _generate_client_proof(self, str password):
        """Compute the SCRAM client proof (RFC 5802).

        Requires the salt, iteration count and nonces from the server's
        first message; also derives and stores the server key used later
        to verify the server's final message.
        """
        cdef:
            bytes salted_password

        if any([getattr(self, val) is None for val in
                self.REQUIREMENTS_CLIENT_PROOF]):
            raise Exception(
                "you need values from server to generate a client proof")
        # generate a salt password
        salted_password = self._generate_salted_password(password,
            self.password_salt, self.password_iterations)

        # client key is derived from the salted password
        client_key = hmac.new(salted_password, b"Client Key", self.DIGEST)
        # this allows us to compute the stored key that is residing on the
        # server
        stored_key = self.DIGEST(client_key.digest())
        # as well as compute the server key
        self.server_key = hmac.new(salted_password, b"Server Key", self.DIGEST)

        # build the authorization message that will be used in the
        # client signature
        # the "c=" portion is for the channel binding, but this is not
        # presently implemented
        self.authorization_message = self.client_first_message_bare + b"," + \
            self.server_first_message + b",c=" + \
            base64.b64encode(self.client_channel_binding) + \
            b",r=" + self.server_nonce
        # sign!
        client_signature = hmac.new(stored_key.digest(),
            self.authorization_message, self.DIGEST)
        # and the proof
        return self._bytes_xor(client_key.digest(), client_signature.digest())

    cdef _generate_salted_password(self, str password, bytes salt, int iterations):
        """This follows the "Hi" algorithm specified in RFC5802"""
        cdef:
            bytes p
            bytes s
            bytes u

        # convert the password to a binary string - UTF8 is safe for SASL
        # (though there are SASLPrep rules)
        p = password.encode("utf8")
        # the salt needs to be base64 decoded -- full binary must be used
        s = base64.b64decode(salt)
        # the initial signature is the salt with a terminator of a 32-bit
        # string ending in 1
        ui = hmac.new(p, s + b'\x00\x00\x00\x01', self.DIGEST)
        # grab the initial digest
        u = ui.digest()
        # for X number of iterations, recompute the HMAC signature against
        # the password and the latest iteration of the hash, and XOR it with
        # the previous version
        for x in range(iterations - 1):
            # use the configured digest rather than naming hashlib.sha256
            # directly, to stay consistent with the rest of the class
            ui = hmac.new(p, ui.digest(), self.DIGEST)
            # this is a fancy way of XORing two byte strings together
            u = self._bytes_xor(u, ui.digest())
        return u

    cdef _normalize_password(self, str original_password):
        """Normalize the password using the SASLprep from RFC4013"""
        cdef:
            str normalized_password

        # Note: Per the PostgreSQL documentation, PostgreSQL does not require
        # UTF-8 to be used for the password, but will perform SASLprep on the
        # password regardless.
        # If the password is not valid UTF-8, PostgreSQL will then **not**
        # use SASLprep processing.
        # If the password fails SASLprep, the password should still be sent
        # See: https://www.postgresql.org/docs/current/sasl-authentication.html
        # and
        # https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/common/saslprep.c
        # using the `pg_saslprep` function
        normalized_password = original_password
        # if the original password is an ASCII string or fails to encode as
        # a UTF-8 string, then no further action is needed
        try:
            original_password.encode("ascii")
        except UnicodeEncodeError:
            pass
        else:
            return original_password

        # Step 1 of SASLPrep: Map. Per the algorithm, we map non-ascii space
        # characters to ASCII spaces (\x20 or \u0020, but we will use ' ')
        # and commonly mapped to nothing characters are removed
        # Table C.1.2 -- non-ASCII spaces
        # Table B.1 -- "Commonly mapped to nothing"
        normalized_password = u"".join(
            ' ' if stringprep.in_table_c12(c) else c
            for c in tuple(normalized_password) if not stringprep.in_table_b1(c)
        )

        # If at this point the password is empty, PostgreSQL uses the
        # original password
        if not normalized_password:
            return original_password

        # Step 2 of SASLPrep: Normalize. Normalize the password using the
        # Unicode normalization algorithm to NFKC form
        normalized_password = unicodedata.normalize('NFKC', normalized_password)

        # If the password is not empty, PostgreSQL uses the original password
        if not normalized_password:
            return original_password

        normalized_password_tuple = tuple(normalized_password)

        # Step 3 of SASLPrep: Prohibited characters. If PostgreSQL detects
        # any of the prohibited characters in SASLPrep, it will use the
        # original password
        # We also include "unassigned code points" in the prohibited
        # character category as PostgreSQL does the same
        for c in normalized_password_tuple:
            if any(
                in_prohibited_table(c)
                for in_prohibited_table in self.SASLPREP_PROHIBITED
            ):
                return original_password

        # Step 4 of SASLPrep: Bi-directional characters. PostgreSQL follows
        # the rules for bi-directional characters laid out in RFC3454 Sec. 6
        # which are:
        # 1. Characters in RFC 3454 Sec 5.8 are prohibited (C.8)
        # 2. If a string contains a RandALCat character, it cannot contain
        #    any LCat character
        # 3. If the string contains any RandALCat character, an RandALCat
        #    character must be the first and last character of the string
        # RandALCat characters are found in table D.1, whereas LCat are in
        # D.2
        if any(stringprep.in_table_d1(c) for c in normalized_password_tuple):
            # if the first character or the last character are not in D.1,
            # return the original password
            if not (stringprep.in_table_d1(normalized_password_tuple[0]) and
                    stringprep.in_table_d1(normalized_password_tuple[-1])):
                return original_password

            # if any characters are in D.2, use the original password
            if any(
                stringprep.in_table_d2(c) for c in normalized_password_tuple
            ):
                return original_password

        # return the normalized password
        return normalized_password

View File

@@ -0,0 +1,30 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
# Declarations for the per-connection settings/codec registry; method
# bodies live in the corresponding .pyx file.
cdef class ConnectionSettings(pgproto.CodecContext):
    cdef:
        str _encoding               # name of the current client encoding
        object _codec               # codecs.CodecInfo for the client encoding
        dict _settings              # raw server settings (name -> value)
        bint _is_utf8               # fast-path flag: client encoding is UTF-8
        DataCodecConfig _data_codecs    # per-connection data type codec registry

    cdef add_setting(self, str name, str val)
    cdef is_encoding_utf8(self)
    cpdef get_text_codec(self)
    cpdef inline register_data_types(self, types)
    cpdef inline add_python_codec(
        self, typeoid, typename, typeschema, typeinfos, typekind, encoder,
        decoder, format)
    cpdef inline remove_python_codec(
        self, typeoid, typename, typeschema)
    cpdef inline clear_type_cache(self)
    cpdef inline set_builtin_type_codec(
        self, typeoid, typename, typeschema, typekind, alias_to, format)
    cpdef inline Codec get_data_codec(
        self, uint32_t oid, ServerDataFormat format=*,
        bint ignore_custom_codec=*)

View File

@@ -0,0 +1,106 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from asyncpg import exceptions
@cython.final
cdef class ConnectionSettings(pgproto.CodecContext):
    """Tracks server-reported settings for a connection and manages the
    text codec and data-type codec registry derived from them."""

    def __cinit__(self, conn_key):
        self._settings = {}
        self._encoding = 'utf-8'
        self._codec = codecs.lookup('utf-8')
        self._is_utf8 = True
        self._data_codecs = DataCodecConfig(conn_key)

    cdef add_setting(self, str name, str val):
        # Record the setting; a change of client_encoding also refreshes
        # the cached text codec and the UTF-8 fast-path flag.
        self._settings[name] = val
        if name == 'client_encoding':
            python_codec_name = get_python_encoding(val)
            self._codec = codecs.lookup(python_codec_name)
            self._encoding = self._codec.name
            self._is_utf8 = self._encoding == 'utf-8'

    cdef is_encoding_utf8(self):
        return self._is_utf8

    cpdef get_text_codec(self):
        return self._codec

    cpdef inline register_data_types(self, types):
        self._data_codecs.add_types(types)

    cpdef inline add_python_codec(self, typeoid, typename, typeschema,
                                  typeinfos, typekind, encoder, decoder,
                                  format):
        """Register a Python-level codec for a type, translating the
        user-facing *format* string into internal format enums."""
        cdef:
            ServerDataFormat server_fmt
            ClientExchangeFormat exchange_fmt

        if format == 'binary':
            server_fmt = PG_FORMAT_BINARY
            exchange_fmt = PG_XFORMAT_OBJECT
        elif format == 'text':
            server_fmt = PG_FORMAT_TEXT
            exchange_fmt = PG_XFORMAT_OBJECT
        elif format == 'tuple':
            server_fmt = PG_FORMAT_ANY
            exchange_fmt = PG_XFORMAT_TUPLE
        else:
            raise exceptions.InterfaceError(
                'invalid `format` argument, expected {}, got {!r}'.format(
                    "'text', 'binary' or 'tuple'", format
                ))

        self._data_codecs.add_python_codec(typeoid, typename, typeschema,
                                           typekind, typeinfos,
                                           encoder, decoder,
                                           server_fmt, exchange_fmt)

    cpdef inline remove_python_codec(self, typeoid, typename, typeschema):
        self._data_codecs.remove_python_codec(typeoid, typename, typeschema)

    cpdef inline clear_type_cache(self):
        self._data_codecs.clear_type_cache()

    cpdef inline set_builtin_type_codec(self, typeoid, typename, typeschema,
                                        typekind, alias_to, format):
        """Alias a type to a built-in codec, translating *format* (or None
        for "any") into the internal server format enum."""
        cdef:
            ServerDataFormat server_fmt

        if format is None:
            server_fmt = PG_FORMAT_ANY
        elif format == 'binary':
            server_fmt = PG_FORMAT_BINARY
        elif format == 'text':
            server_fmt = PG_FORMAT_TEXT
        else:
            raise exceptions.InterfaceError(
                'invalid `format` argument, expected {}, got {!r}'.format(
                    "'text' or 'binary'", format
                ))

        self._data_codecs.set_builtin_type_codec(typeoid, typename, typeschema,
                                                 typekind, alias_to, server_fmt)

    cpdef inline Codec get_data_codec(self, uint32_t oid,
                                      ServerDataFormat format=PG_FORMAT_ANY,
                                      bint ignore_custom_codec=False):
        return self._data_codecs.get_codec(oid, format, ignore_custom_codec)

    def __getattr__(self, name):
        # Expose server settings as attributes; private names fall through
        # to the normal attribute machinery.
        if not name.startswith('_'):
            try:
                return self._settings[name]
            except KeyError:
                raise AttributeError(name) from None

        return object.__getattribute__(self, name)

    def __repr__(self):
        return '<ConnectionSettings {!r}>'.format(self._settings)