Major fixes and new features
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
@@ -0,0 +1,395 @@
|
||||
# Copyright (C) 2016-present the asyncpg authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of asyncpg and is released under
|
||||
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
|
||||
from asyncpg import exceptions
|
||||
|
||||
|
||||
@cython.final
|
||||
cdef class PreparedStatementState:
|
||||
|
||||
def __cinit__(
|
||||
self,
|
||||
str name,
|
||||
str query,
|
||||
BaseProtocol protocol,
|
||||
type record_class,
|
||||
bint ignore_custom_codec
|
||||
):
|
||||
self.name = name
|
||||
self.query = query
|
||||
self.settings = protocol.settings
|
||||
self.row_desc = self.parameters_desc = None
|
||||
self.args_codecs = self.rows_codecs = None
|
||||
self.args_num = self.cols_num = 0
|
||||
self.cols_desc = None
|
||||
self.closed = False
|
||||
self.prepared = True
|
||||
self.refs = 0
|
||||
self.record_class = record_class
|
||||
self.ignore_custom_codec = ignore_custom_codec
|
||||
|
||||
def _get_parameters(self):
|
||||
cdef Codec codec
|
||||
|
||||
result = []
|
||||
for oid in self.parameters_desc:
|
||||
codec = self.settings.get_data_codec(oid)
|
||||
if codec is None:
|
||||
raise exceptions.InternalClientError(
|
||||
'missing codec information for OID {}'.format(oid))
|
||||
result.append(apg_types.Type(
|
||||
oid, codec.name, codec.kind, codec.schema))
|
||||
|
||||
return tuple(result)
|
||||
|
||||
def _get_attributes(self):
|
||||
cdef Codec codec
|
||||
|
||||
if not self.row_desc:
|
||||
return ()
|
||||
|
||||
result = []
|
||||
for d in self.row_desc:
|
||||
name = d[0]
|
||||
oid = d[3]
|
||||
|
||||
codec = self.settings.get_data_codec(oid)
|
||||
if codec is None:
|
||||
raise exceptions.InternalClientError(
|
||||
'missing codec information for OID {}'.format(oid))
|
||||
|
||||
name = name.decode(self.settings._encoding)
|
||||
|
||||
result.append(
|
||||
apg_types.Attribute(name,
|
||||
apg_types.Type(oid, codec.name, codec.kind, codec.schema)))
|
||||
|
||||
return tuple(result)
|
||||
|
||||
def _init_types(self):
|
||||
cdef:
|
||||
Codec codec
|
||||
set missing = set()
|
||||
|
||||
if self.parameters_desc:
|
||||
for p_oid in self.parameters_desc:
|
||||
codec = self.settings.get_data_codec(<uint32_t>p_oid)
|
||||
if codec is None or not codec.has_encoder():
|
||||
missing.add(p_oid)
|
||||
|
||||
if self.row_desc:
|
||||
for rdesc in self.row_desc:
|
||||
codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
|
||||
if codec is None or not codec.has_decoder():
|
||||
missing.add(rdesc[3])
|
||||
|
||||
return missing
|
||||
|
||||
cpdef _init_codecs(self):
|
||||
self._ensure_args_encoder()
|
||||
self._ensure_rows_decoder()
|
||||
|
||||
def attach(self):
|
||||
self.refs += 1
|
||||
|
||||
def detach(self):
|
||||
self.refs -= 1
|
||||
|
||||
def mark_closed(self):
|
||||
self.closed = True
|
||||
|
||||
def mark_unprepared(self):
|
||||
if self.name:
|
||||
raise exceptions.InternalClientError(
|
||||
"named prepared statements cannot be marked unprepared")
|
||||
self.prepared = False
|
||||
|
||||
cdef _encode_bind_msg(self, args, int seqno = -1):
|
||||
cdef:
|
||||
int idx
|
||||
WriteBuffer writer
|
||||
Codec codec
|
||||
|
||||
if not cpython.PySequence_Check(args):
|
||||
if seqno >= 0:
|
||||
raise exceptions.DataError(
|
||||
f'invalid input in executemany() argument sequence '
|
||||
f'element #{seqno}: expected a sequence, got '
|
||||
f'{type(args).__name__}'
|
||||
)
|
||||
else:
|
||||
# Non executemany() callers do not pass user input directly,
|
||||
# so bad input is a bug.
|
||||
raise exceptions.InternalClientError(
|
||||
f'Bind: expected a sequence, got {type(args).__name__}')
|
||||
|
||||
if len(args) > 32767:
|
||||
raise exceptions.InterfaceError(
|
||||
'the number of query arguments cannot exceed 32767')
|
||||
|
||||
writer = WriteBuffer.new()
|
||||
|
||||
num_args_passed = len(args)
|
||||
if self.args_num != num_args_passed:
|
||||
hint = 'Check the query against the passed list of arguments.'
|
||||
|
||||
if self.args_num == 0:
|
||||
# If the server was expecting zero arguments, it is likely
|
||||
# that the user tried to parametrize a statement that does
|
||||
# not support parameters.
|
||||
hint += (r' Note that parameters are supported only in'
|
||||
r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
|
||||
r' statements, and will *not* work in statements '
|
||||
r' like CREATE VIEW or DECLARE CURSOR.')
|
||||
|
||||
raise exceptions.InterfaceError(
|
||||
'the server expects {x} argument{s} for this query, '
|
||||
'{y} {w} passed'.format(
|
||||
x=self.args_num, s='s' if self.args_num != 1 else '',
|
||||
y=num_args_passed,
|
||||
w='was' if num_args_passed == 1 else 'were'),
|
||||
hint=hint)
|
||||
|
||||
if self.have_text_args:
|
||||
writer.write_int16(self.args_num)
|
||||
for idx in range(self.args_num):
|
||||
codec = <Codec>(self.args_codecs[idx])
|
||||
writer.write_int16(<int16_t>codec.format)
|
||||
else:
|
||||
# All arguments are in binary format
|
||||
writer.write_int32(0x00010001)
|
||||
|
||||
writer.write_int16(self.args_num)
|
||||
|
||||
for idx in range(self.args_num):
|
||||
arg = args[idx]
|
||||
if arg is None:
|
||||
writer.write_int32(-1)
|
||||
else:
|
||||
codec = <Codec>(self.args_codecs[idx])
|
||||
try:
|
||||
codec.encode(self.settings, writer, arg)
|
||||
except (AssertionError, exceptions.InternalClientError):
|
||||
# These are internal errors and should raise as-is.
|
||||
raise
|
||||
except exceptions.InterfaceError as e:
|
||||
# This is already a descriptive error, but annotate
|
||||
# with argument name for clarity.
|
||||
pos = f'${idx + 1}'
|
||||
if seqno >= 0:
|
||||
pos = (
|
||||
f'{pos} in element #{seqno} of'
|
||||
f' executemany() sequence'
|
||||
)
|
||||
raise e.with_msg(
|
||||
f'query argument {pos}: {e.args[0]}'
|
||||
) from None
|
||||
except Exception as e:
|
||||
# Everything else is assumed to be an encoding error
|
||||
# due to invalid input.
|
||||
pos = f'${idx + 1}'
|
||||
if seqno >= 0:
|
||||
pos = (
|
||||
f'{pos} in element #{seqno} of'
|
||||
f' executemany() sequence'
|
||||
)
|
||||
value_repr = repr(arg)
|
||||
if len(value_repr) > 40:
|
||||
value_repr = value_repr[:40] + '...'
|
||||
|
||||
raise exceptions.DataError(
|
||||
f'invalid input for query argument'
|
||||
f' {pos}: {value_repr} ({e})'
|
||||
) from e
|
||||
|
||||
if self.have_text_cols:
|
||||
writer.write_int16(self.cols_num)
|
||||
for idx in range(self.cols_num):
|
||||
codec = <Codec>(self.rows_codecs[idx])
|
||||
writer.write_int16(<int16_t>codec.format)
|
||||
else:
|
||||
# All columns are in binary format
|
||||
writer.write_int32(0x00010001)
|
||||
|
||||
return writer
|
||||
|
||||
cdef _ensure_rows_decoder(self):
|
||||
cdef:
|
||||
list cols_names
|
||||
object cols_mapping
|
||||
tuple row
|
||||
uint32_t oid
|
||||
Codec codec
|
||||
list codecs
|
||||
|
||||
if self.cols_desc is not None:
|
||||
return
|
||||
|
||||
if self.cols_num == 0:
|
||||
self.cols_desc = record.ApgRecordDesc_New({}, ())
|
||||
return
|
||||
|
||||
cols_mapping = collections.OrderedDict()
|
||||
cols_names = []
|
||||
codecs = []
|
||||
for i from 0 <= i < self.cols_num:
|
||||
row = self.row_desc[i]
|
||||
col_name = row[0].decode(self.settings._encoding)
|
||||
cols_mapping[col_name] = i
|
||||
cols_names.append(col_name)
|
||||
oid = row[3]
|
||||
codec = self.settings.get_data_codec(
|
||||
oid, ignore_custom_codec=self.ignore_custom_codec)
|
||||
if codec is None or not codec.has_decoder():
|
||||
raise exceptions.InternalClientError(
|
||||
'no decoder for OID {}'.format(oid))
|
||||
if not codec.is_binary():
|
||||
self.have_text_cols = True
|
||||
|
||||
codecs.append(codec)
|
||||
|
||||
self.cols_desc = record.ApgRecordDesc_New(
|
||||
cols_mapping, tuple(cols_names))
|
||||
|
||||
self.rows_codecs = tuple(codecs)
|
||||
|
||||
cdef _ensure_args_encoder(self):
|
||||
cdef:
|
||||
uint32_t p_oid
|
||||
Codec codec
|
||||
list codecs = []
|
||||
|
||||
if self.args_num == 0 or self.args_codecs is not None:
|
||||
return
|
||||
|
||||
for i from 0 <= i < self.args_num:
|
||||
p_oid = self.parameters_desc[i]
|
||||
codec = self.settings.get_data_codec(
|
||||
p_oid, ignore_custom_codec=self.ignore_custom_codec)
|
||||
if codec is None or not codec.has_encoder():
|
||||
raise exceptions.InternalClientError(
|
||||
'no encoder for OID {}'.format(p_oid))
|
||||
if codec.type not in {}:
|
||||
self.have_text_args = True
|
||||
|
||||
codecs.append(codec)
|
||||
|
||||
self.args_codecs = tuple(codecs)
|
||||
|
||||
cdef _set_row_desc(self, object desc):
|
||||
self.row_desc = _decode_row_desc(desc)
|
||||
self.cols_num = <int16_t>(len(self.row_desc))
|
||||
|
||||
cdef _set_args_desc(self, object desc):
|
||||
self.parameters_desc = _decode_parameters_desc(desc)
|
||||
self.args_num = <int16_t>(len(self.parameters_desc))
|
||||
|
||||
cdef _decode_row(self, const char* cbuf, ssize_t buf_len):
|
||||
cdef:
|
||||
Codec codec
|
||||
int16_t fnum
|
||||
int32_t flen
|
||||
object dec_row
|
||||
tuple rows_codecs = self.rows_codecs
|
||||
ConnectionSettings settings = self.settings
|
||||
int32_t i
|
||||
FRBuffer rbuf
|
||||
ssize_t bl
|
||||
|
||||
frb_init(&rbuf, cbuf, buf_len)
|
||||
|
||||
fnum = hton.unpack_int16(frb_read(&rbuf, 2))
|
||||
|
||||
if fnum != self.cols_num:
|
||||
raise exceptions.ProtocolError(
|
||||
'the number of columns in the result row ({}) is '
|
||||
'different from what was described ({})'.format(
|
||||
fnum, self.cols_num))
|
||||
|
||||
dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum)
|
||||
for i in range(fnum):
|
||||
flen = hton.unpack_int32(frb_read(&rbuf, 4))
|
||||
|
||||
if flen == -1:
|
||||
val = None
|
||||
else:
|
||||
# Clamp buffer size to that of the reported field length
|
||||
# to make sure that codecs can rely on read_all() working
|
||||
# properly.
|
||||
bl = frb_get_len(&rbuf)
|
||||
if flen > bl:
|
||||
frb_check(&rbuf, flen)
|
||||
frb_set_len(&rbuf, flen)
|
||||
codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i)
|
||||
val = codec.decode(settings, &rbuf)
|
||||
if frb_get_len(&rbuf) != 0:
|
||||
raise BufferError(
|
||||
'unexpected trailing {} bytes in buffer'.format(
|
||||
frb_get_len(&rbuf)))
|
||||
frb_set_len(&rbuf, bl - flen)
|
||||
|
||||
cpython.Py_INCREF(val)
|
||||
record.ApgRecord_SET_ITEM(dec_row, i, val)
|
||||
|
||||
if frb_get_len(&rbuf) != 0:
|
||||
raise BufferError('unexpected trailing {} bytes in buffer'.format(
|
||||
frb_get_len(&rbuf)))
|
||||
|
||||
return dec_row
|
||||
|
||||
|
||||
cdef _decode_parameters_desc(object desc):
|
||||
cdef:
|
||||
ReadBuffer reader
|
||||
int16_t nparams
|
||||
uint32_t p_oid
|
||||
list result = []
|
||||
|
||||
reader = ReadBuffer.new_message_parser(desc)
|
||||
nparams = reader.read_int16()
|
||||
|
||||
for i from 0 <= i < nparams:
|
||||
p_oid = <uint32_t>reader.read_int32()
|
||||
result.append(p_oid)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
cdef _decode_row_desc(object desc):
|
||||
cdef:
|
||||
ReadBuffer reader
|
||||
|
||||
int16_t nfields
|
||||
|
||||
bytes f_name
|
||||
uint32_t f_table_oid
|
||||
int16_t f_column_num
|
||||
uint32_t f_dt_oid
|
||||
int16_t f_dt_size
|
||||
int32_t f_dt_mod
|
||||
int16_t f_format
|
||||
|
||||
list result
|
||||
|
||||
reader = ReadBuffer.new_message_parser(desc)
|
||||
nfields = reader.read_int16()
|
||||
result = []
|
||||
|
||||
for i from 0 <= i < nfields:
|
||||
f_name = reader.read_null_str()
|
||||
f_table_oid = <uint32_t>reader.read_int32()
|
||||
f_column_num = reader.read_int16()
|
||||
f_dt_oid = <uint32_t>reader.read_int32()
|
||||
f_dt_size = reader.read_int16()
|
||||
f_dt_mod = reader.read_int32()
|
||||
f_format = reader.read_int16()
|
||||
|
||||
result.append(
|
||||
(f_name, f_table_oid, f_column_num, f_dt_oid,
|
||||
f_dt_size, f_dt_mod, f_format))
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user