Major fixes and new features

2025-09-25 15:51:48 +09:00
parent dd7349bb4c
commit ddce9f5125
5586 changed files with 1470941 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from .connection import connect, Connection # NOQA
from .exceptions import * # NOQA
from .pool import create_pool, Pool # NOQA
from .protocol import Record # NOQA
from .types import * # NOQA
from ._version import __version__ # NOQA
__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection')
__all__ += exceptions.__all__ # NOQA

View File

@@ -0,0 +1,87 @@
# Backports from Python/Lib/asyncio for older Pythons
#
# Copyright (c) 2001-2023 Python Software Foundation; All Rights Reserved
#
# SPDX-License-Identifier: PSF-2.0
import asyncio
import functools
import sys
if sys.version_info < (3, 11):
from async_timeout import timeout as timeout_ctx
else:
from asyncio import timeout as timeout_ctx
async def wait_for(fut, timeout):
"""Wait for the single Future or coroutine to complete, with timeout.
Coroutine will be wrapped in Task.
Returns result of the Future or coroutine. When a timeout occurs,
it cancels the task and raises TimeoutError. To avoid the task
cancellation, wrap it in shield().
If the wait is cancelled, the task is also cancelled.
    If the task suppresses the cancellation and returns a value instead,
that value is returned.
This function is a coroutine.
"""
# The special case for timeout <= 0 is for the following case:
#
# async def test_waitfor():
# func_started = False
#
# async def func():
# nonlocal func_started
# func_started = True
#
# try:
# await asyncio.wait_for(func(), 0)
# except asyncio.TimeoutError:
# assert not func_started
# else:
# assert False
#
# asyncio.run(test_waitfor())
if timeout is not None and timeout <= 0:
fut = asyncio.ensure_future(fut)
if fut.done():
return fut.result()
await _cancel_and_wait(fut)
try:
return fut.result()
except asyncio.CancelledError as exc:
raise TimeoutError from exc
async with timeout_ctx(timeout):
return await fut
async def _cancel_and_wait(fut):
"""Cancel the *fut* future or task and wait until it completes."""
loop = asyncio.get_running_loop()
waiter = loop.create_future()
cb = functools.partial(_release_waiter, waiter)
fut.add_done_callback(cb)
try:
fut.cancel()
# We cannot wait on *fut* directly to make
# sure _cancel_and_wait itself is reliably cancellable.
await waiter
finally:
fut.remove_done_callback(cb)
def _release_waiter(waiter, *args):
if not waiter.done():
waiter.set_result(None)
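A minimal usage sketch of the backported wait_for (illustrative, not part of this diff; the module path is taken from the compat import later in this commit):

    import asyncio
    from asyncpg._asyncio_compat import wait_for

    async def main():
        try:
            # Cancels the sleeping task after 0.1 seconds.
            await wait_for(asyncio.sleep(10), 0.1)
        except (TimeoutError, asyncio.TimeoutError):
            # async_timeout raises asyncio.TimeoutError on Python < 3.11;
            # asyncio.timeout raises the built-in TimeoutError on 3.11+.
            print('timed out')

    asyncio.run(main())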

View File

@@ -0,0 +1,527 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import atexit
import contextlib
import functools
import inspect
import logging
import os
import re
import textwrap
import time
import traceback
import unittest
import asyncpg
from asyncpg import cluster as pg_cluster
from asyncpg import connection as pg_connection
from asyncpg import pool as pg_pool
from . import fuzzer
@contextlib.contextmanager
def silence_asyncio_long_exec_warning():
def flt(log_record):
msg = log_record.getMessage()
return not msg.startswith('Executing ')
logger = logging.getLogger('asyncio')
logger.addFilter(flt)
try:
yield
finally:
logger.removeFilter(flt)
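# Example (illustrative, not part of this diff): keep asyncio's
# "Executing <Task ...> took N seconds" warnings out of timing-sensitive
# test output; slow_coro is a placeholder name.
#
#     with silence_asyncio_long_exec_warning():
#         self.loop.run_until_complete(slow_coro())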
def with_timeout(timeout):
def wrap(func):
func.__timeout__ = timeout
return func
return wrap
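# Example (illustrative): override the class-wide TEST_TIMEOUT for a
# single test method; TestCaseMeta reads __timeout__ off the function.
#
#     class MyTests(ConnectedTestCase):
#         @with_timeout(5.0)
#         async def test_slow_query(self):
#             await self.con.execute('SELECT pg_sleep(1)')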
class TestCaseMeta(type(unittest.TestCase)):
TEST_TIMEOUT = None
@staticmethod
def _iter_methods(bases, ns):
for base in bases:
for methname in dir(base):
if not methname.startswith('test_'):
continue
meth = getattr(base, methname)
if not inspect.iscoroutinefunction(meth):
continue
yield methname, meth
for methname, meth in ns.items():
if not methname.startswith('test_'):
continue
if not inspect.iscoroutinefunction(meth):
continue
yield methname, meth
def __new__(mcls, name, bases, ns):
for methname, meth in mcls._iter_methods(bases, ns):
@functools.wraps(meth)
def wrapper(self, *args, __meth__=meth, **kwargs):
coro = __meth__(self, *args, **kwargs)
timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT)
if timeout:
coro = asyncio.wait_for(coro, timeout)
try:
self.loop.run_until_complete(coro)
except asyncio.TimeoutError:
raise self.failureException(
'test timed out after {} seconds'.format(
timeout)) from None
else:
self.loop.run_until_complete(coro)
ns[methname] = wrapper
return super().__new__(mcls, name, bases, ns)
class TestCase(unittest.TestCase, metaclass=TestCaseMeta):
@classmethod
def setUpClass(cls):
if os.environ.get('USE_UVLOOP'):
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
cls.loop = loop
@classmethod
def tearDownClass(cls):
cls.loop.close()
asyncio.set_event_loop(None)
def setUp(self):
self.loop.set_exception_handler(self.loop_exception_handler)
self.__unhandled_exceptions = []
def tearDown(self):
if self.__unhandled_exceptions:
formatted = []
for i, context in enumerate(self.__unhandled_exceptions):
formatted.append(self._format_loop_exception(context, i + 1))
self.fail(
'unexpected exceptions in asynchronous code:\n' +
'\n'.join(formatted))
@contextlib.contextmanager
def assertRunUnder(self, delta):
st = time.monotonic()
try:
yield
finally:
elapsed = time.monotonic() - st
if elapsed > delta:
raise AssertionError(
'running block took {:0.3f}s which is longer '
'than the expected maximum of {:0.3f}s'.format(
elapsed, delta))
@contextlib.contextmanager
def assertLoopErrorHandlerCalled(self, msg_re: str):
contexts = []
def handler(loop, ctx):
contexts.append(ctx)
old_handler = self.loop.get_exception_handler()
self.loop.set_exception_handler(handler)
try:
yield
for ctx in contexts:
msg = ctx.get('message')
if msg and re.search(msg_re, msg):
return
raise AssertionError(
'no message matching {!r} was logged with '
'loop.call_exception_handler()'.format(msg_re))
finally:
self.loop.set_exception_handler(old_handler)
def loop_exception_handler(self, loop, context):
self.__unhandled_exceptions.append(context)
loop.default_exception_handler(context)
def _format_loop_exception(self, context, n):
message = context.get('message', 'Unhandled exception in event loop')
exception = context.get('exception')
if exception is not None:
exc_info = (type(exception), exception, exception.__traceback__)
else:
exc_info = None
lines = []
for key in sorted(context):
if key in {'message', 'exception'}:
continue
value = context[key]
if key == 'source_traceback':
tb = ''.join(traceback.format_list(value))
value = 'Object created at (most recent call last):\n'
value += tb.rstrip()
else:
try:
value = repr(value)
except Exception as ex:
value = ('Exception in __repr__ {!r}; '
'value type: {!r}'.format(ex, type(value)))
lines.append('[{}]: {}\n\n'.format(key, value))
if exc_info is not None:
lines.append('[exception]:\n')
formatted_exc = textwrap.indent(
''.join(traceback.format_exception(*exc_info)), ' ')
lines.append(formatted_exc)
details = textwrap.indent(''.join(lines), ' ')
return '{:02d}. {}:\n{}\n'.format(n, message, details)
_default_cluster = None
def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
cluster = ClusterCls(**cluster_kwargs)
cluster.init(**(initdb_options or {}))
cluster.trust_local_connections()
atexit.register(_shutdown_cluster, cluster)
return cluster
def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
initdb_options=None):
cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
cluster.start(port='dynamic', server_settings=server_settings)
return cluster
def _get_initdb_options(initdb_options=None):
if not initdb_options:
initdb_options = {}
else:
initdb_options = dict(initdb_options)
# Make the default superuser name stable.
if 'username' not in initdb_options:
initdb_options['username'] = 'postgres'
return initdb_options
def _init_default_cluster(initdb_options=None):
global _default_cluster
if _default_cluster is None:
pg_host = os.environ.get('PGHOST')
if pg_host:
# Using existing cluster, assuming it is initialized and running
_default_cluster = pg_cluster.RunningCluster()
else:
_default_cluster = _init_cluster(
pg_cluster.TempCluster, cluster_kwargs={},
initdb_options=_get_initdb_options(initdb_options))
return _default_cluster
def _shutdown_cluster(cluster):
if cluster.get_status() == 'running':
cluster.stop()
if cluster.get_status() != 'not-initialized':
cluster.destroy()
def create_pool(dsn=None, *,
min_size=10,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=60.0,
setup=None,
init=None,
loop=None,
pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection,
record_class=asyncpg.Record,
**connect_kwargs):
return pool_class(
dsn,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
record_class=record_class,
**connect_kwargs)
class ClusterTestCase(TestCase):
@classmethod
def get_server_settings(cls):
settings = {
'log_connections': 'on'
}
if cls.cluster.get_pg_version() >= (11, 0):
# JITting messes up timing tests, and
# is not essential for testing.
settings['jit'] = 'off'
return settings
@classmethod
def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}):
cluster = _init_cluster(ClusterCls, cluster_kwargs,
_get_initdb_options(initdb_options))
cls._clusters.append(cluster)
return cluster
@classmethod
def start_cluster(cls, cluster, *, server_settings={}):
cluster.start(port='dynamic', server_settings=server_settings)
@classmethod
def setup_cluster(cls):
cls.cluster = _init_default_cluster()
if cls.cluster.get_status() != 'running':
cls.cluster.start(
port='dynamic', server_settings=cls.get_server_settings())
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._clusters = []
cls.setup_cluster()
@classmethod
def tearDownClass(cls):
super().tearDownClass()
for cluster in cls._clusters:
if cluster is not _default_cluster:
cluster.stop()
cluster.destroy()
cls._clusters = []
@classmethod
def get_connection_spec(cls, kwargs={}):
conn_spec = cls.cluster.get_connection_spec()
if kwargs.get('dsn'):
conn_spec.pop('host')
conn_spec.update(kwargs)
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
if 'database' not in conn_spec:
conn_spec['database'] = 'postgres'
if 'user' not in conn_spec:
conn_spec['user'] = 'postgres'
return conn_spec
@classmethod
def connect(cls, **kwargs):
conn_spec = cls.get_connection_spec(kwargs)
return pg_connection.connect(**conn_spec, loop=cls.loop)
def setUp(self):
super().setUp()
self._pools = []
def tearDown(self):
super().tearDown()
for pool in self._pools:
pool.terminate()
self._pools = []
def create_pool(self, pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection, **kwargs):
conn_spec = self.get_connection_spec(kwargs)
pool = create_pool(loop=self.loop, pool_class=pool_class,
connection_class=connection_class, **conn_spec)
self._pools.append(pool)
return pool
class ProxiedClusterTestCase(ClusterTestCase):
@classmethod
def get_server_settings(cls):
settings = dict(super().get_server_settings())
settings['listen_addresses'] = '127.0.0.1'
return settings
@classmethod
def get_proxy_settings(cls):
return {'fuzzing-mode': None}
@classmethod
def setUpClass(cls):
super().setUpClass()
conn_spec = cls.cluster.get_connection_spec()
host = conn_spec.get('host')
if not host:
host = '127.0.0.1'
elif host.startswith('/'):
host = '127.0.0.1'
cls.proxy = fuzzer.TCPFuzzingProxy(
backend_host=host,
backend_port=conn_spec['port'],
)
cls.proxy.start()
@classmethod
def tearDownClass(cls):
cls.proxy.stop()
super().tearDownClass()
@classmethod
def get_connection_spec(cls, kwargs):
conn_spec = super().get_connection_spec(kwargs)
conn_spec['host'] = cls.proxy.listening_addr
conn_spec['port'] = cls.proxy.listening_port
return conn_spec
def tearDown(self):
self.proxy.reset()
super().tearDown()
def with_connection_options(**options):
if not options:
raise ValueError('no connection options were specified')
def wrap(func):
func.__connect_options__ = options
return func
return wrap
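# Example (illustrative): ConnectedTestCase.setUp picks these options up
# and passes them to connect() when opening self.con.
#
#     class NoCacheTests(ConnectedTestCase):
#         @with_connection_options(statement_cache_size=0)
#         async def test_uncached(self):
#             await self.con.fetchval('SELECT 1')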
class ConnectedTestCase(ClusterTestCase):
def setUp(self):
super().setUp()
# Extract options set up with `with_connection_options`.
test_func = getattr(self, self._testMethodName).__func__
opts = getattr(test_func, '__connect_options__', {})
self.con = self.loop.run_until_complete(self.connect(**opts))
self.server_version = self.con.get_server_version()
def tearDown(self):
try:
self.loop.run_until_complete(self.con.close())
self.con = None
finally:
super().tearDown()
class HotStandbyTestCase(ClusterTestCase):
@classmethod
def setup_cluster(cls):
cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
cls.start_cluster(
cls.master_cluster,
server_settings={
'max_wal_senders': 10,
'wal_level': 'hot_standby'
}
)
con = None
try:
con = cls.loop.run_until_complete(
cls.master_cluster.connect(
database='postgres', user='postgres', loop=cls.loop))
cls.loop.run_until_complete(
con.execute('''
CREATE ROLE replication WITH LOGIN REPLICATION
'''))
cls.master_cluster.trust_local_replication_by('replication')
conn_spec = cls.master_cluster.get_connection_spec()
cls.standby_cluster = cls.new_cluster(
pg_cluster.HotStandbyCluster,
cluster_kwargs={
'master': conn_spec,
'replication_user': 'replication'
}
)
cls.start_cluster(
cls.standby_cluster,
server_settings={
'hot_standby': True
}
)
finally:
if con is not None:
cls.loop.run_until_complete(con.close())
@classmethod
def get_cluster_connection_spec(cls, cluster, kwargs={}):
conn_spec = cluster.get_connection_spec()
if kwargs.get('dsn'):
conn_spec.pop('host')
conn_spec.update(kwargs)
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
if 'database' not in conn_spec:
conn_spec['database'] = 'postgres'
if 'user' not in conn_spec:
conn_spec['user'] = 'postgres'
return conn_spec
@classmethod
def get_connection_spec(cls, kwargs={}):
primary_spec = cls.get_cluster_connection_spec(
cls.master_cluster, kwargs
)
standby_spec = cls.get_cluster_connection_spec(
cls.standby_cluster, kwargs
)
return {
'host': [primary_spec['host'], standby_spec['host']],
'port': [primary_spec['port'], standby_spec['port']],
'database': primary_spec['database'],
'user': primary_spec['user'],
**kwargs
}
@classmethod
def connect_primary(cls, **kwargs):
conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
return pg_connection.connect(**conn_spec, loop=cls.loop)
@classmethod
def connect_standby(cls, **kwargs):
conn_spec = cls.get_cluster_connection_spec(
cls.standby_cluster,
kwargs
)
return pg_connection.connect(**conn_spec, loop=cls.loop)

View File

@@ -0,0 +1,306 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import socket
import threading
import typing
from asyncpg import cluster
class StopServer(Exception):
pass
class TCPFuzzingProxy:
def __init__(self, *, listening_addr: str='127.0.0.1',
listening_port: typing.Optional[int]=None,
backend_host: str, backend_port: int,
settings: typing.Optional[dict]=None) -> None:
self.listening_addr = listening_addr
self.listening_port = listening_port
self.backend_host = backend_host
self.backend_port = backend_port
self.settings = settings or {}
self.loop = None
self.connectivity = None
self.connectivity_loss = None
self.stop_event = None
self.connections = {}
self.sock = None
self.listen_task = None
async def _wait(self, work):
work_task = asyncio.ensure_future(work)
stop_event_task = asyncio.ensure_future(self.stop_event.wait())
try:
await asyncio.wait(
[work_task, stop_event_task],
return_when=asyncio.FIRST_COMPLETED)
if self.stop_event.is_set():
raise StopServer()
else:
return work_task.result()
finally:
if not work_task.done():
work_task.cancel()
if not stop_event_task.done():
stop_event_task.cancel()
def start(self):
started = threading.Event()
self.thread = threading.Thread(
target=self._start_thread, args=(started,))
self.thread.start()
if not started.wait(timeout=2):
raise RuntimeError('fuzzer proxy failed to start')
def stop(self):
self.loop.call_soon_threadsafe(self._stop)
self.thread.join()
def _stop(self):
self.stop_event.set()
def _start_thread(self, started_event):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.connectivity = asyncio.Event()
self.connectivity.set()
self.connectivity_loss = asyncio.Event()
self.stop_event = asyncio.Event()
if self.listening_port is None:
self.listening_port = cluster.find_available_port()
self.sock = socket.socket()
self.sock.bind((self.listening_addr, self.listening_port))
self.sock.listen(50)
self.sock.setblocking(False)
try:
self.loop.run_until_complete(self._main(started_event))
finally:
self.loop.close()
async def _main(self, started_event):
self.listen_task = asyncio.ensure_future(self.listen())
# Notify the main thread that we are ready to go.
started_event.set()
try:
await self.listen_task
finally:
for c in list(self.connections):
c.close()
await asyncio.sleep(0.01)
if hasattr(self.loop, 'remove_reader'):
self.loop.remove_reader(self.sock.fileno())
self.sock.close()
async def listen(self):
while True:
try:
client_sock, _ = await self._wait(
self.loop.sock_accept(self.sock))
backend_sock = socket.socket()
backend_sock.setblocking(False)
await self._wait(self.loop.sock_connect(
backend_sock, (self.backend_host, self.backend_port)))
except StopServer:
break
conn = Connection(client_sock, backend_sock, self)
conn_task = self.loop.create_task(conn.handle())
self.connections[conn] = conn_task
def trigger_connectivity_loss(self):
self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)
def _trigger_connectivity_loss(self):
self.connectivity.clear()
self.connectivity_loss.set()
def restore_connectivity(self):
self.loop.call_soon_threadsafe(self._restore_connectivity)
def _restore_connectivity(self):
self.connectivity.set()
self.connectivity_loss.clear()
def reset(self):
self.restore_connectivity()
def _close_connection(self, connection):
conn_task = self.connections.pop(connection, None)
if conn_task is not None:
conn_task.cancel()
def close_all_connections(self):
for conn in list(self.connections):
self.loop.call_soon_threadsafe(self._close_connection, conn)
class Connection:
def __init__(self, client_sock, backend_sock, proxy):
self.client_sock = client_sock
self.backend_sock = backend_sock
self.proxy = proxy
self.loop = proxy.loop
self.connectivity = proxy.connectivity
self.connectivity_loss = proxy.connectivity_loss
self.proxy_to_backend_task = None
self.proxy_from_backend_task = None
self.is_closed = False
def close(self):
if self.is_closed:
return
self.is_closed = True
if self.proxy_to_backend_task is not None:
self.proxy_to_backend_task.cancel()
self.proxy_to_backend_task = None
if self.proxy_from_backend_task is not None:
self.proxy_from_backend_task.cancel()
self.proxy_from_backend_task = None
self.proxy._close_connection(self)
async def handle(self):
self.proxy_to_backend_task = asyncio.ensure_future(
self.proxy_to_backend())
self.proxy_from_backend_task = asyncio.ensure_future(
self.proxy_from_backend())
try:
await asyncio.wait(
[self.proxy_to_backend_task, self.proxy_from_backend_task],
return_when=asyncio.FIRST_COMPLETED)
finally:
if self.proxy_to_backend_task is not None:
self.proxy_to_backend_task.cancel()
if self.proxy_from_backend_task is not None:
self.proxy_from_backend_task.cancel()
# Asyncio fails to properly remove the readers and writers
# when the task doing recv() or send() is cancelled, so
# we must remove the readers and writers manually before
# closing the sockets.
self.loop.remove_reader(self.client_sock.fileno())
self.loop.remove_writer(self.client_sock.fileno())
self.loop.remove_reader(self.backend_sock.fileno())
self.loop.remove_writer(self.backend_sock.fileno())
self.client_sock.close()
self.backend_sock.close()
async def _read(self, sock, n):
read_task = asyncio.ensure_future(
self.loop.sock_recv(sock, n))
conn_event_task = asyncio.ensure_future(
self.connectivity_loss.wait())
try:
await asyncio.wait(
[read_task, conn_event_task],
return_when=asyncio.FIRST_COMPLETED)
if self.connectivity_loss.is_set():
return None
else:
return read_task.result()
finally:
if not self.loop.is_closed():
if not read_task.done():
read_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()
async def _write(self, sock, data):
write_task = asyncio.ensure_future(
self.loop.sock_sendall(sock, data))
conn_event_task = asyncio.ensure_future(
self.connectivity_loss.wait())
try:
await asyncio.wait(
[write_task, conn_event_task],
return_when=asyncio.FIRST_COMPLETED)
if self.connectivity_loss.is_set():
return None
else:
return write_task.result()
finally:
if not self.loop.is_closed():
if not write_task.done():
write_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()
async def proxy_to_backend(self):
buf = None
try:
while True:
await self.connectivity.wait()
if buf is not None:
data = buf
buf = None
else:
data = await self._read(self.client_sock, 4096)
if data == b'':
break
if self.connectivity_loss.is_set():
if data:
buf = data
continue
await self._write(self.backend_sock, data)
except ConnectionError:
pass
finally:
if not self.loop.is_closed():
self.loop.call_soon(self.close)
async def proxy_from_backend(self):
buf = None
try:
while True:
await self.connectivity.wait()
if buf is not None:
data = buf
buf = None
else:
data = await self._read(self.backend_sock, 4096)
if data == b'':
break
if self.connectivity_loss.is_set():
if data:
buf = data
continue
await self._write(self.client_sock, data)
except ConnectionError:
pass
finally:
if not self.loop.is_closed():
self.loop.call_soon(self.close)
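A minimal sketch of driving the proxy from a test (illustrative; the backend host and port are placeholder values):

    proxy = TCPFuzzingProxy(backend_host='127.0.0.1', backend_port=5432)
    proxy.start()
    try:
        # Point asyncpg at proxy.listening_addr/proxy.listening_port,
        # then simulate a network outage mid-session:
        proxy.trigger_connectivity_loss()
        proxy.restore_connectivity()
    finally:
        proxy.stop()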

View File

@@ -0,0 +1,13 @@
# This file MUST NOT contain anything but the __version__ assignment.
#
# When making a release, change the value of __version__
# to an appropriate value, and open a pull request against
# the correct branch (master if making a new feature release).
# The commit message MUST contain a properly formatted release
# log, and the commit must be signed.
#
# The release automation will: build and test the packages for the
# supported platforms, publish the packages on PyPI, merge the PR
# to the target branch, create a Git tag pointing to the commit.
__version__ = '0.29.0'

View File

@@ -0,0 +1,688 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import os
import os.path
import platform
import re
import shutil
import socket
import subprocess
import sys
import tempfile
import textwrap
import time
import asyncpg
from asyncpg import serverversion
_system = platform.uname().system
if _system == 'Windows':
def platform_exe(name):
if name.endswith('.exe'):
return name
return name + '.exe'
else:
def platform_exe(name):
return name
def find_available_port():
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.bind(('127.0.0.1', 0))
return sock.getsockname()[1]
except Exception:
return None
finally:
sock.close()
class ClusterError(Exception):
pass
class Cluster:
def __init__(self, data_dir, *, pg_config_path=None):
self._data_dir = data_dir
self._pg_config_path = pg_config_path
self._pg_bin_dir = (
os.environ.get('PGINSTALLATION')
or os.environ.get('PGBIN')
)
self._pg_ctl = None
self._daemon_pid = None
self._daemon_process = None
self._connection_addr = None
self._connection_spec_override = None
def get_pg_version(self):
return self._pg_version
def is_managed(self):
return True
def get_data_dir(self):
return self._data_dir
def get_status(self):
if self._pg_ctl is None:
self._init_env()
process = subprocess.run(
[self._pg_ctl, 'status', '-D', self._data_dir],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.stdout, process.stderr
if (process.returncode == 4 or not os.path.exists(self._data_dir) or
not os.listdir(self._data_dir)):
return 'not-initialized'
elif process.returncode == 3:
return 'stopped'
elif process.returncode == 0:
r = re.match(r'.*PID\s?:\s+(\d+).*', stdout.decode())
if not r:
raise ClusterError(
'could not parse pg_ctl status output: {}'.format(
stdout.decode()))
self._daemon_pid = int(r.group(1))
return self._test_connection(timeout=0)
else:
raise ClusterError(
'pg_ctl status exited with status {:d}: {}'.format(
process.returncode, stderr))
async def connect(self, loop=None, **kwargs):
conn_info = self.get_connection_spec()
conn_info.update(kwargs)
return await asyncpg.connect(loop=loop, **conn_info)
def init(self, **settings):
"""Initialize cluster."""
if self.get_status() != 'not-initialized':
raise ClusterError(
'cluster in {!r} has already been initialized'.format(
self._data_dir))
settings = dict(settings)
if 'encoding' not in settings:
settings['encoding'] = 'UTF-8'
if settings:
settings_args = ['--{}={}'.format(k, v)
for k, v in settings.items()]
extra_args = ['-o'] + [' '.join(settings_args)]
else:
extra_args = []
process = subprocess.run(
[self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
output = process.stdout
if process.returncode != 0:
raise ClusterError(
'pg_ctl init exited with status {:d}:\n{}'.format(
process.returncode, output.decode()))
return output.decode()
def start(self, wait=60, *, server_settings={}, **opts):
"""Start the cluster."""
status = self.get_status()
if status == 'running':
return
elif status == 'not-initialized':
raise ClusterError(
'cluster in {!r} has not been initialized'.format(
self._data_dir))
port = opts.pop('port', None)
if port == 'dynamic':
port = find_available_port()
extra_args = ['--{}={}'.format(k, v) for k, v in opts.items()]
extra_args.append('--port={}'.format(port))
sockdir = server_settings.get('unix_socket_directories')
if sockdir is None:
sockdir = server_settings.get('unix_socket_directory')
if sockdir is None and _system != 'Windows':
sockdir = tempfile.gettempdir()
ssl_key = server_settings.get('ssl_key_file')
if ssl_key:
# Make sure server certificate key file has correct permissions.
keyfile = os.path.join(self._data_dir, 'srvkey.pem')
shutil.copy(ssl_key, keyfile)
os.chmod(keyfile, 0o600)
server_settings = server_settings.copy()
server_settings['ssl_key_file'] = keyfile
if sockdir is not None:
if self._pg_version < (9, 3):
sockdir_opt = 'unix_socket_directory'
else:
sockdir_opt = 'unix_socket_directories'
server_settings[sockdir_opt] = sockdir
for k, v in server_settings.items():
extra_args.extend(['-c', '{}={}'.format(k, v)])
if _system == 'Windows':
# On Windows we have to use pg_ctl as direct execution
# of postgres daemon under an Administrative account
# is not permitted and there is no easy way to drop
# privileges.
if os.getenv('ASYNCPG_DEBUG_SERVER'):
stdout = sys.stdout
print(
'asyncpg.cluster: Running',
' '.join([
self._pg_ctl, 'start', '-D', self._data_dir,
'-o', ' '.join(extra_args)
]),
file=sys.stderr,
)
else:
stdout = subprocess.DEVNULL
process = subprocess.run(
[self._pg_ctl, 'start', '-D', self._data_dir,
'-o', ' '.join(extra_args)],
stdout=stdout, stderr=subprocess.STDOUT)
if process.returncode != 0:
if process.stderr:
stderr = ':\n{}'.format(process.stderr.decode())
else:
stderr = ''
raise ClusterError(
'pg_ctl start exited with status {:d}{}'.format(
process.returncode, stderr))
else:
if os.getenv('ASYNCPG_DEBUG_SERVER'):
stdout = sys.stdout
else:
stdout = subprocess.DEVNULL
self._daemon_process = \
subprocess.Popen(
[self._postgres, '-D', self._data_dir, *extra_args],
stdout=stdout, stderr=subprocess.STDOUT)
self._daemon_pid = self._daemon_process.pid
self._test_connection(timeout=wait)
def reload(self):
"""Reload server configuration."""
status = self.get_status()
if status != 'running':
raise ClusterError('cannot reload: cluster is not running')
process = subprocess.run(
[self._pg_ctl, 'reload', '-D', self._data_dir],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stderr = process.stderr
if process.returncode != 0:
raise ClusterError(
                'pg_ctl reload exited with status {:d}: {}'.format(
process.returncode, stderr.decode()))
def stop(self, wait=60):
process = subprocess.run(
[self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
'-m', 'fast'],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stderr = process.stderr
if process.returncode != 0:
raise ClusterError(
'pg_ctl stop exited with status {:d}: {}'.format(
process.returncode, stderr.decode()))
if (self._daemon_process is not None and
self._daemon_process.returncode is None):
self._daemon_process.kill()
def destroy(self):
status = self.get_status()
if status == 'stopped' or status == 'not-initialized':
shutil.rmtree(self._data_dir)
else:
raise ClusterError('cannot destroy {} cluster'.format(status))
def _get_connection_spec(self):
if self._connection_addr is None:
self._connection_addr = self._connection_addr_from_pidfile()
if self._connection_addr is not None:
if self._connection_spec_override:
args = self._connection_addr.copy()
args.update(self._connection_spec_override)
return args
else:
return self._connection_addr
def get_connection_spec(self):
status = self.get_status()
if status != 'running':
raise ClusterError('cluster is not running')
return self._get_connection_spec()
def override_connection_spec(self, **kwargs):
self._connection_spec_override = kwargs
def reset_wal(self, *, oid=None, xid=None):
status = self.get_status()
if status == 'not-initialized':
raise ClusterError(
'cannot modify WAL status: cluster is not initialized')
if status == 'running':
raise ClusterError(
'cannot modify WAL status: cluster is running')
opts = []
if oid is not None:
opts.extend(['-o', str(oid)])
if xid is not None:
opts.extend(['-x', str(xid)])
if not opts:
return
opts.append(self._data_dir)
try:
reset_wal = self._find_pg_binary('pg_resetwal')
except ClusterError:
reset_wal = self._find_pg_binary('pg_resetxlog')
process = subprocess.run(
[reset_wal] + opts,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stderr = process.stderr
if process.returncode != 0:
raise ClusterError(
'pg_resetwal exited with status {:d}: {}'.format(
process.returncode, stderr.decode()))
def reset_hba(self):
"""Remove all records from pg_hba.conf."""
status = self.get_status()
if status == 'not-initialized':
raise ClusterError(
'cannot modify HBA records: cluster is not initialized')
pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')
try:
with open(pg_hba, 'w'):
pass
except IOError as e:
raise ClusterError(
'cannot modify HBA records: {}'.format(e)) from e
def add_hba_entry(self, *, type='host', database, user, address=None,
auth_method, auth_options=None):
"""Add a record to pg_hba.conf."""
status = self.get_status()
if status == 'not-initialized':
raise ClusterError(
'cannot modify HBA records: cluster is not initialized')
if type not in {'local', 'host', 'hostssl', 'hostnossl'}:
raise ValueError('invalid HBA record type: {!r}'.format(type))
pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')
record = '{} {} {}'.format(type, database, user)
if type != 'local':
if address is None:
raise ValueError(
'{!r} entry requires a valid address'.format(type))
else:
record += ' {}'.format(address)
record += ' {}'.format(auth_method)
if auth_options is not None:
record += ' ' + ' '.join(
'{}={}'.format(k, v) for k, v in auth_options)
try:
with open(pg_hba, 'a') as f:
print(record, file=f)
except IOError as e:
raise ClusterError(
'cannot modify HBA records: {}'.format(e)) from e
def trust_local_connections(self):
self.reset_hba()
if _system != 'Windows':
self.add_hba_entry(type='local', database='all',
user='all', auth_method='trust')
self.add_hba_entry(type='host', address='127.0.0.1/32',
database='all', user='all',
auth_method='trust')
self.add_hba_entry(type='host', address='::1/128',
database='all', user='all',
auth_method='trust')
status = self.get_status()
if status == 'running':
self.reload()
def trust_local_replication_by(self, user):
if _system != 'Windows':
self.add_hba_entry(type='local', database='replication',
user=user, auth_method='trust')
self.add_hba_entry(type='host', address='127.0.0.1/32',
database='replication', user=user,
auth_method='trust')
self.add_hba_entry(type='host', address='::1/128',
database='replication', user=user,
auth_method='trust')
status = self.get_status()
if status == 'running':
self.reload()
def _init_env(self):
if not self._pg_bin_dir:
pg_config = self._find_pg_config(self._pg_config_path)
pg_config_data = self._run_pg_config(pg_config)
self._pg_bin_dir = pg_config_data.get('bindir')
if not self._pg_bin_dir:
raise ClusterError(
'pg_config output did not provide the BINDIR value')
self._pg_ctl = self._find_pg_binary('pg_ctl')
self._postgres = self._find_pg_binary('postgres')
self._pg_version = self._get_pg_version()
def _connection_addr_from_pidfile(self):
pidfile = os.path.join(self._data_dir, 'postmaster.pid')
try:
with open(pidfile, 'rt') as f:
piddata = f.read()
except FileNotFoundError:
return None
lines = piddata.splitlines()
if len(lines) < 6:
# A complete postgres pidfile is at least 6 lines
return None
pmpid = int(lines[0])
if self._daemon_pid and pmpid != self._daemon_pid:
# This might be an old pidfile left from previous postgres
# daemon run.
return None
portnum = lines[3]
sockdir = lines[4]
hostaddr = lines[5]
if sockdir:
if sockdir[0] != '/':
# Relative sockdir
sockdir = os.path.normpath(
os.path.join(self._data_dir, sockdir))
host_str = sockdir
else:
host_str = hostaddr
if host_str == '*':
host_str = 'localhost'
elif host_str == '0.0.0.0':
host_str = '127.0.0.1'
elif host_str == '::':
host_str = '::1'
return {
'host': host_str,
'port': portnum
}
def _test_connection(self, timeout=60):
self._connection_addr = None
loop = asyncio.new_event_loop()
try:
for i in range(timeout):
if self._connection_addr is None:
conn_spec = self._get_connection_spec()
if conn_spec is None:
time.sleep(1)
continue
try:
con = loop.run_until_complete(
asyncpg.connect(database='postgres',
user='postgres',
timeout=5, loop=loop,
**self._connection_addr))
except (OSError, asyncio.TimeoutError,
asyncpg.CannotConnectNowError,
asyncpg.PostgresConnectionError):
time.sleep(1)
continue
except asyncpg.PostgresError:
                    # Any error other than ServerNotReadyError or
                    # ConnectionError is interpreted to indicate that the
                    # server is up.
break
else:
loop.run_until_complete(con.close())
break
finally:
loop.close()
return 'running'
def _run_pg_config(self, pg_config_path):
process = subprocess.run(
pg_config_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.stdout, process.stderr
if process.returncode != 0:
raise ClusterError('pg_config exited with status {:d}: {}'.format(
process.returncode, stderr))
else:
config = {}
for line in stdout.splitlines():
k, eq, v = line.decode('utf-8').partition('=')
if eq:
config[k.strip().lower()] = v.strip()
return config
def _find_pg_config(self, pg_config_path):
if pg_config_path is None:
pg_install = (
os.environ.get('PGINSTALLATION')
or os.environ.get('PGBIN')
)
if pg_install:
pg_config_path = platform_exe(
os.path.join(pg_install, 'pg_config'))
else:
                pathenv = os.environ.get('PATH', '').split(os.pathsep)
for path in pathenv:
pg_config_path = platform_exe(
os.path.join(path, 'pg_config'))
if os.path.exists(pg_config_path):
break
else:
pg_config_path = None
if not pg_config_path:
raise ClusterError('could not find pg_config executable')
if not os.path.isfile(pg_config_path):
raise ClusterError('{!r} is not an executable'.format(
pg_config_path))
return pg_config_path
def _find_pg_binary(self, binary):
bpath = platform_exe(os.path.join(self._pg_bin_dir, binary))
if not os.path.isfile(bpath):
raise ClusterError(
'could not find {} executable: '.format(binary) +
'{!r} does not exist or is not a file'.format(bpath))
return bpath
def _get_pg_version(self):
process = subprocess.run(
[self._postgres, '--version'],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.stdout, process.stderr
if process.returncode != 0:
raise ClusterError(
'postgres --version exited with status {:d}: {}'.format(
process.returncode, stderr))
version_string = stdout.decode('utf-8').strip(' \n')
prefix = 'postgres (PostgreSQL) '
if not version_string.startswith(prefix):
raise ClusterError(
'could not determine server version from {!r}'.format(
version_string))
version_string = version_string[len(prefix):]
return serverversion.split_server_version_string(version_string)
class TempCluster(Cluster):
def __init__(self, *,
data_dir_suffix=None, data_dir_prefix=None,
data_dir_parent=None, pg_config_path=None):
self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
prefix=data_dir_prefix,
dir=data_dir_parent)
super().__init__(self._data_dir, pg_config_path=pg_config_path)
class HotStandbyCluster(TempCluster):
def __init__(self, *,
master, replication_user,
data_dir_suffix=None, data_dir_prefix=None,
data_dir_parent=None, pg_config_path=None):
self._master = master
self._repl_user = replication_user
super().__init__(
data_dir_suffix=data_dir_suffix,
data_dir_prefix=data_dir_prefix,
data_dir_parent=data_dir_parent,
pg_config_path=pg_config_path)
def _init_env(self):
super()._init_env()
self._pg_basebackup = self._find_pg_binary('pg_basebackup')
def init(self, **settings):
"""Initialize cluster."""
if self.get_status() != 'not-initialized':
raise ClusterError(
'cluster in {!r} has already been initialized'.format(
self._data_dir))
process = subprocess.run(
[self._pg_basebackup, '-h', self._master['host'],
'-p', self._master['port'], '-D', self._data_dir,
'-U', self._repl_user],
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
output = process.stdout
if process.returncode != 0:
raise ClusterError(
'pg_basebackup init exited with status {:d}:\n{}'.format(
process.returncode, output.decode()))
if self._pg_version < (12, 0):
with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
f.write(textwrap.dedent("""\
standby_mode = 'on'
primary_conninfo = 'host={host} port={port} user={user}'
""".format(
host=self._master['host'],
port=self._master['port'],
user=self._repl_user)))
else:
            with open(os.path.join(self._data_dir, 'standby.signal'), 'w'):
                pass
return output.decode()
def start(self, wait=60, *, server_settings={}, **opts):
if self._pg_version >= (12, 0):
server_settings = server_settings.copy()
server_settings['primary_conninfo'] = (
'"host={host} port={port} user={user}"'.format(
host=self._master['host'],
port=self._master['port'],
user=self._repl_user,
)
)
super().start(wait=wait, server_settings=server_settings, **opts)
class RunningCluster(Cluster):
def __init__(self, **kwargs):
self.conn_spec = kwargs
def is_managed(self):
return False
def get_connection_spec(self):
return dict(self.conn_spec)
def get_status(self):
return 'running'
def init(self, **settings):
pass
def start(self, wait=60, **settings):
pass
def stop(self, wait=60):
pass
def destroy(self):
pass
def reset_hba(self):
raise ClusterError('cannot modify HBA records of unmanaged cluster')
def add_hba_entry(self, *, type='host', database, user, address=None,
auth_method, auth_options=None):
raise ClusterError('cannot modify HBA records of unmanaged cluster')
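A minimal lifecycle sketch (illustrative; assumes PostgreSQL binaries are discoverable via PGINSTALLATION, PGBIN, or PATH):

    cluster = TempCluster()
    cluster.init()
    cluster.trust_local_connections()
    cluster.start(port='dynamic')
    try:
        print(cluster.get_connection_spec())
    finally:
        cluster.stop()
        cluster.destroy()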

View File

@@ -0,0 +1,61 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import pathlib
import platform
import typing
import sys
SYSTEM = platform.uname().system
if SYSTEM == 'Windows':
import ctypes.wintypes
CSIDL_APPDATA = 0x001a
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
# We cannot simply use expanduser() as that returns the user's
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
buf = ctypes.create_unicode_buffer(ctypes.wintypes.MAX_PATH)
r = ctypes.windll.shell32.SHGetFolderPathW(0, CSIDL_APPDATA, 0, 0, buf)
if r:
return None
else:
return pathlib.Path(buf.value) / 'postgresql'
else:
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
try:
return pathlib.Path.home()
except (RuntimeError, KeyError):
return None
async def wait_closed(stream):
# Not all asyncio versions have StreamWriter.wait_closed().
if hasattr(stream, 'wait_closed'):
try:
await stream.wait_closed()
except ConnectionResetError:
# On Windows wait_closed() sometimes propagates
# ConnectionResetError which is totally unnecessary.
pass
if sys.version_info < (3, 12):
from ._asyncio_compat import wait_for as wait_for # noqa: F401
else:
from asyncio import wait_for as wait_for # noqa: F401
if sys.version_info < (3, 11):
from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
else:
from asyncio import timeout as timeout # noqa: F401
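A minimal usage sketch (illustrative): both import branches expose the same asynchronous context manager interface, so callers can stay version-agnostic.

    from asyncpg import compat

    async def fetch_with_deadline(con):
        # con is assumed to be an established asyncpg connection.
        async with compat.timeout(5.0):
            return await con.fetch('SELECT 1')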

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,44 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import functools
from . import exceptions
def guarded(meth):
"""A decorator to add a sanity check to ConnectionResource methods."""
@functools.wraps(meth)
def _check(self, *args, **kwargs):
self._check_conn_validity(meth.__name__)
return meth(self, *args, **kwargs)
return _check
class ConnectionResource:
__slots__ = ('_connection', '_con_release_ctr')
def __init__(self, connection):
self._connection = connection
self._con_release_ctr = connection._pool_release_ctr
def _check_conn_validity(self, meth_name):
con_release_ctr = self._connection._pool_release_ctr
if con_release_ctr != self._con_release_ctr:
raise exceptions.InterfaceError(
'cannot call {}.{}(): '
'the underlying connection has been released back '
'to the pool'.format(self.__class__.__name__, meth_name))
if self._connection.is_closed():
raise exceptions.InterfaceError(
'cannot call {}.{}(): '
'the underlying connection is closed'.format(
self.__class__.__name__, meth_name))
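A minimal sketch (illustrative) of how a resource class uses the guard:

    class MyResource(ConnectionResource):
        @guarded
        def describe(self):
            # Raises InterfaceError if the underlying connection has been
            # released back to the pool or closed in the meantime.
            return repr(self._connection)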

View File

@@ -0,0 +1,323 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import collections
from . import connresource
from . import exceptions
class CursorFactory(connresource.ConnectionResource):
"""A cursor interface for the results of a query.
A cursor interface can be used to initiate efficient traversal of the
results of a large query.
"""
__slots__ = (
'_state',
'_args',
'_prefetch',
'_query',
'_timeout',
'_record_class',
)
def __init__(
self,
connection,
query,
state,
args,
prefetch,
timeout,
record_class
):
super().__init__(connection)
self._args = args
self._prefetch = prefetch
self._query = query
self._timeout = timeout
self._state = state
self._record_class = record_class
if state is not None:
state.attach()
@connresource.guarded
def __aiter__(self):
prefetch = 50 if self._prefetch is None else self._prefetch
return CursorIterator(
self._connection,
self._query,
self._state,
self._args,
self._record_class,
prefetch,
self._timeout,
)
@connresource.guarded
def __await__(self):
if self._prefetch is not None:
raise exceptions.InterfaceError(
'prefetch argument can only be specified for iterable cursor')
cursor = Cursor(
self._connection,
self._query,
self._state,
self._args,
self._record_class,
)
return cursor._init(self._timeout).__await__()
def __del__(self):
if self._state is not None:
self._state.detach()
self._connection._maybe_gc_stmt(self._state)
class BaseCursor(connresource.ConnectionResource):
__slots__ = (
'_state',
'_args',
'_portal_name',
'_exhausted',
'_query',
'_record_class',
)
def __init__(self, connection, query, state, args, record_class):
super().__init__(connection)
self._args = args
self._state = state
if state is not None:
state.attach()
self._portal_name = None
self._exhausted = False
self._query = query
self._record_class = record_class
def _check_ready(self):
if self._state is None:
raise exceptions.InterfaceError(
'cursor: no associated prepared statement')
if self._state.closed:
raise exceptions.InterfaceError(
'cursor: the prepared statement is closed')
if not self._connection._top_xact:
raise exceptions.NoActiveSQLTransactionError(
'cursor cannot be created outside of a transaction')
async def _bind_exec(self, n, timeout):
self._check_ready()
if self._portal_name:
raise exceptions.InterfaceError(
'cursor already has an open portal')
con = self._connection
protocol = con._protocol
self._portal_name = con._get_unique_id('portal')
buffer, _, self._exhausted = await protocol.bind_execute(
self._state, self._args, self._portal_name, n, True, timeout)
return buffer
async def _bind(self, timeout):
self._check_ready()
if self._portal_name:
raise exceptions.InterfaceError(
'cursor already has an open portal')
con = self._connection
protocol = con._protocol
self._portal_name = con._get_unique_id('portal')
buffer = await protocol.bind(self._state, self._args,
self._portal_name,
timeout)
return buffer
async def _exec(self, n, timeout):
self._check_ready()
if not self._portal_name:
raise exceptions.InterfaceError(
'cursor does not have an open portal')
protocol = self._connection._protocol
buffer, _, self._exhausted = await protocol.execute(
self._state, self._portal_name, n, True, timeout)
return buffer
async def _close_portal(self, timeout):
self._check_ready()
if not self._portal_name:
raise exceptions.InterfaceError(
'cursor does not have an open portal')
protocol = self._connection._protocol
await protocol.close_portal(self._portal_name, timeout)
self._portal_name = None
def __repr__(self):
attrs = []
if self._exhausted:
attrs.append('exhausted')
attrs.append('') # to separate from id
if self.__class__.__module__.startswith('asyncpg.'):
mod = 'asyncpg'
else:
mod = self.__class__.__module__
return '<{}.{} "{!s:.30}" {}{:#x}>'.format(
mod, self.__class__.__name__,
self._state.query,
' '.join(attrs), id(self))
def __del__(self):
if self._state is not None:
self._state.detach()
self._connection._maybe_gc_stmt(self._state)
class CursorIterator(BaseCursor):
__slots__ = ('_buffer', '_prefetch', '_timeout')
def __init__(
self,
connection,
query,
state,
args,
record_class,
prefetch,
timeout
):
super().__init__(connection, query, state, args, record_class)
if prefetch <= 0:
raise exceptions.InterfaceError(
'prefetch argument must be greater than zero')
self._buffer = collections.deque()
self._prefetch = prefetch
self._timeout = timeout
@connresource.guarded
def __aiter__(self):
return self
@connresource.guarded
async def __anext__(self):
if self._state is None:
self._state = await self._connection._get_statement(
self._query,
self._timeout,
named=True,
record_class=self._record_class,
)
self._state.attach()
if not self._portal_name and not self._exhausted:
buffer = await self._bind_exec(self._prefetch, self._timeout)
self._buffer.extend(buffer)
if not self._buffer and not self._exhausted:
buffer = await self._exec(self._prefetch, self._timeout)
self._buffer.extend(buffer)
if self._portal_name and self._exhausted:
await self._close_portal(self._timeout)
if self._buffer:
return self._buffer.popleft()
raise StopAsyncIteration
class Cursor(BaseCursor):
"""An open *portal* into the results of a query."""
__slots__ = ()
async def _init(self, timeout):
if self._state is None:
self._state = await self._connection._get_statement(
self._query,
timeout,
named=True,
record_class=self._record_class,
)
self._state.attach()
self._check_ready()
await self._bind(timeout)
return self
@connresource.guarded
async def fetch(self, n, *, timeout=None):
r"""Return the next *n* rows as a list of :class:`Record` objects.
:param float timeout: Optional timeout value in seconds.
:return: A list of :class:`Record` instances.
"""
self._check_ready()
if n <= 0:
raise exceptions.InterfaceError('n must be greater than zero')
if self._exhausted:
return []
recs = await self._exec(n, timeout)
if len(recs) < n:
self._exhausted = True
return recs
@connresource.guarded
async def fetchrow(self, *, timeout=None):
r"""Return the next row.
:param float timeout: Optional timeout value in seconds.
:return: A :class:`Record` instance.
"""
self._check_ready()
if self._exhausted:
return None
recs = await self._exec(1, timeout)
if len(recs) < 1:
self._exhausted = True
return None
return recs[0]
@connresource.guarded
async def forward(self, n, *, timeout=None) -> int:
r"""Skip over the next *n* rows.
:param float timeout: Optional timeout value in seconds.
        :return: The number of rows actually skipped over (<= *n*).
"""
self._check_ready()
if n <= 0:
raise exceptions.InterfaceError('n must be greater than zero')
protocol = self._connection._protocol
status = await protocol.query('MOVE FORWARD {:d} {}'.format(
n, self._portal_name), timeout)
advanced = int(status.split()[1])
if advanced < n:
self._exhausted = True
return advanced
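A minimal usage sketch (illustrative; con is assumed to be an established asyncpg connection). Cursors must be created inside a transaction:

    async def scan(con):
        async with con.transaction():
            # Iterable cursor with a prefetch buffer:
            async for record in con.cursor('SELECT generate_series(1, 100)'):
                print(record[0])

            # Explicit portal with fetch/forward:
            cur = await con.cursor('SELECT generate_series(1, 100)')
            first = await cur.fetch(10)
            skipped = await cur.forward(5)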

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,299 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncpg
import sys
import textwrap
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
'ClientConfigurationError',
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
'UnsupportedServerFeatureError')
def _is_asyncpg_class(cls):
modname = cls.__module__
return modname == 'asyncpg' or modname.startswith('asyncpg.')
class PostgresMessageMeta(type):
_message_map = {}
_field_map = {
'S': 'severity',
'V': 'severity_en',
'C': 'sqlstate',
'M': 'message',
'D': 'detail',
'H': 'hint',
'P': 'position',
'p': 'internal_position',
'q': 'internal_query',
'W': 'context',
's': 'schema_name',
't': 'table_name',
'c': 'column_name',
'd': 'data_type_name',
'n': 'constraint_name',
'F': 'server_source_filename',
'L': 'server_source_line',
'R': 'server_source_function'
}
def __new__(mcls, name, bases, dct):
cls = super().__new__(mcls, name, bases, dct)
if cls.__module__ == mcls.__module__ and name == 'PostgresMessage':
for f in mcls._field_map.values():
setattr(cls, f, None)
if _is_asyncpg_class(cls):
mod = sys.modules[cls.__module__]
if hasattr(mod, name):
raise RuntimeError('exception class redefinition: {}'.format(
name))
code = dct.get('sqlstate')
if code is not None:
existing = mcls._message_map.get(code)
if existing is not None:
                    raise TypeError('{} has duplicate SQLSTATE code, which is '
                                    'already defined by {}'.format(
name, existing.__name__))
mcls._message_map[code] = cls
return cls
@classmethod
def get_message_class_for_sqlstate(mcls, code):
return mcls._message_map.get(code, UnknownPostgresError)
class PostgresMessage(metaclass=PostgresMessageMeta):
@classmethod
def _get_error_class(cls, fields):
sqlstate = fields.get('C')
return type(cls).get_message_class_for_sqlstate(sqlstate)
@classmethod
def _get_error_dict(cls, fields, query):
dct = {
'query': query
}
field_map = type(cls)._field_map
for k, v in fields.items():
field = field_map.get(k)
if field:
dct[field] = v
return dct
@classmethod
def _make_constructor(cls, fields, query=None):
dct = cls._get_error_dict(fields, query)
exccls = cls._get_error_class(fields)
message = dct.get('message', '')
# PostgreSQL will raise an exception when it detects
# that the result type of the query has changed from
# when the statement was prepared.
#
# The original error is somewhat cryptic and unspecific,
# so we raise a custom subclass that is easier to handle
# and identify.
#
# Note that we specifically do not rely on the error
# message, as it is localizable.
is_icse = (
exccls.__name__ == 'FeatureNotSupportedError' and
_is_asyncpg_class(exccls) and
dct.get('server_source_function') == 'RevalidateCachedQuery'
)
if is_icse:
exceptions = sys.modules[exccls.__module__]
exccls = exceptions.InvalidCachedStatementError
message = ('cached statement plan is invalid due to a database '
'schema or configuration change')
is_prepared_stmt_error = (
exccls.__name__ in ('DuplicatePreparedStatementError',
'InvalidSQLStatementNameError') and
_is_asyncpg_class(exccls)
)
if is_prepared_stmt_error:
hint = dct.get('hint', '')
hint += textwrap.dedent("""\
NOTE: pgbouncer with pool_mode set to "transaction" or
"statement" does not support prepared statements properly.
You have two options:
* if you are using pgbouncer for connection pooling to a
single server, switch to the connection pool functionality
provided by asyncpg, it is a much better option for this
purpose;
* if you have no option of avoiding the use of pgbouncer,
then you can set statement_cache_size to 0 when creating
the asyncpg connection object.
""")
dct['hint'] = hint
return exccls, message, dct
def as_dict(self):
dct = {}
for f in type(self)._field_map.values():
val = getattr(self, f)
if val is not None:
dct[f] = val
return dct
class PostgresError(PostgresMessage, Exception):
"""Base class for all Postgres errors."""
def __str__(self):
msg = self.args[0]
if self.detail:
msg += '\nDETAIL: {}'.format(self.detail)
if self.hint:
msg += '\nHINT: {}'.format(self.hint)
return msg
@classmethod
def new(cls, fields, query=None):
exccls, message, dct = cls._make_constructor(fields, query)
ex = exccls(message)
ex.__dict__.update(dct)
return ex
class FatalPostgresError(PostgresError):
"""A fatal error that should result in server disconnection."""
class UnknownPostgresError(FatalPostgresError):
"""An error with an unknown SQLSTATE code."""
class InterfaceMessage:
def __init__(self, *, detail=None, hint=None):
self.detail = detail
self.hint = hint
def __str__(self):
msg = self.args[0]
if self.detail:
msg += '\nDETAIL: {}'.format(self.detail)
if self.hint:
msg += '\nHINT: {}'.format(self.hint)
return msg
class InterfaceError(InterfaceMessage, Exception):
"""An error caused by improper use of asyncpg API."""
def __init__(self, msg, *, detail=None, hint=None):
InterfaceMessage.__init__(self, detail=detail, hint=hint)
Exception.__init__(self, msg)
def with_msg(self, msg):
return type(self)(
msg,
detail=self.detail,
hint=self.hint,
).with_traceback(
self.__traceback__
)
class ClientConfigurationError(InterfaceError, ValueError):
"""An error caused by improper client configuration."""
class DataError(InterfaceError, ValueError):
"""An error caused by invalid query input."""
class UnsupportedClientFeatureError(InterfaceError):
"""Requested feature is unsupported by asyncpg."""
class UnsupportedServerFeatureError(InterfaceError):
"""Requested feature is unsupported by PostgreSQL server."""
class InterfaceWarning(InterfaceMessage, UserWarning):
"""A warning caused by an improper use of asyncpg API."""
def __init__(self, msg, *, detail=None, hint=None):
InterfaceMessage.__init__(self, detail=detail, hint=hint)
UserWarning.__init__(self, msg)
class InternalClientError(Exception):
"""All unexpected errors not classified otherwise."""
class ProtocolError(InternalClientError):
"""Unexpected condition in the handling of PostgreSQL protocol input."""
class TargetServerAttributeNotMatched(InternalClientError):
"""Could not find a host that satisfies the target attribute requirement"""
class OutdatedSchemaCacheError(InternalClientError):
"""A value decoding error caused by a schema change before row fetching."""
def __init__(self, msg, *, schema=None, data_type=None, position=None):
super().__init__(msg)
self.schema_name = schema
self.data_type_name = data_type
self.position = position
class PostgresLogMessage(PostgresMessage):
"""A base class for non-error server messages."""
def __str__(self):
return '{}: {}'.format(type(self).__name__, self.message)
def __setattr__(self, name, val):
raise TypeError('instances of {} are immutable'.format(
type(self).__name__))
@classmethod
def new(cls, fields, query=None):
exccls, message_text, dct = cls._make_constructor(fields, query)
if exccls is UnknownPostgresError:
exccls = PostgresLogMessage
if exccls is PostgresLogMessage:
severity = dct.get('severity_en') or dct.get('severity')
if severity and severity.upper() == 'WARNING':
exccls = asyncpg.PostgresWarning
if issubclass(exccls, (BaseException, Warning)):
msg = exccls(message_text)
else:
msg = exccls()
msg.__dict__.update(dct)
return msg
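A minimal sketch (illustrative; con is assumed to be an established asyncpg connection) of inspecting a server error:

    async def show_error(con):
        try:
            await con.execute('SELECT * FROM nonexistent_table')
        except asyncpg.PostgresError as e:
            print(e.sqlstate, e.message)
            print(e.as_dict())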

View File

@@ -0,0 +1,292 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
_TYPEINFO_13 = '''\
(
SELECT
t.oid AS oid,
ns.nspname AS ns,
t.typname AS name,
t.typtype AS kind,
(CASE WHEN t.typtype = 'd' THEN
(WITH RECURSIVE typebases(oid, depth) AS (
SELECT
t2.typbasetype AS oid,
0 AS depth
FROM
pg_type t2
WHERE
t2.oid = t.oid
UNION ALL
SELECT
t2.typbasetype AS oid,
tb.depth + 1 AS depth
FROM
pg_type t2,
typebases tb
WHERE
tb.oid = t2.oid
AND t2.typbasetype != 0
) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1)
ELSE NULL
END) AS basetype,
t.typelem AS elemtype,
elem_t.typdelim AS elemdelim,
range_t.rngsubtype AS range_subtype,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.atttypid ORDER BY ia.attnum)
FROM
pg_attribute ia
INNER JOIN pg_class c
ON (ia.attrelid = c.oid)
WHERE
ia.attnum > 0 AND NOT ia.attisdropped
AND c.reltype = t.oid)
ELSE NULL
END) AS attrtypoids,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.attname::text ORDER BY ia.attnum)
FROM
pg_attribute ia
INNER JOIN pg_class c
ON (ia.attrelid = c.oid)
WHERE
ia.attnum > 0 AND NOT ia.attisdropped
AND c.reltype = t.oid)
ELSE NULL
END) AS attrnames
FROM
pg_catalog.pg_type AS t
INNER JOIN pg_catalog.pg_namespace ns ON (
ns.oid = t.typnamespace)
LEFT JOIN pg_type elem_t ON (
t.typlen = -1 AND
t.typelem != 0 AND
t.typelem = elem_t.oid
)
LEFT JOIN pg_range range_t ON (
t.oid = range_t.rngtypid
)
)
'''
INTRO_LOOKUP_TYPES_13 = '''\
WITH RECURSIVE typeinfo_tree(
oid, ns, name, kind, basetype, elemtype, elemdelim,
range_subtype, attrtypoids, attrnames, depth)
AS (
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, 0
FROM
{typeinfo} AS ti
WHERE
ti.oid = any($1::oid[])
UNION ALL
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, tt.depth + 1
FROM
{typeinfo} ti,
typeinfo_tree tt
WHERE
(tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype)
OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids))
OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype)
OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype)
)
SELECT DISTINCT
*,
basetype::regtype::text AS basetype_name,
elemtype::regtype::text AS elemtype_name,
range_subtype::regtype::text AS range_subtype_name
FROM
typeinfo_tree
ORDER BY
depth DESC
'''.format(typeinfo=_TYPEINFO_13)
_TYPEINFO = '''\
(
SELECT
t.oid AS oid,
ns.nspname AS ns,
t.typname AS name,
t.typtype AS kind,
(CASE WHEN t.typtype = 'd' THEN
(WITH RECURSIVE typebases(oid, depth) AS (
SELECT
t2.typbasetype AS oid,
0 AS depth
FROM
pg_type t2
WHERE
t2.oid = t.oid
UNION ALL
SELECT
t2.typbasetype AS oid,
tb.depth + 1 AS depth
FROM
pg_type t2,
typebases tb
WHERE
tb.oid = t2.oid
AND t2.typbasetype != 0
) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1)
ELSE NULL
END) AS basetype,
t.typelem AS elemtype,
elem_t.typdelim AS elemdelim,
COALESCE(
range_t.rngsubtype,
multirange_t.rngsubtype) AS range_subtype,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.atttypid ORDER BY ia.attnum)
FROM
pg_attribute ia
INNER JOIN pg_class c
ON (ia.attrelid = c.oid)
WHERE
ia.attnum > 0 AND NOT ia.attisdropped
AND c.reltype = t.oid)
ELSE NULL
END) AS attrtypoids,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.attname::text ORDER BY ia.attnum)
FROM
pg_attribute ia
INNER JOIN pg_class c
ON (ia.attrelid = c.oid)
WHERE
ia.attnum > 0 AND NOT ia.attisdropped
AND c.reltype = t.oid)
ELSE NULL
END) AS attrnames
FROM
pg_catalog.pg_type AS t
INNER JOIN pg_catalog.pg_namespace ns ON (
ns.oid = t.typnamespace)
LEFT JOIN pg_type elem_t ON (
t.typlen = -1 AND
t.typelem != 0 AND
t.typelem = elem_t.oid
)
LEFT JOIN pg_range range_t ON (
t.oid = range_t.rngtypid
)
LEFT JOIN pg_range multirange_t ON (
t.oid = multirange_t.rngmultitypid
)
)
'''
INTRO_LOOKUP_TYPES = '''\
WITH RECURSIVE typeinfo_tree(
oid, ns, name, kind, basetype, elemtype, elemdelim,
range_subtype, attrtypoids, attrnames, depth)
AS (
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, 0
FROM
{typeinfo} AS ti
WHERE
ti.oid = any($1::oid[])
UNION ALL
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, tt.depth + 1
FROM
{typeinfo} ti,
typeinfo_tree tt
WHERE
(tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype)
OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids))
OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype)
OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype)
)
SELECT DISTINCT
*,
basetype::regtype::text AS basetype_name,
elemtype::regtype::text AS elemtype_name,
range_subtype::regtype::text AS range_subtype_name
FROM
typeinfo_tree
ORDER BY
depth DESC
'''.format(typeinfo=_TYPEINFO)
TYPE_BY_NAME = '''\
SELECT
t.oid,
t.typelem AS elemtype,
t.typtype AS kind
FROM
pg_catalog.pg_type AS t
INNER JOIN pg_catalog.pg_namespace ns ON (ns.oid = t.typnamespace)
WHERE
t.typname = $1 AND ns.nspname = $2
'''
TYPE_BY_OID = '''\
SELECT
t.oid,
t.typelem AS elemtype,
t.typtype AS kind
FROM
pg_catalog.pg_type AS t
WHERE
t.oid = $1
'''
# 'b' for a base type, 'd' for a domain, 'e' for enum.
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
def is_scalar_type(typeinfo) -> bool:
return (
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)
def is_domain_type(typeinfo) -> bool:
return typeinfo['kind'] == b'd'
def is_composite_type(typeinfo) -> bool:
return typeinfo['kind'] == b'c'
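# A hedged usage sketch: the lookup queries above take an array of type
# OIDs and return rows ordered so that dependencies (element, attribute,
# base and range subtypes) appear before the types that reference them
# ('ORDER BY depth DESC'); 'conn' is a hypothetical asyncpg connection:
#
#     rows = await conn.fetch(INTRO_LOOKUP_TYPES, [composite_oid])
#     for row in rows:
#         if is_composite_type(row):
#             ...  # element and attribute codecs are already registered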

View File

@@ -0,0 +1,5 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

View File

@@ -0,0 +1,5 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

View File

@@ -0,0 +1,136 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef class WriteBuffer:
cdef:
# Preallocated small buffer
bint _smallbuf_inuse
char _smallbuf[_BUFFER_INITIAL_SIZE]
char *_buf
# Allocated size
ssize_t _size
# Length of data in the buffer
ssize_t _length
# Number of memoryviews attached to the buffer
int _view_count
        # True if start_message was used
        bint _message_mode
cdef inline len(self):
return self._length
cdef inline write_len_prefixed_utf8(self, str s):
return self.write_len_prefixed_bytes(s.encode('utf-8'))
cdef inline _check_readonly(self)
cdef inline _ensure_alloced(self, ssize_t extra_length)
cdef _reallocate(self, ssize_t new_size)
cdef inline reset(self)
cdef inline start_message(self, char type)
cdef inline end_message(self)
cdef write_buffer(self, WriteBuffer buf)
cdef write_byte(self, char b)
cdef write_bytes(self, bytes data)
cdef write_len_prefixed_buffer(self, WriteBuffer buf)
cdef write_len_prefixed_bytes(self, bytes data)
cdef write_bytestring(self, bytes string)
cdef write_str(self, str string, str encoding)
cdef write_frbuf(self, FRBuffer *buf)
cdef write_cstr(self, const char *data, ssize_t len)
cdef write_int16(self, int16_t i)
cdef write_int32(self, int32_t i)
cdef write_int64(self, int64_t i)
cdef write_float(self, float f)
cdef write_double(self, double d)
@staticmethod
cdef WriteBuffer new_message(char type)
@staticmethod
cdef WriteBuffer new()
ctypedef const char * (*try_consume_message_method)(object, ssize_t*)
ctypedef int32_t (*take_message_type_method)(object, char) except -1
ctypedef int32_t (*take_message_method)(object) except -1
ctypedef char (*get_message_type_method)(object)
cdef class ReadBuffer:
cdef:
# A deque of buffers (bytes objects)
object _bufs
object _bufs_append
object _bufs_popleft
# A pointer to the first buffer in `_bufs`
bytes _buf0
# A pointer to the previous first buffer
# (used to prolong the life of _buf0 when using
# methods like _try_read_bytes)
bytes _buf0_prev
# Number of buffers in `_bufs`
int32_t _bufs_len
# A read position in the first buffer in `_bufs`
ssize_t _pos0
# Length of the first buffer in `_bufs`
ssize_t _len0
# A total number of buffered bytes in ReadBuffer
ssize_t _length
char _current_message_type
int32_t _current_message_len
ssize_t _current_message_len_unread
bint _current_message_ready
cdef inline len(self):
return self._length
cdef inline char get_message_type(self):
return self._current_message_type
cdef inline int32_t get_message_length(self):
return self._current_message_len
cdef feed_data(self, data)
cdef inline _ensure_first_buf(self)
cdef _switch_to_next_buf(self)
cdef inline char read_byte(self) except? -1
cdef inline const char* _try_read_bytes(self, ssize_t nbytes)
cdef inline _read_into(self, char *buf, ssize_t nbytes)
cdef inline _read_and_discard(self, ssize_t nbytes)
cdef bytes read_bytes(self, ssize_t nbytes)
cdef bytes read_len_prefixed_bytes(self)
cdef str read_len_prefixed_utf8(self)
cdef read_uuid(self)
cdef inline int64_t read_int64(self) except? -1
cdef inline int32_t read_int32(self) except? -1
cdef inline int16_t read_int16(self) except? -1
cdef inline read_null_str(self)
cdef int32_t take_message(self) except -1
cdef inline int32_t take_message_type(self, char mtype) except -1
cdef int32_t put_message(self) except -1
cdef inline const char* try_consume_message(self, ssize_t* len)
cdef bytes consume_message(self)
cdef discard_message(self)
cdef redirect_messages(self, WriteBuffer buf, char mtype, int stop_at=?)
cdef bytearray consume_messages(self, char mtype)
cdef finish_message(self)
cdef inline _finish_message(self)
@staticmethod
cdef ReadBuffer new_message_parser(object data)

View File

@@ -0,0 +1,817 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from libc.string cimport memcpy
import collections
class BufferError(Exception):
pass
@cython.no_gc_clear
@cython.final
@cython.freelist(_BUFFER_FREELIST_SIZE)
cdef class WriteBuffer:
def __cinit__(self):
self._smallbuf_inuse = True
self._buf = self._smallbuf
self._size = _BUFFER_INITIAL_SIZE
self._length = 0
self._message_mode = 0
def __dealloc__(self):
if self._buf is not NULL and not self._smallbuf_inuse:
cpython.PyMem_Free(self._buf)
self._buf = NULL
self._size = 0
if self._view_count:
raise BufferError(
'Deallocating buffer with attached memoryviews')
def __getbuffer__(self, Py_buffer *buffer, int flags):
self._view_count += 1
cpython.PyBuffer_FillInfo(
buffer, self, self._buf, self._length,
1, # read-only
flags)
def __releasebuffer__(self, Py_buffer *buffer):
self._view_count -= 1
cdef inline _check_readonly(self):
if self._view_count:
raise BufferError('the buffer is in read-only mode')
cdef inline _ensure_alloced(self, ssize_t extra_length):
cdef ssize_t new_size = extra_length + self._length
if new_size > self._size:
self._reallocate(new_size)
cdef _reallocate(self, ssize_t new_size):
cdef char *new_buf
if new_size < _BUFFER_MAX_GROW:
new_size = _BUFFER_MAX_GROW
else:
# Add a little extra
new_size += _BUFFER_INITIAL_SIZE
if self._smallbuf_inuse:
new_buf = <char*>cpython.PyMem_Malloc(
sizeof(char) * <size_t>new_size)
if new_buf is NULL:
self._buf = NULL
self._size = 0
self._length = 0
raise MemoryError
memcpy(new_buf, self._buf, <size_t>self._size)
self._size = new_size
self._buf = new_buf
self._smallbuf_inuse = False
else:
new_buf = <char*>cpython.PyMem_Realloc(
<void*>self._buf, <size_t>new_size)
if new_buf is NULL:
cpython.PyMem_Free(self._buf)
self._buf = NULL
self._size = 0
self._length = 0
raise MemoryError
self._buf = new_buf
self._size = new_size
cdef inline start_message(self, char type):
if self._length != 0:
raise BufferError(
'cannot start_message for a non-empty buffer')
self._ensure_alloced(5)
self._message_mode = 1
self._buf[0] = type
self._length = 5
cdef inline end_message(self):
# "length-1" to exclude the message type byte
cdef ssize_t mlen = self._length - 1
self._check_readonly()
if not self._message_mode:
raise BufferError(
'end_message can only be called with start_message')
if self._length < 5:
raise BufferError('end_message: buffer is too small')
if mlen > _MAXINT32:
raise BufferError('end_message: message is too large')
hton.pack_int32(&self._buf[1], <int32_t>mlen)
return self
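    # A worked illustration of the framing math above (not part of the
    # API): for a simple-query message, start_message(b'Q') reserves
    # 5 bytes (1 type byte + a 4-byte length placeholder); writing the
    # 9-byte NUL-terminated string b'SELECT 1\x00' brings _length to 14,
    # and end_message() then stores mlen = 14 - 1 = 13 into bytes 1..4,
    # since the PostgreSQL wire length counts itself but not the type
    # byte.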
cdef inline reset(self):
self._length = 0
self._message_mode = 0
cdef write_buffer(self, WriteBuffer buf):
self._check_readonly()
if not buf._length:
return
self._ensure_alloced(buf._length)
memcpy(self._buf + self._length,
<void*>buf._buf,
<size_t>buf._length)
self._length += buf._length
cdef write_byte(self, char b):
self._check_readonly()
self._ensure_alloced(1)
self._buf[self._length] = b
self._length += 1
cdef write_bytes(self, bytes data):
cdef char* buf
cdef ssize_t len
cpython.PyBytes_AsStringAndSize(data, &buf, &len)
self.write_cstr(buf, len)
cdef write_bytestring(self, bytes string):
cdef char* buf
cdef ssize_t len
cpython.PyBytes_AsStringAndSize(string, &buf, &len)
        # PyBytes_AsStringAndSize returns a null-terminated buffer,
        # but the null byte is not counted in len, hence the + 1.
        self.write_cstr(buf, len + 1)
cdef write_str(self, str string, str encoding):
self.write_bytestring(string.encode(encoding))
cdef write_len_prefixed_buffer(self, WriteBuffer buf):
# Write a length-prefixed (not NULL-terminated) bytes sequence.
self.write_int32(<int32_t>buf.len())
self.write_buffer(buf)
cdef write_len_prefixed_bytes(self, bytes data):
# Write a length-prefixed (not NULL-terminated) bytes sequence.
cdef:
char *buf
ssize_t size
cpython.PyBytes_AsStringAndSize(data, &buf, &size)
if size > _MAXINT32:
raise BufferError('string is too large')
# `size` does not account for the NULL at the end.
self.write_int32(<int32_t>size)
self.write_cstr(buf, size)
cdef write_frbuf(self, FRBuffer *buf):
cdef:
ssize_t buf_len = buf.len
if buf_len > 0:
self.write_cstr(frb_read_all(buf), buf_len)
cdef write_cstr(self, const char *data, ssize_t len):
self._check_readonly()
self._ensure_alloced(len)
memcpy(self._buf + self._length, <void*>data, <size_t>len)
self._length += len
cdef write_int16(self, int16_t i):
self._check_readonly()
self._ensure_alloced(2)
hton.pack_int16(&self._buf[self._length], i)
self._length += 2
cdef write_int32(self, int32_t i):
self._check_readonly()
self._ensure_alloced(4)
hton.pack_int32(&self._buf[self._length], i)
self._length += 4
cdef write_int64(self, int64_t i):
self._check_readonly()
self._ensure_alloced(8)
hton.pack_int64(&self._buf[self._length], i)
self._length += 8
cdef write_float(self, float f):
self._check_readonly()
self._ensure_alloced(4)
hton.pack_float(&self._buf[self._length], f)
self._length += 4
cdef write_double(self, double d):
self._check_readonly()
self._ensure_alloced(8)
hton.pack_double(&self._buf[self._length], d)
self._length += 8
@staticmethod
cdef WriteBuffer new_message(char type):
cdef WriteBuffer buf
buf = WriteBuffer.__new__(WriteBuffer)
buf.start_message(type)
return buf
@staticmethod
cdef WriteBuffer new():
cdef WriteBuffer buf
buf = WriteBuffer.__new__(WriteBuffer)
return buf
@cython.no_gc_clear
@cython.final
@cython.freelist(_BUFFER_FREELIST_SIZE)
cdef class ReadBuffer:
def __cinit__(self):
self._bufs = collections.deque()
self._bufs_append = self._bufs.append
self._bufs_popleft = self._bufs.popleft
self._bufs_len = 0
self._buf0 = None
self._buf0_prev = None
self._pos0 = 0
self._len0 = 0
self._length = 0
self._current_message_type = 0
self._current_message_len = 0
self._current_message_len_unread = 0
self._current_message_ready = 0
cdef feed_data(self, data):
cdef:
ssize_t dlen
bytes data_bytes
if not cpython.PyBytes_CheckExact(data):
if cpythonx.PyByteArray_CheckExact(data):
# ProactorEventLoop in Python 3.10+ seems to be sending
# bytearray objects instead of bytes. Handle this here
# to avoid duplicating this check in every data_received().
data = bytes(data)
else:
raise BufferError(
'feed_data: a bytes or bytearray object expected')
        # Uncomment the code below to test code paths where single
        # int/str/bytes sequences are split over multiple received
        # buffers.
#
# ll = 107
# if len(data) > ll:
# self.feed_data(data[:ll])
# self.feed_data(data[ll:])
# return
data_bytes = <bytes>data
dlen = cpython.Py_SIZE(data_bytes)
if dlen == 0:
# EOF?
return
self._bufs_append(data_bytes)
self._length += dlen
if self._bufs_len == 0:
# First buffer
self._len0 = dlen
self._buf0 = data_bytes
self._bufs_len += 1
cdef inline _ensure_first_buf(self):
if PG_DEBUG:
if self._len0 == 0:
raise BufferError('empty first buffer')
if self._length == 0:
raise BufferError('empty buffer')
if self._pos0 == self._len0:
self._switch_to_next_buf()
cdef _switch_to_next_buf(self):
# The first buffer is fully read, discard it
self._bufs_popleft()
self._bufs_len -= 1
# Shouldn't fail, since we've checked that `_length >= 1`
# in _ensure_first_buf()
self._buf0_prev = self._buf0
self._buf0 = <bytes>self._bufs[0]
self._pos0 = 0
self._len0 = len(self._buf0)
if PG_DEBUG:
if self._len0 < 1:
raise BufferError(
'debug: second buffer of ReadBuffer is empty')
cdef inline const char* _try_read_bytes(self, ssize_t nbytes):
# Try to read *nbytes* from the first buffer.
#
# Returns pointer to data if there is at least *nbytes*
# in the buffer, NULL otherwise.
#
        # Important: the caller must call _ensure_first_buf() prior
        # to calling _try_read_bytes, and must not overread.
cdef:
const char *result
if PG_DEBUG:
if nbytes > self._length:
return NULL
if self._current_message_ready:
if self._current_message_len_unread < nbytes:
return NULL
if self._pos0 + nbytes <= self._len0:
result = cpython.PyBytes_AS_STRING(self._buf0)
result += self._pos0
self._pos0 += nbytes
self._length -= nbytes
if self._current_message_ready:
self._current_message_len_unread -= nbytes
return result
else:
return NULL
cdef inline _read_into(self, char *buf, ssize_t nbytes):
cdef:
ssize_t nread
char *buf0
while True:
buf0 = cpython.PyBytes_AS_STRING(self._buf0)
if self._pos0 + nbytes > self._len0:
nread = self._len0 - self._pos0
memcpy(buf, buf0 + self._pos0, <size_t>nread)
self._pos0 = self._len0
self._length -= nread
nbytes -= nread
buf += nread
self._ensure_first_buf()
else:
memcpy(buf, buf0 + self._pos0, <size_t>nbytes)
self._pos0 += nbytes
self._length -= nbytes
break
cdef inline _read_and_discard(self, ssize_t nbytes):
cdef:
ssize_t nread
self._ensure_first_buf()
while True:
if self._pos0 + nbytes > self._len0:
nread = self._len0 - self._pos0
self._pos0 = self._len0
self._length -= nread
nbytes -= nread
self._ensure_first_buf()
else:
self._pos0 += nbytes
self._length -= nbytes
break
cdef bytes read_bytes(self, ssize_t nbytes):
cdef:
bytes result
ssize_t nread
const char *cbuf
char *buf
self._ensure_first_buf()
cbuf = self._try_read_bytes(nbytes)
if cbuf != NULL:
return cpython.PyBytes_FromStringAndSize(cbuf, nbytes)
if nbytes > self._length:
raise BufferError(
'not enough data to read {} bytes'.format(nbytes))
if self._current_message_ready:
self._current_message_len_unread -= nbytes
if self._current_message_len_unread < 0:
raise BufferError('buffer overread')
result = cpython.PyBytes_FromStringAndSize(NULL, nbytes)
buf = cpython.PyBytes_AS_STRING(result)
self._read_into(buf, nbytes)
return result
cdef bytes read_len_prefixed_bytes(self):
cdef int32_t size = self.read_int32()
if size < 0:
raise BufferError(
'negative length for a len-prefixed bytes value')
if size == 0:
return b''
return self.read_bytes(size)
cdef str read_len_prefixed_utf8(self):
cdef:
int32_t size
const char *cbuf
size = self.read_int32()
if size < 0:
raise BufferError(
'negative length for a len-prefixed bytes value')
if size == 0:
return ''
self._ensure_first_buf()
cbuf = self._try_read_bytes(size)
if cbuf != NULL:
return cpython.PyUnicode_DecodeUTF8(cbuf, size, NULL)
else:
return self.read_bytes(size).decode('utf-8')
cdef read_uuid(self):
cdef:
bytes mem
const char *cbuf
self._ensure_first_buf()
cbuf = self._try_read_bytes(16)
if cbuf != NULL:
return pg_uuid_from_buf(cbuf)
else:
return pg_UUID(self.read_bytes(16))
cdef inline char read_byte(self) except? -1:
cdef const char *first_byte
if PG_DEBUG:
if not self._buf0:
raise BufferError(
'debug: first buffer of ReadBuffer is empty')
self._ensure_first_buf()
first_byte = self._try_read_bytes(1)
if first_byte is NULL:
raise BufferError('not enough data to read one byte')
return first_byte[0]
cdef inline int64_t read_int64(self) except? -1:
cdef:
bytes mem
const char *cbuf
self._ensure_first_buf()
cbuf = self._try_read_bytes(8)
if cbuf != NULL:
return hton.unpack_int64(cbuf)
else:
mem = self.read_bytes(8)
return hton.unpack_int64(cpython.PyBytes_AS_STRING(mem))
cdef inline int32_t read_int32(self) except? -1:
cdef:
bytes mem
const char *cbuf
self._ensure_first_buf()
cbuf = self._try_read_bytes(4)
if cbuf != NULL:
return hton.unpack_int32(cbuf)
else:
mem = self.read_bytes(4)
return hton.unpack_int32(cpython.PyBytes_AS_STRING(mem))
cdef inline int16_t read_int16(self) except? -1:
cdef:
bytes mem
const char *cbuf
self._ensure_first_buf()
cbuf = self._try_read_bytes(2)
if cbuf != NULL:
return hton.unpack_int16(cbuf)
else:
mem = self.read_bytes(2)
return hton.unpack_int16(cpython.PyBytes_AS_STRING(mem))
cdef inline read_null_str(self):
if not self._current_message_ready:
raise BufferError(
'read_null_str only works when the message guaranteed '
'to be in the buffer')
cdef:
ssize_t pos
ssize_t nread
bytes result
const char *buf
const char *buf_start
self._ensure_first_buf()
buf_start = cpython.PyBytes_AS_STRING(self._buf0)
buf = buf_start + self._pos0
while buf - buf_start < self._len0:
if buf[0] == 0:
pos = buf - buf_start
nread = pos - self._pos0
buf = self._try_read_bytes(nread + 1)
if buf != NULL:
return cpython.PyBytes_FromStringAndSize(buf, nread)
else:
break
else:
buf += 1
result = b''
while True:
pos = self._buf0.find(b'\x00', self._pos0)
if pos >= 0:
result += self._buf0[self._pos0 : pos]
nread = pos - self._pos0 + 1
self._pos0 = pos + 1
self._length -= nread
self._current_message_len_unread -= nread
if self._current_message_len_unread < 0:
raise BufferError(
'read_null_str: buffer overread')
return result
else:
result += self._buf0[self._pos0:]
nread = self._len0 - self._pos0
self._pos0 = self._len0
self._length -= nread
self._current_message_len_unread -= nread
if self._current_message_len_unread < 0:
raise BufferError(
'read_null_str: buffer overread')
self._ensure_first_buf()
cdef int32_t take_message(self) except -1:
cdef:
const char *cbuf
if self._current_message_ready:
return 1
if self._current_message_type == 0:
if self._length < 1:
return 0
self._ensure_first_buf()
cbuf = self._try_read_bytes(1)
if cbuf == NULL:
raise BufferError(
'failed to read one byte on a non-empty buffer')
self._current_message_type = cbuf[0]
if self._current_message_len == 0:
if self._length < 4:
return 0
self._ensure_first_buf()
cbuf = self._try_read_bytes(4)
if cbuf != NULL:
self._current_message_len = hton.unpack_int32(cbuf)
else:
self._current_message_len = self.read_int32()
self._current_message_len_unread = self._current_message_len - 4
if self._length < self._current_message_len_unread:
return 0
self._current_message_ready = 1
return 1
cdef inline int32_t take_message_type(self, char mtype) except -1:
cdef const char *buf0
if self._current_message_ready:
return self._current_message_type == mtype
elif self._length >= 1:
self._ensure_first_buf()
buf0 = cpython.PyBytes_AS_STRING(self._buf0)
return buf0[self._pos0] == mtype and self.take_message()
else:
return 0
cdef int32_t put_message(self) except -1:
if not self._current_message_ready:
raise BufferError(
'cannot put message: no message taken')
self._current_message_ready = False
return 0
cdef inline const char* try_consume_message(self, ssize_t* len):
cdef:
ssize_t buf_len
const char *buf
if not self._current_message_ready:
return NULL
self._ensure_first_buf()
buf_len = self._current_message_len_unread
buf = self._try_read_bytes(buf_len)
if buf != NULL:
len[0] = buf_len
self._finish_message()
return buf
cdef discard_message(self):
if not self._current_message_ready:
raise BufferError('no message to discard')
if self._current_message_len_unread > 0:
self._read_and_discard(self._current_message_len_unread)
self._current_message_len_unread = 0
self._finish_message()
cdef bytes consume_message(self):
if not self._current_message_ready:
raise BufferError('no message to consume')
if self._current_message_len_unread > 0:
mem = self.read_bytes(self._current_message_len_unread)
else:
mem = b''
self._finish_message()
return mem
cdef redirect_messages(self, WriteBuffer buf, char mtype,
int stop_at=0):
if not self._current_message_ready:
raise BufferError(
                'redirect_messages called on a buffer without a '
                'complete first message')
if mtype != self._current_message_type:
raise BufferError(
                'redirect_messages called with a wrong mtype')
if self._current_message_len_unread != self._current_message_len - 4:
raise BufferError(
                'redirect_messages called on a partially read message')
cdef:
const char* cbuf
ssize_t cbuf_len
int32_t msg_len
ssize_t new_pos0
ssize_t pos_delta
int32_t done
while True:
buf.write_byte(mtype)
buf.write_int32(self._current_message_len)
cbuf = self.try_consume_message(&cbuf_len)
if cbuf != NULL:
buf.write_cstr(cbuf, cbuf_len)
else:
buf.write_bytes(self.consume_message())
if self._length > 0:
self._ensure_first_buf()
else:
return
if stop_at and buf._length >= stop_at:
return
# Fast path: exhaust buf0 as efficiently as possible.
if self._pos0 + 5 <= self._len0:
cbuf = cpython.PyBytes_AS_STRING(self._buf0)
new_pos0 = self._pos0
cbuf_len = self._len0
done = 0
# Scan the first buffer and find the position of the
# end of the last "mtype" message.
while new_pos0 + 5 <= cbuf_len:
if (cbuf + new_pos0)[0] != mtype:
done = 1
break
if (stop_at and
(buf._length + new_pos0 - self._pos0) > stop_at):
done = 1
break
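                    # The length field counts itself but not the type
                    # byte, hence the + 1 to get the full on-wire size.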
msg_len = hton.unpack_int32(cbuf + new_pos0 + 1) + 1
if new_pos0 + msg_len > cbuf_len:
break
new_pos0 += msg_len
if new_pos0 != self._pos0:
assert self._pos0 < new_pos0 <= self._len0
pos_delta = new_pos0 - self._pos0
buf.write_cstr(
cbuf + self._pos0,
pos_delta)
self._pos0 = new_pos0
self._length -= pos_delta
assert self._length >= 0
if done:
# The next message is of a different type.
return
# Back to slow path.
if not self.take_message_type(mtype):
return
cdef bytearray consume_messages(self, char mtype):
"""Consume consecutive messages of the same type."""
cdef:
char *buf
ssize_t nbytes
ssize_t total_bytes = 0
bytearray result
if not self.take_message_type(mtype):
return None
# consume_messages is a volume-oriented method, so
# we assume that the remainder of the buffer will contain
# messages of the requested type.
result = cpythonx.PyByteArray_FromStringAndSize(NULL, self._length)
buf = cpythonx.PyByteArray_AsString(result)
while self.take_message_type(mtype):
self._ensure_first_buf()
nbytes = self._current_message_len_unread
self._read_into(buf, nbytes)
buf += nbytes
total_bytes += nbytes
self._finish_message()
# Clamp the result to an actual size read.
cpythonx.PyByteArray_Resize(result, total_bytes)
return result
cdef finish_message(self):
if self._current_message_type == 0 or not self._current_message_ready:
            # The message has already been finished (e.g. by consume_message()),
# or has been put back by put_message().
return
if self._current_message_len_unread:
if PG_DEBUG:
mtype = chr(self._current_message_type)
discarded = self.consume_message()
if PG_DEBUG:
print('!!! discarding message {!r} unread data: {!r}'.format(
mtype,
discarded))
self._finish_message()
cdef inline _finish_message(self):
self._current_message_type = 0
self._current_message_len = 0
self._current_message_ready = 0
self._current_message_len_unread = 0
@staticmethod
cdef ReadBuffer new_message_parser(object data):
cdef ReadBuffer buf
buf = ReadBuffer.__new__(ReadBuffer)
buf.feed_data(data)
buf._current_message_ready = 1
buf._current_message_len_unread = buf._len0
return buf
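# A hedged sketch of the read cycle implemented above, using PostgreSQL's
# framing (1 type byte + an int32 length that includes itself but not the
# type byte); the calls are cdef-level and shown as pseudocode only:
#
#     rbuf.feed_data(b'Z\x00\x00\x00\x05I')   # ReadyForQuery, status 'I'
#     rbuf.take_message()       # -> 1: header parsed, 1 payload byte left
#     rbuf.get_message_type()   # -> b'Z'
#     rbuf.consume_message()    # -> b'I'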

View File

@@ -0,0 +1,157 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef class CodecContext:
cpdef get_text_codec(self)
cdef is_encoding_utf8(self)
cpdef get_json_decoder(self)
cdef is_decoding_json(self)
cpdef get_json_encoder(self)
cdef is_encoding_json(self)
ctypedef object (*encode_func)(CodecContext settings,
WriteBuffer buf,
object obj)
ctypedef object (*decode_func)(CodecContext settings,
FRBuffer *buf)
# Datetime
cdef date_encode(CodecContext settings, WriteBuffer buf, obj)
cdef date_decode(CodecContext settings, FRBuffer * buf)
cdef date_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef date_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef timestamp_encode(CodecContext settings, WriteBuffer buf, obj)
cdef timestamp_decode(CodecContext settings, FRBuffer * buf)
cdef timestamp_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef timestamp_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef timestamptz_encode(CodecContext settings, WriteBuffer buf, obj)
cdef timestamptz_decode(CodecContext settings, FRBuffer * buf)
cdef time_encode(CodecContext settings, WriteBuffer buf, obj)
cdef time_decode(CodecContext settings, FRBuffer * buf)
cdef time_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef time_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef timetz_encode(CodecContext settings, WriteBuffer buf, obj)
cdef timetz_decode(CodecContext settings, FRBuffer * buf)
cdef timetz_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef timetz_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef interval_encode(CodecContext settings, WriteBuffer buf, obj)
cdef interval_decode(CodecContext settings, FRBuffer * buf)
cdef interval_encode_tuple(CodecContext settings, WriteBuffer buf, tuple obj)
cdef interval_decode_tuple(CodecContext settings, FRBuffer * buf)
# Bits
cdef bits_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef bits_decode(CodecContext settings, FRBuffer * buf)
# Bools
cdef bool_encode(CodecContext settings, WriteBuffer buf, obj)
cdef bool_decode(CodecContext settings, FRBuffer * buf)
# Geometry
cdef box_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef box_decode(CodecContext settings, FRBuffer * buf)
cdef line_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef line_decode(CodecContext settings, FRBuffer * buf)
cdef lseg_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef lseg_decode(CodecContext settings, FRBuffer * buf)
cdef point_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef point_decode(CodecContext settings, FRBuffer * buf)
cdef path_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef path_decode(CodecContext settings, FRBuffer * buf)
cdef poly_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef poly_decode(CodecContext settings, FRBuffer * buf)
cdef circle_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef circle_decode(CodecContext settings, FRBuffer * buf)
# Hstore
cdef hstore_encode(CodecContext settings, WriteBuffer buf, obj)
cdef hstore_decode(CodecContext settings, FRBuffer * buf)
# Ints
cdef int2_encode(CodecContext settings, WriteBuffer buf, obj)
cdef int2_decode(CodecContext settings, FRBuffer * buf)
cdef int4_encode(CodecContext settings, WriteBuffer buf, obj)
cdef int4_decode(CodecContext settings, FRBuffer * buf)
cdef uint4_encode(CodecContext settings, WriteBuffer buf, obj)
cdef uint4_decode(CodecContext settings, FRBuffer * buf)
cdef int8_encode(CodecContext settings, WriteBuffer buf, obj)
cdef int8_decode(CodecContext settings, FRBuffer * buf)
cdef uint8_encode(CodecContext settings, WriteBuffer buf, obj)
cdef uint8_decode(CodecContext settings, FRBuffer * buf)
# Floats
cdef float4_encode(CodecContext settings, WriteBuffer buf, obj)
cdef float4_decode(CodecContext settings, FRBuffer * buf)
cdef float8_encode(CodecContext settings, WriteBuffer buf, obj)
cdef float8_decode(CodecContext settings, FRBuffer * buf)
# JSON
cdef jsonb_encode(CodecContext settings, WriteBuffer buf, obj)
cdef jsonb_decode(CodecContext settings, FRBuffer * buf)
# JSON path
cdef jsonpath_encode(CodecContext settings, WriteBuffer buf, obj)
cdef jsonpath_decode(CodecContext settings, FRBuffer * buf)
# Text
cdef as_pg_string_and_size(
CodecContext settings, obj, char **cstr, ssize_t *size)
cdef text_encode(CodecContext settings, WriteBuffer buf, obj)
cdef text_decode(CodecContext settings, FRBuffer * buf)
# Bytea
cdef bytea_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef bytea_decode(CodecContext settings, FRBuffer * buf)
# UUID
cdef uuid_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef uuid_decode(CodecContext settings, FRBuffer * buf)
# Numeric
cdef numeric_encode_text(CodecContext settings, WriteBuffer buf, obj)
cdef numeric_decode_text(CodecContext settings, FRBuffer * buf)
cdef numeric_encode_binary(CodecContext settings, WriteBuffer buf, obj)
cdef numeric_decode_binary(CodecContext settings, FRBuffer * buf)
cdef numeric_decode_binary_ex(CodecContext settings, FRBuffer * buf,
bint trail_fract_zero)
# Void
cdef void_encode(CodecContext settings, WriteBuffer buf, obj)
cdef void_decode(CodecContext settings, FRBuffer * buf)
# tid
cdef tid_encode(CodecContext settings, WriteBuffer buf, obj)
cdef tid_decode(CodecContext settings, FRBuffer * buf)
# Network
cdef cidr_encode(CodecContext settings, WriteBuffer buf, obj)
cdef cidr_decode(CodecContext settings, FRBuffer * buf)
cdef inet_encode(CodecContext settings, WriteBuffer buf, obj)
cdef inet_decode(CodecContext settings, FRBuffer * buf)
# pg_snapshot
cdef pg_snapshot_encode(CodecContext settings, WriteBuffer buf, obj)
cdef pg_snapshot_decode(CodecContext settings, FRBuffer * buf)

View File

@@ -0,0 +1,47 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef bits_encode(CodecContext settings, WriteBuffer wbuf, obj):
cdef:
Py_buffer pybuf
bint pybuf_used = False
char *buf
ssize_t len
ssize_t bitlen
if cpython.PyBytes_CheckExact(obj):
buf = cpython.PyBytes_AS_STRING(obj)
len = cpython.Py_SIZE(obj)
bitlen = len * 8
elif isinstance(obj, pgproto_types.BitString):
cpython.PyBytes_AsStringAndSize(obj.bytes, &buf, &len)
bitlen = obj.__len__()
else:
cpython.PyObject_GetBuffer(obj, &pybuf, cpython.PyBUF_SIMPLE)
pybuf_used = True
buf = <char*>pybuf.buf
len = pybuf.len
bitlen = len * 8
try:
if bitlen > _MAXINT32:
raise ValueError('bit value too long')
wbuf.write_int32(4 + <int32_t>len)
wbuf.write_int32(<int32_t>bitlen)
wbuf.write_cstr(buf, len)
finally:
if pybuf_used:
cpython.PyBuffer_Release(&pybuf)
cdef bits_decode(CodecContext settings, FRBuffer *buf):
cdef:
int32_t bitlen = hton.unpack_int32(frb_read(buf, 4))
ssize_t buf_len = buf.len
bytes_ = cpython.PyBytes_FromStringAndSize(frb_read_all(buf), buf_len)
return pgproto_types.BitString.frombytes(bytes_, bitlen)

View File

@@ -0,0 +1,34 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef bytea_encode(CodecContext settings, WriteBuffer wbuf, obj):
cdef:
Py_buffer pybuf
bint pybuf_used = False
char *buf
ssize_t len
if cpython.PyBytes_CheckExact(obj):
buf = cpython.PyBytes_AS_STRING(obj)
len = cpython.Py_SIZE(obj)
else:
cpython.PyObject_GetBuffer(obj, &pybuf, cpython.PyBUF_SIMPLE)
pybuf_used = True
buf = <char*>pybuf.buf
len = pybuf.len
try:
wbuf.write_int32(<int32_t>len)
wbuf.write_cstr(buf, len)
finally:
if pybuf_used:
cpython.PyBuffer_Release(&pybuf)
cdef bytea_decode(CodecContext settings, FRBuffer *buf):
cdef ssize_t buf_len = buf.len
return cpython.PyBytes_FromStringAndSize(frb_read_all(buf), buf_len)

View File

@@ -0,0 +1,26 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef class CodecContext:
cpdef get_text_codec(self):
raise NotImplementedError
cdef is_encoding_utf8(self):
raise NotImplementedError
cpdef get_json_decoder(self):
raise NotImplementedError
cdef is_decoding_json(self):
return False
cpdef get_json_encoder(self):
raise NotImplementedError
cdef is_encoding_json(self):
return False

View File

@@ -0,0 +1,423 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cimport cpython.datetime
import datetime
cpython.datetime.import_datetime()
utc = datetime.timezone.utc
date_from_ordinal = datetime.date.fromordinal
timedelta = datetime.timedelta
pg_epoch_datetime = datetime.datetime(2000, 1, 1)
cdef int32_t pg_epoch_datetime_ts = \
<int32_t>cpython.PyLong_AsLong(int(pg_epoch_datetime.timestamp()))
pg_epoch_datetime_utc = datetime.datetime(2000, 1, 1, tzinfo=utc)
cdef int32_t pg_epoch_datetime_utc_ts = \
<int32_t>cpython.PyLong_AsLong(int(pg_epoch_datetime_utc.timestamp()))
pg_epoch_date = datetime.date(2000, 1, 1)
cdef int32_t pg_date_offset_ord = \
<int32_t>cpython.PyLong_AsLong(pg_epoch_date.toordinal())
# Binary representations of infinity for datetimes.
cdef int64_t pg_time64_infinity = 0x7fffffffffffffff
cdef int64_t pg_time64_negative_infinity = <int64_t>0x8000000000000000
cdef int32_t pg_date_infinity = 0x7fffffff
cdef int32_t pg_date_negative_infinity = <int32_t>0x80000000
infinity_datetime = datetime.datetime(
datetime.MAXYEAR, 12, 31, 23, 59, 59, 999999)
cdef int32_t infinity_datetime_ord = <int32_t>cpython.PyLong_AsLong(
infinity_datetime.toordinal())
cdef int64_t infinity_datetime_ts = 252455615999999999
negative_infinity_datetime = datetime.datetime(
datetime.MINYEAR, 1, 1, 0, 0, 0, 0)
cdef int32_t negative_infinity_datetime_ord = <int32_t>cpython.PyLong_AsLong(
negative_infinity_datetime.toordinal())
cdef int64_t negative_infinity_datetime_ts = -63082281600000000
infinity_date = datetime.date(datetime.MAXYEAR, 12, 31)
cdef int32_t infinity_date_ord = <int32_t>cpython.PyLong_AsLong(
infinity_date.toordinal())
negative_infinity_date = datetime.date(datetime.MINYEAR, 1, 1)
cdef int32_t negative_infinity_date_ord = <int32_t>cpython.PyLong_AsLong(
negative_infinity_date.toordinal())
cdef inline _local_timezone():
d = datetime.datetime.now(datetime.timezone.utc).astimezone()
return datetime.timezone(d.utcoffset())
cdef inline _encode_time(WriteBuffer buf, int64_t seconds,
int32_t microseconds):
    # XXX: add support for double (floating-point) timestamps in
    # addition to the int64 timestamps handled below.
cdef int64_t ts = seconds * 1000000 + microseconds
if ts == infinity_datetime_ts:
buf.write_int64(pg_time64_infinity)
elif ts == negative_infinity_datetime_ts:
buf.write_int64(pg_time64_negative_infinity)
else:
buf.write_int64(ts)
cdef inline int32_t _decode_time(FRBuffer *buf, int64_t *seconds,
int32_t *microseconds):
cdef int64_t ts = hton.unpack_int64(frb_read(buf, 8))
if ts == pg_time64_infinity:
return 1
elif ts == pg_time64_negative_infinity:
return -1
else:
seconds[0] = ts // 1000000
microseconds[0] = <int32_t>(ts % 1000000)
return 0
cdef date_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
int32_t ordinal = <int32_t>cpython.PyLong_AsLong(obj.toordinal())
int32_t pg_ordinal
if ordinal == infinity_date_ord:
pg_ordinal = pg_date_infinity
elif ordinal == negative_infinity_date_ord:
pg_ordinal = pg_date_negative_infinity
else:
pg_ordinal = ordinal - pg_date_offset_ord
buf.write_int32(4)
buf.write_int32(pg_ordinal)
cdef date_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
cdef:
int32_t pg_ordinal
if len(obj) != 1:
raise ValueError(
'date tuple encoder: expecting 1 element '
'in tuple, got {}'.format(len(obj)))
pg_ordinal = obj[0]
buf.write_int32(4)
buf.write_int32(pg_ordinal)
cdef date_decode(CodecContext settings, FRBuffer *buf):
cdef int32_t pg_ordinal = hton.unpack_int32(frb_read(buf, 4))
if pg_ordinal == pg_date_infinity:
return infinity_date
elif pg_ordinal == pg_date_negative_infinity:
return negative_infinity_date
else:
return date_from_ordinal(pg_ordinal + pg_date_offset_ord)
cdef date_decode_tuple(CodecContext settings, FRBuffer *buf):
cdef int32_t pg_ordinal = hton.unpack_int32(frb_read(buf, 4))
return (pg_ordinal,)
cdef timestamp_encode(CodecContext settings, WriteBuffer buf, obj):
if not cpython.datetime.PyDateTime_Check(obj):
if cpython.datetime.PyDate_Check(obj):
obj = datetime.datetime(obj.year, obj.month, obj.day)
else:
raise TypeError(
'expected a datetime.date or datetime.datetime instance, '
'got {!r}'.format(type(obj).__name__)
)
delta = obj - pg_epoch_datetime
cdef:
int64_t seconds = cpython.PyLong_AsLongLong(delta.days) * 86400 + \
cpython.PyLong_AsLong(delta.seconds)
int32_t microseconds = <int32_t>cpython.PyLong_AsLong(
delta.microseconds)
buf.write_int32(8)
_encode_time(buf, seconds, microseconds)
cdef timestamp_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
cdef:
int64_t microseconds
if len(obj) != 1:
raise ValueError(
'timestamp tuple encoder: expecting 1 element '
'in tuple, got {}'.format(len(obj)))
microseconds = obj[0]
buf.write_int32(8)
buf.write_int64(microseconds)
cdef timestamp_decode(CodecContext settings, FRBuffer *buf):
cdef:
int64_t seconds = 0
int32_t microseconds = 0
int32_t inf = _decode_time(buf, &seconds, &microseconds)
if inf > 0:
# positive infinity
return infinity_datetime
elif inf < 0:
# negative infinity
return negative_infinity_datetime
else:
return pg_epoch_datetime.__add__(
timedelta(0, seconds, microseconds))
cdef timestamp_decode_tuple(CodecContext settings, FRBuffer *buf):
cdef:
int64_t ts = hton.unpack_int64(frb_read(buf, 8))
return (ts,)
cdef timestamptz_encode(CodecContext settings, WriteBuffer buf, obj):
if not cpython.datetime.PyDateTime_Check(obj):
if cpython.datetime.PyDate_Check(obj):
obj = datetime.datetime(obj.year, obj.month, obj.day,
tzinfo=_local_timezone())
else:
raise TypeError(
'expected a datetime.date or datetime.datetime instance, '
'got {!r}'.format(type(obj).__name__)
)
buf.write_int32(8)
if obj == infinity_datetime:
buf.write_int64(pg_time64_infinity)
return
elif obj == negative_infinity_datetime:
buf.write_int64(pg_time64_negative_infinity)
return
utc_dt = obj.astimezone(utc)
delta = utc_dt - pg_epoch_datetime_utc
cdef:
int64_t seconds = cpython.PyLong_AsLongLong(delta.days) * 86400 + \
cpython.PyLong_AsLong(delta.seconds)
int32_t microseconds = <int32_t>cpython.PyLong_AsLong(
delta.microseconds)
_encode_time(buf, seconds, microseconds)
cdef timestamptz_decode(CodecContext settings, FRBuffer *buf):
cdef:
int64_t seconds = 0
int32_t microseconds = 0
int32_t inf = _decode_time(buf, &seconds, &microseconds)
if inf > 0:
# positive infinity
return infinity_datetime
elif inf < 0:
# negative infinity
return negative_infinity_datetime
else:
return pg_epoch_datetime_utc.__add__(
timedelta(0, seconds, microseconds))
cdef time_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
int64_t seconds = cpython.PyLong_AsLong(obj.hour) * 3600 + \
cpython.PyLong_AsLong(obj.minute) * 60 + \
cpython.PyLong_AsLong(obj.second)
int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microsecond)
buf.write_int32(8)
_encode_time(buf, seconds, microseconds)
cdef time_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
cdef:
int64_t microseconds
if len(obj) != 1:
raise ValueError(
'time tuple encoder: expecting 1 element '
'in tuple, got {}'.format(len(obj)))
microseconds = obj[0]
buf.write_int32(8)
buf.write_int64(microseconds)
cdef time_decode(CodecContext settings, FRBuffer *buf):
cdef:
int64_t seconds = 0
int32_t microseconds = 0
_decode_time(buf, &seconds, &microseconds)
cdef:
int64_t minutes = <int64_t>(seconds / 60)
int64_t sec = seconds % 60
int64_t hours = <int64_t>(minutes / 60)
int64_t min = minutes % 60
return datetime.time(hours, min, sec, microseconds)
cdef time_decode_tuple(CodecContext settings, FRBuffer *buf):
cdef:
int64_t ts = hton.unpack_int64(frb_read(buf, 8))
return (ts,)
cdef timetz_encode(CodecContext settings, WriteBuffer buf, obj):
offset = obj.tzinfo.utcoffset(None)
cdef:
int32_t offset_sec = \
<int32_t>cpython.PyLong_AsLong(offset.days) * 24 * 60 * 60 + \
<int32_t>cpython.PyLong_AsLong(offset.seconds)
int64_t seconds = cpython.PyLong_AsLong(obj.hour) * 3600 + \
cpython.PyLong_AsLong(obj.minute) * 60 + \
cpython.PyLong_AsLong(obj.second)
int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microsecond)
buf.write_int32(12)
_encode_time(buf, seconds, microseconds)
# In Python utcoffset() is the difference between the local time
# and the UTC, whereas in PostgreSQL it's the opposite,
# so we need to flip the sign.
buf.write_int32(-offset_sec)
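# A worked illustration of the sign flip above: for UTC+9 (e.g.
# Asia/Tokyo) Python's utcoffset() is +32400 seconds, while the
# PostgreSQL binary format counts the zone as seconds west of UTC,
# so -32400 is written to the wire and timetz_decode() flips it back.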
cdef timetz_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
cdef:
int64_t microseconds
int32_t offset_sec
if len(obj) != 2:
raise ValueError(
            'timetz tuple encoder: expecting 2 elements '
'in tuple, got {}'.format(len(obj)))
microseconds = obj[0]
offset_sec = obj[1]
buf.write_int32(12)
buf.write_int64(microseconds)
buf.write_int32(offset_sec)
cdef timetz_decode(CodecContext settings, FRBuffer *buf):
time = time_decode(settings, buf)
cdef int32_t offset = <int32_t>(hton.unpack_int32(frb_read(buf, 4)) / 60)
# See the comment in the `timetz_encode` method.
return time.replace(tzinfo=datetime.timezone(timedelta(minutes=-offset)))
cdef timetz_decode_tuple(CodecContext settings, FRBuffer *buf):
cdef:
int64_t microseconds = hton.unpack_int64(frb_read(buf, 8))
int32_t offset_sec = hton.unpack_int32(frb_read(buf, 4))
return (microseconds, offset_sec)
cdef interval_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
int32_t days = <int32_t>cpython.PyLong_AsLong(obj.days)
int64_t seconds = cpython.PyLong_AsLongLong(obj.seconds)
int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microseconds)
buf.write_int32(16)
_encode_time(buf, seconds, microseconds)
buf.write_int32(days)
buf.write_int32(0) # Months
cdef interval_encode_tuple(CodecContext settings, WriteBuffer buf,
tuple obj):
cdef:
int32_t months
int32_t days
int64_t microseconds
if len(obj) != 3:
raise ValueError(
'interval tuple encoder: expecting 3 elements '
'in tuple, got {}'.format(len(obj)))
months = obj[0]
days = obj[1]
microseconds = obj[2]
buf.write_int32(16)
buf.write_int64(microseconds)
buf.write_int32(days)
buf.write_int32(months)
cdef interval_decode(CodecContext settings, FRBuffer *buf):
cdef:
int32_t days
int32_t months
int32_t years
int64_t seconds = 0
int32_t microseconds = 0
_decode_time(buf, &seconds, &microseconds)
days = hton.unpack_int32(frb_read(buf, 4))
months = hton.unpack_int32(frb_read(buf, 4))
if months < 0:
years = -<int32_t>(-months // 12)
months = -<int32_t>(-months % 12)
else:
years = <int32_t>(months // 12)
months = <int32_t>(months % 12)
return datetime.timedelta(days=days + months * 30 + years * 365,
seconds=seconds, microseconds=microseconds)
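# A worked illustration of the approximation above: a server value of
# months=14, days=3 and zero time splits into 1 year and 2 months, so
# the result is timedelta(days=3 + 2*30 + 1*365) == timedelta(days=428).
# Month and year lengths are necessarily approximate here, because
# datetime.timedelta has no calendar-aware components.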
cdef interval_decode_tuple(CodecContext settings, FRBuffer *buf):
cdef:
int32_t days
int32_t months
int64_t microseconds
microseconds = hton.unpack_int64(frb_read(buf, 8))
days = hton.unpack_int32(frb_read(buf, 4))
months = hton.unpack_int32(frb_read(buf, 4))
return (months, days, microseconds)

View File

@@ -0,0 +1,34 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from libc cimport math
cdef float4_encode(CodecContext settings, WriteBuffer buf, obj):
cdef double dval = cpython.PyFloat_AsDouble(obj)
cdef float fval = <float>dval
if math.isinf(fval) and not math.isinf(dval):
raise ValueError('value out of float32 range')
buf.write_int32(4)
buf.write_float(fval)
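# Illustrative edge case for the range check above: 3.5e38 is finite as
# a double but becomes infinity when narrowed to float32 (whose maximum
# is about 3.4028e38), so the encoder raises ValueError instead of
# silently sending an infinity the caller never produced.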
cdef float4_decode(CodecContext settings, FRBuffer *buf):
cdef float f = hton.unpack_float(frb_read(buf, 4))
return cpython.PyFloat_FromDouble(f)
cdef float8_encode(CodecContext settings, WriteBuffer buf, obj):
cdef double dval = cpython.PyFloat_AsDouble(obj)
buf.write_int32(8)
buf.write_double(dval)
cdef float8_decode(CodecContext settings, FRBuffer *buf):
cdef double f = hton.unpack_double(frb_read(buf, 8))
return cpython.PyFloat_FromDouble(f)

View File

@@ -0,0 +1,164 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef inline _encode_points(WriteBuffer wbuf, object points):
cdef object point
for point in points:
wbuf.write_double(point[0])
wbuf.write_double(point[1])
cdef inline _decode_points(FRBuffer *buf):
cdef:
int32_t npts = hton.unpack_int32(frb_read(buf, 4))
pts = cpython.PyTuple_New(npts)
int32_t i
object point
double x
double y
for i in range(npts):
x = hton.unpack_double(frb_read(buf, 8))
y = hton.unpack_double(frb_read(buf, 8))
point = pgproto_types.Point(x, y)
cpython.Py_INCREF(point)
cpython.PyTuple_SET_ITEM(pts, i, point)
return pts
cdef box_encode(CodecContext settings, WriteBuffer wbuf, obj):
wbuf.write_int32(32)
_encode_points(wbuf, (obj[0], obj[1]))
cdef box_decode(CodecContext settings, FRBuffer *buf):
cdef:
double high_x = hton.unpack_double(frb_read(buf, 8))
double high_y = hton.unpack_double(frb_read(buf, 8))
double low_x = hton.unpack_double(frb_read(buf, 8))
double low_y = hton.unpack_double(frb_read(buf, 8))
return pgproto_types.Box(
pgproto_types.Point(high_x, high_y),
pgproto_types.Point(low_x, low_y))
cdef line_encode(CodecContext settings, WriteBuffer wbuf, obj):
wbuf.write_int32(24)
wbuf.write_double(obj[0])
wbuf.write_double(obj[1])
wbuf.write_double(obj[2])
cdef line_decode(CodecContext settings, FRBuffer *buf):
cdef:
double A = hton.unpack_double(frb_read(buf, 8))
double B = hton.unpack_double(frb_read(buf, 8))
double C = hton.unpack_double(frb_read(buf, 8))
return pgproto_types.Line(A, B, C)
cdef lseg_encode(CodecContext settings, WriteBuffer wbuf, obj):
wbuf.write_int32(32)
_encode_points(wbuf, (obj[0], obj[1]))
cdef lseg_decode(CodecContext settings, FRBuffer *buf):
cdef:
double p1_x = hton.unpack_double(frb_read(buf, 8))
double p1_y = hton.unpack_double(frb_read(buf, 8))
double p2_x = hton.unpack_double(frb_read(buf, 8))
double p2_y = hton.unpack_double(frb_read(buf, 8))
return pgproto_types.LineSegment((p1_x, p1_y), (p2_x, p2_y))
cdef point_encode(CodecContext settings, WriteBuffer wbuf, obj):
wbuf.write_int32(16)
wbuf.write_double(obj[0])
wbuf.write_double(obj[1])
cdef point_decode(CodecContext settings, FRBuffer *buf):
cdef:
double x = hton.unpack_double(frb_read(buf, 8))
double y = hton.unpack_double(frb_read(buf, 8))
return pgproto_types.Point(x, y)
cdef path_encode(CodecContext settings, WriteBuffer wbuf, obj):
cdef:
int8_t is_closed = 0
ssize_t npts
ssize_t encoded_len
int32_t i
if cpython.PyTuple_Check(obj):
is_closed = 1
elif cpython.PyList_Check(obj):
is_closed = 0
elif isinstance(obj, pgproto_types.Path):
is_closed = obj.is_closed
npts = len(obj)
encoded_len = 1 + 4 + 16 * npts
if encoded_len > _MAXINT32:
raise ValueError('path value too long')
wbuf.write_int32(<int32_t>encoded_len)
wbuf.write_byte(is_closed)
wbuf.write_int32(<int32_t>npts)
_encode_points(wbuf, obj)
cdef path_decode(CodecContext settings, FRBuffer *buf):
cdef:
int8_t is_closed = <int8_t>(frb_read(buf, 1)[0])
return pgproto_types.Path(*_decode_points(buf), is_closed=is_closed == 1)
cdef poly_encode(CodecContext settings, WriteBuffer wbuf, obj):
cdef:
bint is_closed
ssize_t npts
ssize_t encoded_len
int32_t i
npts = len(obj)
encoded_len = 4 + 16 * npts
if encoded_len > _MAXINT32:
raise ValueError('polygon value too long')
wbuf.write_int32(<int32_t>encoded_len)
wbuf.write_int32(<int32_t>npts)
_encode_points(wbuf, obj)
cdef poly_decode(CodecContext settings, FRBuffer *buf):
return pgproto_types.Polygon(*_decode_points(buf))
cdef circle_encode(CodecContext settings, WriteBuffer wbuf, obj):
wbuf.write_int32(24)
wbuf.write_double(obj[0][0])
wbuf.write_double(obj[0][1])
wbuf.write_double(obj[1])
cdef circle_decode(CodecContext settings, FRBuffer *buf):
cdef:
double center_x = hton.unpack_double(frb_read(buf, 8))
double center_y = hton.unpack_double(frb_read(buf, 8))
double radius = hton.unpack_double(frb_read(buf, 8))
return pgproto_types.Circle((center_x, center_y), radius)

View File

@@ -0,0 +1,73 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef hstore_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
char *str
ssize_t size
ssize_t count
object items
WriteBuffer item_buf = WriteBuffer.new()
count = len(obj)
if count > _MAXINT32:
raise ValueError('hstore value is too large')
item_buf.write_int32(<int32_t>count)
if hasattr(obj, 'items'):
items = obj.items()
else:
items = obj
for k, v in items:
if k is None:
raise ValueError('null value not allowed in hstore key')
as_pg_string_and_size(settings, k, &str, &size)
item_buf.write_int32(<int32_t>size)
item_buf.write_cstr(str, size)
if v is None:
item_buf.write_int32(<int32_t>-1)
else:
as_pg_string_and_size(settings, v, &str, &size)
item_buf.write_int32(<int32_t>size)
item_buf.write_cstr(str, size)
buf.write_int32(item_buf.len())
buf.write_buffer(item_buf)
cdef hstore_decode(CodecContext settings, FRBuffer *buf):
cdef:
dict result
uint32_t elem_count
int32_t elem_len
uint32_t i
str k
str v
result = {}
elem_count = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
if elem_count == 0:
return result
for i in range(elem_count):
elem_len = hton.unpack_int32(frb_read(buf, 4))
if elem_len < 0:
raise ValueError('null value not allowed in hstore key')
k = decode_pg_string(settings, frb_read(buf, elem_len), elem_len)
elem_len = hton.unpack_int32(frb_read(buf, 4))
if elem_len < 0:
v = None
else:
v = decode_pg_string(settings, frb_read(buf, elem_len), elem_len)
result[k] = v
return result
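# A hedged sketch of the pair layout written and read above: after the
# outer value-length prefix, encoding {'a': '1', 'b': None} produces
#
#     int32 2          -- pair count
#     int32 1  'a'     -- length-prefixed key
#     int32 1  '1'     -- length-prefixed value
#     int32 1  'b'
#     int32 -1         -- NULL value marker (keys may not be NULL)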

View File

@@ -0,0 +1,144 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef bool_encode(CodecContext settings, WriteBuffer buf, obj):
if not cpython.PyBool_Check(obj):
raise TypeError('a boolean is required (got type {})'.format(
type(obj).__name__))
buf.write_int32(1)
buf.write_byte(b'\x01' if obj is True else b'\x00')
cdef bool_decode(CodecContext settings, FRBuffer *buf):
return frb_read(buf, 1)[0] is b'\x01'
cdef int2_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef long val
try:
if type(obj) is not int and hasattr(type(obj), '__int__'):
# Silence a Python warning about implicit __int__
# conversion.
obj = int(obj)
val = cpython.PyLong_AsLong(obj)
except OverflowError:
overflow = 1
if overflow or val < INT16_MIN or val > INT16_MAX:
raise OverflowError('value out of int16 range')
buf.write_int32(2)
buf.write_int16(<int16_t>val)
cdef int2_decode(CodecContext settings, FRBuffer *buf):
return cpython.PyLong_FromLong(hton.unpack_int16(frb_read(buf, 2)))
cdef int4_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef long val = 0
try:
if type(obj) is not int and hasattr(type(obj), '__int__'):
# Silence a Python warning about implicit __int__
# conversion.
obj = int(obj)
val = cpython.PyLong_AsLong(obj)
except OverflowError:
overflow = 1
# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(val) > 4 and (val < INT32_MIN or val > INT32_MAX)):
raise OverflowError('value out of int32 range')
buf.write_int32(4)
buf.write_int32(<int32_t>val)
cdef int4_decode(CodecContext settings, FRBuffer *buf):
return cpython.PyLong_FromLong(hton.unpack_int32(frb_read(buf, 4)))
cdef uint4_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef unsigned long val = 0
try:
if type(obj) is not int and hasattr(type(obj), '__int__'):
# Silence a Python warning about implicit __int__
# conversion.
obj = int(obj)
val = cpython.PyLong_AsUnsignedLong(obj)
except OverflowError:
overflow = 1
# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(val) > 4 and val > UINT32_MAX):
raise OverflowError('value out of uint32 range')
buf.write_int32(4)
buf.write_int32(<int32_t>val)
cdef uint4_decode(CodecContext settings, FRBuffer *buf):
return cpython.PyLong_FromUnsignedLong(
<uint32_t>hton.unpack_int32(frb_read(buf, 4)))
cdef int8_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef long long val
try:
if type(obj) is not int and hasattr(type(obj), '__int__'):
# Silence a Python warning about implicit __int__
# conversion.
obj = int(obj)
val = cpython.PyLong_AsLongLong(obj)
except OverflowError:
overflow = 1
# Just in case for systems with "long long" bigger than 8 bytes
if overflow or (sizeof(val) > 8 and (val < INT64_MIN or val > INT64_MAX)):
raise OverflowError('value out of int64 range')
buf.write_int32(8)
buf.write_int64(<int64_t>val)
cdef int8_decode(CodecContext settings, FRBuffer *buf):
return cpython.PyLong_FromLongLong(hton.unpack_int64(frb_read(buf, 8)))
cdef uint8_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef unsigned long long val = 0
try:
if type(obj) is not int and hasattr(type(obj), '__int__'):
# Silence a Python warning about implicit __int__
# conversion.
obj = int(obj)
val = cpython.PyLong_AsUnsignedLongLong(obj)
except OverflowError:
overflow = 1
# Just in case for systems with "long long" bigger than 8 bytes
if overflow or (sizeof(val) > 8 and val > UINT64_MAX):
raise OverflowError('value out of uint64 range')
buf.write_int32(8)
buf.write_int64(<int64_t>val)
cdef uint8_decode(CodecContext settings, FRBuffer *buf):
return cpython.PyLong_FromUnsignedLongLong(
<uint64_t>hton.unpack_int64(frb_read(buf, 8)))

View File

@@ -0,0 +1,57 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef jsonb_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
char *str
ssize_t size
if settings.is_encoding_json():
obj = settings.get_json_encoder().encode(obj)
as_pg_string_and_size(settings, obj, &str, &size)
if size > 0x7fffffff - 1:
raise ValueError('string too long')
buf.write_int32(<int32_t>size + 1)
buf.write_byte(1) # JSONB format version
buf.write_cstr(str, size)
cdef jsonb_decode(CodecContext settings, FRBuffer *buf):
cdef uint8_t format = <uint8_t>(frb_read(buf, 1)[0])
if format != 1:
raise ValueError('unexpected JSONB format: {}'.format(format))
rv = text_decode(settings, buf)
if settings.is_decoding_json():
rv = settings.get_json_decoder().decode(rv)
return rv
cdef json_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
char *str
ssize_t size
if settings.is_encoding_json():
obj = settings.get_json_encoder().encode(obj)
text_encode(settings, buf, obj)
cdef json_decode(CodecContext settings, FRBuffer *buf):
rv = text_decode(settings, buf)
if settings.is_decoding_json():
rv = settings.get_json_decoder().decode(rv)
return rv
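# Illustrative sketch (not part of the module above): json is framed exactly
# like text (length prefix plus payload), while jsonb prepends a one-byte
# format version, as encoded above:
#
#   payload = '{"a": 1}'.encode('utf-8')
#   json_wire = len(payload).to_bytes(4, 'big') + payload
#   jsonb_wire = (len(payload) + 1).to_bytes(4, 'big') + b'\x01' + payload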

View File

@@ -0,0 +1,29 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef jsonpath_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
char *str
ssize_t size
as_pg_string_and_size(settings, obj, &str, &size)
if size > 0x7fffffff - 1:
raise ValueError('string too long')
buf.write_int32(<int32_t>size + 1)
buf.write_byte(1) # jsonpath format version
buf.write_cstr(str, size)
cdef jsonpath_decode(CodecContext settings, FRBuffer *buf):
cdef uint8_t format = <uint8_t>(frb_read(buf, 1)[0])
if format != 1:
raise ValueError('unexpected jsonpath format: {}'.format(format))
return text_decode(settings, buf)

View File

@@ -0,0 +1,16 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef void_encode(CodecContext settings, WriteBuffer buf, obj):
# Void is zero bytes
buf.write_int32(0)
cdef void_decode(CodecContext settings, FRBuffer *buf):
# Do nothing; void will be passed as NULL so this function
# will never be called.
pass

View File

@@ -0,0 +1,139 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import ipaddress
# defined in postgresql/src/include/inet.h
#
DEF PGSQL_AF_INET = 2 # AF_INET
DEF PGSQL_AF_INET6 = 3 # AF_INET + 1
_ipaddr = ipaddress.ip_address
_ipiface = ipaddress.ip_interface
_ipnet = ipaddress.ip_network
cdef inline uint8_t _ip_max_prefix_len(int32_t family):
# Maximum number of bits in the network prefix of the specified
# IP protocol version.
if family == PGSQL_AF_INET:
return 32
else:
return 128
cdef inline int32_t _ip_addr_len(int32_t family):
# Length of address in bytes for the specified IP protocol version.
if family == PGSQL_AF_INET:
return 4
else:
return 16
cdef inline int8_t _ver_to_family(int32_t version):
if version == 4:
return PGSQL_AF_INET
else:
return PGSQL_AF_INET6
cdef inline _net_encode(WriteBuffer buf, int8_t family, uint32_t bits,
int8_t is_cidr, bytes addr):
cdef:
char *addrbytes
ssize_t addrlen
cpython.PyBytes_AsStringAndSize(addr, &addrbytes, &addrlen)
buf.write_int32(4 + <int32_t>addrlen)
buf.write_byte(family)
buf.write_byte(<int8_t>bits)
buf.write_byte(is_cidr)
buf.write_byte(<int8_t>addrlen)
buf.write_cstr(addrbytes, addrlen)
cdef net_decode(CodecContext settings, FRBuffer *buf, bint as_cidr):
cdef:
int32_t family = <int32_t>frb_read(buf, 1)[0]
uint8_t bits = <uint8_t>frb_read(buf, 1)[0]
int prefix_len
int32_t is_cidr = <int32_t>frb_read(buf, 1)[0]
int32_t addrlen = <int32_t>frb_read(buf, 1)[0]
bytes addr
        uint8_t max_prefix_len
if is_cidr != as_cidr:
raise ValueError('unexpected CIDR flag set in non-cidr value')
if family != PGSQL_AF_INET and family != PGSQL_AF_INET6:
raise ValueError('invalid address family in "{}" value'.format(
'cidr' if is_cidr else 'inet'
))
max_prefix_len = _ip_max_prefix_len(family)
if bits > max_prefix_len:
raise ValueError('invalid network prefix length in "{}" value'.format(
'cidr' if is_cidr else 'inet'
))
if addrlen != _ip_addr_len(family):
raise ValueError('invalid address length in "{}" value'.format(
'cidr' if is_cidr else 'inet'
))
addr = cpython.PyBytes_FromStringAndSize(frb_read(buf, addrlen), addrlen)
if as_cidr or bits != max_prefix_len:
prefix_len = cpython.PyLong_FromLong(bits)
if as_cidr:
return _ipnet((addr, prefix_len))
else:
return _ipiface((addr, prefix_len))
else:
return _ipaddr(addr)
cdef cidr_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
object ipnet
int8_t family
ipnet = _ipnet(obj)
family = _ver_to_family(ipnet.version)
_net_encode(buf, family, ipnet.prefixlen, 1, ipnet.network_address.packed)
cdef cidr_decode(CodecContext settings, FRBuffer *buf):
return net_decode(settings, buf, True)
cdef inet_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
object ipaddr
int8_t family
try:
ipaddr = _ipaddr(obj)
except ValueError:
        # PostgreSQL accepts *both* CIDR and host values
        # for the inet datatype.
ipaddr = _ipiface(obj)
family = _ver_to_family(ipaddr.version)
_net_encode(buf, family, ipaddr.network.prefixlen, 1, ipaddr.packed)
else:
family = _ver_to_family(ipaddr.version)
_net_encode(buf, family, _ip_max_prefix_len(family), 0, ipaddr.packed)
cdef inet_decode(CodecContext settings, FRBuffer *buf):
return net_decode(settings, buf, False)
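# Illustrative sketch (not part of the module above): the inet/cidr payload
# handled by net_decode() is family, prefix bits, the is_cidr flag and the
# address length, followed by the raw address bytes. For 192.168.0.0/24:
#
#   import ipaddress
#   net = ipaddress.ip_network('192.168.0.0/24')
#   payload = bytes([2, net.prefixlen, 1, 4]) + net.network_address.packed
#   # 2 == PGSQL_AF_INET, 1 == is_cidr, 4 == IPv4 address length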

View File

@@ -0,0 +1,356 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from libc.math cimport abs, log10
from libc.stdio cimport snprintf
import decimal
# defined in postgresql/src/backend/utils/adt/numeric.c
DEF DEC_DIGITS = 4
DEF MAX_DSCALE = 0x3FFF
DEF NUMERIC_POS = 0x0000
DEF NUMERIC_NEG = 0x4000
DEF NUMERIC_NAN = 0xC000
DEF NUMERIC_PINF = 0xD000
DEF NUMERIC_NINF = 0xF000
_Dec = decimal.Decimal
cdef numeric_encode_text(CodecContext settings, WriteBuffer buf, obj):
text_encode(settings, buf, str(obj))
cdef numeric_decode_text(CodecContext settings, FRBuffer *buf):
return _Dec(text_decode(settings, buf))
cdef numeric_encode_binary(CodecContext settings, WriteBuffer buf, obj):
cdef:
object dec
object dt
int64_t exponent
int64_t i
int64_t j
tuple pydigits
int64_t num_pydigits
int16_t pgdigit
int64_t num_pgdigits
int16_t dscale
int64_t dweight
int64_t weight
uint16_t sign
int64_t padding_size = 0
if isinstance(obj, _Dec):
dec = obj
else:
dec = _Dec(obj)
dt = dec.as_tuple()
if dt.exponent == 'n' or dt.exponent == 'N':
# NaN
sign = NUMERIC_NAN
num_pgdigits = 0
weight = 0
dscale = 0
elif dt.exponent == 'F':
# Infinity
if dt.sign:
sign = NUMERIC_NINF
else:
sign = NUMERIC_PINF
num_pgdigits = 0
weight = 0
dscale = 0
else:
exponent = dt.exponent
if exponent < 0 and -exponent > MAX_DSCALE:
raise ValueError(
'cannot encode Decimal value into numeric: '
'exponent is too small')
if dt.sign:
sign = NUMERIC_NEG
else:
sign = NUMERIC_POS
pydigits = dt.digits
num_pydigits = len(pydigits)
dweight = num_pydigits + exponent - 1
if dweight >= 0:
weight = (dweight + DEC_DIGITS) // DEC_DIGITS - 1
else:
weight = -((-dweight - 1) // DEC_DIGITS + 1)
if weight > 2 ** 16 - 1:
raise ValueError(
'cannot encode Decimal value into numeric: '
'exponent is too large')
padding_size = \
(weight + 1) * DEC_DIGITS - (dweight + 1)
num_pgdigits = \
(num_pydigits + padding_size + DEC_DIGITS - 1) // DEC_DIGITS
if num_pgdigits > 2 ** 16 - 1:
raise ValueError(
'cannot encode Decimal value into numeric: '
'number of digits is too large')
# Pad decimal digits to provide room for correct Postgres
# digit alignment in the digit computation loop.
pydigits = (0,) * DEC_DIGITS + pydigits + (0,) * DEC_DIGITS
if exponent < 0:
if -exponent > MAX_DSCALE:
raise ValueError(
'cannot encode Decimal value into numeric: '
'exponent is too small')
dscale = <int16_t>-exponent
else:
dscale = 0
buf.write_int32(2 + 2 + 2 + 2 + 2 * <uint16_t>num_pgdigits)
buf.write_int16(<int16_t>num_pgdigits)
buf.write_int16(<int16_t>weight)
buf.write_int16(<int16_t>sign)
buf.write_int16(dscale)
j = DEC_DIGITS - padding_size
for i in range(num_pgdigits):
pgdigit = (pydigits[j] * 1000 + pydigits[j + 1] * 100 +
pydigits[j + 2] * 10 + pydigits[j + 3])
j += DEC_DIGITS
buf.write_int16(pgdigit)
# The decoding strategy here is to form a string representation of
# the numeric var, as it is faster than passing an iterable of digits.
# For this reason the below code is pure overhead and is ~25% slower
# than the simple text decoder above. That said, we need the binary
# decoder to support binary COPY with numeric values.
cdef numeric_decode_binary_ex(
CodecContext settings,
FRBuffer *buf,
bint trail_fract_zero,
):
cdef:
uint16_t num_pgdigits = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
int16_t weight = hton.unpack_int16(frb_read(buf, 2))
uint16_t sign = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
uint16_t dscale = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
int16_t pgdigit0
ssize_t i
int16_t pgdigit
object pydigits
ssize_t num_pydigits
ssize_t actual_num_pydigits
ssize_t buf_size
int64_t exponent
int64_t abs_exponent
ssize_t exponent_chars
ssize_t front_padding = 0
ssize_t num_fract_digits
ssize_t trailing_fract_zeros_adj
char smallbuf[_NUMERIC_DECODER_SMALLBUF_SIZE]
char *charbuf
char *bufptr
bint buf_allocated = False
if sign == NUMERIC_NAN:
# Not-a-number
return _Dec('NaN')
elif sign == NUMERIC_PINF:
# +Infinity
return _Dec('Infinity')
elif sign == NUMERIC_NINF:
# -Infinity
return _Dec('-Infinity')
if num_pgdigits == 0:
# Zero
return _Dec('0e-' + str(dscale))
pgdigit0 = hton.unpack_int16(frb_read(buf, 2))
if weight >= 0:
if pgdigit0 < 10:
front_padding = 3
elif pgdigit0 < 100:
front_padding = 2
elif pgdigit0 < 1000:
front_padding = 1
# The number of fractional decimal digits actually encoded in
    # base-DEC_DIGITS digits sent by Postgres.
num_fract_digits = (num_pgdigits - weight - 1) * DEC_DIGITS
# The trailing zero adjustment necessary to obtain exactly
# dscale number of fractional digits in output. May be negative,
# which indicates that trailing zeros in the last input digit
# should be discarded.
trailing_fract_zeros_adj = dscale - num_fract_digits
# Maximum possible number of decimal digits in base 10.
# The actual number might be up to 3 digits smaller due to
# leading zeros in first input digit.
num_pydigits = num_pgdigits * DEC_DIGITS
if trailing_fract_zeros_adj > 0:
num_pydigits += trailing_fract_zeros_adj
# Exponent.
exponent = (weight + 1) * DEC_DIGITS - front_padding
abs_exponent = abs(exponent)
if abs_exponent != 0:
# Number of characters required to render absolute exponent value
# in decimal.
exponent_chars = <ssize_t>log10(<double>abs_exponent) + 1
else:
exponent_chars = 0
# Output buffer size.
buf_size = (
1 + # sign
1 + # leading zero
1 + # decimal dot
num_pydigits + # digits
1 + # possible trailing zero padding
2 + # exponent indicator (E-,E+)
exponent_chars + # exponent
1 # null terminator char
)
if buf_size > _NUMERIC_DECODER_SMALLBUF_SIZE:
charbuf = <char *>cpython.PyMem_Malloc(<size_t>buf_size)
buf_allocated = True
else:
charbuf = smallbuf
try:
bufptr = charbuf
if sign == NUMERIC_NEG:
bufptr[0] = b'-'
bufptr += 1
bufptr[0] = b'0'
bufptr[1] = b'.'
bufptr += 2
if weight >= 0:
bufptr = _unpack_digit_stripping_lzeros(bufptr, pgdigit0)
else:
bufptr = _unpack_digit(bufptr, pgdigit0)
for i in range(1, num_pgdigits):
pgdigit = hton.unpack_int16(frb_read(buf, 2))
bufptr = _unpack_digit(bufptr, pgdigit)
if dscale:
if trailing_fract_zeros_adj > 0:
for i in range(trailing_fract_zeros_adj):
bufptr[i] = <char>b'0'
# If display scale is _less_ than the number of rendered digits,
# trailing_fract_zeros_adj will be negative and this will strip
# the excess trailing zeros.
bufptr += trailing_fract_zeros_adj
if trail_fract_zero:
# Check if the number of rendered digits matches the exponent,
# and if so, add another trailing zero, so the result always
# appears with a decimal point.
actual_num_pydigits = bufptr - charbuf - 2
if sign == NUMERIC_NEG:
actual_num_pydigits -= 1
if actual_num_pydigits == abs_exponent:
bufptr[0] = <char>b'0'
bufptr += 1
if exponent != 0:
bufptr[0] = b'E'
if exponent < 0:
bufptr[1] = b'-'
else:
bufptr[1] = b'+'
bufptr += 2
snprintf(bufptr, <size_t>exponent_chars + 1, '%d',
<int>abs_exponent)
bufptr += exponent_chars
bufptr[0] = 0
pydigits = cpythonx.PyUnicode_FromString(charbuf)
return _Dec(pydigits)
finally:
if buf_allocated:
cpython.PyMem_Free(charbuf)
cdef numeric_decode_binary(CodecContext settings, FRBuffer *buf):
return numeric_decode_binary_ex(settings, buf, False)
cdef inline char *_unpack_digit_stripping_lzeros(char *buf, int64_t pgdigit):
cdef:
int64_t d
bint significant
d = pgdigit // 1000
significant = (d > 0)
if significant:
pgdigit -= d * 1000
buf[0] = <char>(d + <int32_t>b'0')
buf += 1
d = pgdigit // 100
significant |= (d > 0)
if significant:
pgdigit -= d * 100
buf[0] = <char>(d + <int32_t>b'0')
buf += 1
d = pgdigit // 10
significant |= (d > 0)
if significant:
pgdigit -= d * 10
buf[0] = <char>(d + <int32_t>b'0')
buf += 1
buf[0] = <char>(pgdigit + <int32_t>b'0')
buf += 1
return buf
cdef inline char *_unpack_digit(char *buf, int64_t pgdigit):
cdef:
int64_t d
d = pgdigit // 1000
pgdigit -= d * 1000
buf[0] = <char>(d + <int32_t>b'0')
d = pgdigit // 100
pgdigit -= d * 100
buf[1] = <char>(d + <int32_t>b'0')
d = pgdigit // 10
pgdigit -= d * 10
buf[2] = <char>(d + <int32_t>b'0')
buf[3] = <char>(pgdigit + <int32_t>b'0')
buf += 4
return buf
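# Worked example (illustrative, not part of the module above): encoding
# Decimal('123.45') with numeric_encode_binary() yields
#
#   digits = (1, 2, 3, 4, 5), exponent = -2
#   dweight = 5 + (-2) - 1 = 2            # weight in decimal digits
#   weight = (2 + 4) // 4 - 1 = 0         # weight in base-10000 digits
#   dscale = 2                            # fractional decimal digits
#   padding_size = (0 + 1) * 4 - (2 + 1) = 1
#   pgdigits = [123, 4500]                # i.e. 0123.4500 == 123.45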

View File

@@ -0,0 +1,63 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef pg_snapshot_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
ssize_t nxip
uint64_t xmin
uint64_t xmax
int i
WriteBuffer xip_buf = WriteBuffer.new()
if not (cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj)):
raise TypeError(
'list or tuple expected (got type {})'.format(type(obj)))
if len(obj) != 3:
raise ValueError(
            'invalid number of elements in txid_snapshot tuple, expecting 3')
nxip = len(obj[2])
if nxip > _MAXINT32:
raise ValueError('txid_snapshot value is too long')
xmin = obj[0]
xmax = obj[1]
for i in range(nxip):
xip_buf.write_int64(
<int64_t>cpython.PyLong_AsUnsignedLongLong(obj[2][i]))
buf.write_int32(20 + xip_buf.len())
buf.write_int32(<int32_t>nxip)
buf.write_int64(<int64_t>xmin)
buf.write_int64(<int64_t>xmax)
buf.write_buffer(xip_buf)
cdef pg_snapshot_decode(CodecContext settings, FRBuffer *buf):
cdef:
int32_t nxip
uint64_t xmin
uint64_t xmax
tuple xip_tup
int32_t i
object xip
nxip = hton.unpack_int32(frb_read(buf, 4))
xmin = <uint64_t>hton.unpack_int64(frb_read(buf, 8))
xmax = <uint64_t>hton.unpack_int64(frb_read(buf, 8))
xip_tup = cpython.PyTuple_New(nxip)
for i in range(nxip):
xip = cpython.PyLong_FromUnsignedLongLong(
<uint64_t>hton.unpack_int64(frb_read(buf, 8)))
cpython.Py_INCREF(xip)
cpython.PyTuple_SET_ITEM(xip_tup, i, xip)
return (xmin, xmax, xip_tup)
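# Illustrative sketch (not part of the module above): the txid_snapshot
# payload is an int32 xip count, int64 xmin, int64 xmax, then one int64 per
# in-progress transaction id, e.g. for the value (100, 200, (150,)):
#
#   import struct
#   payload = struct.pack('!iqqq', 1, 100, 200, 150)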

View File

@@ -0,0 +1,48 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef inline as_pg_string_and_size(
CodecContext settings, obj, char **cstr, ssize_t *size):
if not cpython.PyUnicode_Check(obj):
raise TypeError('expected str, got {}'.format(type(obj).__name__))
if settings.is_encoding_utf8():
cstr[0] = <char*>cpythonx.PyUnicode_AsUTF8AndSize(obj, size)
else:
encoded = settings.get_text_codec().encode(obj)[0]
cpython.PyBytes_AsStringAndSize(encoded, cstr, size)
if size[0] > 0x7fffffff:
raise ValueError('string too long')
cdef text_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
char *str
ssize_t size
as_pg_string_and_size(settings, obj, &str, &size)
buf.write_int32(<int32_t>size)
buf.write_cstr(str, size)
cdef inline decode_pg_string(CodecContext settings, const char* data,
ssize_t len):
if settings.is_encoding_utf8():
# decode UTF-8 in strict mode
return cpython.PyUnicode_DecodeUTF8(data, len, NULL)
else:
bytes = cpython.PyBytes_FromStringAndSize(data, len)
return settings.get_text_codec().decode(bytes)[0]
cdef text_decode(CodecContext settings, FRBuffer *buf):
cdef ssize_t buf_len = buf.len
return decode_pg_string(settings, frb_read_all(buf), buf_len)

View File

@@ -0,0 +1,51 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef tid_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef unsigned long block, offset
if not (cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj)):
raise TypeError(
'list or tuple expected (got type {})'.format(type(obj)))
if len(obj) != 2:
raise ValueError(
'invalid number of elements in tid tuple, expecting 2')
try:
block = cpython.PyLong_AsUnsignedLong(obj[0])
except OverflowError:
overflow = 1
# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(block) > 4 and block > UINT32_MAX):
raise OverflowError('tuple id block value out of uint32 range')
try:
offset = cpython.PyLong_AsUnsignedLong(obj[1])
overflow = 0
except OverflowError:
overflow = 1
if overflow or offset > 65535:
raise OverflowError('tuple id offset value out of uint16 range')
buf.write_int32(6)
buf.write_int32(<int32_t>block)
buf.write_int16(<int16_t>offset)
cdef tid_decode(CodecContext settings, FRBuffer *buf):
cdef:
uint32_t block
uint16_t offset
block = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
offset = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
return (block, offset)

View File

@@ -0,0 +1,27 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef uuid_encode(CodecContext settings, WriteBuffer wbuf, obj):
cdef:
char buf[16]
if type(obj) is pg_UUID:
wbuf.write_int32(<int32_t>16)
wbuf.write_cstr((<UUID>obj)._data, 16)
elif cpython.PyUnicode_Check(obj):
pg_uuid_bytes_from_str(obj, buf)
wbuf.write_int32(<int32_t>16)
wbuf.write_cstr(buf, 16)
else:
bytea_encode(settings, wbuf, obj.bytes)
cdef uuid_decode(CodecContext settings, FRBuffer *buf):
if buf.len != 16:
raise TypeError(
f'cannot decode UUID, expected 16 bytes, got {buf.len}')
return pg_uuid_from_buf(frb_read_all(buf))

View File

@@ -0,0 +1,12 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
DEF _BUFFER_INITIAL_SIZE = 1024
DEF _BUFFER_MAX_GROW = 65536
DEF _BUFFER_FREELIST_SIZE = 256
DEF _MAXINT32 = 2**31 - 1
DEF _NUMERIC_DECODER_SMALLBUF_SIZE = 256

View File

@@ -0,0 +1,23 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from cpython cimport Py_buffer
cdef extern from "Python.h":
int PyUnicode_1BYTE_KIND
int PyByteArray_CheckExact(object)
int PyByteArray_Resize(object, ssize_t) except -1
object PyByteArray_FromStringAndSize(const char *, ssize_t)
char* PyByteArray_AsString(object)
object PyUnicode_FromString(const char *u)
const char* PyUnicode_AsUTF8AndSize(
object unicode, ssize_t *size) except NULL
object PyUnicode_FromKindAndData(
int kind, const void *buffer, Py_ssize_t size)

View File

@@ -0,0 +1,10 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef extern from "debug.h":
cdef int PG_DEBUG

View File

@@ -0,0 +1,48 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef:
struct FRBuffer:
const char* buf
ssize_t len
inline ssize_t frb_get_len(FRBuffer *frb):
return frb.len
inline void frb_set_len(FRBuffer *frb, ssize_t new_len):
frb.len = new_len
inline void frb_init(FRBuffer *frb, const char *buf, ssize_t len):
frb.buf = buf
frb.len = len
inline const char* frb_read(FRBuffer *frb, ssize_t n) except NULL:
cdef const char *result
frb_check(frb, n)
result = frb.buf
frb.buf += n
frb.len -= n
return result
inline const char* frb_read_all(FRBuffer *frb):
cdef const char *result
result = frb.buf
frb.buf += frb.len
frb.len = 0
return result
inline FRBuffer *frb_slice_from(FRBuffer *frb,
FRBuffer* source, ssize_t len):
frb.buf = frb_read(source, len)
frb.len = len
return frb
object frb_check(FRBuffer *frb, ssize_t n)

View File

@@ -0,0 +1,12 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef object frb_check(FRBuffer *frb, ssize_t n):
if n > frb.len:
raise AssertionError(
f'insufficient data in buffer: requested {n} '
f'remaining {frb.len}')

View File

@@ -0,0 +1,24 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from libc.stdint cimport int16_t, int32_t, uint16_t, uint32_t, int64_t, uint64_t
cdef extern from "./hton.h":
cdef void pack_int16(char *buf, int16_t x);
cdef void pack_int32(char *buf, int32_t x);
cdef void pack_int64(char *buf, int64_t x);
cdef void pack_float(char *buf, float f);
cdef void pack_double(char *buf, double f);
cdef int16_t unpack_int16(const char *buf);
cdef uint16_t unpack_uint16(const char *buf);
cdef int32_t unpack_int32(const char *buf);
cdef uint32_t unpack_uint32(const char *buf);
cdef int64_t unpack_int64(const char *buf);
cdef uint64_t unpack_uint64(const char *buf);
cdef float unpack_float(const char *buf);
cdef double unpack_double(const char *buf);

View File

@@ -0,0 +1,19 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cimport cython
cimport cpython
from libc.stdint cimport int16_t, int32_t, uint16_t, uint32_t, int64_t, uint64_t
include "./consts.pxi"
include "./frb.pxd"
include "./buffer.pxd"
include "./codecs/__init__.pxd"

View File

@@ -0,0 +1,49 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cimport cython
cimport cpython
from . cimport cpythonx
from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
int32_t, uint32_t, int64_t, uint64_t, \
INT16_MIN, INT16_MAX, INT32_MIN, INT32_MAX, \
UINT32_MAX, INT64_MIN, INT64_MAX, UINT64_MAX
from . cimport hton
from . cimport tohex
from .debug cimport PG_DEBUG
from . import types as pgproto_types
include "./consts.pxi"
include "./frb.pyx"
include "./buffer.pyx"
include "./uuid.pyx"
include "./codecs/context.pyx"
include "./codecs/bytea.pyx"
include "./codecs/text.pyx"
include "./codecs/datetime.pyx"
include "./codecs/float.pyx"
include "./codecs/int.pyx"
include "./codecs/json.pyx"
include "./codecs/jsonpath.pyx"
include "./codecs/uuid.pyx"
include "./codecs/numeric.pyx"
include "./codecs/bits.pyx"
include "./codecs/geometry.pyx"
include "./codecs/hstore.pyx"
include "./codecs/misc.pyx"
include "./codecs/network.pyx"
include "./codecs/tid.pyx"
include "./codecs/pg_snapshot.pyx"

View File

@@ -0,0 +1,10 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef extern from "./tohex.h":
cdef void uuid_to_str(const char *source, char *dest)
cdef void uuid_to_hex(const char *source, char *dest)

View File

@@ -0,0 +1,423 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import builtins
import sys
import typing
if sys.version_info >= (3, 8):
from typing import Literal, SupportsIndex
else:
from typing_extensions import Literal, SupportsIndex
__all__ = (
'BitString', 'Point', 'Path', 'Polygon',
'Box', 'Line', 'LineSegment', 'Circle',
)
_BitString = typing.TypeVar('_BitString', bound='BitString')
_BitOrderType = Literal['big', 'little']
class BitString:
"""Immutable representation of PostgreSQL `bit` and `varbit` types."""
__slots__ = '_bytes', '_bitlength'
def __init__(self,
bitstring: typing.Optional[builtins.bytes] = None) -> None:
if not bitstring:
self._bytes = bytes()
self._bitlength = 0
else:
bytelen = len(bitstring) // 8 + 1
bytes_ = bytearray(bytelen)
byte = 0
byte_pos = 0
bit_pos = 0
for i, bit in enumerate(bitstring):
if bit == ' ': # type: ignore
continue
bit = int(bit)
if bit != 0 and bit != 1:
raise ValueError(
'invalid bit value at position {}'.format(i))
byte |= bit << (8 - bit_pos - 1)
bit_pos += 1
if bit_pos == 8:
bytes_[byte_pos] = byte
byte = 0
byte_pos += 1
bit_pos = 0
if bit_pos != 0:
bytes_[byte_pos] = byte
bitlen = byte_pos * 8 + bit_pos
bytelen = byte_pos + (1 if bit_pos else 0)
self._bytes = bytes(bytes_[:bytelen])
self._bitlength = bitlen
@classmethod
def frombytes(cls: typing.Type[_BitString],
bytes_: typing.Optional[builtins.bytes] = None,
bitlength: typing.Optional[int] = None) -> _BitString:
if bitlength is None:
if bytes_ is None:
bytes_ = bytes()
bitlength = 0
else:
bitlength = len(bytes_) * 8
else:
            if bytes_ is None:
                bytes_ = bytes(bitlength // 8 + 1)
else:
bytes_len = len(bytes_) * 8
if bytes_len == 0 and bitlength != 0:
raise ValueError('invalid bit length specified')
if bytes_len != 0 and bitlength == 0:
raise ValueError('invalid bit length specified')
if bitlength < bytes_len - 8:
raise ValueError('invalid bit length specified')
if bitlength > bytes_len:
raise ValueError('invalid bit length specified')
result = cls()
result._bytes = bytes_
result._bitlength = bitlength
return result
@property
def bytes(self) -> builtins.bytes:
return self._bytes
def as_string(self) -> str:
s = ''
for i in range(self._bitlength):
s += str(self._getitem(i))
if i % 4 == 3:
s += ' '
return s.strip()
def to_int(self, bitorder: _BitOrderType = 'big',
*, signed: bool = False) -> int:
"""Interpret the BitString as a Python int.
Acts similarly to int.from_bytes.
:param bitorder:
Determines the bit order used to interpret the BitString. By
default, this function uses Postgres conventions for casting bits
to ints. If bitorder is 'big', the most significant bit is at the
start of the string (this is the same as the default). If bitorder
is 'little', the most significant bit is at the end of the string.
:param bool signed:
Determines whether two's complement is used to interpret the
BitString. If signed is False, the returned value is always
non-negative.
:return int: An integer representing the BitString. Information about
the BitString's exact length is lost.
.. versionadded:: 0.18.0
"""
x = int.from_bytes(self._bytes, byteorder='big')
x >>= -self._bitlength % 8
if bitorder == 'big':
pass
elif bitorder == 'little':
x = int(bin(x)[:1:-1].ljust(self._bitlength, '0'), 2)
else:
raise ValueError("bitorder must be either 'big' or 'little'")
if signed and self._bitlength > 0 and x & (1 << (self._bitlength - 1)):
x -= 1 << self._bitlength
return x
@classmethod
def from_int(cls: typing.Type[_BitString], x: int, length: int,
bitorder: _BitOrderType = 'big', *, signed: bool = False) \
-> _BitString:
"""Represent the Python int x as a BitString.
Acts similarly to int.to_bytes.
:param int x:
An integer to represent. Negative integers are represented in two's
complement form, unless the argument signed is False, in which case
negative integers raise an OverflowError.
:param int length:
The length of the resulting BitString. An OverflowError is raised
if the integer is not representable in this many bits.
:param bitorder:
Determines the bit order used in the BitString representation. By
default, this function uses Postgres conventions for casting ints
to bits. If bitorder is 'big', the most significant bit is at the
start of the string (this is the same as the default). If bitorder
is 'little', the most significant bit is at the end of the string.
:param bool signed:
Determines whether two's complement is used in the BitString
representation. If signed is False and a negative integer is given,
an OverflowError is raised.
:return BitString: A BitString representing the input integer, in the
form specified by the other input args.
.. versionadded:: 0.18.0
"""
# Exception types are by analogy to int.to_bytes
if length < 0:
raise ValueError("length argument must be non-negative")
elif length < x.bit_length():
raise OverflowError("int too big to convert")
if x < 0:
if not signed:
raise OverflowError("can't convert negative int to unsigned")
x &= (1 << length) - 1
if bitorder == 'big':
pass
elif bitorder == 'little':
x = int(bin(x)[:1:-1].ljust(length, '0'), 2)
else:
raise ValueError("bitorder must be either 'big' or 'little'")
x <<= (-length % 8)
bytes_ = x.to_bytes((length + 7) // 8, byteorder='big')
return cls.frombytes(bytes_, length)
def __repr__(self) -> str:
return '<BitString {}>'.format(self.as_string())
__str__: typing.Callable[['BitString'], str] = __repr__
def __eq__(self, other: object) -> bool:
if not isinstance(other, BitString):
return NotImplemented
return (self._bytes == other._bytes and
self._bitlength == other._bitlength)
def __hash__(self) -> int:
return hash((self._bytes, self._bitlength))
def _getitem(self, i: int) -> int:
byte = self._bytes[i // 8]
shift = 8 - i % 8 - 1
return (byte >> shift) & 0x1
def __getitem__(self, i: int) -> int:
if isinstance(i, slice):
raise NotImplementedError('BitString does not support slices')
if i >= self._bitlength:
raise IndexError('index out of range')
return self._getitem(i)
def __len__(self) -> int:
return self._bitlength
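# Usage sketch (illustrative, not part of the module above):
#
#   bits = BitString('1101')
#   assert len(bits) == 4
#   assert bits.to_int() == 13                       # MSB first by default
#   assert bits.to_int('little') == 11               # read as 1011
#   assert BitString.from_int(13, length=4) == bits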
class Point(typing.Tuple[float, float]):
"""Immutable representation of PostgreSQL `point` type."""
__slots__ = ()
def __new__(cls,
x: typing.Union[typing.SupportsFloat,
SupportsIndex,
typing.Text,
builtins.bytes,
builtins.bytearray],
y: typing.Union[typing.SupportsFloat,
SupportsIndex,
typing.Text,
builtins.bytes,
builtins.bytearray]) -> 'Point':
return super().__new__(cls,
typing.cast(typing.Any, (float(x), float(y))))
def __repr__(self) -> str:
return '{}.{}({})'.format(
type(self).__module__,
type(self).__name__,
tuple.__repr__(self)
)
@property
def x(self) -> float:
return self[0]
@property
def y(self) -> float:
return self[1]
class Box(typing.Tuple[Point, Point]):
"""Immutable representation of PostgreSQL `box` type."""
__slots__ = ()
def __new__(cls, high: typing.Sequence[float],
low: typing.Sequence[float]) -> 'Box':
return super().__new__(cls,
typing.cast(typing.Any, (Point(*high),
Point(*low))))
def __repr__(self) -> str:
return '{}.{}({})'.format(
type(self).__module__,
type(self).__name__,
tuple.__repr__(self)
)
@property
def high(self) -> Point:
return self[0]
@property
def low(self) -> Point:
return self[1]
class Line(typing.Tuple[float, float, float]):
"""Immutable representation of PostgreSQL `line` type."""
__slots__ = ()
def __new__(cls, A: float, B: float, C: float) -> 'Line':
return super().__new__(cls, typing.cast(typing.Any, (A, B, C)))
@property
def A(self) -> float:
return self[0]
@property
def B(self) -> float:
return self[1]
@property
def C(self) -> float:
return self[2]
class LineSegment(typing.Tuple[Point, Point]):
"""Immutable representation of PostgreSQL `lseg` type."""
__slots__ = ()
def __new__(cls, p1: typing.Sequence[float],
p2: typing.Sequence[float]) -> 'LineSegment':
return super().__new__(cls,
typing.cast(typing.Any, (Point(*p1),
Point(*p2))))
def __repr__(self) -> str:
return '{}.{}({})'.format(
type(self).__module__,
type(self).__name__,
tuple.__repr__(self)
)
@property
def p1(self) -> Point:
return self[0]
@property
def p2(self) -> Point:
return self[1]
class Path:
"""Immutable representation of PostgreSQL `path` type."""
__slots__ = '_is_closed', 'points'
points: typing.Tuple[Point, ...]
def __init__(self, *points: typing.Sequence[float],
is_closed: bool = False) -> None:
self.points = tuple(Point(*p) for p in points)
self._is_closed = is_closed
@property
def is_closed(self) -> bool:
return self._is_closed
def __eq__(self, other: object) -> bool:
if not isinstance(other, Path):
return NotImplemented
return (self.points == other.points and
self._is_closed == other._is_closed)
def __hash__(self) -> int:
return hash((self.points, self.is_closed))
def __iter__(self) -> typing.Iterator[Point]:
return iter(self.points)
def __len__(self) -> int:
return len(self.points)
@typing.overload
def __getitem__(self, i: int) -> Point:
...
@typing.overload
def __getitem__(self, i: slice) -> typing.Tuple[Point, ...]:
...
def __getitem__(self, i: typing.Union[int, slice]) \
-> typing.Union[Point, typing.Tuple[Point, ...]]:
return self.points[i]
def __contains__(self, point: object) -> bool:
return point in self.points
class Polygon(Path):
"""Immutable representation of PostgreSQL `polygon` type."""
__slots__ = ()
def __init__(self, *points: typing.Sequence[float]) -> None:
# polygon is always closed
super().__init__(*points, is_closed=True)
class Circle(typing.Tuple[Point, float]):
"""Immutable representation of PostgreSQL `circle` type."""
__slots__ = ()
def __new__(cls, center: Point, radius: float) -> 'Circle':
return super().__new__(cls, typing.cast(typing.Any, (center, radius)))
@property
def center(self) -> Point:
return self[0]
@property
def radius(self) -> float:
return self[1]
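# Usage sketch (illustrative, not part of the module above):
#
#   p = Point(1.0, 2.0)
#   assert (p.x, p.y) == (1.0, 2.0)
#   box = Box((2.0, 2.0), (1.0, 1.0))
#   assert box.high == Point(2.0, 2.0) and box.low == Point(1.0, 1.0)
#   poly = Polygon((0, 0), (0, 1), (1, 1))
#   assert poly.is_closed and len(poly) == 3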

View File

@@ -0,0 +1,353 @@
import functools
import uuid
cimport cython
cimport cpython
from libc.stdint cimport uint8_t, int8_t
from libc.string cimport memcpy, memcmp
cdef extern from "Python.h":
int PyUnicode_1BYTE_KIND
const char* PyUnicode_AsUTF8AndSize(
object unicode, Py_ssize_t *size) except NULL
object PyUnicode_FromKindAndData(
int kind, const void *buffer, Py_ssize_t size)
cdef extern from "./tohex.h":
cdef void uuid_to_str(const char *source, char *dest)
cdef void uuid_to_hex(const char *source, char *dest)
# A more efficient UUID type implementation
# (6-7x faster than the standard uuid.UUID):
#
# -= Benchmark results (less is better): =-
#
# std_UUID(bytes): 1.2368
# c_UUID(bytes): * 0.1645 (7.52x)
# object(): 0.1483
#
# std_UUID(str): 1.8038
# c_UUID(str): * 0.2313 (7.80x)
#
# str(std_UUID()): 1.4625
# str(c_UUID()): * 0.2681 (5.46x)
# str(object()): 0.5975
#
# std_UUID().bytes: 0.3508
# c_UUID().bytes: * 0.1068 (3.28x)
#
# std_UUID().int: 0.0871
# c_UUID().int: * 0.0856
#
# std_UUID().hex: 0.4871
# c_UUID().hex: * 0.1405
#
# hash(std_UUID()): 0.3635
# hash(c_UUID()): * 0.1564 (2.32x)
#
# dct[std_UUID()]: 0.3319
# dct[c_UUID()]: * 0.1570 (2.11x)
#
# std_UUID() ==: 0.3478
# c_UUID() ==: * 0.0915 (3.80x)
cdef char _hextable[256]
_hextable[:] = [
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1, 0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,-1,10,11,12,13,14,15,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,10,11,12,13,14,15,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
]
cdef std_UUID = uuid.UUID
cdef pg_uuid_bytes_from_str(str u, char *out):
cdef:
const char *orig_buf
Py_ssize_t size
unsigned char ch
uint8_t acc, part, acc_set
int i, j
orig_buf = PyUnicode_AsUTF8AndSize(u, &size)
if size > 36 or size < 32:
raise ValueError(
f'invalid UUID {u!r}: '
f'length must be between 32..36 characters, got {size}')
acc_set = 0
j = 0
for i in range(size):
ch = <unsigned char>orig_buf[i]
if ch == <unsigned char>b'-':
continue
part = <uint8_t><int8_t>_hextable[ch]
if part == <uint8_t>-1:
if ch >= 0x20 and ch <= 0x7e:
raise ValueError(
f'invalid UUID {u!r}: unexpected character {chr(ch)!r}')
else:
                raise ValueError(f'invalid UUID {u!r}: unexpected character')
if acc_set:
acc |= part
out[j] = <char>acc
acc_set = 0
j += 1
else:
acc = <uint8_t>(part << 4)
acc_set = 1
if j > 16 or (j == 16 and acc_set):
raise ValueError(
f'invalid UUID {u!r}: decodes to more than 16 bytes')
if j != 16:
raise ValueError(
f'invalid UUID {u!r}: decodes to less than 16 bytes')
cdef class __UUIDReplaceMe:
pass
cdef pg_uuid_from_buf(const char *buf):
cdef:
UUID u = UUID.__new__(UUID)
memcpy(u._data, buf, 16)
return u
@cython.final
@cython.no_gc_clear
cdef class UUID(__UUIDReplaceMe):
cdef:
char _data[16]
object _int
object _hash
object __weakref__
def __cinit__(self):
self._int = None
self._hash = None
def __init__(self, inp):
cdef:
char *buf
Py_ssize_t size
if cpython.PyBytes_Check(inp):
cpython.PyBytes_AsStringAndSize(inp, &buf, &size)
if size != 16:
raise ValueError(f'16 bytes were expected, got {size}')
memcpy(self._data, buf, 16)
elif cpython.PyUnicode_Check(inp):
pg_uuid_bytes_from_str(inp, self._data)
else:
raise TypeError(f'a bytes or str object expected, got {inp!r}')
@property
def bytes(self):
return cpython.PyBytes_FromStringAndSize(self._data, 16)
@property
def int(self):
if self._int is None:
# The cache is important because `self.int` can be
# used multiple times by __hash__ etc.
self._int = int.from_bytes(self.bytes, 'big')
return self._int
@property
def is_safe(self):
return uuid.SafeUUID.unknown
def __str__(self):
cdef char out[36]
uuid_to_str(self._data, out)
return PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, <void*>out, 36)
@property
def hex(self):
cdef char out[32]
uuid_to_hex(self._data, out)
return PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, <void*>out, 32)
def __repr__(self):
return f"UUID('{self}')"
def __reduce__(self):
return (type(self), (self.bytes,))
def __eq__(self, other):
if type(other) is UUID:
return memcmp(self._data, (<UUID>other)._data, 16) == 0
if isinstance(other, std_UUID):
return self.int == other.int
return NotImplemented
def __ne__(self, other):
if type(other) is UUID:
return memcmp(self._data, (<UUID>other)._data, 16) != 0
if isinstance(other, std_UUID):
return self.int != other.int
return NotImplemented
def __lt__(self, other):
if type(other) is UUID:
return memcmp(self._data, (<UUID>other)._data, 16) < 0
if isinstance(other, std_UUID):
return self.int < other.int
return NotImplemented
def __gt__(self, other):
if type(other) is UUID:
return memcmp(self._data, (<UUID>other)._data, 16) > 0
if isinstance(other, std_UUID):
return self.int > other.int
return NotImplemented
def __le__(self, other):
if type(other) is UUID:
return memcmp(self._data, (<UUID>other)._data, 16) <= 0
if isinstance(other, std_UUID):
return self.int <= other.int
return NotImplemented
def __ge__(self, other):
if type(other) is UUID:
return memcmp(self._data, (<UUID>other)._data, 16) >= 0
if isinstance(other, std_UUID):
return self.int >= other.int
return NotImplemented
def __hash__(self):
# In EdgeDB every schema object has a uuid and there are
# huge hash-maps of them. We want UUID.__hash__ to be
# as fast as possible.
if self._hash is not None:
return self._hash
self._hash = hash(self.int)
return self._hash
def __int__(self):
return self.int
@property
def bytes_le(self):
bytes = self.bytes
return (bytes[4-1::-1] + bytes[6-1:4-1:-1] + bytes[8-1:6-1:-1] +
bytes[8:])
@property
def fields(self):
return (self.time_low, self.time_mid, self.time_hi_version,
self.clock_seq_hi_variant, self.clock_seq_low, self.node)
@property
def time_low(self):
return self.int >> 96
@property
def time_mid(self):
return (self.int >> 80) & 0xffff
@property
def time_hi_version(self):
return (self.int >> 64) & 0xffff
@property
def clock_seq_hi_variant(self):
return (self.int >> 56) & 0xff
@property
def clock_seq_low(self):
return (self.int >> 48) & 0xff
@property
def time(self):
return (((self.time_hi_version & 0x0fff) << 48) |
(self.time_mid << 32) | self.time_low)
@property
def clock_seq(self):
return (((self.clock_seq_hi_variant & 0x3f) << 8) |
self.clock_seq_low)
@property
def node(self):
return self.int & 0xffffffffffff
@property
def urn(self):
return 'urn:uuid:' + str(self)
@property
def variant(self):
if not self.int & (0x8000 << 48):
return uuid.RESERVED_NCS
elif not self.int & (0x4000 << 48):
return uuid.RFC_4122
elif not self.int & (0x2000 << 48):
return uuid.RESERVED_MICROSOFT
else:
return uuid.RESERVED_FUTURE
@property
def version(self):
# The version bits are only meaningful for RFC 4122 UUIDs.
if self.variant == uuid.RFC_4122:
return int((self.int >> 76) & 0xf)
# <hack>
# In order for `isinstance(pgproto.UUID, uuid.UUID)` to work,
# patch __bases__ and __mro__ by injecting `uuid.UUID`.
#
# We apply brute-force here because the following pattern stopped
# working with Python 3.8:
#
# cdef class OurUUID:
# ...
#
# class UUID(OurUUID, uuid.UUID):
# ...
#
# With Python 3.8 it now produces
#
# "TypeError: multiple bases have instance lay-out conflict"
#
# error. Maybe it's possible to fix this some other way, but
# the best solution possible would be to just contribute our
# faster UUID to the standard library and not have this problem
# at all. For now this hack is pretty safe and should be
# compatible with future Pythons for long enough.
#
assert UUID.__bases__[0] is __UUIDReplaceMe
assert UUID.__mro__[1] is __UUIDReplaceMe
cpython.Py_INCREF(std_UUID)
cpython.PyTuple_SET_ITEM(UUID.__bases__, 0, std_UUID)
cpython.Py_INCREF(std_UUID)
cpython.PyTuple_SET_ITEM(UUID.__mro__, 1, std_UUID)
# </hack>
cdef pg_UUID = UUID
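# Usage sketch (illustrative): after the __bases__ patch above this UUID is
# a drop-in replacement for the standard library type:
#
#   import uuid
#   u = pg_UUID('12345678-1234-5678-1234-567812345678')
#   assert isinstance(u, uuid.UUID)
#   assert u == uuid.UUID('12345678-1234-5678-1234-567812345678')
#   assert u.hex == '12345678123456781234567812345678'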

File diff suppressed because it is too large

View File

@@ -0,0 +1,259 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import json
from . import connresource
from . import cursor
from . import exceptions
class PreparedStatement(connresource.ConnectionResource):
"""A representation of a prepared statement."""
__slots__ = ('_state', '_query', '_last_status')
def __init__(self, connection, query, state):
super().__init__(connection)
self._state = state
self._query = query
state.attach()
self._last_status = None
@connresource.guarded
def get_name(self) -> str:
"""Return the name of this prepared statement.
.. versionadded:: 0.25.0
"""
return self._state.name
@connresource.guarded
def get_query(self) -> str:
"""Return the text of the query for this prepared statement.
Example::
stmt = await connection.prepare('SELECT $1::int')
assert stmt.get_query() == "SELECT $1::int"
"""
return self._query
@connresource.guarded
def get_statusmsg(self) -> str:
"""Return the status of the executed command.
Example::
stmt = await connection.prepare('CREATE TABLE mytab (a int)')
await stmt.fetch()
assert stmt.get_statusmsg() == "CREATE TABLE"
"""
if self._last_status is None:
return self._last_status
return self._last_status.decode()
@connresource.guarded
def get_parameters(self):
"""Return a description of statement parameters types.
:return: A tuple of :class:`asyncpg.types.Type`.
Example::
stmt = await connection.prepare('SELECT ($1::int, $2::text)')
print(stmt.get_parameters())
# Will print:
# (Type(oid=23, name='int4', kind='scalar', schema='pg_catalog'),
# Type(oid=25, name='text', kind='scalar', schema='pg_catalog'))
"""
return self._state._get_parameters()
@connresource.guarded
def get_attributes(self):
"""Return a description of relation attributes (columns).
:return: A tuple of :class:`asyncpg.types.Attribute`.
Example::
st = await self.con.prepare('''
SELECT typname, typnamespace FROM pg_type
''')
print(st.get_attributes())
# Will print:
# (Attribute(
# name='typname',
# type=Type(oid=19, name='name', kind='scalar',
# schema='pg_catalog')),
# Attribute(
# name='typnamespace',
# type=Type(oid=26, name='oid', kind='scalar',
# schema='pg_catalog')))
"""
return self._state._get_attributes()
@connresource.guarded
def cursor(self, *args, prefetch=None,
timeout=None) -> cursor.CursorFactory:
"""Return a *cursor factory* for the prepared statement.
:param args: Query arguments.
:param int prefetch: The number of rows the *cursor iterator*
will prefetch (defaults to ``50``.)
:param float timeout: Optional timeout in seconds.
:return: A :class:`~cursor.CursorFactory` object.
"""
return cursor.CursorFactory(
self._connection,
self._query,
self._state,
args,
prefetch,
timeout,
self._state.record_class,
)
@connresource.guarded
async def explain(self, *args, analyze=False):
"""Return the execution plan of the statement.
:param args: Query arguments.
:param analyze: If ``True``, the statement will be executed and
                        the run-time statistics added to the return value.
:return: An object representing the execution plan. This value
is actually a deserialized JSON output of the SQL
``EXPLAIN`` command.
"""
query = 'EXPLAIN (FORMAT JSON, VERBOSE'
if analyze:
query += ', ANALYZE) '
else:
query += ') '
query += self._state.query
if analyze:
# From PostgreSQL docs:
# Important: Keep in mind that the statement is actually
# executed when the ANALYZE option is used. Although EXPLAIN
# will discard any output that a SELECT would return, other
# side effects of the statement will happen as usual. If you
# wish to use EXPLAIN ANALYZE on an INSERT, UPDATE, DELETE,
# CREATE TABLE AS, or EXECUTE statement without letting the
# command affect your data, use this approach:
# BEGIN;
# EXPLAIN ANALYZE ...;
# ROLLBACK;
tr = self._connection.transaction()
await tr.start()
try:
data = await self._connection.fetchval(query, *args)
finally:
await tr.rollback()
else:
data = await self._connection.fetchval(query, *args)
return json.loads(data)
@connresource.guarded
async def fetch(self, *args, timeout=None):
r"""Execute the statement and return a list of :class:`Record` objects.
        :param args: Query arguments
        :param float timeout: Optional timeout value in seconds.
:return: A list of :class:`Record` instances.
"""
data = await self.__bind_execute(args, 0, timeout)
return data
@connresource.guarded
async def fetchval(self, *args, column=0, timeout=None):
"""Execute the statement and return a value in the first row.
:param args: Query arguments.
:param int column: Numeric index within the record of the value to
return (defaults to 0).
:param float timeout: Optional timeout value in seconds.
If not specified, defaults to the value of
``command_timeout`` argument to the ``Connection``
instance constructor.
:return: The value of the specified column of the first record.
"""
data = await self.__bind_execute(args, 1, timeout)
if not data:
return None
return data[0][column]
@connresource.guarded
async def fetchrow(self, *args, timeout=None):
"""Execute the statement and return the first row.
        :param args: Query arguments
        :param float timeout: Optional timeout value in seconds.
:return: The first row as a :class:`Record` instance.
"""
data = await self.__bind_execute(args, 1, timeout)
if not data:
return None
return data[0]
@connresource.guarded
async def executemany(self, args, *, timeout: float=None):
"""Execute the statement for each sequence of arguments in *args*.
:param args: An iterable containing sequences of arguments.
:param float timeout: Optional timeout value in seconds.
:return None: This method discards the results of the operations.
.. versionadded:: 0.22.0
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
self._state, args, '', timeout))
async def __do_execute(self, executor):
protocol = self._connection._protocol
try:
return await executor(protocol)
except exceptions.OutdatedSchemaCacheError:
await self._connection.reload_schema_state()
            # We cannot find all manually created prepared statements, so just
# drop known cached ones in the `self._connection`.
# Other manually created prepared statements will fail and
# invalidate themselves (unfortunately, clearing caches again).
self._state.mark_closed()
raise
async def __bind_execute(self, args, limit, timeout):
data, status, _ = await self.__do_execute(
lambda protocol: protocol.bind_execute(
self._state, args, '', limit, True, timeout))
self._last_status = status
return data
def _check_open(self, meth_name):
if self._state.closed:
raise exceptions.InterfaceError(
'cannot call PreparedStmt.{}(): '
'the prepared statement is closed'.format(meth_name))
def _check_conn_validity(self, meth_name):
self._check_open(meth_name)
super()._check_conn_validity(meth_name)
def __del__(self):
self._state.detach()
self._connection._maybe_gc_stmt(self._state)
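# Usage sketch (illustrative, assuming a reachable database):
#
#   async def example(connection):
#       stmt = await connection.prepare('SELECT $1::int + $2::int')
#       assert await stmt.fetchval(2, 3) == 5
#       row = await stmt.fetchrow(2, 3)
#       assert row[0] == 5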

View File

@@ -0,0 +1,9 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
# flake8: NOQA
from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP

View File

@@ -0,0 +1,875 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from collections.abc import (Iterable as IterableABC,
Mapping as MappingABC,
Sized as SizedABC)
from asyncpg import exceptions
DEF ARRAY_MAXDIM = 6  # defined in postgresql/src/include/c.h
# "NULL"
cdef Py_UCS4 *APG_NULL = [0x004E, 0x0055, 0x004C, 0x004C, 0x0000]
ctypedef object (*encode_func_ex)(ConnectionSettings settings,
WriteBuffer buf,
object obj,
const void *arg)
ctypedef object (*decode_func_ex)(ConnectionSettings settings,
FRBuffer *buf,
const void *arg)
cdef inline bint _is_trivial_container(object obj):
return cpython.PyUnicode_Check(obj) or cpython.PyBytes_Check(obj) or \
cpythonx.PyByteArray_Check(obj) or cpythonx.PyMemoryView_Check(obj)
cdef inline _is_array_iterable(object obj):
return (
isinstance(obj, IterableABC) and
isinstance(obj, SizedABC) and
not _is_trivial_container(obj) and
not isinstance(obj, MappingABC)
)
cdef inline _is_sub_array_iterable(object obj):
# Sub-arrays have a specialized check, because we treat
# nested tuples as records.
return _is_array_iterable(obj) and not cpython.PyTuple_Check(obj)
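# Consequence of the check above (illustrative): [[1, 2], [3, 4]] encodes as
# a 2x2 array, whereas [(1, 2), (3, 4)] encodes as a one-dimensional array
# of two composite (record) values, because tuples are excluded here.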
cdef _get_array_shape(object obj, int32_t *dims, int32_t *ndims):
cdef:
ssize_t mylen = len(obj)
ssize_t elemlen = -2
object it
if mylen > _MAXINT32:
raise ValueError('too many elements in array value')
if ndims[0] > ARRAY_MAXDIM:
raise ValueError(
            'number of array dimensions ({}) exceeds the maximum expected ({})'.
format(ndims[0], ARRAY_MAXDIM))
dims[ndims[0] - 1] = <int32_t>mylen
for elem in obj:
if _is_sub_array_iterable(elem):
if elemlen == -2:
elemlen = len(elem)
if elemlen > _MAXINT32:
raise ValueError('too many elements in array value')
ndims[0] += 1
_get_array_shape(elem, dims, ndims)
else:
if len(elem) != elemlen:
raise ValueError('non-homogeneous array')
else:
if elemlen >= 0:
raise ValueError('non-homogeneous array')
else:
elemlen = -1
cdef _write_array_data(ConnectionSettings settings, object obj, int32_t ndims,
int32_t dim, WriteBuffer elem_data,
encode_func_ex encoder, const void *encoder_arg):
if dim < ndims - 1:
for item in obj:
_write_array_data(settings, item, ndims, dim + 1, elem_data,
encoder, encoder_arg)
else:
for item in obj:
if item is None:
elem_data.write_int32(-1)
else:
try:
encoder(settings, elem_data, item, encoder_arg)
except TypeError as e:
raise ValueError(
'invalid array element: {}'.format(e.args[0])) from None
cdef inline array_encode(ConnectionSettings settings, WriteBuffer buf,
object obj, uint32_t elem_oid,
encode_func_ex encoder, const void *encoder_arg):
cdef:
WriteBuffer elem_data
int32_t dims[ARRAY_MAXDIM]
int32_t ndims = 1
int32_t i
if not _is_array_iterable(obj):
raise TypeError(
'a sized iterable container expected (got type {!r})'.format(
type(obj).__name__))
_get_array_shape(obj, dims, &ndims)
elem_data = WriteBuffer.new()
if ndims > 1:
_write_array_data(settings, obj, ndims, 0, elem_data,
encoder, encoder_arg)
else:
for i, item in enumerate(obj):
if item is None:
elem_data.write_int32(-1)
else:
try:
encoder(settings, elem_data, item, encoder_arg)
except TypeError as e:
raise ValueError(
'invalid array element at index {}: {}'.format(
i, e.args[0])) from None
buf.write_int32(12 + 8 * ndims + elem_data.len())
# Number of dimensions
buf.write_int32(ndims)
# flags
buf.write_int32(0)
# element type
buf.write_int32(<int32_t>elem_oid)
# upper / lower bounds
for i in range(ndims):
buf.write_int32(dims[i])
buf.write_int32(1)
# element data
buf.write_buffer(elem_data)
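# Illustrative sketch (not part of the module above): after the 4-byte total
# length, the header written above for a flat [1, 2, 3] int4 array
# (element OID 23) is:
#
#   import struct
#   header = struct.pack('!iiiii', 1, 0, 23, 3, 1)
#   #                     ndims, flags, elem_oid, dim size, lower bound
#   # ...followed by three (length, value) pairs: struct.pack('!ii', 4, n)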
cdef _write_textarray_data(ConnectionSettings settings, object obj,
int32_t ndims, int32_t dim, WriteBuffer array_data,
encode_func_ex encoder, const void *encoder_arg,
Py_UCS4 typdelim):
cdef:
ssize_t i = 0
int8_t delim = <int8_t>typdelim
WriteBuffer elem_data
Py_buffer pybuf
const char *elem_str
char ch
ssize_t elem_len
ssize_t quoted_elem_len
bint need_quoting
array_data.write_byte(b'{')
if dim < ndims - 1:
for item in obj:
if i > 0:
array_data.write_byte(delim)
array_data.write_byte(b' ')
_write_textarray_data(settings, item, ndims, dim + 1, array_data,
encoder, encoder_arg, typdelim)
i += 1
else:
for item in obj:
elem_data = WriteBuffer.new()
if i > 0:
array_data.write_byte(delim)
array_data.write_byte(b' ')
if item is None:
array_data.write_bytes(b'NULL')
i += 1
continue
else:
try:
encoder(settings, elem_data, item, encoder_arg)
except TypeError as e:
raise ValueError(
'invalid array element: {}'.format(
e.args[0])) from None
# element string length (first four bytes are the encoded length.)
elem_len = elem_data.len() - 4
if elem_len == 0:
# Empty string
array_data.write_bytes(b'""')
else:
cpython.PyObject_GetBuffer(
elem_data, &pybuf, cpython.PyBUF_SIMPLE)
elem_str = <const char*>(pybuf.buf) + 4
try:
if not apg_strcasecmp_char(elem_str, b'NULL'):
array_data.write_byte(b'"')
array_data.write_cstr(elem_str, 4)
array_data.write_byte(b'"')
else:
quoted_elem_len = elem_len
need_quoting = False
for i in range(elem_len):
ch = elem_str[i]
if ch == b'"' or ch == b'\\':
# Quotes and backslashes need escaping.
quoted_elem_len += 1
need_quoting = True
elif (ch == b'{' or ch == b'}' or ch == delim or
apg_ascii_isspace(<uint32_t>ch)):
need_quoting = True
if need_quoting:
array_data.write_byte(b'"')
if quoted_elem_len == elem_len:
array_data.write_cstr(elem_str, elem_len)
else:
# Escaping required.
for i in range(elem_len):
ch = elem_str[i]
if ch == b'"' or ch == b'\\':
array_data.write_byte(b'\\')
array_data.write_byte(ch)
array_data.write_byte(b'"')
else:
array_data.write_cstr(elem_str, elem_len)
finally:
cpython.PyBuffer_Release(&pybuf)
i += 1
array_data.write_byte(b'}')
cdef inline textarray_encode(ConnectionSettings settings, WriteBuffer buf,
object obj, encode_func_ex encoder,
const void *encoder_arg, Py_UCS4 typdelim):
cdef:
WriteBuffer array_data
int32_t dims[ARRAY_MAXDIM]
int32_t ndims = 1
int32_t i
if not _is_array_iterable(obj):
raise TypeError(
'a sized iterable container expected (got type {!r})'.format(
type(obj).__name__))
_get_array_shape(obj, dims, &ndims)
array_data = WriteBuffer.new()
_write_textarray_data(settings, obj, ndims, 0, array_data,
encoder, encoder_arg, typdelim)
buf.write_int32(array_data.len())
buf.write_buffer(array_data)
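# An illustrative sketch (comment only): textarray_encode emits the
# PostgreSQL array-literal syntax, so a text[] value such as
# ['a b', None] is written as
#
#   {"a b", NULL}
#
# with quoting applied only where an element contains quotes, backslashes,
# braces, the delimiter, or whitespace.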
cdef inline array_decode(ConnectionSettings settings, FRBuffer *buf,
decode_func_ex decoder, const void *decoder_arg):
cdef:
int32_t ndims = hton.unpack_int32(frb_read(buf, 4))
int32_t flags = hton.unpack_int32(frb_read(buf, 4))
uint32_t elem_oid = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
list result
int i
int32_t elem_len
int32_t elem_count = 1
FRBuffer elem_buf
int32_t dims[ARRAY_MAXDIM]
Codec elem_codec
if ndims == 0:
return []
if ndims > ARRAY_MAXDIM:
raise exceptions.ProtocolError(
            'number of array dimensions ({}) exceeds the maximum expected ({})'.
format(ndims, ARRAY_MAXDIM))
elif ndims < 0:
raise exceptions.ProtocolError(
'unexpected array dimensions value: {}'.format(ndims))
for i in range(ndims):
dims[i] = hton.unpack_int32(frb_read(buf, 4))
if dims[i] < 0:
raise exceptions.ProtocolError(
'unexpected array dimension size: {}'.format(dims[i]))
# Ignore the lower bound information
frb_read(buf, 4)
if ndims == 1:
# Fast path for flat arrays
elem_count = dims[0]
result = cpython.PyList_New(elem_count)
for i in range(elem_count):
elem_len = hton.unpack_int32(frb_read(buf, 4))
if elem_len == -1:
elem = None
else:
frb_slice_from(&elem_buf, buf, elem_len)
elem = decoder(settings, &elem_buf, decoder_arg)
cpython.Py_INCREF(elem)
cpython.PyList_SET_ITEM(result, i, elem)
else:
result = _nested_array_decode(settings, buf,
decoder, decoder_arg, ndims, dims,
&elem_buf)
return result
cdef _nested_array_decode(ConnectionSettings settings,
FRBuffer *buf,
decode_func_ex decoder,
const void *decoder_arg,
int32_t ndims, int32_t *dims,
FRBuffer *elem_buf):
cdef:
int32_t elem_len
int64_t i, j
int64_t array_len = 1
object elem, stride
# An array of pointers to lists for each current array level.
void *strides[ARRAY_MAXDIM]
# An array of current positions at each array level.
int32_t indexes[ARRAY_MAXDIM]
for i in range(ndims):
array_len *= dims[i]
indexes[i] = 0
strides[i] = NULL
if array_len == 0:
# A multidimensional array with a zero-sized dimension?
return []
elif array_len < 0:
# Array length overflow
raise exceptions.ProtocolError('array length overflow')
for i in range(array_len):
# Decode the element.
elem_len = hton.unpack_int32(frb_read(buf, 4))
if elem_len == -1:
elem = None
else:
elem = decoder(settings,
frb_slice_from(elem_buf, buf, elem_len),
decoder_arg)
        # Take an explicit reference, as PyList_SET_ITEM in the loop
        # below expects one.
cpython.Py_INCREF(elem)
        # Iterate over array dimensions and put the element in
# the correctly nested sublist.
for j in reversed(range(ndims)):
if indexes[j] == 0:
# Allocate the list for this array level.
stride = cpython.PyList_New(dims[j])
strides[j] = <void*><cpython.PyObject>stride
                # Take an explicit reference, as PyList_SET_ITEM below
                # expects one.
cpython.Py_INCREF(stride)
stride = <object><cpython.PyObject*>strides[j]
cpython.PyList_SET_ITEM(stride, indexes[j], elem)
indexes[j] += 1
if indexes[j] == dims[j] and j != 0:
                # This array level is full; continue the
                # ascent through the dimensions so that this level's
                # sublist is appended to the parent list.
elem = stride
# Reset the index, this will cause the
# new list to be allocated on the next
# iteration on this array axis.
indexes[j] = 0
else:
break
stride = <object><cpython.PyObject*>strides[0]
# Since each element in strides has a refcount of 1,
# returning strides[0] will increment it to 2, so
# balance that.
cpython.Py_DECREF(stride)
return stride
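# An illustrative sketch (comment only): for ndims = 2 and dims = [2, 3],
# the six elements arrive as a flat, row-major stream e0..e5 and the loop
# above distributes them as
#
#   [[e0, e1, e2],
#    [e3, e4, e5]]
#
# allocating each sublist lazily whenever its index counter wraps to zero.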
cdef textarray_decode(ConnectionSettings settings, FRBuffer *buf,
decode_func_ex decoder, const void *decoder_arg,
Py_UCS4 typdelim):
cdef:
Py_UCS4 *array_text
str s
# Make a copy of array data since we will be mutating it for
# the purposes of element decoding.
s = pgproto.text_decode(settings, buf)
array_text = cpythonx.PyUnicode_AsUCS4Copy(s)
try:
return _textarray_decode(
settings, array_text, decoder, decoder_arg, typdelim)
except ValueError as e:
raise exceptions.ProtocolError(
'malformed array literal {!r}: {}'.format(s, e.args[0]))
finally:
cpython.PyMem_Free(array_text)
cdef _textarray_decode(ConnectionSettings settings,
Py_UCS4 *array_text,
decode_func_ex decoder,
const void *decoder_arg,
Py_UCS4 typdelim):
cdef:
bytearray array_bytes
list result
list new_stride
Py_UCS4 *ptr
int32_t ndims = 0
int32_t ubound = 0
int32_t lbound = 0
int32_t dims[ARRAY_MAXDIM]
int32_t inferred_dims[ARRAY_MAXDIM]
int32_t inferred_ndims = 0
void *strides[ARRAY_MAXDIM]
int32_t indexes[ARRAY_MAXDIM]
int32_t nest_level = 0
int32_t item_level = 0
bint end_of_array = False
bint end_of_item = False
bint has_quoting = False
bint strip_spaces = False
bint in_quotes = False
Py_UCS4 *item_start
Py_UCS4 *item_ptr
Py_UCS4 *item_end
int i
object item
str item_text
FRBuffer item_buf
char *pg_item_str
ssize_t pg_item_len
ptr = array_text
while True:
while apg_ascii_isspace(ptr[0]):
ptr += 1
if ptr[0] != '[':
# Finished parsing dimensions spec.
break
ptr += 1 # '['
if ndims > ARRAY_MAXDIM:
raise ValueError(
                'number of array dimensions ({}) exceeds the '
'maximum expected ({})'.format(ndims, ARRAY_MAXDIM))
ptr = apg_parse_int32(ptr, &ubound)
if ptr == NULL:
raise ValueError('missing array dimension value')
if ptr[0] == ':':
ptr += 1
lbound = ubound
# [lower:upper] spec. We disregard the lbound for decoding.
ptr = apg_parse_int32(ptr, &ubound)
if ptr == NULL:
raise ValueError('missing array dimension value')
else:
lbound = 1
if ptr[0] != ']':
raise ValueError('missing \']\' after array dimensions')
ptr += 1 # ']'
dims[ndims] = ubound - lbound + 1
ndims += 1
if ndims != 0:
# If dimensions were given, the '=' token is expected.
if ptr[0] != '=':
raise ValueError('missing \'=\' after array dimensions')
ptr += 1 # '='
# Skip any whitespace after the '=', whitespace
# before was consumed in the above loop.
while apg_ascii_isspace(ptr[0]):
ptr += 1
# Infer the dimensions from the brace structure in the
# array literal body, and check that it matches the explicit
# spec. This also validates that the array literal is sane.
_infer_array_dims(ptr, typdelim, inferred_dims, &inferred_ndims)
if inferred_ndims != ndims:
raise ValueError(
'specified array dimensions do not match array content')
for i in range(ndims):
if inferred_dims[i] != dims[i]:
raise ValueError(
'specified array dimensions do not match array content')
else:
# Infer the dimensions from the brace structure in the array literal
# body. This also validates that the array literal is sane.
_infer_array_dims(ptr, typdelim, dims, &ndims)
while not end_of_array:
# We iterate over the literal character by character
# and modify the string in-place removing the array-specific
# quoting and determining the boundaries of each element.
end_of_item = has_quoting = in_quotes = False
strip_spaces = True
# Pointers to array element start, end, and the current pointer
# tracking the position where characters are written when
# escaping is folded.
item_start = item_end = item_ptr = ptr
item_level = 0
while not end_of_item:
if ptr[0] == '"':
in_quotes = not in_quotes
if in_quotes:
strip_spaces = False
else:
item_end = item_ptr
has_quoting = True
elif ptr[0] == '\\':
# Quoted character, collapse the backslash.
ptr += 1
has_quoting = True
item_ptr[0] = ptr[0]
item_ptr += 1
strip_spaces = False
item_end = item_ptr
elif in_quotes:
# Consume the string until we see the closing quote.
item_ptr[0] = ptr[0]
item_ptr += 1
elif ptr[0] == '{':
# Nesting level increase.
nest_level += 1
indexes[nest_level - 1] = 0
new_stride = cpython.PyList_New(dims[nest_level - 1])
strides[nest_level - 1] = \
<void*>(<cpython.PyObject>new_stride)
if nest_level > 1:
cpython.Py_INCREF(new_stride)
cpython.PyList_SET_ITEM(
<object><cpython.PyObject*>strides[nest_level - 2],
indexes[nest_level - 2],
new_stride)
else:
result = new_stride
elif ptr[0] == '}':
if item_level == 0:
# Make sure we keep track of which nesting
# level the item belongs to, as the loop
# will continue to consume closing braces
# until the delimiter or the end of input.
item_level = nest_level
nest_level -= 1
if nest_level == 0:
end_of_array = end_of_item = True
elif ptr[0] == typdelim:
                # Array element delimiter.
end_of_item = True
if item_level == 0:
item_level = nest_level
elif apg_ascii_isspace(ptr[0]):
if not strip_spaces:
item_ptr[0] = ptr[0]
item_ptr += 1
# Ignore the leading literal whitespace.
else:
item_ptr[0] = ptr[0]
item_ptr += 1
strip_spaces = False
item_end = item_ptr
ptr += 1
# end while not end_of_item
if item_end == item_start:
# Empty array
continue
item_end[0] = '\0'
if not has_quoting and apg_strcasecmp(item_start, APG_NULL) == 0:
# NULL element.
item = None
else:
# XXX: find a way to avoid the redundant encode/decode
# cycle here.
item_text = cpythonx.PyUnicode_FromKindAndData(
cpythonx.PyUnicode_4BYTE_KIND,
<void *>item_start,
item_end - item_start)
# Prepare the element buffer and call the text decoder
# for the element type.
pgproto.as_pg_string_and_size(
settings, item_text, &pg_item_str, &pg_item_len)
frb_init(&item_buf, pg_item_str, pg_item_len)
item = decoder(settings, &item_buf, decoder_arg)
# Place the decoded element in the array.
cpython.Py_INCREF(item)
cpython.PyList_SET_ITEM(
<object><cpython.PyObject*>strides[item_level - 1],
indexes[item_level - 1],
item)
if nest_level > 0:
indexes[nest_level - 1] += 1
return result
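# An illustrative sketch (comment only): given the array literal
#
#   {"a b",NULL,"c\"d"}
#
# the in-place scan above folds the escapes, NUL-terminates each element,
# and, with a text element decoder, yields ['a b', None, 'c"d'].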
cdef enum _ArrayParseState:
APS_START = 1
APS_STRIDE_STARTED = 2
APS_STRIDE_DONE = 3
APS_STRIDE_DELIMITED = 4
APS_ELEM_STARTED = 5
APS_ELEM_DELIMITED = 6
cdef _UnexpectedCharacter(const Py_UCS4 *array_text, const Py_UCS4 *ptr):
return ValueError('unexpected character {!r} at position {}'.format(
cpython.PyUnicode_FromOrdinal(<int>ptr[0]), ptr - array_text + 1))
cdef _infer_array_dims(const Py_UCS4 *array_text,
Py_UCS4 typdelim,
int32_t *dims,
int32_t *ndims):
cdef:
const Py_UCS4 *ptr = array_text
int i
int nest_level = 0
bint end_of_array = False
bint end_of_item = False
bint in_quotes = False
bint array_is_empty = True
int stride_len[ARRAY_MAXDIM]
int prev_stride_len[ARRAY_MAXDIM]
_ArrayParseState parse_state = APS_START
for i in range(ARRAY_MAXDIM):
dims[i] = prev_stride_len[i] = 0
stride_len[i] = 1
while not end_of_array:
end_of_item = False
while not end_of_item:
if ptr[0] == '\0':
raise ValueError('unexpected end of string')
elif ptr[0] == '"':
if (parse_state not in (APS_STRIDE_STARTED,
APS_ELEM_DELIMITED) and
not (parse_state == APS_ELEM_STARTED and in_quotes)):
raise _UnexpectedCharacter(array_text, ptr)
in_quotes = not in_quotes
if in_quotes:
parse_state = APS_ELEM_STARTED
array_is_empty = False
elif ptr[0] == '\\':
if parse_state not in (APS_STRIDE_STARTED,
APS_ELEM_STARTED,
APS_ELEM_DELIMITED):
raise _UnexpectedCharacter(array_text, ptr)
parse_state = APS_ELEM_STARTED
array_is_empty = False
if ptr[1] != '\0':
ptr += 1
else:
raise ValueError('unexpected end of string')
elif in_quotes:
# Ignore everything inside the quotes.
pass
elif ptr[0] == '{':
if parse_state not in (APS_START,
APS_STRIDE_STARTED,
APS_STRIDE_DELIMITED):
raise _UnexpectedCharacter(array_text, ptr)
parse_state = APS_STRIDE_STARTED
if nest_level >= ARRAY_MAXDIM:
raise ValueError(
                        'number of array dimensions ({}) exceeds the '
'maximum expected ({})'.format(
nest_level, ARRAY_MAXDIM))
dims[nest_level] = 0
nest_level += 1
if ndims[0] < nest_level:
ndims[0] = nest_level
elif ptr[0] == '}':
if (parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE) and
not (nest_level == 1 and
parse_state == APS_STRIDE_STARTED)):
raise _UnexpectedCharacter(array_text, ptr)
parse_state = APS_STRIDE_DONE
if nest_level == 0:
raise _UnexpectedCharacter(array_text, ptr)
nest_level -= 1
if (prev_stride_len[nest_level] != 0 and
stride_len[nest_level] != prev_stride_len[nest_level]):
raise ValueError(
'inconsistent sub-array dimensions'
' at position {}'.format(
ptr - array_text + 1))
prev_stride_len[nest_level] = stride_len[nest_level]
stride_len[nest_level] = 1
if nest_level == 0:
end_of_array = end_of_item = True
else:
dims[nest_level - 1] += 1
elif ptr[0] == typdelim:
if parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE):
raise _UnexpectedCharacter(array_text, ptr)
if parse_state == APS_STRIDE_DONE:
parse_state = APS_STRIDE_DELIMITED
else:
parse_state = APS_ELEM_DELIMITED
end_of_item = True
stride_len[nest_level - 1] += 1
elif not apg_ascii_isspace(ptr[0]):
if parse_state not in (APS_STRIDE_STARTED,
APS_ELEM_STARTED,
APS_ELEM_DELIMITED):
raise _UnexpectedCharacter(array_text, ptr)
parse_state = APS_ELEM_STARTED
array_is_empty = False
if not end_of_item:
ptr += 1
if not array_is_empty:
dims[ndims[0] - 1] += 1
ptr += 1
# only whitespace is allowed after the closing brace
while ptr[0] != '\0':
if not apg_ascii_isspace(ptr[0]):
raise _UnexpectedCharacter(array_text, ptr)
ptr += 1
if array_is_empty:
ndims[0] = 0
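# An illustrative sketch (comment only): for the literal body
# {{1,2,3},{4,5,6}} the brace scan above infers ndims = 2 and
# dims = [2, 3], while a ragged body such as {{1,2},{3}} raises
# 'inconsistent sub-array dimensions'.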
cdef uint4_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj,
const void *arg):
return pgproto.uint4_encode(settings, buf, obj)
cdef uint4_decode_ex(ConnectionSettings settings, FRBuffer *buf,
const void *arg):
return pgproto.uint4_decode(settings, buf)
cdef arrayoid_encode(ConnectionSettings settings, WriteBuffer buf, items):
array_encode(settings, buf, items, OIDOID,
<encode_func_ex>&uint4_encode_ex, NULL)
cdef arrayoid_decode(ConnectionSettings settings, FRBuffer *buf):
return array_decode(settings, buf, <decode_func_ex>&uint4_decode_ex, NULL)
cdef text_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj,
const void *arg):
return pgproto.text_encode(settings, buf, obj)
cdef text_decode_ex(ConnectionSettings settings, FRBuffer *buf,
const void *arg):
return pgproto.text_decode(settings, buf)
cdef arraytext_encode(ConnectionSettings settings, WriteBuffer buf, items):
array_encode(settings, buf, items, TEXTOID,
<encode_func_ex>&text_encode_ex, NULL)
cdef arraytext_decode(ConnectionSettings settings, FRBuffer *buf):
return array_decode(settings, buf, <decode_func_ex>&text_decode_ex, NULL)
cdef init_array_codecs():
# oid[] and text[] are registered as core codecs
    # to make the type introspection query work
#
register_core_codec(_OIDOID,
<encode_func>&arrayoid_encode,
<decode_func>&arrayoid_decode,
PG_FORMAT_BINARY)
register_core_codec(_TEXTOID,
<encode_func>&arraytext_encode,
<decode_func>&arraytext_decode,
PG_FORMAT_BINARY)
init_array_codecs()

View File

@@ -0,0 +1,187 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
ctypedef object (*encode_func)(ConnectionSettings settings,
WriteBuffer buf,
object obj)
ctypedef object (*decode_func)(ConnectionSettings settings,
FRBuffer *buf)
ctypedef object (*codec_encode_func)(Codec codec,
ConnectionSettings settings,
WriteBuffer buf,
object obj)
ctypedef object (*codec_decode_func)(Codec codec,
ConnectionSettings settings,
FRBuffer *buf)
cdef enum CodecType:
CODEC_UNDEFINED = 0
CODEC_C = 1
CODEC_PY = 2
CODEC_ARRAY = 3
CODEC_COMPOSITE = 4
CODEC_RANGE = 5
CODEC_MULTIRANGE = 6
cdef enum ServerDataFormat:
PG_FORMAT_ANY = -1
PG_FORMAT_TEXT = 0
PG_FORMAT_BINARY = 1
cdef enum ClientExchangeFormat:
PG_XFORMAT_OBJECT = 1
PG_XFORMAT_TUPLE = 2
cdef class Codec:
cdef:
uint32_t oid
str name
str schema
str kind
CodecType type
ServerDataFormat format
ClientExchangeFormat xformat
encode_func c_encoder
decode_func c_decoder
Codec base_codec
object py_encoder
object py_decoder
# arrays
Codec element_codec
Py_UCS4 element_delimiter
# composite types
tuple element_type_oids
object element_names
object record_desc
list element_codecs
# Pointers to actual encoder/decoder functions for this codec
codec_encode_func encoder
codec_decode_func decoder
cdef init(self, str name, str schema, str kind,
CodecType type, ServerDataFormat format,
ClientExchangeFormat xformat,
encode_func c_encoder, decode_func c_decoder,
Codec base_codec,
object py_encoder, object py_decoder,
Codec element_codec, tuple element_type_oids,
object element_names, list element_codecs,
Py_UCS4 element_delimiter)
cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
cdef decode_scalar(self, ConnectionSettings settings, FRBuffer *buf)
cdef decode_array(self, ConnectionSettings settings, FRBuffer *buf)
cdef decode_array_text(self, ConnectionSettings settings, FRBuffer *buf)
cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf)
cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf)
cdef decode_composite(self, ConnectionSettings settings, FRBuffer *buf)
cdef decode_in_python(self, ConnectionSettings settings, FRBuffer *buf)
cdef inline encode(self,
ConnectionSettings settings,
WriteBuffer buf,
object obj)
cdef inline decode(self, ConnectionSettings settings, FRBuffer *buf)
cdef has_encoder(self)
cdef has_decoder(self)
cdef is_binary(self)
cdef inline Codec copy(self)
@staticmethod
cdef Codec new_array_codec(uint32_t oid,
str name,
str schema,
Codec element_codec,
Py_UCS4 element_delimiter)
@staticmethod
cdef Codec new_range_codec(uint32_t oid,
str name,
str schema,
Codec element_codec)
@staticmethod
cdef Codec new_multirange_codec(uint32_t oid,
str name,
str schema,
Codec element_codec)
@staticmethod
cdef Codec new_composite_codec(uint32_t oid,
str name,
str schema,
ServerDataFormat format,
list element_codecs,
tuple element_type_oids,
object element_names)
@staticmethod
cdef Codec new_python_codec(uint32_t oid,
str name,
str schema,
str kind,
object encoder,
object decoder,
encode_func c_encoder,
decode_func c_decoder,
Codec base_codec,
ServerDataFormat format,
ClientExchangeFormat xformat)
cdef class DataCodecConfig:
cdef:
dict _derived_type_codecs
dict _custom_type_codecs
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
bint ignore_custom_codec=*)
cdef inline Codec get_custom_codec(self, uint32_t oid,
ServerDataFormat format)

View File

@@ -0,0 +1,895 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from collections.abc import Mapping as MappingABC
import asyncpg
from asyncpg import exceptions
cdef void* binary_codec_map[(MAXSUPPORTEDOID + 1) * 2]
cdef void* text_codec_map[(MAXSUPPORTEDOID + 1) * 2]
cdef dict EXTRA_CODECS = {}
@cython.final
cdef class Codec:
def __cinit__(self, uint32_t oid):
self.oid = oid
self.type = CODEC_UNDEFINED
cdef init(
self,
str name,
str schema,
str kind,
CodecType type,
ServerDataFormat format,
ClientExchangeFormat xformat,
encode_func c_encoder,
decode_func c_decoder,
Codec base_codec,
object py_encoder,
object py_decoder,
Codec element_codec,
tuple element_type_oids,
object element_names,
list element_codecs,
Py_UCS4 element_delimiter,
):
self.name = name
self.schema = schema
self.kind = kind
self.type = type
self.format = format
self.xformat = xformat
self.c_encoder = c_encoder
self.c_decoder = c_decoder
self.base_codec = base_codec
self.py_encoder = py_encoder
self.py_decoder = py_decoder
self.element_codec = element_codec
self.element_type_oids = element_type_oids
self.element_codecs = element_codecs
self.element_delimiter = element_delimiter
self.element_names = element_names
if base_codec is not None:
if c_encoder != NULL or c_decoder != NULL:
raise exceptions.InternalClientError(
'base_codec is mutually exclusive with c_encoder/c_decoder'
)
if element_names is not None:
self.record_desc = record.ApgRecordDesc_New(
element_names, tuple(element_names))
else:
self.record_desc = None
if type == CODEC_C:
self.encoder = <codec_encode_func>&self.encode_scalar
self.decoder = <codec_decode_func>&self.decode_scalar
elif type == CODEC_ARRAY:
if format == PG_FORMAT_BINARY:
self.encoder = <codec_encode_func>&self.encode_array
self.decoder = <codec_decode_func>&self.decode_array
else:
self.encoder = <codec_encode_func>&self.encode_array_text
self.decoder = <codec_decode_func>&self.decode_array_text
elif type == CODEC_RANGE:
if format != PG_FORMAT_BINARY:
raise exceptions.UnsupportedClientFeatureError(
'cannot decode type "{}"."{}": text encoding of '
'range types is not supported'.format(schema, name))
self.encoder = <codec_encode_func>&self.encode_range
self.decoder = <codec_decode_func>&self.decode_range
elif type == CODEC_MULTIRANGE:
if format != PG_FORMAT_BINARY:
raise exceptions.UnsupportedClientFeatureError(
'cannot decode type "{}"."{}": text encoding of '
'range types is not supported'.format(schema, name))
self.encoder = <codec_encode_func>&self.encode_multirange
self.decoder = <codec_decode_func>&self.decode_multirange
elif type == CODEC_COMPOSITE:
if format != PG_FORMAT_BINARY:
raise exceptions.UnsupportedClientFeatureError(
'cannot decode type "{}"."{}": text encoding of '
'composite types is not supported'.format(schema, name))
self.encoder = <codec_encode_func>&self.encode_composite
self.decoder = <codec_decode_func>&self.decode_composite
elif type == CODEC_PY:
self.encoder = <codec_encode_func>&self.encode_in_python
self.decoder = <codec_decode_func>&self.decode_in_python
else:
raise exceptions.InternalClientError(
'unexpected codec type: {}'.format(type))
cdef Codec copy(self):
cdef Codec codec
codec = Codec(self.oid)
codec.init(self.name, self.schema, self.kind,
self.type, self.format, self.xformat,
self.c_encoder, self.c_decoder, self.base_codec,
self.py_encoder, self.py_decoder,
self.element_codec,
self.element_type_oids, self.element_names,
self.element_codecs, self.element_delimiter)
return codec
cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf,
object obj):
self.c_encoder(settings, buf, obj)
cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf,
object obj):
array_encode(settings, buf, obj, self.element_codec.oid,
codec_encode_func_ex,
<void*>(<cpython.PyObject>self.element_codec))
cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf,
object obj):
return textarray_encode(settings, buf, obj,
codec_encode_func_ex,
<void*>(<cpython.PyObject>self.element_codec),
self.element_delimiter)
cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf,
object obj):
range_encode(settings, buf, obj, self.element_codec.oid,
codec_encode_func_ex,
<void*>(<cpython.PyObject>self.element_codec))
cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf,
object obj):
multirange_encode(settings, buf, obj, self.element_codec.oid,
codec_encode_func_ex,
<void*>(<cpython.PyObject>self.element_codec))
cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf,
object obj):
cdef:
WriteBuffer elem_data
int i
list elem_codecs = self.element_codecs
ssize_t count
ssize_t composite_size
tuple rec
if isinstance(obj, MappingABC):
# Input is dict-like, form a tuple
composite_size = len(self.element_type_oids)
rec = cpython.PyTuple_New(composite_size)
for i in range(composite_size):
cpython.Py_INCREF(None)
cpython.PyTuple_SET_ITEM(rec, i, None)
for field in obj:
try:
i = self.element_names[field]
except KeyError:
raise ValueError(
'{!r} is not a valid element of composite '
'type {}'.format(field, self.name)) from None
item = obj[field]
cpython.Py_INCREF(item)
cpython.PyTuple_SET_ITEM(rec, i, item)
obj = rec
count = len(obj)
if count > _MAXINT32:
raise ValueError('too many elements in composite type record')
elem_data = WriteBuffer.new()
i = 0
for item in obj:
elem_data.write_int32(<int32_t>self.element_type_oids[i])
if item is None:
elem_data.write_int32(-1)
else:
(<Codec>elem_codecs[i]).encode(settings, elem_data, item)
i += 1
record_encode_frame(settings, buf, elem_data, <int32_t>count)
cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf,
object obj):
data = self.py_encoder(obj)
if self.xformat == PG_XFORMAT_OBJECT:
if self.format == PG_FORMAT_BINARY:
pgproto.bytea_encode(settings, buf, data)
elif self.format == PG_FORMAT_TEXT:
pgproto.text_encode(settings, buf, data)
else:
raise exceptions.InternalClientError(
'unexpected data format: {}'.format(self.format))
elif self.xformat == PG_XFORMAT_TUPLE:
if self.base_codec is not None:
self.base_codec.encode(settings, buf, data)
else:
self.c_encoder(settings, buf, data)
else:
raise exceptions.InternalClientError(
'unexpected exchange format: {}'.format(self.xformat))
cdef encode(self, ConnectionSettings settings, WriteBuffer buf,
object obj):
return self.encoder(self, settings, buf, obj)
cdef decode_scalar(self, ConnectionSettings settings, FRBuffer *buf):
return self.c_decoder(settings, buf)
cdef decode_array(self, ConnectionSettings settings, FRBuffer *buf):
return array_decode(settings, buf, codec_decode_func_ex,
<void*>(<cpython.PyObject>self.element_codec))
cdef decode_array_text(self, ConnectionSettings settings,
FRBuffer *buf):
return textarray_decode(settings, buf, codec_decode_func_ex,
<void*>(<cpython.PyObject>self.element_codec),
self.element_delimiter)
cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf):
return range_decode(settings, buf, codec_decode_func_ex,
<void*>(<cpython.PyObject>self.element_codec))
cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf):
return multirange_decode(settings, buf, codec_decode_func_ex,
<void*>(<cpython.PyObject>self.element_codec))
cdef decode_composite(self, ConnectionSettings settings,
FRBuffer *buf):
cdef:
object result
ssize_t elem_count
ssize_t i
int32_t elem_len
uint32_t elem_typ
uint32_t received_elem_typ
Codec elem_codec
FRBuffer elem_buf
elem_count = <ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4))
if elem_count != len(self.element_type_oids):
raise exceptions.OutdatedSchemaCacheError(
'unexpected number of attributes of composite type: '
'{}, expected {}'
.format(
elem_count,
len(self.element_type_oids),
),
schema=self.schema,
data_type=self.name,
)
result = record.ApgRecord_New(asyncpg.Record, self.record_desc, elem_count)
for i in range(elem_count):
elem_typ = self.element_type_oids[i]
received_elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
if received_elem_typ != elem_typ:
raise exceptions.OutdatedSchemaCacheError(
'unexpected data type of composite type attribute {}: '
'{!r}, expected {!r}'
.format(
i,
BUILTIN_TYPE_OID_MAP.get(
received_elem_typ, received_elem_typ),
BUILTIN_TYPE_OID_MAP.get(
elem_typ, elem_typ)
),
schema=self.schema,
data_type=self.name,
position=i,
)
elem_len = hton.unpack_int32(frb_read(buf, 4))
if elem_len == -1:
elem = None
else:
elem_codec = self.element_codecs[i]
elem = elem_codec.decode(
settings, frb_slice_from(&elem_buf, buf, elem_len))
cpython.Py_INCREF(elem)
record.ApgRecord_SET_ITEM(result, i, elem)
return result
cdef decode_in_python(self, ConnectionSettings settings,
FRBuffer *buf):
if self.xformat == PG_XFORMAT_OBJECT:
if self.format == PG_FORMAT_BINARY:
data = pgproto.bytea_decode(settings, buf)
elif self.format == PG_FORMAT_TEXT:
data = pgproto.text_decode(settings, buf)
else:
raise exceptions.InternalClientError(
'unexpected data format: {}'.format(self.format))
elif self.xformat == PG_XFORMAT_TUPLE:
if self.base_codec is not None:
data = self.base_codec.decode(settings, buf)
else:
data = self.c_decoder(settings, buf)
else:
raise exceptions.InternalClientError(
'unexpected exchange format: {}'.format(self.xformat))
return self.py_decoder(data)
cdef inline decode(self, ConnectionSettings settings, FRBuffer *buf):
return self.decoder(self, settings, buf)
cdef inline has_encoder(self):
cdef Codec elem_codec
if self.c_encoder is not NULL or self.py_encoder is not None:
return True
elif (
self.type == CODEC_ARRAY
or self.type == CODEC_RANGE
or self.type == CODEC_MULTIRANGE
):
return self.element_codec.has_encoder()
elif self.type == CODEC_COMPOSITE:
for elem_codec in self.element_codecs:
if not elem_codec.has_encoder():
return False
return True
else:
return False
cdef has_decoder(self):
cdef Codec elem_codec
if self.c_decoder is not NULL or self.py_decoder is not None:
return True
elif (
self.type == CODEC_ARRAY
or self.type == CODEC_RANGE
or self.type == CODEC_MULTIRANGE
):
return self.element_codec.has_decoder()
elif self.type == CODEC_COMPOSITE:
for elem_codec in self.element_codecs:
if not elem_codec.has_decoder():
return False
return True
else:
return False
cdef is_binary(self):
return self.format == PG_FORMAT_BINARY
def __repr__(self):
return '<Codec oid={} elem_oid={} core={}>'.format(
self.oid,
'NA' if self.element_codec is None else self.element_codec.oid,
has_core_codec(self.oid))
@staticmethod
cdef Codec new_array_codec(uint32_t oid,
str name,
str schema,
Codec element_codec,
Py_UCS4 element_delimiter):
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format,
PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
element_codec, None, None, None, element_delimiter)
return codec
@staticmethod
cdef Codec new_range_codec(uint32_t oid,
str name,
str schema,
Codec element_codec):
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format,
PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
element_codec, None, None, None, 0)
return codec
@staticmethod
cdef Codec new_multirange_codec(uint32_t oid,
str name,
str schema,
Codec element_codec):
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'multirange', CODEC_MULTIRANGE,
element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None,
None, None, element_codec, None, None, None, 0)
return codec
@staticmethod
cdef Codec new_composite_codec(uint32_t oid,
str name,
str schema,
ServerDataFormat format,
list element_codecs,
tuple element_type_oids,
object element_names):
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'composite', CODEC_COMPOSITE,
format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
None, element_type_oids, element_names, element_codecs, 0)
return codec
@staticmethod
cdef Codec new_python_codec(uint32_t oid,
str name,
str schema,
str kind,
object encoder,
object decoder,
encode_func c_encoder,
decode_func c_decoder,
Codec base_codec,
ServerDataFormat format,
ClientExchangeFormat xformat):
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, kind, CODEC_PY, format, xformat,
c_encoder, c_decoder, base_codec, encoder, decoder,
None, None, None, None, 0)
return codec
# Encode callback for arrays
cdef codec_encode_func_ex(ConnectionSettings settings, WriteBuffer buf,
object obj, const void *arg):
return (<Codec>arg).encode(settings, buf, obj)
# Decode callback for arrays
cdef codec_decode_func_ex(ConnectionSettings settings, FRBuffer *buf,
const void *arg):
return (<Codec>arg).decode(settings, buf)
cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
cdef:
int64_t oid = 0
bint overflow = False
try:
oid = cpython.PyLong_AsLongLong(val)
except OverflowError:
overflow = True
if overflow or (oid < 0 or oid > UINT32_MAX):
raise OverflowError('OID value too large: {!r}'.format(val))
return <uint32_t>val
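# An illustrative sketch (comment only): pylong_as_oid(23) returns 23,
# while pylong_as_oid(-1) and pylong_as_oid(2**32) both raise
# OverflowError, since OIDs are unsigned 32-bit values.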
cdef class DataCodecConfig:
def __init__(self, cache_key):
# Codec instance cache for derived types:
# composites, arrays, ranges, domains and their combinations.
self._derived_type_codecs = {}
# Codec instances set up by the user for the connection.
self._custom_type_codecs = {}
def add_types(self, types):
cdef:
Codec elem_codec
list comp_elem_codecs
ServerDataFormat format
ServerDataFormat elem_format
bint has_text_elements
Py_UCS4 elem_delim
for ti in types:
oid = ti['oid']
if self.get_codec(oid, PG_FORMAT_ANY) is not None:
continue
name = ti['name']
schema = ti['ns']
array_element_oid = ti['elemtype']
range_subtype_oid = ti['range_subtype']
if ti['attrtypoids']:
comp_type_attrs = tuple(ti['attrtypoids'])
else:
comp_type_attrs = None
base_type = ti['basetype']
if array_element_oid:
# Array type (note, there is no separate 'kind' for arrays)
# Canonicalize type name to "elemtype[]"
if name.startswith('_'):
name = name[1:]
name = '{}[]'.format(name)
elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY)
if elem_codec is None:
elem_codec = self.declare_fallback_codec(
array_element_oid, ti['elemtype_name'], schema)
elem_delim = <Py_UCS4>ti['elemdelim'][0]
self._derived_type_codecs[oid, elem_codec.format] = \
Codec.new_array_codec(
oid, name, schema, elem_codec, elem_delim)
elif ti['kind'] == b'c':
# Composite type
if not comp_type_attrs:
raise exceptions.InternalClientError(
f'type record missing field types for composite {oid}')
comp_elem_codecs = []
has_text_elements = False
for typoid in comp_type_attrs:
elem_codec = self.get_codec(typoid, PG_FORMAT_ANY)
if elem_codec is None:
raise exceptions.InternalClientError(
f'no codec for composite attribute type {typoid}')
if elem_codec.format is PG_FORMAT_TEXT:
has_text_elements = True
comp_elem_codecs.append(elem_codec)
element_names = collections.OrderedDict()
for i, attrname in enumerate(ti['attrnames']):
element_names[attrname] = i
# If at least one element is text-encoded, we must
# encode the whole composite as text.
if has_text_elements:
elem_format = PG_FORMAT_TEXT
else:
elem_format = PG_FORMAT_BINARY
self._derived_type_codecs[oid, elem_format] = \
Codec.new_composite_codec(
oid, name, schema, elem_format, comp_elem_codecs,
comp_type_attrs, element_names)
elif ti['kind'] == b'd':
# Domain type
if not base_type:
raise exceptions.InternalClientError(
f'type record missing base type for domain {oid}')
elem_codec = self.get_codec(base_type, PG_FORMAT_ANY)
if elem_codec is None:
elem_codec = self.declare_fallback_codec(
base_type, ti['basetype_name'], schema)
self._derived_type_codecs[oid, elem_codec.format] = elem_codec
elif ti['kind'] == b'r':
# Range type
if not range_subtype_oid:
raise exceptions.InternalClientError(
f'type record missing base type for range {oid}')
elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
if elem_codec is None:
elem_codec = self.declare_fallback_codec(
range_subtype_oid, ti['range_subtype_name'], schema)
self._derived_type_codecs[oid, elem_codec.format] = \
Codec.new_range_codec(oid, name, schema, elem_codec)
elif ti['kind'] == b'm':
# Multirange type
if not range_subtype_oid:
raise exceptions.InternalClientError(
f'type record missing base type for multirange {oid}')
elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
if elem_codec is None:
elem_codec = self.declare_fallback_codec(
range_subtype_oid, ti['range_subtype_name'], schema)
self._derived_type_codecs[oid, elem_codec.format] = \
Codec.new_multirange_codec(oid, name, schema, elem_codec)
elif ti['kind'] == b'e':
# Enum types are essentially text
self._set_builtin_type_codec(oid, name, schema, 'scalar',
TEXTOID, PG_FORMAT_ANY)
else:
self.declare_fallback_codec(oid, name, schema)
def add_python_codec(self, typeoid, typename, typeschema, typekind,
typeinfos, encoder, decoder, format, xformat):
cdef:
Codec core_codec = None
encode_func c_encoder = NULL
decode_func c_decoder = NULL
Codec base_codec = None
uint32_t oid = pylong_as_oid(typeoid)
bint codec_set = False
# Clear all previous overrides (this also clears type cache).
self.remove_python_codec(typeoid, typename, typeschema)
if typeinfos:
self.add_types(typeinfos)
if format == PG_FORMAT_ANY:
formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY)
else:
formats = (format,)
for fmt in formats:
if xformat == PG_XFORMAT_TUPLE:
if typekind == "scalar":
core_codec = get_core_codec(oid, fmt, xformat)
if core_codec is None:
continue
c_encoder = core_codec.c_encoder
c_decoder = core_codec.c_decoder
elif typekind == "composite":
base_codec = self.get_codec(oid, fmt)
if base_codec is None:
continue
self._custom_type_codecs[typeoid, fmt] = \
Codec.new_python_codec(oid, typename, typeschema, typekind,
encoder, decoder, c_encoder, c_decoder,
base_codec, fmt, xformat)
codec_set = True
if not codec_set:
raise exceptions.InterfaceError(
"{} type does not support the 'tuple' exchange format".format(
typename))
def remove_python_codec(self, typeoid, typename, typeschema):
for fmt in (PG_FORMAT_BINARY, PG_FORMAT_TEXT):
self._custom_type_codecs.pop((typeoid, fmt), None)
self.clear_type_cache()
def _set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
alias_to, format=PG_FORMAT_ANY):
cdef:
Codec codec
Codec target_codec
uint32_t oid = pylong_as_oid(typeoid)
uint32_t alias_oid = 0
bint codec_set = False
if format == PG_FORMAT_ANY:
formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT)
else:
formats = (format,)
if isinstance(alias_to, int):
alias_oid = pylong_as_oid(alias_to)
else:
alias_oid = BUILTIN_TYPE_NAME_MAP.get(alias_to, 0)
for format in formats:
if alias_oid != 0:
target_codec = self.get_codec(alias_oid, format)
else:
target_codec = get_extra_codec(alias_to, format)
if target_codec is None:
continue
codec = target_codec.copy()
codec.oid = typeoid
codec.name = typename
codec.schema = typeschema
codec.kind = typekind
self._custom_type_codecs[typeoid, format] = codec
codec_set = True
if not codec_set:
if format == PG_FORMAT_BINARY:
codec_str = 'binary'
elif format == PG_FORMAT_TEXT:
codec_str = 'text'
else:
codec_str = 'text or binary'
raise exceptions.InterfaceError(
f'cannot alias {typename} to {alias_to}: '
f'there is no {codec_str} codec for {alias_to}')
def set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
alias_to, format=PG_FORMAT_ANY):
self._set_builtin_type_codec(typeoid, typename, typeschema, typekind,
alias_to, format)
self.clear_type_cache()
def clear_type_cache(self):
self._derived_type_codecs.clear()
def declare_fallback_codec(self, uint32_t oid, str name, str schema):
cdef Codec codec
if oid <= MAXBUILTINOID:
# This is a BKI type, for which asyncpg has no
# defined codec. This should only happen for newly
# added builtin types, for which this version of
# asyncpg is lacking support.
#
raise exceptions.UnsupportedClientFeatureError(
f'unhandled standard data type {name!r} (OID {oid})')
else:
# This is a non-BKI type, and as such, has no
# stable OID, so no possibility of a builtin codec.
# In this case, fallback to text format. Applications
# can avoid this by specifying a codec for this type
# using Connection.set_type_codec().
#
self._set_builtin_type_codec(oid, name, schema, 'scalar',
TEXTOID, PG_FORMAT_TEXT)
codec = self.get_codec(oid, PG_FORMAT_TEXT)
return codec
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
bint ignore_custom_codec=False):
cdef Codec codec
if format == PG_FORMAT_ANY:
codec = self.get_codec(
oid, PG_FORMAT_BINARY, ignore_custom_codec)
if codec is None:
codec = self.get_codec(
oid, PG_FORMAT_TEXT, ignore_custom_codec)
return codec
else:
if not ignore_custom_codec:
codec = self.get_custom_codec(oid, PG_FORMAT_ANY)
if codec is not None:
if codec.format != format:
# The codec for this OID has been overridden by
# set_{builtin}_type_codec with a different format.
# We must respect that and not return a core codec.
return None
else:
return codec
codec = get_core_codec(oid, format)
if codec is not None:
return codec
else:
try:
return self._derived_type_codecs[oid, format]
except KeyError:
return None
cdef inline Codec get_custom_codec(
self,
uint32_t oid,
ServerDataFormat format
):
cdef Codec codec
if format == PG_FORMAT_ANY:
codec = self.get_custom_codec(oid, PG_FORMAT_BINARY)
if codec is None:
codec = self.get_custom_codec(oid, PG_FORMAT_TEXT)
else:
codec = self._custom_type_codecs.get((oid, format))
return codec
cdef inline Codec get_core_codec(
uint32_t oid, ServerDataFormat format,
ClientExchangeFormat xformat=PG_XFORMAT_OBJECT):
cdef:
void *ptr = NULL
if oid > MAXSUPPORTEDOID:
return None
if format == PG_FORMAT_BINARY:
ptr = binary_codec_map[oid * xformat]
elif format == PG_FORMAT_TEXT:
ptr = text_codec_map[oid * xformat]
if ptr is NULL:
return None
else:
return <Codec>ptr
cdef inline Codec get_any_core_codec(
uint32_t oid, ServerDataFormat format,
ClientExchangeFormat xformat=PG_XFORMAT_OBJECT):
"""A version of get_core_codec that accepts PG_FORMAT_ANY."""
cdef:
Codec codec
if format == PG_FORMAT_ANY:
codec = get_core_codec(oid, PG_FORMAT_BINARY, xformat)
if codec is None:
codec = get_core_codec(oid, PG_FORMAT_TEXT, xformat)
else:
codec = get_core_codec(oid, format, xformat)
return codec
cdef inline int has_core_codec(uint32_t oid):
return binary_codec_map[oid] != NULL or text_codec_map[oid] != NULL
cdef register_core_codec(uint32_t oid,
encode_func encode,
decode_func decode,
ServerDataFormat format,
ClientExchangeFormat xformat=PG_XFORMAT_OBJECT):
if oid > MAXSUPPORTEDOID:
raise exceptions.InternalClientError(
'cannot register core codec for OID {}: it is greater '
'than MAXSUPPORTEDOID ({})'.format(oid, MAXSUPPORTEDOID))
cdef:
Codec codec
str name
str kind
name = BUILTIN_TYPE_OID_MAP[oid]
kind = 'array' if oid in ARRAY_TYPES else 'scalar'
codec = Codec(oid)
codec.init(name, 'pg_catalog', kind, CODEC_C, format, xformat,
encode, decode, None, None, None, None, None, None, None, 0)
cpython.Py_INCREF(codec) # immortalize
if format == PG_FORMAT_BINARY:
binary_codec_map[oid * xformat] = <void*>codec
elif format == PG_FORMAT_TEXT:
text_codec_map[oid * xformat] = <void*>codec
else:
raise exceptions.InternalClientError(
'invalid data format: {}'.format(format))
cdef register_extra_codec(str name,
encode_func encode,
decode_func decode,
ServerDataFormat format):
cdef:
Codec codec
str kind
kind = 'scalar'
codec = Codec(INVALIDOID)
codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT,
encode, decode, None, None, None, None, None, None, None, 0)
EXTRA_CODECS[name, format] = codec
cdef inline Codec get_extra_codec(str name, ServerDataFormat format):
return EXTRA_CODECS.get((name, format))

View File

@@ -0,0 +1,484 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef init_bits_codecs():
register_core_codec(BITOID,
<encode_func>pgproto.bits_encode,
<decode_func>pgproto.bits_decode,
PG_FORMAT_BINARY)
register_core_codec(VARBITOID,
<encode_func>pgproto.bits_encode,
<decode_func>pgproto.bits_decode,
PG_FORMAT_BINARY)
cdef init_bytea_codecs():
register_core_codec(BYTEAOID,
<encode_func>pgproto.bytea_encode,
<decode_func>pgproto.bytea_decode,
PG_FORMAT_BINARY)
register_core_codec(CHAROID,
<encode_func>pgproto.bytea_encode,
<decode_func>pgproto.bytea_decode,
PG_FORMAT_BINARY)
cdef init_datetime_codecs():
register_core_codec(DATEOID,
<encode_func>pgproto.date_encode,
<decode_func>pgproto.date_decode,
PG_FORMAT_BINARY)
register_core_codec(DATEOID,
<encode_func>pgproto.date_encode_tuple,
<decode_func>pgproto.date_decode_tuple,
PG_FORMAT_BINARY,
PG_XFORMAT_TUPLE)
register_core_codec(TIMEOID,
<encode_func>pgproto.time_encode,
<decode_func>pgproto.time_decode,
PG_FORMAT_BINARY)
register_core_codec(TIMEOID,
<encode_func>pgproto.time_encode_tuple,
<decode_func>pgproto.time_decode_tuple,
PG_FORMAT_BINARY,
PG_XFORMAT_TUPLE)
register_core_codec(TIMETZOID,
<encode_func>pgproto.timetz_encode,
<decode_func>pgproto.timetz_decode,
PG_FORMAT_BINARY)
register_core_codec(TIMETZOID,
<encode_func>pgproto.timetz_encode_tuple,
<decode_func>pgproto.timetz_decode_tuple,
PG_FORMAT_BINARY,
PG_XFORMAT_TUPLE)
register_core_codec(TIMESTAMPOID,
<encode_func>pgproto.timestamp_encode,
<decode_func>pgproto.timestamp_decode,
PG_FORMAT_BINARY)
register_core_codec(TIMESTAMPOID,
<encode_func>pgproto.timestamp_encode_tuple,
<decode_func>pgproto.timestamp_decode_tuple,
PG_FORMAT_BINARY,
PG_XFORMAT_TUPLE)
register_core_codec(TIMESTAMPTZOID,
<encode_func>pgproto.timestamptz_encode,
<decode_func>pgproto.timestamptz_decode,
PG_FORMAT_BINARY)
register_core_codec(TIMESTAMPTZOID,
<encode_func>pgproto.timestamp_encode_tuple,
<decode_func>pgproto.timestamp_decode_tuple,
PG_FORMAT_BINARY,
PG_XFORMAT_TUPLE)
register_core_codec(INTERVALOID,
<encode_func>pgproto.interval_encode,
<decode_func>pgproto.interval_decode,
PG_FORMAT_BINARY)
register_core_codec(INTERVALOID,
<encode_func>pgproto.interval_encode_tuple,
<decode_func>pgproto.interval_decode_tuple,
PG_FORMAT_BINARY,
PG_XFORMAT_TUPLE)
    # For the obsolete abstime/reltime/tinterval types, we do not bother
    # to interpret the value and simply pass it through as text.
#
register_core_codec(ABSTIMEOID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
register_core_codec(RELTIMEOID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
register_core_codec(TINTERVALOID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
cdef init_float_codecs():
register_core_codec(FLOAT4OID,
<encode_func>pgproto.float4_encode,
<decode_func>pgproto.float4_decode,
PG_FORMAT_BINARY)
register_core_codec(FLOAT8OID,
<encode_func>pgproto.float8_encode,
<decode_func>pgproto.float8_decode,
PG_FORMAT_BINARY)
cdef init_geometry_codecs():
register_core_codec(BOXOID,
<encode_func>pgproto.box_encode,
<decode_func>pgproto.box_decode,
PG_FORMAT_BINARY)
register_core_codec(LINEOID,
<encode_func>pgproto.line_encode,
<decode_func>pgproto.line_decode,
PG_FORMAT_BINARY)
register_core_codec(LSEGOID,
<encode_func>pgproto.lseg_encode,
<decode_func>pgproto.lseg_decode,
PG_FORMAT_BINARY)
register_core_codec(POINTOID,
<encode_func>pgproto.point_encode,
<decode_func>pgproto.point_decode,
PG_FORMAT_BINARY)
register_core_codec(PATHOID,
<encode_func>pgproto.path_encode,
<decode_func>pgproto.path_decode,
PG_FORMAT_BINARY)
register_core_codec(POLYGONOID,
<encode_func>pgproto.poly_encode,
<decode_func>pgproto.poly_decode,
PG_FORMAT_BINARY)
register_core_codec(CIRCLEOID,
<encode_func>pgproto.circle_encode,
<decode_func>pgproto.circle_decode,
PG_FORMAT_BINARY)
cdef init_hstore_codecs():
register_extra_codec('pg_contrib.hstore',
<encode_func>pgproto.hstore_encode,
<decode_func>pgproto.hstore_decode,
PG_FORMAT_BINARY)
cdef init_json_codecs():
register_core_codec(JSONOID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_BINARY)
register_core_codec(JSONBOID,
<encode_func>pgproto.jsonb_encode,
<decode_func>pgproto.jsonb_decode,
PG_FORMAT_BINARY)
register_core_codec(JSONPATHOID,
<encode_func>pgproto.jsonpath_encode,
<decode_func>pgproto.jsonpath_decode,
PG_FORMAT_BINARY)
cdef init_int_codecs():
register_core_codec(BOOLOID,
<encode_func>pgproto.bool_encode,
<decode_func>pgproto.bool_decode,
PG_FORMAT_BINARY)
register_core_codec(INT2OID,
<encode_func>pgproto.int2_encode,
<decode_func>pgproto.int2_decode,
PG_FORMAT_BINARY)
register_core_codec(INT4OID,
<encode_func>pgproto.int4_encode,
<decode_func>pgproto.int4_decode,
PG_FORMAT_BINARY)
register_core_codec(INT8OID,
<encode_func>pgproto.int8_encode,
<decode_func>pgproto.int8_decode,
PG_FORMAT_BINARY)
cdef init_pseudo_codecs():
# Void type is returned by SELECT void_returning_function()
register_core_codec(VOIDOID,
<encode_func>pgproto.void_encode,
<decode_func>pgproto.void_decode,
PG_FORMAT_BINARY)
# Unknown type, always decoded as text
register_core_codec(UNKNOWNOID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
# OID and friends
oid_types = [
OIDOID, XIDOID, CIDOID
]
for oid_type in oid_types:
register_core_codec(oid_type,
<encode_func>pgproto.uint4_encode,
<decode_func>pgproto.uint4_decode,
PG_FORMAT_BINARY)
# 64-bit OID types
oid8_types = [
XID8OID,
]
for oid_type in oid8_types:
register_core_codec(oid_type,
<encode_func>pgproto.uint8_encode,
<decode_func>pgproto.uint8_decode,
PG_FORMAT_BINARY)
# reg* types -- these are really system catalog OIDs, but
# allow the catalog object name as an input. We could just
# decode these as OIDs, but handling them as text seems more
# useful.
#
reg_types = [
REGPROCOID, REGPROCEDUREOID, REGOPEROID, REGOPERATOROID,
REGCLASSOID, REGTYPEOID, REGCONFIGOID, REGDICTIONARYOID,
REGNAMESPACEOID, REGROLEOID, REFCURSOROID, REGCOLLATIONOID,
]
for reg_type in reg_types:
register_core_codec(reg_type,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
# cstring type is used by Postgres' I/O functions
register_core_codec(CSTRINGOID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_BINARY)
# various system pseudotypes with no I/O
no_io_types = [
ANYOID, TRIGGEROID, EVENT_TRIGGEROID, LANGUAGE_HANDLEROID,
FDW_HANDLEROID, TSM_HANDLEROID, INTERNALOID, OPAQUEOID,
ANYELEMENTOID, ANYNONARRAYOID, ANYCOMPATIBLEOID,
ANYCOMPATIBLEARRAYOID, ANYCOMPATIBLENONARRAYOID,
ANYCOMPATIBLERANGEOID, ANYCOMPATIBLEMULTIRANGEOID,
ANYRANGEOID, ANYMULTIRANGEOID, ANYARRAYOID,
PG_DDL_COMMANDOID, INDEX_AM_HANDLEROID, TABLE_AM_HANDLEROID,
]
register_core_codec(ANYENUMOID,
NULL,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
for no_io_type in no_io_types:
register_core_codec(no_io_type,
NULL,
NULL,
PG_FORMAT_BINARY)
# ACL specification string
register_core_codec(ACLITEMOID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
# Postgres' serialized expression tree type
register_core_codec(PG_NODE_TREEOID,
NULL,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
# pg_lsn type -- a pointer to a location in the XLOG.
register_core_codec(PG_LSNOID,
<encode_func>pgproto.int8_encode,
<decode_func>pgproto.int8_decode,
PG_FORMAT_BINARY)
register_core_codec(SMGROID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
# pg_dependencies and pg_ndistinct are special types
# used in pg_statistic_ext columns.
register_core_codec(PG_DEPENDENCIESOID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
register_core_codec(PG_NDISTINCTOID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
# pg_mcv_list is a special type used in pg_statistic_ext_data
# system catalog
register_core_codec(PG_MCV_LISTOID,
<encode_func>pgproto.bytea_encode,
<decode_func>pgproto.bytea_decode,
PG_FORMAT_BINARY)
# These two are internal to BRIN index support and are unlikely
# to be sent, but since I/O functions for these exist, add decoders
# nonetheless.
register_core_codec(PG_BRIN_BLOOM_SUMMARYOID,
NULL,
<decode_func>pgproto.bytea_decode,
PG_FORMAT_BINARY)
register_core_codec(PG_BRIN_MINMAX_MULTI_SUMMARYOID,
NULL,
<decode_func>pgproto.bytea_decode,
PG_FORMAT_BINARY)
cdef init_text_codecs():
textoids = [
NAMEOID,
BPCHAROID,
VARCHAROID,
TEXTOID,
XMLOID
]
for oid in textoids:
register_core_codec(oid,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_BINARY)
register_core_codec(oid,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
cdef init_tid_codecs():
register_core_codec(TIDOID,
<encode_func>pgproto.tid_encode,
<decode_func>pgproto.tid_decode,
PG_FORMAT_BINARY)
cdef init_txid_codecs():
register_core_codec(TXID_SNAPSHOTOID,
<encode_func>pgproto.pg_snapshot_encode,
<decode_func>pgproto.pg_snapshot_decode,
PG_FORMAT_BINARY)
register_core_codec(PG_SNAPSHOTOID,
<encode_func>pgproto.pg_snapshot_encode,
<decode_func>pgproto.pg_snapshot_decode,
PG_FORMAT_BINARY)
cdef init_tsearch_codecs():
ts_oids = [
TSQUERYOID,
TSVECTOROID,
]
for oid in ts_oids:
register_core_codec(oid,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
register_core_codec(GTSVECTOROID,
NULL,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
cdef init_uuid_codecs():
register_core_codec(UUIDOID,
<encode_func>pgproto.uuid_encode,
<decode_func>pgproto.uuid_decode,
PG_FORMAT_BINARY)
cdef init_numeric_codecs():
register_core_codec(NUMERICOID,
<encode_func>pgproto.numeric_encode_text,
<decode_func>pgproto.numeric_decode_text,
PG_FORMAT_TEXT)
register_core_codec(NUMERICOID,
<encode_func>pgproto.numeric_encode_binary,
<decode_func>pgproto.numeric_decode_binary,
PG_FORMAT_BINARY)
cdef init_network_codecs():
register_core_codec(CIDROID,
<encode_func>pgproto.cidr_encode,
<decode_func>pgproto.cidr_decode,
PG_FORMAT_BINARY)
register_core_codec(INETOID,
<encode_func>pgproto.inet_encode,
<decode_func>pgproto.inet_decode,
PG_FORMAT_BINARY)
register_core_codec(MACADDROID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
register_core_codec(MACADDR8OID,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
cdef init_monetary_codecs():
moneyoids = [
MONEYOID,
]
for oid in moneyoids:
register_core_codec(oid,
<encode_func>pgproto.text_encode,
<decode_func>pgproto.text_decode,
PG_FORMAT_TEXT)
cdef init_all_pgproto_codecs():
# Builtin types, in lexicographical order.
init_bits_codecs()
init_bytea_codecs()
init_datetime_codecs()
init_float_codecs()
init_geometry_codecs()
init_int_codecs()
init_json_codecs()
init_monetary_codecs()
init_network_codecs()
init_numeric_codecs()
init_text_codecs()
init_tid_codecs()
init_tsearch_codecs()
init_txid_codecs()
init_uuid_codecs()
# Various pseudotypes and system types
init_pseudo_codecs()
# contrib
init_hstore_codecs()
init_all_pgproto_codecs()
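# An illustrative usage sketch (comment only, assuming an existing
# asyncpg connection `conn`): the hstore codec registered above is an
# extra codec keyed by name rather than OID, so it must be aliased
# explicitly before hstore values can round-trip:
#
#   await conn.set_builtin_type_codec(
#       'hstore', codec_name='pg_contrib.hstore')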

View File

@@ -0,0 +1,207 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from asyncpg import types as apg_types
from collections.abc import Sequence as SequenceABC
# defined in postgresql/src/include/utils/rangetypes.h
DEF RANGE_EMPTY = 0x01 # range is empty
DEF RANGE_LB_INC = 0x02 # lower bound is inclusive
DEF RANGE_UB_INC = 0x04 # upper bound is inclusive
DEF RANGE_LB_INF = 0x08 # lower bound is -infinity
DEF RANGE_UB_INF = 0x10 # upper bound is +infinity
cdef enum _RangeArgumentType:
_RANGE_ARGUMENT_INVALID = 0
_RANGE_ARGUMENT_TUPLE = 1
_RANGE_ARGUMENT_RANGE = 2
cdef inline bint _range_has_lbound(uint8_t flags):
return not (flags & (RANGE_EMPTY | RANGE_LB_INF))
cdef inline bint _range_has_ubound(uint8_t flags):
return not (flags & (RANGE_EMPTY | RANGE_UB_INF))
cdef inline _RangeArgumentType _range_type(object obj):
if cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj):
return _RANGE_ARGUMENT_TUPLE
elif isinstance(obj, apg_types.Range):
return _RANGE_ARGUMENT_RANGE
else:
return _RANGE_ARGUMENT_INVALID
cdef range_encode(ConnectionSettings settings, WriteBuffer buf,
object obj, uint32_t elem_oid,
encode_func_ex encoder, const void *encoder_arg):
cdef:
ssize_t obj_len
uint8_t flags = 0
object lower = None
object upper = None
WriteBuffer bounds_data = WriteBuffer.new()
_RangeArgumentType arg_type = _range_type(obj)
if arg_type == _RANGE_ARGUMENT_INVALID:
raise TypeError(
'list, tuple or Range object expected (got type {})'.format(
type(obj)))
elif arg_type == _RANGE_ARGUMENT_TUPLE:
obj_len = len(obj)
if obj_len == 2:
lower = obj[0]
upper = obj[1]
if lower is None:
flags |= RANGE_LB_INF
if upper is None:
flags |= RANGE_UB_INF
flags |= RANGE_LB_INC | RANGE_UB_INC
elif obj_len == 1:
lower = obj[0]
flags |= RANGE_LB_INC | RANGE_UB_INF
elif obj_len == 0:
flags |= RANGE_EMPTY
else:
raise ValueError(
'expected 0, 1 or 2 elements in range (got {})'.format(
obj_len))
else:
if obj.isempty:
flags |= RANGE_EMPTY
else:
lower = obj.lower
upper = obj.upper
if obj.lower_inc:
flags |= RANGE_LB_INC
elif lower is None:
flags |= RANGE_LB_INF
if obj.upper_inc:
flags |= RANGE_UB_INC
elif upper is None:
flags |= RANGE_UB_INF
if _range_has_lbound(flags):
encoder(settings, bounds_data, lower, encoder_arg)
if _range_has_ubound(flags):
encoder(settings, bounds_data, upper, encoder_arg)
buf.write_int32(1 + bounds_data.len())
buf.write_byte(<int8_t>flags)
buf.write_buffer(bounds_data)
cdef range_decode(ConnectionSettings settings, FRBuffer *buf,
decode_func_ex decoder, const void *decoder_arg):
cdef:
uint8_t flags = <uint8_t>frb_read(buf, 1)[0]
int32_t bound_len
object lower = None
object upper = None
FRBuffer bound_buf
if _range_has_lbound(flags):
bound_len = hton.unpack_int32(frb_read(buf, 4))
if bound_len == -1:
lower = None
else:
frb_slice_from(&bound_buf, buf, bound_len)
lower = decoder(settings, &bound_buf, decoder_arg)
if _range_has_ubound(flags):
bound_len = hton.unpack_int32(frb_read(buf, 4))
if bound_len == -1:
upper = None
else:
frb_slice_from(&bound_buf, buf, bound_len)
upper = decoder(settings, &bound_buf, decoder_arg)
return apg_types.Range(lower=lower, upper=upper,
lower_inc=(flags & RANGE_LB_INC) != 0,
upper_inc=(flags & RANGE_UB_INC) != 0,
empty=(flags & RANGE_EMPTY) != 0)
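# Wire layout implied by the encoder/decoder above (a hedged sketch; see
# PostgreSQL's rangetypes.c for the authoritative definition):
#   uint8  flags (RANGE_* bits)
#   then, for each bound present per _range_has_lbound/_range_has_ubound:
#   int32  bound length (-1 for NULL), followed by the bound datum
#          encoded with the element codec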
cdef multirange_encode(ConnectionSettings settings, WriteBuffer buf,
object obj, uint32_t elem_oid,
encode_func_ex encoder, const void *encoder_arg):
cdef:
WriteBuffer elem_data
ssize_t elem_data_len
ssize_t elem_count
if not isinstance(obj, SequenceABC):
raise TypeError(
'expected a sequence (got type {!r})'.format(type(obj).__name__)
)
elem_data = WriteBuffer.new()
for elem in obj:
range_encode(settings, elem_data, elem, elem_oid, encoder, encoder_arg)
elem_count = len(obj)
if elem_count > INT32_MAX:
raise OverflowError('too many elements in multirange value')
elem_data_len = elem_data.len()
if elem_data_len > INT32_MAX - 4:
raise OverflowError(
f'size of encoded multirange datum exceeds the maximum allowed'
f' {INT32_MAX - 4} bytes')
# Datum length
buf.write_int32(4 + <int32_t>elem_data_len)
# Number of elements in multirange
buf.write_int32(<int32_t>elem_count)
buf.write_buffer(elem_data)
cdef multirange_decode(ConnectionSettings settings, FRBuffer *buf,
decode_func_ex decoder, const void *decoder_arg):
cdef:
int32_t nelems = hton.unpack_int32(frb_read(buf, 4))
FRBuffer elem_buf
int32_t elem_len
int i
list result
if nelems == 0:
return []
if nelems < 0:
raise exceptions.ProtocolError(
'unexpected multirange size value: {}'.format(nelems))
result = cpython.PyList_New(nelems)
for i in range(nelems):
elem_len = hton.unpack_int32(frb_read(buf, 4))
if elem_len == -1:
raise exceptions.ProtocolError(
'unexpected NULL element in multirange value')
else:
frb_slice_from(&elem_buf, buf, elem_len)
elem = range_decode(settings, &elem_buf, decoder, decoder_arg)
cpython.Py_INCREF(elem)
cpython.PyList_SET_ITEM(result, i, elem)
return result
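# Note: a multirange is thus a thin container -- an int32 element count
# followed by each member range serialized exactly as in range_encode()
# above, including its own int32 length prefix.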

View File

@@ -0,0 +1,71 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from asyncpg import exceptions
cdef inline record_encode_frame(ConnectionSettings settings, WriteBuffer buf,
WriteBuffer elem_data, int32_t elem_count):
buf.write_int32(4 + elem_data.len())
# attribute count
buf.write_int32(elem_count)
# encoded attribute data
buf.write_buffer(elem_data)
cdef anonymous_record_decode(ConnectionSettings settings, FRBuffer *buf):
cdef:
tuple result
ssize_t elem_count
ssize_t i
int32_t elem_len
uint32_t elem_typ
Codec elem_codec
FRBuffer elem_buf
elem_count = <ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4))
result = cpython.PyTuple_New(elem_count)
for i in range(elem_count):
elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
elem_len = hton.unpack_int32(frb_read(buf, 4))
if elem_len == -1:
elem = None
else:
elem_codec = settings.get_data_codec(elem_typ)
if elem_codec is None or not elem_codec.has_decoder():
raise exceptions.InternalClientError(
'no decoder for composite type element in '
'position {} of type OID {}'.format(i, elem_typ))
elem = elem_codec.decode(settings,
frb_slice_from(&elem_buf, buf, elem_len))
cpython.Py_INCREF(elem)
cpython.PyTuple_SET_ITEM(result, i, elem)
return result
cdef anonymous_record_encode(ConnectionSettings settings, WriteBuffer buf, obj):
raise exceptions.UnsupportedClientFeatureError(
'input of anonymous composite types is not supported',
hint=(
'Consider declaring an explicit composite type and '
'using it to cast the argument.'
),
detail='PostgreSQL does not implement anonymous composite type input.'
)
cdef init_record_codecs():
register_core_codec(RECORDOID,
<encode_func>anonymous_record_encode,
<decode_func>anonymous_record_decode,
PG_FORMAT_BINARY)
init_record_codecs()

View File

@@ -0,0 +1,99 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef inline uint32_t _apg_tolower(uint32_t c):
if c >= <uint32_t><Py_UCS4>'A' and c <= <uint32_t><Py_UCS4>'Z':
return c + <uint32_t><Py_UCS4>'a' - <uint32_t><Py_UCS4>'A'
else:
return c
cdef int apg_strcasecmp(const Py_UCS4 *s1, const Py_UCS4 *s2):
cdef:
uint32_t c1
uint32_t c2
int i = 0
while True:
c1 = s1[i]
c2 = s2[i]
if c1 != c2:
c1 = _apg_tolower(c1)
c2 = _apg_tolower(c2)
if c1 != c2:
return <int32_t>c1 - <int32_t>c2
if c1 == 0 or c2 == 0:
break
i += 1
return 0
cdef int apg_strcasecmp_char(const char *s1, const char *s2):
cdef:
uint8_t c1
uint8_t c2
int i = 0
while True:
c1 = <uint8_t>s1[i]
c2 = <uint8_t>s2[i]
if c1 != c2:
c1 = <uint8_t>_apg_tolower(c1)
c2 = <uint8_t>_apg_tolower(c2)
if c1 != c2:
return <int8_t>c1 - <int8_t>c2
if c1 == 0 or c2 == 0:
break
i += 1
return 0
cdef inline bint apg_ascii_isspace(Py_UCS4 ch):
return (
ch == ' ' or
ch == '\n' or
ch == '\r' or
ch == '\t' or
ch == '\v' or
ch == '\f'
)
cdef Py_UCS4 *apg_parse_int32(Py_UCS4 *buf, int32_t *num):
cdef:
Py_UCS4 *p
int32_t n = 0
int32_t neg = 0
if buf[0] == '-':
neg = 1
buf += 1
elif buf[0] == '+':
buf += 1
p = buf
while <int>p[0] >= <int><Py_UCS4>'0' and <int>p[0] <= <int><Py_UCS4>'9':
n = 10 * n - (<int>p[0] - <int32_t><Py_UCS4>'0')
p += 1
if p == buf:
return NULL
if not neg:
n = -n
num[0] = n
return p
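# Note: digits are accumulated into a *negative* total (n = 10 * n - digit)
# and negated at the end for positive input.  This is the classic strtol
# trick: INT32_MIN (-2147483648) is representable while +2147483648 is not,
# so '-2147483648' parses without intermediate overflow.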

View File

@@ -0,0 +1,12 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
DEF _MAXINT32 = 2**31 - 1
DEF _COPY_BUFFER_SIZE = 524288
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
DEF _EXECUTE_MANY_BUF_NUM = 4
DEF _EXECUTE_MANY_BUF_SIZE = 32768

View File

@@ -0,0 +1,195 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
include "scram.pxd"
cdef enum ConnectionStatus:
CONNECTION_OK = 1
CONNECTION_BAD = 2
CONNECTION_STARTED = 3 # Waiting for connection to be made.
cdef enum ProtocolState:
PROTOCOL_IDLE = 0
PROTOCOL_FAILED = 1
PROTOCOL_ERROR_CONSUME = 2
PROTOCOL_CANCELLED = 3
PROTOCOL_TERMINATING = 4
PROTOCOL_AUTH = 10
PROTOCOL_PREPARE = 11
PROTOCOL_BIND_EXECUTE = 12
PROTOCOL_BIND_EXECUTE_MANY = 13
PROTOCOL_CLOSE_STMT_PORTAL = 14
PROTOCOL_SIMPLE_QUERY = 15
PROTOCOL_EXECUTE = 16
PROTOCOL_BIND = 17
PROTOCOL_COPY_OUT = 18
PROTOCOL_COPY_OUT_DATA = 19
PROTOCOL_COPY_OUT_DONE = 20
PROTOCOL_COPY_IN = 21
PROTOCOL_COPY_IN_DATA = 22
cdef enum AuthenticationMessage:
AUTH_SUCCESSFUL = 0
AUTH_REQUIRED_KERBEROS = 2
AUTH_REQUIRED_PASSWORD = 3
AUTH_REQUIRED_PASSWORDMD5 = 5
AUTH_REQUIRED_SCMCRED = 6
AUTH_REQUIRED_GSS = 7
AUTH_REQUIRED_GSS_CONTINUE = 8
AUTH_REQUIRED_SSPI = 9
AUTH_REQUIRED_SASL = 10
AUTH_SASL_CONTINUE = 11
AUTH_SASL_FINAL = 12
AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}
cdef enum ResultType:
RESULT_OK = 1
RESULT_FAILED = 2
cdef enum TransactionStatus:
PQTRANS_IDLE = 0 # connection idle
PQTRANS_ACTIVE = 1 # command in progress
PQTRANS_INTRANS = 2 # idle, within transaction block
PQTRANS_INERROR = 3 # idle, within failed transaction
PQTRANS_UNKNOWN = 4 # cannot determine status
ctypedef object (*decode_row_method)(object, const char*, ssize_t)
cdef class CoreProtocol:
cdef:
ReadBuffer buffer
bint _skip_discard
bint _discard_data
# executemany support data
object _execute_iter
str _execute_portal_name
str _execute_stmt_name
ConnectionStatus con_status
ProtocolState state
TransactionStatus xact_status
str encoding
object transport
# Instance of _ConnectionParameters
object con_params
# Instance of SCRAMAuthentication
SCRAMAuthentication scram
readonly int32_t backend_pid
readonly int32_t backend_secret
## Result
ResultType result_type
object result
bytes result_param_desc
bytes result_row_desc
bytes result_status_msg
# True - completed, False - suspended
bint result_execute_completed
cpdef is_in_transaction(self)
cdef _process__auth(self, char mtype)
cdef _process__prepare(self, char mtype)
cdef _process__bind_execute(self, char mtype)
cdef _process__bind_execute_many(self, char mtype)
cdef _process__close_stmt_portal(self, char mtype)
cdef _process__simple_query(self, char mtype)
cdef _process__bind(self, char mtype)
cdef _process__copy_out(self, char mtype)
cdef _process__copy_out_data(self, char mtype)
cdef _process__copy_in(self, char mtype)
cdef _process__copy_in_data(self, char mtype)
cdef _parse_msg_authentication(self)
cdef _parse_msg_parameter_status(self)
cdef _parse_msg_notification(self)
cdef _parse_msg_backend_key_data(self)
cdef _parse_msg_ready_for_query(self)
cdef _parse_data_msgs(self)
cdef _parse_copy_data_msgs(self)
cdef _parse_msg_error_response(self, is_error)
cdef _parse_msg_command_complete(self)
cdef _write_copy_data_msg(self, object data)
cdef _write_copy_done_msg(self)
cdef _write_copy_fail_msg(self, str cause)
cdef _auth_password_message_cleartext(self)
cdef _auth_password_message_md5(self, bytes salt)
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
cdef _auth_password_message_sasl_continue(self, bytes server_response)
cdef _write(self, buf)
cdef _writelines(self, list buffers)
cdef _read_server_messages(self)
cdef _push_result(self)
cdef _reset_result(self)
cdef _set_state(self, ProtocolState new_state)
cdef _ensure_connected(self)
cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
cdef WriteBuffer _build_bind_message(self, str portal_name,
str stmt_name,
WriteBuffer bind_data)
cdef WriteBuffer _build_empty_bind_data(self)
cdef WriteBuffer _build_execute_message(self, str portal_name,
int32_t limit)
cdef _connect(self)
cdef _prepare_and_describe(self, str stmt_name, str query)
cdef _send_parse_message(self, str stmt_name, str query)
cdef _send_bind_message(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
object bind_data)
cdef bint _bind_execute_many_more(self, bint first=*)
cdef _bind_execute_many_fail(self, object error, bint first=*)
cdef _bind(self, str portal_name, str stmt_name,
WriteBuffer bind_data)
cdef _execute(self, str portal_name, int32_t limit)
cdef _close(self, str name, bint is_portal)
cdef _simple_query(self, str query)
cdef _copy_out(self, str copy_stmt)
cdef _copy_in(self, str copy_stmt)
cdef _terminate(self)
cdef _decode_row(self, const char* buf, ssize_t buf_len)
cdef _on_result(self)
cdef _on_notification(self, pid, channel, payload)
cdef _on_notice(self, parsed)
cdef _set_server_parameter(self, name, val)
cdef _on_connection_lost(self, exc)

File diff suppressed because it is too large

View File

@@ -0,0 +1,19 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef extern from "Python.h":
int PyByteArray_Check(object)
int PyMemoryView_Check(object)
Py_buffer *PyMemoryView_GET_BUFFER(object)
object PyMemoryView_GetContiguous(object, int buffertype, char order)
Py_UCS4* PyUnicode_AsUCS4Copy(object) except NULL
object PyUnicode_FromKindAndData(
int kind, const void *buffer, Py_ssize_t size)
int PyUnicode_4BYTE_KIND

View File

@@ -0,0 +1,63 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
'''Map PostgreSQL encoding names to Python encoding names
https://www.postgresql.org/docs/current/static/multibyte.html#CHARSET-TABLE
'''
cdef dict ENCODINGS_MAP = {
'abc': 'cp1258',
'alt': 'cp866',
'euc_cn': 'euccn',
'euc_jp': 'eucjp',
'euc_kr': 'euckr',
'koi8r': 'koi8_r',
'koi8u': 'koi8_u',
'shift_jis_2004': 'euc_jis_2004',
'sjis': 'shift_jis',
'sql_ascii': 'ascii',
'vscii': 'cp1258',
'tcvn': 'cp1258',
'tcvn5712': 'cp1258',
'unicode': 'utf_8',
'win': 'cp1251',
'win1250': 'cp1250',
'win1251': 'cp1251',
'win1252': 'cp1252',
'win1253': 'cp1253',
'win1254': 'cp1254',
'win1255': 'cp1255',
'win1256': 'cp1256',
'win1257': 'cp1257',
'win1258': 'cp1258',
'win866': 'cp866',
'win874': 'cp874',
'win932': 'cp932',
'win936': 'cp936',
'win949': 'cp949',
'win950': 'cp950',
'windows1250': 'cp1250',
'windows1251': 'cp1251',
'windows1252': 'cp1252',
'windows1253': 'cp1253',
'windows1254': 'cp1254',
'windows1255': 'cp1255',
'windows1256': 'cp1256',
'windows1257': 'cp1257',
'windows1258': 'cp1258',
'windows866': 'cp866',
'windows874': 'cp874',
'windows932': 'cp932',
'windows936': 'cp936',
'windows949': 'cp949',
'windows950': 'cp950',
}
cdef get_python_encoding(pg_encoding):
return ENCODINGS_MAP.get(pg_encoding.lower(), pg_encoding.lower())
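# Illustrative usage: get_python_encoding('WIN1251') -> 'cp1251', while
# names absent from the table, e.g. 'latin1', pass through lowercased and
# are resolved later by codecs.lookup().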

View File

@@ -0,0 +1,266 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
# GENERATED FROM pg_catalog.pg_type
# DO NOT MODIFY, use tools/generate_type_map.py to update
DEF INVALIDOID = 0
DEF MAXBUILTINOID = 9999
DEF MAXSUPPORTEDOID = 5080
DEF BOOLOID = 16
DEF BYTEAOID = 17
DEF CHAROID = 18
DEF NAMEOID = 19
DEF INT8OID = 20
DEF INT2OID = 21
DEF INT4OID = 23
DEF REGPROCOID = 24
DEF TEXTOID = 25
DEF OIDOID = 26
DEF TIDOID = 27
DEF XIDOID = 28
DEF CIDOID = 29
DEF PG_DDL_COMMANDOID = 32
DEF JSONOID = 114
DEF XMLOID = 142
DEF PG_NODE_TREEOID = 194
DEF SMGROID = 210
DEF TABLE_AM_HANDLEROID = 269
DEF INDEX_AM_HANDLEROID = 325
DEF POINTOID = 600
DEF LSEGOID = 601
DEF PATHOID = 602
DEF BOXOID = 603
DEF POLYGONOID = 604
DEF LINEOID = 628
DEF CIDROID = 650
DEF FLOAT4OID = 700
DEF FLOAT8OID = 701
DEF ABSTIMEOID = 702
DEF RELTIMEOID = 703
DEF TINTERVALOID = 704
DEF UNKNOWNOID = 705
DEF CIRCLEOID = 718
DEF MACADDR8OID = 774
DEF MONEYOID = 790
DEF MACADDROID = 829
DEF INETOID = 869
DEF _TEXTOID = 1009
DEF _OIDOID = 1028
DEF ACLITEMOID = 1033
DEF BPCHAROID = 1042
DEF VARCHAROID = 1043
DEF DATEOID = 1082
DEF TIMEOID = 1083
DEF TIMESTAMPOID = 1114
DEF TIMESTAMPTZOID = 1184
DEF INTERVALOID = 1186
DEF TIMETZOID = 1266
DEF BITOID = 1560
DEF VARBITOID = 1562
DEF NUMERICOID = 1700
DEF REFCURSOROID = 1790
DEF REGPROCEDUREOID = 2202
DEF REGOPEROID = 2203
DEF REGOPERATOROID = 2204
DEF REGCLASSOID = 2205
DEF REGTYPEOID = 2206
DEF RECORDOID = 2249
DEF CSTRINGOID = 2275
DEF ANYOID = 2276
DEF ANYARRAYOID = 2277
DEF VOIDOID = 2278
DEF TRIGGEROID = 2279
DEF LANGUAGE_HANDLEROID = 2280
DEF INTERNALOID = 2281
DEF OPAQUEOID = 2282
DEF ANYELEMENTOID = 2283
DEF ANYNONARRAYOID = 2776
DEF UUIDOID = 2950
DEF TXID_SNAPSHOTOID = 2970
DEF FDW_HANDLEROID = 3115
DEF PG_LSNOID = 3220
DEF TSM_HANDLEROID = 3310
DEF PG_NDISTINCTOID = 3361
DEF PG_DEPENDENCIESOID = 3402
DEF ANYENUMOID = 3500
DEF TSVECTOROID = 3614
DEF TSQUERYOID = 3615
DEF GTSVECTOROID = 3642
DEF REGCONFIGOID = 3734
DEF REGDICTIONARYOID = 3769
DEF JSONBOID = 3802
DEF ANYRANGEOID = 3831
DEF EVENT_TRIGGEROID = 3838
DEF JSONPATHOID = 4072
DEF REGNAMESPACEOID = 4089
DEF REGROLEOID = 4096
DEF REGCOLLATIONOID = 4191
DEF ANYMULTIRANGEOID = 4537
DEF ANYCOMPATIBLEMULTIRANGEOID = 4538
DEF PG_BRIN_BLOOM_SUMMARYOID = 4600
DEF PG_BRIN_MINMAX_MULTI_SUMMARYOID = 4601
DEF PG_MCV_LISTOID = 5017
DEF PG_SNAPSHOTOID = 5038
DEF XID8OID = 5069
DEF ANYCOMPATIBLEOID = 5077
DEF ANYCOMPATIBLEARRAYOID = 5078
DEF ANYCOMPATIBLENONARRAYOID = 5079
DEF ANYCOMPATIBLERANGEOID = 5080
cdef ARRAY_TYPES = (_TEXTOID, _OIDOID,)
BUILTIN_TYPE_OID_MAP = {
ABSTIMEOID: 'abstime',
ACLITEMOID: 'aclitem',
ANYARRAYOID: 'anyarray',
ANYCOMPATIBLEARRAYOID: 'anycompatiblearray',
ANYCOMPATIBLEMULTIRANGEOID: 'anycompatiblemultirange',
ANYCOMPATIBLENONARRAYOID: 'anycompatiblenonarray',
ANYCOMPATIBLEOID: 'anycompatible',
ANYCOMPATIBLERANGEOID: 'anycompatiblerange',
ANYELEMENTOID: 'anyelement',
ANYENUMOID: 'anyenum',
ANYMULTIRANGEOID: 'anymultirange',
ANYNONARRAYOID: 'anynonarray',
ANYOID: 'any',
ANYRANGEOID: 'anyrange',
BITOID: 'bit',
BOOLOID: 'bool',
BOXOID: 'box',
BPCHAROID: 'bpchar',
BYTEAOID: 'bytea',
CHAROID: 'char',
CIDOID: 'cid',
CIDROID: 'cidr',
CIRCLEOID: 'circle',
CSTRINGOID: 'cstring',
DATEOID: 'date',
EVENT_TRIGGEROID: 'event_trigger',
FDW_HANDLEROID: 'fdw_handler',
FLOAT4OID: 'float4',
FLOAT8OID: 'float8',
GTSVECTOROID: 'gtsvector',
INDEX_AM_HANDLEROID: 'index_am_handler',
INETOID: 'inet',
INT2OID: 'int2',
INT4OID: 'int4',
INT8OID: 'int8',
INTERNALOID: 'internal',
INTERVALOID: 'interval',
JSONBOID: 'jsonb',
JSONOID: 'json',
JSONPATHOID: 'jsonpath',
LANGUAGE_HANDLEROID: 'language_handler',
LINEOID: 'line',
LSEGOID: 'lseg',
MACADDR8OID: 'macaddr8',
MACADDROID: 'macaddr',
MONEYOID: 'money',
NAMEOID: 'name',
NUMERICOID: 'numeric',
OIDOID: 'oid',
OPAQUEOID: 'opaque',
PATHOID: 'path',
PG_BRIN_BLOOM_SUMMARYOID: 'pg_brin_bloom_summary',
PG_BRIN_MINMAX_MULTI_SUMMARYOID: 'pg_brin_minmax_multi_summary',
PG_DDL_COMMANDOID: 'pg_ddl_command',
PG_DEPENDENCIESOID: 'pg_dependencies',
PG_LSNOID: 'pg_lsn',
PG_MCV_LISTOID: 'pg_mcv_list',
PG_NDISTINCTOID: 'pg_ndistinct',
PG_NODE_TREEOID: 'pg_node_tree',
PG_SNAPSHOTOID: 'pg_snapshot',
POINTOID: 'point',
POLYGONOID: 'polygon',
RECORDOID: 'record',
REFCURSOROID: 'refcursor',
REGCLASSOID: 'regclass',
REGCOLLATIONOID: 'regcollation',
REGCONFIGOID: 'regconfig',
REGDICTIONARYOID: 'regdictionary',
REGNAMESPACEOID: 'regnamespace',
REGOPERATOROID: 'regoperator',
REGOPEROID: 'regoper',
REGPROCEDUREOID: 'regprocedure',
REGPROCOID: 'regproc',
REGROLEOID: 'regrole',
REGTYPEOID: 'regtype',
RELTIMEOID: 'reltime',
SMGROID: 'smgr',
TABLE_AM_HANDLEROID: 'table_am_handler',
TEXTOID: 'text',
TIDOID: 'tid',
TIMEOID: 'time',
TIMESTAMPOID: 'timestamp',
TIMESTAMPTZOID: 'timestamptz',
TIMETZOID: 'timetz',
TINTERVALOID: 'tinterval',
TRIGGEROID: 'trigger',
TSM_HANDLEROID: 'tsm_handler',
TSQUERYOID: 'tsquery',
TSVECTOROID: 'tsvector',
TXID_SNAPSHOTOID: 'txid_snapshot',
UNKNOWNOID: 'unknown',
UUIDOID: 'uuid',
VARBITOID: 'varbit',
VARCHAROID: 'varchar',
VOIDOID: 'void',
XID8OID: 'xid8',
XIDOID: 'xid',
XMLOID: 'xml',
_OIDOID: 'oid[]',
_TEXTOID: 'text[]'
}
BUILTIN_TYPE_NAME_MAP = {v: k for k, v in BUILTIN_TYPE_OID_MAP.items()}
BUILTIN_TYPE_NAME_MAP['smallint'] = \
BUILTIN_TYPE_NAME_MAP['int2']
BUILTIN_TYPE_NAME_MAP['int'] = \
BUILTIN_TYPE_NAME_MAP['int4']
BUILTIN_TYPE_NAME_MAP['integer'] = \
BUILTIN_TYPE_NAME_MAP['int4']
BUILTIN_TYPE_NAME_MAP['bigint'] = \
BUILTIN_TYPE_NAME_MAP['int8']
BUILTIN_TYPE_NAME_MAP['decimal'] = \
BUILTIN_TYPE_NAME_MAP['numeric']
BUILTIN_TYPE_NAME_MAP['real'] = \
BUILTIN_TYPE_NAME_MAP['float4']
BUILTIN_TYPE_NAME_MAP['double precision'] = \
BUILTIN_TYPE_NAME_MAP['float8']
BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \
BUILTIN_TYPE_NAME_MAP['timestamptz']
BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \
BUILTIN_TYPE_NAME_MAP['timestamp']
BUILTIN_TYPE_NAME_MAP['time with timezone'] = \
BUILTIN_TYPE_NAME_MAP['timetz']
BUILTIN_TYPE_NAME_MAP['time without timezone'] = \
BUILTIN_TYPE_NAME_MAP['time']
BUILTIN_TYPE_NAME_MAP['char'] = \
BUILTIN_TYPE_NAME_MAP['bpchar']
BUILTIN_TYPE_NAME_MAP['character'] = \
BUILTIN_TYPE_NAME_MAP['bpchar']
BUILTIN_TYPE_NAME_MAP['character varying'] = \
BUILTIN_TYPE_NAME_MAP['varchar']
BUILTIN_TYPE_NAME_MAP['bit varying'] = \
BUILTIN_TYPE_NAME_MAP['varbit']

View File

@@ -0,0 +1,39 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef class PreparedStatementState:
cdef:
readonly str name
readonly str query
readonly bint closed
readonly bint prepared
readonly int refs
readonly type record_class
readonly bint ignore_custom_codec
list row_desc
list parameters_desc
ConnectionSettings settings
int16_t args_num
bint have_text_args
tuple args_codecs
int16_t cols_num
object cols_desc
bint have_text_cols
tuple rows_codecs
cdef _encode_bind_msg(self, args, int seqno = ?)
cpdef _init_codecs(self)
cdef _ensure_rows_decoder(self)
cdef _ensure_args_encoder(self)
cdef _set_row_desc(self, object desc)
cdef _set_args_desc(self, object desc)
cdef _decode_row(self, const char* cbuf, ssize_t buf_len)

View File

@@ -0,0 +1,395 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from asyncpg import exceptions
@cython.final
cdef class PreparedStatementState:
def __cinit__(
self,
str name,
str query,
BaseProtocol protocol,
type record_class,
bint ignore_custom_codec
):
self.name = name
self.query = query
self.settings = protocol.settings
self.row_desc = self.parameters_desc = None
self.args_codecs = self.rows_codecs = None
self.args_num = self.cols_num = 0
self.cols_desc = None
self.closed = False
self.prepared = True
self.refs = 0
self.record_class = record_class
self.ignore_custom_codec = ignore_custom_codec
def _get_parameters(self):
cdef Codec codec
result = []
for oid in self.parameters_desc:
codec = self.settings.get_data_codec(oid)
if codec is None:
raise exceptions.InternalClientError(
'missing codec information for OID {}'.format(oid))
result.append(apg_types.Type(
oid, codec.name, codec.kind, codec.schema))
return tuple(result)
def _get_attributes(self):
cdef Codec codec
if not self.row_desc:
return ()
result = []
for d in self.row_desc:
name = d[0]
oid = d[3]
codec = self.settings.get_data_codec(oid)
if codec is None:
raise exceptions.InternalClientError(
'missing codec information for OID {}'.format(oid))
name = name.decode(self.settings._encoding)
result.append(
apg_types.Attribute(name,
apg_types.Type(oid, codec.name, codec.kind, codec.schema)))
return tuple(result)
def _init_types(self):
cdef:
Codec codec
set missing = set()
if self.parameters_desc:
for p_oid in self.parameters_desc:
codec = self.settings.get_data_codec(<uint32_t>p_oid)
if codec is None or not codec.has_encoder():
missing.add(p_oid)
if self.row_desc:
for rdesc in self.row_desc:
codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
if codec is None or not codec.has_decoder():
missing.add(rdesc[3])
return missing
cpdef _init_codecs(self):
self._ensure_args_encoder()
self._ensure_rows_decoder()
def attach(self):
self.refs += 1
def detach(self):
self.refs -= 1
def mark_closed(self):
self.closed = True
def mark_unprepared(self):
if self.name:
raise exceptions.InternalClientError(
"named prepared statements cannot be marked unprepared")
self.prepared = False
cdef _encode_bind_msg(self, args, int seqno = -1):
cdef:
int idx
WriteBuffer writer
Codec codec
if not cpython.PySequence_Check(args):
if seqno >= 0:
raise exceptions.DataError(
f'invalid input in executemany() argument sequence '
f'element #{seqno}: expected a sequence, got '
f'{type(args).__name__}'
)
else:
# Non executemany() callers do not pass user input directly,
# so bad input is a bug.
raise exceptions.InternalClientError(
f'Bind: expected a sequence, got {type(args).__name__}')
if len(args) > 32767:
raise exceptions.InterfaceError(
'the number of query arguments cannot exceed 32767')
writer = WriteBuffer.new()
num_args_passed = len(args)
if self.args_num != num_args_passed:
hint = 'Check the query against the passed list of arguments.'
if self.args_num == 0:
# If the server was expecting zero arguments, it is likely
# that the user tried to parametrize a statement that does
# not support parameters.
hint += (r' Note that parameters are supported only in'
r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
r' statements, and will *not* work in statements'
r' like CREATE VIEW or DECLARE CURSOR.')
raise exceptions.InterfaceError(
'the server expects {x} argument{s} for this query, '
'{y} {w} passed'.format(
x=self.args_num, s='s' if self.args_num != 1 else '',
y=num_args_passed,
w='was' if num_args_passed == 1 else 'were'),
hint=hint)
if self.have_text_args:
writer.write_int16(self.args_num)
for idx in range(self.args_num):
codec = <Codec>(self.args_codecs[idx])
writer.write_int16(<int16_t>codec.format)
else:
# All arguments are in binary format
writer.write_int32(0x00010001)
writer.write_int16(self.args_num)
for idx in range(self.args_num):
arg = args[idx]
if arg is None:
writer.write_int32(-1)
else:
codec = <Codec>(self.args_codecs[idx])
try:
codec.encode(self.settings, writer, arg)
except (AssertionError, exceptions.InternalClientError):
# These are internal errors and should raise as-is.
raise
except exceptions.InterfaceError as e:
# This is already a descriptive error, but annotate
# with argument name for clarity.
pos = f'${idx + 1}'
if seqno >= 0:
pos = (
f'{pos} in element #{seqno} of'
f' executemany() sequence'
)
raise e.with_msg(
f'query argument {pos}: {e.args[0]}'
) from None
except Exception as e:
# Everything else is assumed to be an encoding error
# due to invalid input.
pos = f'${idx + 1}'
if seqno >= 0:
pos = (
f'{pos} in element #{seqno} of'
f' executemany() sequence'
)
value_repr = repr(arg)
if len(value_repr) > 40:
value_repr = value_repr[:40] + '...'
raise exceptions.DataError(
f'invalid input for query argument'
f' {pos}: {value_repr} ({e})'
) from e
if self.have_text_cols:
writer.write_int16(self.cols_num)
for idx in range(self.cols_num):
codec = <Codec>(self.rows_codecs[idx])
writer.write_int16(<int16_t>codec.format)
else:
# All columns are in binary format
writer.write_int32(0x00010001)
return writer
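# Sketch of the Bind payload assembled above (a hedged reading of the
# code, mirroring the PostgreSQL protocol's Bind message body):
#   - parameter format codes: an int16 count plus one int16 code per
#     argument, or the compact int32 0x00010001 (count=1, code=1)
#     meaning "all binary"
#   - int16 argument count, then per argument an int32 length
#     (-1 encodes NULL) followed by the encoded datum
#   - result-column format codes, using the same compact form when
#     every column codec is binary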
cdef _ensure_rows_decoder(self):
cdef:
list cols_names
object cols_mapping
tuple row
uint32_t oid
Codec codec
list codecs
if self.cols_desc is not None:
return
if self.cols_num == 0:
self.cols_desc = record.ApgRecordDesc_New({}, ())
return
cols_mapping = collections.OrderedDict()
cols_names = []
codecs = []
for i from 0 <= i < self.cols_num:
row = self.row_desc[i]
col_name = row[0].decode(self.settings._encoding)
cols_mapping[col_name] = i
cols_names.append(col_name)
oid = row[3]
codec = self.settings.get_data_codec(
oid, ignore_custom_codec=self.ignore_custom_codec)
if codec is None or not codec.has_decoder():
raise exceptions.InternalClientError(
'no decoder for OID {}'.format(oid))
if not codec.is_binary():
self.have_text_cols = True
codecs.append(codec)
self.cols_desc = record.ApgRecordDesc_New(
cols_mapping, tuple(cols_names))
self.rows_codecs = tuple(codecs)
cdef _ensure_args_encoder(self):
cdef:
uint32_t p_oid
Codec codec
list codecs = []
if self.args_num == 0 or self.args_codecs is not None:
return
for i from 0 <= i < self.args_num:
p_oid = self.parameters_desc[i]
codec = self.settings.get_data_codec(
p_oid, ignore_custom_codec=self.ignore_custom_codec)
if codec is None or not codec.has_encoder():
raise exceptions.InternalClientError(
'no encoder for OID {}'.format(p_oid))
if not codec.is_binary():
self.have_text_args = True
codecs.append(codec)
self.args_codecs = tuple(codecs)
cdef _set_row_desc(self, object desc):
self.row_desc = _decode_row_desc(desc)
self.cols_num = <int16_t>(len(self.row_desc))
cdef _set_args_desc(self, object desc):
self.parameters_desc = _decode_parameters_desc(desc)
self.args_num = <int16_t>(len(self.parameters_desc))
cdef _decode_row(self, const char* cbuf, ssize_t buf_len):
cdef:
Codec codec
int16_t fnum
int32_t flen
object dec_row
tuple rows_codecs = self.rows_codecs
ConnectionSettings settings = self.settings
int32_t i
FRBuffer rbuf
ssize_t bl
frb_init(&rbuf, cbuf, buf_len)
fnum = hton.unpack_int16(frb_read(&rbuf, 2))
if fnum != self.cols_num:
raise exceptions.ProtocolError(
'the number of columns in the result row ({}) is '
'different from what was described ({})'.format(
fnum, self.cols_num))
dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum)
for i in range(fnum):
flen = hton.unpack_int32(frb_read(&rbuf, 4))
if flen == -1:
val = None
else:
# Clamp buffer size to that of the reported field length
# to make sure that codecs can rely on read_all() working
# properly.
bl = frb_get_len(&rbuf)
if flen > bl:
frb_check(&rbuf, flen)
frb_set_len(&rbuf, flen)
codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i)
val = codec.decode(settings, &rbuf)
if frb_get_len(&rbuf) != 0:
raise BufferError(
'unexpected trailing {} bytes in buffer'.format(
frb_get_len(&rbuf)))
frb_set_len(&rbuf, bl - flen)
cpython.Py_INCREF(val)
record.ApgRecord_SET_ITEM(dec_row, i, val)
if frb_get_len(&rbuf) != 0:
raise BufferError('unexpected trailing {} bytes in buffer'.format(
frb_get_len(&rbuf)))
return dec_row
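# DataRow layout consumed above: an int16 field count, then per field an
# int32 length (-1 encodes SQL NULL) followed by that many bytes for the
# codec.  The frb_set_len() bookkeeping clamps the shared buffer so a
# misbehaving codec cannot read past its own field.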
cdef _decode_parameters_desc(object desc):
cdef:
ReadBuffer reader
int16_t nparams
uint32_t p_oid
list result = []
reader = ReadBuffer.new_message_parser(desc)
nparams = reader.read_int16()
for i from 0 <= i < nparams:
p_oid = <uint32_t>reader.read_int32()
result.append(p_oid)
return result
cdef _decode_row_desc(object desc):
cdef:
ReadBuffer reader
int16_t nfields
bytes f_name
uint32_t f_table_oid
int16_t f_column_num
uint32_t f_dt_oid
int16_t f_dt_size
int32_t f_dt_mod
int16_t f_format
list result
reader = ReadBuffer.new_message_parser(desc)
nfields = reader.read_int16()
result = []
for i from 0 <= i < nfields:
f_name = reader.read_null_str()
f_table_oid = <uint32_t>reader.read_int32()
f_column_num = reader.read_int16()
f_dt_oid = <uint32_t>reader.read_int32()
f_dt_size = reader.read_int16()
f_dt_mod = reader.read_int32()
f_format = reader.read_int16()
result.append(
(f_name, f_table_oid, f_column_num, f_dt_oid,
f_dt_size, f_dt_mod, f_format))
return result

View File

@@ -0,0 +1,78 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from libc.stdint cimport int16_t, int32_t, uint16_t, \
uint32_t, int64_t, uint64_t
from asyncpg.pgproto.debug cimport PG_DEBUG
from asyncpg.pgproto.pgproto cimport (
WriteBuffer,
ReadBuffer,
FRBuffer,
)
from asyncpg.pgproto cimport pgproto
include "consts.pxi"
include "pgtypes.pxi"
include "codecs/base.pxd"
include "settings.pxd"
include "coreproto.pxd"
include "prepared_stmt.pxd"
cdef class BaseProtocol(CoreProtocol):
cdef:
object loop
object address
ConnectionSettings settings
object cancel_sent_waiter
object cancel_waiter
object waiter
bint return_extra
object create_future
object timeout_handle
object conref
type record_class
bint is_reading
str last_query
bint writing_paused
bint closing
readonly uint64_t queries_count
bint _is_ssl
PreparedStatementState statement
cdef get_connection(self)
cdef _get_timeout_impl(self, timeout)
cdef _check_state(self)
cdef _new_waiter(self, timeout)
cdef _coreproto_error(self)
cdef _on_result__connect(self, object waiter)
cdef _on_result__prepare(self, object waiter)
cdef _on_result__bind_and_exec(self, object waiter)
cdef _on_result__close_stmt_or_portal(self, object waiter)
cdef _on_result__simple_query(self, object waiter)
cdef _on_result__bind(self, object waiter)
cdef _on_result__copy_out(self, object waiter)
cdef _on_result__copy_in(self, object waiter)
cdef _handle_waiter_on_connection_lost(self, cause)
cdef _dispatch_result(self)
cdef inline resume_reading(self)
cdef inline pause_reading(self)

File diff suppressed because it is too large

View File

@@ -0,0 +1,19 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cimport cpython
cdef extern from "record/recordobj.h":
cpython.PyTypeObject *ApgRecord_InitTypes() except NULL
int ApgRecord_CheckExact(object)
object ApgRecord_New(type, object, int)
void ApgRecord_SET_ITEM(object, int, object)
object ApgRecordDesc_New(object, object)

View File

@@ -0,0 +1,31 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef class SCRAMAuthentication:
cdef:
readonly bytes authentication_method
readonly bytes authorization_message
readonly bytes client_channel_binding
readonly bytes client_first_message_bare
readonly bytes client_nonce
readonly bytes client_proof
readonly bytes password_salt
readonly int password_iterations
readonly bytes server_first_message
# server_key is an instance of hmac.HMAC
readonly object server_key
readonly bytes server_nonce
cdef create_client_first_message(self, str username)
cdef create_client_final_message(self, str password)
cdef parse_server_first_message(self, bytes server_response)
cdef verify_server_final_message(self, bytes server_final_message)
cdef _bytes_xor(self, bytes a, bytes b)
cdef _generate_client_nonce(self, int num_bytes)
cdef _generate_client_proof(self, str password)
cdef _generate_salted_password(self, str password, bytes salt, int iterations)
cdef _normalize_password(self, str original_password)

View File

@@ -0,0 +1,341 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import base64
import hashlib
import hmac
import re
import secrets
import stringprep
import unicodedata
@cython.final
cdef class SCRAMAuthentication:
"""Contains the protocol for generating and a SCRAM hashed password.
Since PostgreSQL 10, the option to hash passwords using the SCRAM-SHA-256
method was added. This module follows the defined protocol, which can be
referenced from here:
https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256
libpq references the following RFCs that it uses for implementation:
* RFC 5802
* RFC 5803
* RFC 7677
The protocol works as such:
- A client connets to the server. The server requests the client to begin
SASL authentication using SCRAM and presents a client with the methods it
supports. At present, those are SCRAM-SHA-256, and, on servers that are
built with OpenSSL and
are PG11+, SCRAM-SHA-256-PLUS (which supports channel binding, more on that
below)
- The client sends a "first message" to the server, where it chooses which
method to authenticate with, and sends, along with the method, an indication
of channel binding (we disable for now), a nonce, and the username.
(Technically, PostgreSQL ignores the username as it already has it from the
initical connection, but we add it for completeness)
- The server responds with a "first message" in which it extends the nonce,
as well as a password salt and the number of iterations to hash the password
with. The client validates that the new nonce contains the first part of the
client's original nonce
- The client generates a salted password, but does not sent this up to the
server. Instead, the client follows the SCRAM algorithm (RFC5802) to
generate a proof. This proof is sent aspart of a client "final message" to
the server for it to validate.
- The server validates the proof. If it is valid, the server sends a
verification code for the client to verify that the server came to the same
proof the client did. PostgreSQL immediately sends an AuthenticationOK
response right after a valid negotiation. If the password the client
provided was invalid, then authentication fails.
(The beauty of this is that the salted password is never transmitted over
the wire!)
PostgreSQL 11 added support for the channel binding (i.e.
SCRAM-SHA-256-PLUS) but to do some ongoing discussion, there is a conscious
decision by several driver authors to not support it as of yet. As such, the
channel binding parameter is hard-coded to "n" for now, but can be updated
to support other channel binding methos in the future
"""
AUTHENTICATION_METHODS = [b"SCRAM-SHA-256"]
DEFAULT_CLIENT_NONCE_BYTES = 24
DIGEST = hashlib.sha256
REQUIREMENTS_CLIENT_FINAL_MESSAGE = ['client_channel_binding',
'server_nonce']
REQUIREMENTS_CLIENT_PROOF = ['password_iterations', 'password_salt',
'server_first_message', 'server_nonce']
SASLPREP_PROHIBITED = (
stringprep.in_table_a1, # PostgreSQL treats this as prohibited
stringprep.in_table_c12,
stringprep.in_table_c21_c22,
stringprep.in_table_c3,
stringprep.in_table_c4,
stringprep.in_table_c5,
stringprep.in_table_c6,
stringprep.in_table_c7,
stringprep.in_table_c8,
stringprep.in_table_c9,
)
def __cinit__(self, bytes authentication_method):
self.authentication_method = authentication_method
self.authorization_message = None
# channel binding is turned off for the time being
self.client_channel_binding = b"n,,"
self.client_first_message_bare = None
self.client_nonce = None
self.client_proof = None
self.password_salt = None
# self.password_iterations = None
self.server_first_message = None
self.server_key = None
self.server_nonce = None
cdef create_client_first_message(self, str username):
"""Create the initial client message for SCRAM authentication"""
cdef:
bytes msg
bytes client_first_message
self.client_nonce = \
self._generate_client_nonce(self.DEFAULT_CLIENT_NONCE_BYTES)
# set the client first message bare here, as it's used in a later step
self.client_first_message_bare = b"n=" + username.encode("utf-8") + \
b",r=" + self.client_nonce
# put together the full message here
msg = bytes()
msg += self.authentication_method + b"\0"
client_first_message = self.client_channel_binding + \
self.client_first_message_bare
msg += (len(client_first_message)).to_bytes(4, byteorder='big') + \
client_first_message
return msg
cdef create_client_final_message(self, str password):
"""Create the final client message as part of SCRAM authentication"""
cdef:
bytes msg
if any([getattr(self, val) is None for val in
self.REQUIREMENTS_CLIENT_FINAL_MESSAGE]):
raise Exception(
"you need values from server to generate a client proof")
# normalize the password using the SASLprep algorithm in RFC 4013
password = self._normalize_password(password)
# generate the client proof
self.client_proof = self._generate_client_proof(password=password)
msg = bytes()
msg += b"c=" + base64.b64encode(self.client_channel_binding) + \
b",r=" + self.server_nonce + \
b",p=" + base64.b64encode(self.client_proof)
return msg
cdef parse_server_first_message(self, bytes server_response):
"""Parse the response from the first message from the server"""
self.server_first_message = server_response
try:
self.server_nonce = re.search(b'r=([^,]+),',
self.server_first_message).group(1)
except (AttributeError, IndexError):
raise Exception("could not get nonce")
if not self.server_nonce.startswith(self.client_nonce):
raise Exception("invalid nonce")
try:
self.password_salt = re.search(b',s=([^,]+),',
self.server_first_message).group(1)
except (AttributeError, IndexError):
raise Exception("could not get salt")
try:
self.password_iterations = int(re.search(rb',i=(\d+),?',
self.server_first_message).group(1))
except (AttributeError, IndexError, TypeError, ValueError):
raise Exception("could not get iterations")
cdef verify_server_final_message(self, bytes server_final_message):
"""Verify the final message from the server"""
cdef:
bytes server_signature
try:
server_signature = re.search(b'v=([^,]+)',
server_final_message).group(1)
except (AttributeError, IndexError):
raise Exception("could not get server signature")
verify_server_signature = hmac.new(self.server_key.digest(),
self.authorization_message, self.DIGEST)
# validate the server signature against the verifier
return server_signature == base64.b64encode(
verify_server_signature.digest())
cdef _bytes_xor(self, bytes a, bytes b):
"""XOR two bytestrings together"""
return bytes(a_i ^ b_i for a_i, b_i in zip(a, b))
cdef _generate_client_nonce(self, int num_bytes):
cdef:
bytes token
token = secrets.token_bytes(num_bytes)
return base64.b64encode(token)
cdef _generate_client_proof(self, str password):
"""need to ensure a server response exists, i.e. """
cdef:
bytes salted_password
if any([getattr(self, val) is None for val in
self.REQUIREMENTS_CLIENT_PROOF]):
raise Exception(
"you need values from server to generate a client proof")
# generate a salt password
salted_password = self._generate_salted_password(password,
self.password_salt, self.password_iterations)
# client key is derived from the salted password
client_key = hmac.new(salted_password, b"Client Key", self.DIGEST)
# this allows us to compute the stored key that is residing on the server
stored_key = self.DIGEST(client_key.digest())
# as well as compute the server key
self.server_key = hmac.new(salted_password, b"Server Key", self.DIGEST)
# build the authorization message that will be used in the
# client signature
# the "c=" portion is for the channel binding, but this is not
# presently implemented
self.authorization_message = self.client_first_message_bare + b"," + \
self.server_first_message + b",c=" + \
base64.b64encode(self.client_channel_binding) + \
b",r=" + self.server_nonce
# sign!
client_signature = hmac.new(stored_key.digest(),
self.authorization_message, self.DIGEST)
# and the proof
return self._bytes_xor(client_key.digest(), client_signature.digest())
cdef _generate_salted_password(self, str password, bytes salt, int iterations):
"""This follows the "Hi" algorithm specified in RFC5802"""
cdef:
bytes p
bytes s
bytes u
# convert the password to a binary string - UTF8 is safe for SASL
# (though there are SASLPrep rules)
p = password.encode("utf8")
# the salt needs to be base64 decoded -- full binary must be used
s = base64.b64decode(salt)
# the initial HMAC input is the salt concatenated with a 32-bit
# big-endian integer 1 (INT(1) from RFC 5802)
ui = hmac.new(p, s + b'\x00\x00\x00\x01', self.DIGEST)
# grab the initial digest
u = ui.digest()
# for X number of iterations, recompute the HMAC signature against the
# password and the latest iteration of the hash, and XOR it with the
# previous version
for x in range(iterations - 1):
ui = hmac.new(p, ui.digest(), hashlib.sha256)
# this is a fancy way of XORing two byte strings together
u = self._bytes_xor(u, ui.digest())
return u
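# In RFC 5802 terms the loop above computes Hi(password, salt, i):
#   U1 = HMAC(p, s + INT(1));  U_k = HMAC(p, U_{k-1})
#   Hi = U1 XOR U2 XOR ... XOR U_i
# e.g. with iterations == 1 the function returns U1 unchanged, since the
# loop body never runs.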
cdef _normalize_password(self, str original_password):
"""Normalize the password using the SASLprep from RFC4013"""
cdef:
str normalized_password
# Note: Per the PostgreSQL documentation, PostgreSQL does not require
# UTF-8 to be used for the password, but will perform SASLprep on the
# password regardless.
# If the password is not valid UTF-8, PostgreSQL will then **not** use
# SASLprep processing.
# If the password fails SASLprep, the password should still be sent
# See: https://www.postgresql.org/docs/current/sasl-authentication.html
# and
# https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/common/saslprep.c
# using the `pg_saslprep` function
normalized_password = original_password
# if the original password is an ASCII string or fails to encode as a
# UTF-8 string, then no further action is needed
try:
original_password.encode("ascii")
except UnicodeEncodeError:
pass
else:
return original_password
# Step 1 of SASLPrep: Map. Per the algorithm, we map non-ascii space
# characters to ASCII spaces (\x20 or \u0020, but we will use ' ') and
# commonly mapped to nothing characters are removed
# Table C.1.2 -- non-ASCII spaces
# Table B.1 -- "Commonly mapped to nothing"
normalized_password = u"".join(
' ' if stringprep.in_table_c12(c) else c
for c in tuple(normalized_password) if not stringprep.in_table_b1(c)
)
# If at this point the password is empty, PostgreSQL uses the original
# password
if not normalized_password:
return original_password
# Step 2 of SASLPrep: Normalize. Normalize the password using the
# Unicode normalization algorithm to NFKC form
normalized_password = unicodedata.normalize('NFKC', normalized_password)
# If the password is empty after normalization, PostgreSQL uses the
# original password
if not normalized_password:
return original_password
normalized_password_tuple = tuple(normalized_password)
# Step 3 of SASLPrep: Prohibited characters. If PostgreSQL detects any
# of the prohibited characters in SASLPrep, it will use the original
# password
# We also include "unassigned code points" in the prohibited character
# category as PostgreSQL does the same
for c in normalized_password_tuple:
if any(
in_prohibited_table(c)
for in_prohibited_table in self.SASLPREP_PROHIBITED
):
return original_password
# Step 4 of SASLPrep: Bi-directional characters. PostgreSQL follows the
# rules for bi-directional characters laid out in RFC3454 Sec. 6 which
# are:
# 1. Characters in RFC 3454 Sec 5.8 are prohibited (C.8)
# 2. If a string contains a RandALCat character, it cannot contain any
# LCat character
# 3. If the string contains any RandALCat character, an RandALCat
# character must be the first and last character of the string
# RandALCat characters are found in table D.1, whereas LCat are in D.2
if any(stringprep.in_table_d1(c) for c in normalized_password_tuple):
# if the first character or the last character are not in D.1,
# return the original password
if not (stringprep.in_table_d1(normalized_password_tuple[0]) and
stringprep.in_table_d1(normalized_password_tuple[-1])):
return original_password
# if any characters are in D.2, use the original password
if any(
stringprep.in_table_d2(c) for c in normalized_password_tuple
):
return original_password
# return the normalized password
return normalized_password
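# Worked example (hypothetical input): for '\u00adpa\u00a0ss' the mapping
# step drops the soft hyphen (table B.1) and turns the no-break space
# (table C.1.2) into ' ', yielding 'pa ss', which then survives NFKC and
# the prohibited-character and bidi checks unchanged.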

View File

@@ -0,0 +1,30 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef class ConnectionSettings(pgproto.CodecContext):
cdef:
str _encoding
object _codec
dict _settings
bint _is_utf8
DataCodecConfig _data_codecs
cdef add_setting(self, str name, str val)
cdef is_encoding_utf8(self)
cpdef get_text_codec(self)
cpdef inline register_data_types(self, types)
cpdef inline add_python_codec(
self, typeoid, typename, typeschema, typeinfos, typekind, encoder,
decoder, format)
cpdef inline remove_python_codec(
self, typeoid, typename, typeschema)
cpdef inline clear_type_cache(self)
cpdef inline set_builtin_type_codec(
self, typeoid, typename, typeschema, typekind, alias_to, format)
cpdef inline Codec get_data_codec(
self, uint32_t oid, ServerDataFormat format=*,
bint ignore_custom_codec=*)

View File

@@ -0,0 +1,106 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from asyncpg import exceptions
@cython.final
cdef class ConnectionSettings(pgproto.CodecContext):
def __cinit__(self, conn_key):
self._encoding = 'utf-8'
self._is_utf8 = True
self._settings = {}
self._codec = codecs.lookup('utf-8')
self._data_codecs = DataCodecConfig(conn_key)
cdef add_setting(self, str name, str val):
self._settings[name] = val
if name == 'client_encoding':
py_enc = get_python_encoding(val)
self._codec = codecs.lookup(py_enc)
self._encoding = self._codec.name
self._is_utf8 = self._encoding == 'utf-8'
cdef is_encoding_utf8(self):
return self._is_utf8
cpdef get_text_codec(self):
return self._codec
cpdef inline register_data_types(self, types):
self._data_codecs.add_types(types)
cpdef inline add_python_codec(self, typeoid, typename, typeschema,
typeinfos, typekind, encoder, decoder,
format):
cdef:
ServerDataFormat _format
ClientExchangeFormat xformat
if format == 'binary':
_format = PG_FORMAT_BINARY
xformat = PG_XFORMAT_OBJECT
elif format == 'text':
_format = PG_FORMAT_TEXT
xformat = PG_XFORMAT_OBJECT
elif format == 'tuple':
_format = PG_FORMAT_ANY
xformat = PG_XFORMAT_TUPLE
else:
raise exceptions.InterfaceError(
'invalid `format` argument, expected {}, got {!r}'.format(
"'text', 'binary' or 'tuple'", format
))
self._data_codecs.add_python_codec(typeoid, typename, typeschema,
typekind, typeinfos,
encoder, decoder,
_format, xformat)
cpdef inline remove_python_codec(self, typeoid, typename, typeschema):
self._data_codecs.remove_python_codec(typeoid, typename, typeschema)
cpdef inline clear_type_cache(self):
self._data_codecs.clear_type_cache()
cpdef inline set_builtin_type_codec(self, typeoid, typename, typeschema,
typekind, alias_to, format):
cdef:
ServerDataFormat _format
if format is None:
_format = PG_FORMAT_ANY
elif format == 'binary':
_format = PG_FORMAT_BINARY
elif format == 'text':
_format = PG_FORMAT_TEXT
else:
raise exceptions.InterfaceError(
'invalid `format` argument, expected {}, got {!r}'.format(
"'text' or 'binary'", format
))
self._data_codecs.set_builtin_type_codec(typeoid, typename, typeschema,
typekind, alias_to, _format)
cpdef inline Codec get_data_codec(self, uint32_t oid,
ServerDataFormat format=PG_FORMAT_ANY,
bint ignore_custom_codec=False):
return self._data_codecs.get_codec(oid, format, ignore_custom_codec)
def __getattr__(self, name):
if not name.startswith('_'):
try:
return self._settings[name]
except KeyError:
raise AttributeError(name) from None
return object.__getattribute__(self, name)
def __repr__(self):
return '<ConnectionSettings {!r}>'.format(self._settings)
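# Illustrative: server-reported GUCs become attributes, so once the
# ParameterStatus messages have been processed, settings.client_encoding
# returns the raw value sent by the server (e.g. 'UTF8'); unknown
# attribute names raise AttributeError as usual.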

View File

@@ -0,0 +1,60 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import re
from .types import ServerVersion
version_regex = re.compile(
r"(Postgre[^\s]*)?\s*"
r"(?P<major>[0-9]+)\.?"
r"((?P<minor>[0-9]+)\.?)?"
r"(?P<micro>[0-9]+)?"
r"(?P<releaselevel>[a-z]+)?"
r"(?P<serial>[0-9]+)?"
)
def split_server_version_string(version_string):
version_match = version_regex.search(version_string)
if version_match is None:
raise ValueError(
"Unable to parse Postgres "
f'version from "{version_string}"'
)
version = version_match.groupdict()
for ver_key, ver_value in version.items():
# Cast all possible versions parts to int
try:
version[ver_key] = int(ver_value)
except (TypeError, ValueError):
pass
if version.get("major") < 10:
return ServerVersion(
version.get("major"),
version.get("minor") or 0,
version.get("micro") or 0,
version.get("releaselevel") or "final",
version.get("serial") or 0,
)
# Since PostgreSQL 10 the versioning scheme has changed.
# 10.x really means 10.0.x. While parsing 10.1
# as (10, 1) may seem less confusing, in practice most
# version checks are written as version[:2], and we
# want to keep that behaviour consistent, i.e. not fail
# a major version check due to a bugfix release.
return ServerVersion(
version.get("major"),
0,
version.get("minor") or 0,
version.get("releaselevel") or "final",
version.get("serial") or 0,
)
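# Illustrative: split_server_version_string('PostgreSQL 10.1') returns
# ServerVersion(major=10, minor=0, micro=1, releaselevel='final', serial=0)
# under the post-PG10 rule above, whereas '9.6.2' parses as
# ServerVersion(9, 6, 2, 'final', 0).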

View File

@@ -0,0 +1,246 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import enum
from . import connresource
from . import exceptions as apg_errors
class TransactionState(enum.Enum):
NEW = 0
STARTED = 1
COMMITTED = 2
ROLLEDBACK = 3
FAILED = 4
ISOLATION_LEVELS = {
'read_committed',
'read_uncommitted',
'serializable',
'repeatable_read',
}
ISOLATION_LEVELS_BY_VALUE = {
'read committed': 'read_committed',
'read uncommitted': 'read_uncommitted',
'serializable': 'serializable',
'repeatable read': 'repeatable_read',
}
class Transaction(connresource.ConnectionResource):
"""Represents a transaction or savepoint block.
Transactions are created by calling the
:meth:`Connection.transaction() <connection.Connection.transaction>`
function.
"""
__slots__ = ('_connection', '_isolation', '_readonly', '_deferrable',
'_state', '_nested', '_id', '_managed')
def __init__(self, connection, isolation, readonly, deferrable):
super().__init__(connection)
if isolation and isolation not in ISOLATION_LEVELS:
raise ValueError(
'isolation is expected to be either of {}, '
'got {!r}'.format(ISOLATION_LEVELS, isolation))
self._isolation = isolation
self._readonly = readonly
self._deferrable = deferrable
self._state = TransactionState.NEW
self._nested = False
self._id = None
self._managed = False
async def __aenter__(self):
if self._managed:
raise apg_errors.InterfaceError(
'cannot enter context: already in an `async with` block')
self._managed = True
await self.start()
async def __aexit__(self, extype, ex, tb):
try:
self._check_conn_validity('__aexit__')
except apg_errors.InterfaceError:
if extype is GeneratorExit:
# When a PoolAcquireContext is being exited, and there
# is an open transaction in an async generator that has
# not been iterated fully, there is a possibility that
# Pool.release() would race with this __aexit__(), since
# both would be in concurrent tasks. In such case we
# yield to Pool.release() to do the ROLLBACK for us.
# See https://github.com/MagicStack/asyncpg/issues/232
# for an example.
return
else:
raise
try:
if extype is not None:
await self.__rollback()
else:
await self.__commit()
finally:
self._managed = False
@connresource.guarded
async def start(self):
"""Enter the transaction or savepoint block."""
self.__check_state_base('start')
if self._state is TransactionState.STARTED:
raise apg_errors.InterfaceError(
'cannot start; the transaction is already started')
con = self._connection
if con._top_xact is None:
if con._protocol.is_in_transaction():
raise apg_errors.InterfaceError(
'cannot use Connection.transaction() in '
'a manually started transaction')
con._top_xact = self
else:
# Nested transaction block
if self._isolation:
top_xact_isolation = con._top_xact._isolation
if top_xact_isolation is None:
top_xact_isolation = ISOLATION_LEVELS_BY_VALUE[
await self._connection.fetchval(
'SHOW transaction_isolation;')]
if self._isolation != top_xact_isolation:
raise apg_errors.InterfaceError(
'nested transaction has a different isolation level: '
'current {!r} != outer {!r}'.format(
self._isolation, top_xact_isolation))
self._nested = True
if self._nested:
self._id = con._get_unique_id('savepoint')
query = 'SAVEPOINT {};'.format(self._id)
else:
query = 'BEGIN'
if self._isolation == 'read_committed':
query += ' ISOLATION LEVEL READ COMMITTED'
elif self._isolation == 'read_uncommitted':
query += ' ISOLATION LEVEL READ UNCOMMITTED'
elif self._isolation == 'repeatable_read':
query += ' ISOLATION LEVEL REPEATABLE READ'
elif self._isolation == 'serializable':
query += ' ISOLATION LEVEL SERIALIZABLE'
if self._readonly:
query += ' READ ONLY'
if self._deferrable:
query += ' DEFERRABLE'
query += ';'
try:
await self._connection.execute(query)
except BaseException:
self._state = TransactionState.FAILED
raise
else:
self._state = TransactionState.STARTED
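# For example, a top-level block created with isolation='serializable',
# readonly=True and deferrable=True issues
# 'BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE;', while a
# nested block degenerates to 'SAVEPOINT <generated id>;'.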
def __check_state_base(self, opname):
if self._state is TransactionState.COMMITTED:
raise apg_errors.InterfaceError(
'cannot {}; the transaction is already committed'.format(
opname))
if self._state is TransactionState.ROLLEDBACK:
raise apg_errors.InterfaceError(
'cannot {}; the transaction is already rolled back'.format(
opname))
if self._state is TransactionState.FAILED:
raise apg_errors.InterfaceError(
'cannot {}; the transaction is in error state'.format(
opname))
def __check_state(self, opname):
if self._state is not TransactionState.STARTED:
if self._state is TransactionState.NEW:
raise apg_errors.InterfaceError(
'cannot {}; the transaction is not yet started'.format(
opname))
self.__check_state_base(opname)
async def __commit(self):
self.__check_state('commit')
if self._connection._top_xact is self:
self._connection._top_xact = None
if self._nested:
query = 'RELEASE SAVEPOINT {};'.format(self._id)
else:
query = 'COMMIT;'
try:
await self._connection.execute(query)
except BaseException:
self._state = TransactionState.FAILED
raise
else:
self._state = TransactionState.COMMITTED
async def __rollback(self):
self.__check_state('rollback')
if self._connection._top_xact is self:
self._connection._top_xact = None
if self._nested:
query = 'ROLLBACK TO {};'.format(self._id)
else:
query = 'ROLLBACK;'
try:
await self._connection.execute(query)
except BaseException:
self._state = TransactionState.FAILED
raise
else:
self._state = TransactionState.ROLLEDBACK
@connresource.guarded
async def commit(self):
"""Exit the transaction or savepoint block and commit changes."""
if self._managed:
raise apg_errors.InterfaceError(
'cannot manually commit from within an `async with` block')
await self.__commit()
@connresource.guarded
async def rollback(self):
"""Exit the transaction or savepoint block and rollback changes."""
if self._managed:
raise apg_errors.InterfaceError(
'cannot manually rollback from within an `async with` block')
await self.__rollback()

    def __repr__(self):
        attrs = []
        attrs.append('state:{}'.format(self._state.name.lower()))

        if self._isolation is not None:
            attrs.append(self._isolation)

        if self._readonly:
            attrs.append('readonly')

        if self._deferrable:
            attrs.append('deferrable')

        if self.__class__.__module__.startswith('asyncpg.'):
            mod = 'asyncpg'
        else:
            mod = self.__class__.__module__

        return '<{}.{} {} {:#x}>'.format(
            mod, self.__class__.__name__, ' '.join(attrs), id(self))
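

# A minimal usage sketch (illustration only, not part of this module;
# assumes `con` is an established asyncpg.Connection):
#
#     async with con.transaction():        # top-level: BEGIN ... COMMIT
#         await con.execute("INSERT INTO t VALUES (1)")
#         async with con.transaction():    # nested: SAVEPOINT ... RELEASE
#             await con.execute("UPDATE t SET x = 2")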

View File

@@ -0,0 +1,177 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

import collections

from asyncpg.pgproto.types import (
BitString, Point, Path, Polygon,
Box, Line, LineSegment, Circle,
)


__all__ = (
'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon',
'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion',
)


Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema'])
Type.__doc__ = 'Database data type.'
Type.oid.__doc__ = 'OID of the type.'
Type.name.__doc__ = 'Type name. For example "int2".'
Type.kind.__doc__ = \
'Type kind. Can be "scalar", "array", "composite" or "range".'
Type.schema.__doc__ = 'Name of the database schema that defines the type.'


Attribute = collections.namedtuple('Attribute', ['name', 'type'])
Attribute.__doc__ = 'Database relation attribute.'
Attribute.name.__doc__ = 'Attribute name.'
Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.'


ServerVersion = collections.namedtuple(
'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial'])
ServerVersion.__doc__ = 'PostgreSQL server version tuple.'
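
# For example (illustration only): a server reporting version string
# "9.6.12" would correspond to
#
#     ServerVersion(major=9, minor=6, micro=12, releaselevel='final',
#                   serial=0)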


class Range:
"""Immutable representation of PostgreSQL `range` type."""

    __slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty'

def __init__(self, lower=None, upper=None, *,
lower_inc=True, upper_inc=False,
empty=False):
self._empty = empty
if empty:
self._lower = self._upper = None
self._lower_inc = self._upper_inc = False
else:
self._lower = lower
self._upper = upper
self._lower_inc = lower is not None and lower_inc
self._upper_inc = upper is not None and upper_inc

    @property
    def lower(self):
        return self._lower

    @property
    def lower_inc(self):
        return self._lower_inc

    @property
    def lower_inf(self):
        return self._lower is None and not self._empty

    @property
    def upper(self):
        return self._upper

    @property
    def upper_inc(self):
        return self._upper_inc

    @property
    def upper_inf(self):
        return self._upper is None and not self._empty

    @property
    def isempty(self):
        return self._empty

    def _issubset_lower(self, other):
if other._lower is None:
return True
if self._lower is None:
return False
return self._lower > other._lower or (
self._lower == other._lower
and (other._lower_inc or not self._lower_inc)
)

    def _issubset_upper(self, other):
if other._upper is None:
return True
if self._upper is None:
return False
return self._upper < other._upper or (
self._upper == other._upper
and (other._upper_inc or not self._upper_inc)
)

    def issubset(self, other):
if self._empty:
return True
if other._empty:
return False
return self._issubset_lower(other) and self._issubset_upper(other)

    def issuperset(self, other):
return other.issubset(self)

    def __bool__(self):
return not self._empty

    def __eq__(self, other):
if not isinstance(other, Range):
return NotImplemented
return (
self._lower,
self._upper,
self._lower_inc,
self._upper_inc,
self._empty
) == (
other._lower,
other._upper,
other._lower_inc,
other._upper_inc,
other._empty
)

    def __hash__(self):
return hash((
self._lower,
self._upper,
self._lower_inc,
self._upper_inc,
self._empty
))

    def __repr__(self):
if self._empty:
desc = 'empty'
else:
if self._lower is None or not self._lower_inc:
lb = '('
else:
lb = '['
if self._lower is not None:
lb += repr(self._lower)
if self._upper is not None:
ub = repr(self._upper)
else:
ub = ''
if self._upper is None or not self._upper_inc:
ub += ')'
else:
ub += ']'
desc = '{}, {}'.format(lb, ub)
return '<Range {}>'.format(desc)

    __str__ = __repr__
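

# Behavior sketch (illustration only, not part of this module):
#
#     >>> r = Range(1, 10)          # repr: '<Range [1, 10)>'
#     >>> r.lower_inc, r.upper_inc
#     (True, False)
#     >>> Range(2, 5).issubset(r)
#     True
#     >>> Range(None, 5).lower_inf  # open lower bound
#     True
#     >>> bool(Range(empty=True))
#     False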

View File

@@ -0,0 +1,45 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

import re


def _quote_ident(ident):
return '"{}"'.format(ident.replace('"', '""'))


def _quote_literal(string):
return "'{}'".format(string.replace("'", "''"))


async def _mogrify(conn, query, args):
"""Safely inline arguments to query text."""

    # Introspect the target query for argument types and
# build a list of safely-quoted fully-qualified type names.
ps = await conn.prepare(query)
paramtypes = []
for t in ps.get_parameters():
if t.name.endswith('[]'):
pname = '_' + t.name[:-2]
else:
pname = t.name
paramtypes.append('{}.{}'.format(
_quote_ident(t.schema), _quote_ident(pname)))
del ps

    # Use Postgres to convert arguments to text representation
# by casting each value to text.
cols = ['quote_literal(${}::{}::text)'.format(i, t)
for i, t in enumerate(paramtypes, start=1)]
textified = await conn.fetchrow(
'SELECT {cols}'.format(cols=', '.join(cols)), *args)

    # Finally, replace $n references with text values.
return re.sub(
r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query)
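

# A minimal usage sketch (illustration only, not part of this module;
# assumes `con` is an established asyncpg connection):
#
#     text = await _mogrify(con, 'SELECT $1::int + $2::int', (3, 4))
#     # -> "SELECT '3'::int + '4'::int"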