Major fixes and new features
All checks were successful
continuous-integration/drone/push Build is passing
19
venv/lib/python3.12/site-packages/asyncpg/__init__.py
Normal file
@@ -0,0 +1,19 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from .connection import connect, Connection  # NOQA
from .exceptions import *  # NOQA
from .pool import create_pool, Pool  # NOQA
from .protocol import Record  # NOQA
from .types import *  # NOQA


from ._version import __version__  # NOQA


__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection')
__all__ += exceptions.__all__  # NOQA
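Note: this file is asyncpg's public API surface; everything a caller needs is re-exported here. A minimal usage sketch (the DSN is a placeholder and assumes a reachable local PostgreSQL):

import asyncio
import asyncpg

async def main():
    # Placeholder DSN; adjust host/database/credentials as needed.
    conn = await asyncpg.connect('postgresql://localhost/postgres')
    try:
        print(await conn.fetchval('SELECT 2 + 2'))  # -> 4
    finally:
        await conn.close()

asyncio.run(main())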
87
venv/lib/python3.12/site-packages/asyncpg/_asyncio_compat.py
Normal file
@@ -0,0 +1,87 @@
# Backports from Python/Lib/asyncio for older Pythons
#
# Copyright (c) 2001-2023 Python Software Foundation; All Rights Reserved
#
# SPDX-License-Identifier: PSF-2.0


import asyncio
import functools
import sys

if sys.version_info < (3, 11):
    from async_timeout import timeout as timeout_ctx
else:
    from asyncio import timeout as timeout_ctx


async def wait_for(fut, timeout):
    """Wait for the single Future or coroutine to complete, with timeout.

    Coroutine will be wrapped in Task.

    Returns result of the Future or coroutine. When a timeout occurs,
    it cancels the task and raises TimeoutError. To avoid the task
    cancellation, wrap it in shield().

    If the wait is cancelled, the task is also cancelled.

    If the task suppresses the cancellation and returns a value instead,
    that value is returned.

    This function is a coroutine.
    """
    # The special case for timeout <= 0 is for the following case:
    #
    # async def test_waitfor():
    #     func_started = False
    #
    #     async def func():
    #         nonlocal func_started
    #         func_started = True
    #
    #     try:
    #         await asyncio.wait_for(func(), 0)
    #     except asyncio.TimeoutError:
    #         assert not func_started
    #     else:
    #         assert False
    #
    # asyncio.run(test_waitfor())

    if timeout is not None and timeout <= 0:
        fut = asyncio.ensure_future(fut)

        if fut.done():
            return fut.result()

        await _cancel_and_wait(fut)
        try:
            return fut.result()
        except asyncio.CancelledError as exc:
            raise TimeoutError from exc

    async with timeout_ctx(timeout):
        return await fut


async def _cancel_and_wait(fut):
    """Cancel the *fut* future or task and wait until it completes."""

    loop = asyncio.get_running_loop()
    waiter = loop.create_future()
    cb = functools.partial(_release_waiter, waiter)
    fut.add_done_callback(cb)

    try:
        fut.cancel()
        # We cannot wait on *fut* directly to make
        # sure _cancel_and_wait itself is reliably cancellable.
        await waiter
    finally:
        fut.remove_done_callback(cb)


def _release_waiter(waiter, *args):
    if not waiter.done():
        waiter.set_result(None)
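Note: wait_for above special-cases timeout <= 0 so that an already-finished future returns its result while a pending one is cancelled before it ever runs. A minimal sketch of the observable behavior (assumes Python 3.11+, or async_timeout installed on older versions):

import asyncio
from asyncpg._asyncio_compat import wait_for

async def main():
    async def slow():
        await asyncio.sleep(10)

    try:
        await wait_for(slow(), 0.1)  # task is cancelled after 0.1s
    except TimeoutError:
        print('timed out')

asyncio.run(main())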
527
venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py
Normal file
@@ -0,0 +1,527 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import atexit
import contextlib
import functools
import inspect
import logging
import os
import re
import textwrap
import time
import traceback
import unittest


import asyncpg
from asyncpg import cluster as pg_cluster
from asyncpg import connection as pg_connection
from asyncpg import pool as pg_pool

from . import fuzzer


@contextlib.contextmanager
def silence_asyncio_long_exec_warning():
    def flt(log_record):
        msg = log_record.getMessage()
        return not msg.startswith('Executing ')

    logger = logging.getLogger('asyncio')
    logger.addFilter(flt)
    try:
        yield
    finally:
        logger.removeFilter(flt)


def with_timeout(timeout):
    def wrap(func):
        func.__timeout__ = timeout
        return func

    return wrap


class TestCaseMeta(type(unittest.TestCase)):
    TEST_TIMEOUT = None

    @staticmethod
    def _iter_methods(bases, ns):
        for base in bases:
            for methname in dir(base):
                if not methname.startswith('test_'):
                    continue

                meth = getattr(base, methname)
                if not inspect.iscoroutinefunction(meth):
                    continue

                yield methname, meth

        for methname, meth in ns.items():
            if not methname.startswith('test_'):
                continue

            if not inspect.iscoroutinefunction(meth):
                continue

            yield methname, meth

    def __new__(mcls, name, bases, ns):
        for methname, meth in mcls._iter_methods(bases, ns):
            @functools.wraps(meth)
            def wrapper(self, *args, __meth__=meth, **kwargs):
                coro = __meth__(self, *args, **kwargs)
                timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT)
                if timeout:
                    coro = asyncio.wait_for(coro, timeout)
                    try:
                        self.loop.run_until_complete(coro)
                    except asyncio.TimeoutError:
                        raise self.failureException(
                            'test timed out after {} seconds'.format(
                                timeout)) from None
                else:
                    self.loop.run_until_complete(coro)
            ns[methname] = wrapper

        return super().__new__(mcls, name, bases, ns)


class TestCase(unittest.TestCase, metaclass=TestCaseMeta):

    @classmethod
    def setUpClass(cls):
        if os.environ.get('USE_UVLOOP'):
            import uvloop
            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)
        cls.loop = loop

    @classmethod
    def tearDownClass(cls):
        cls.loop.close()
        asyncio.set_event_loop(None)

    def setUp(self):
        self.loop.set_exception_handler(self.loop_exception_handler)
        self.__unhandled_exceptions = []

    def tearDown(self):
        if self.__unhandled_exceptions:
            formatted = []

            for i, context in enumerate(self.__unhandled_exceptions):
                formatted.append(self._format_loop_exception(context, i + 1))

            self.fail(
                'unexpected exceptions in asynchronous code:\n' +
                '\n'.join(formatted))

    @contextlib.contextmanager
    def assertRunUnder(self, delta):
        st = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - st
            if elapsed > delta:
                raise AssertionError(
                    'running block took {:0.3f}s which is longer '
                    'than the expected maximum of {:0.3f}s'.format(
                        elapsed, delta))

    @contextlib.contextmanager
    def assertLoopErrorHandlerCalled(self, msg_re: str):
        contexts = []

        def handler(loop, ctx):
            contexts.append(ctx)

        old_handler = self.loop.get_exception_handler()
        self.loop.set_exception_handler(handler)
        try:
            yield

            for ctx in contexts:
                msg = ctx.get('message')
                if msg and re.search(msg_re, msg):
                    return

            raise AssertionError(
                'no message matching {!r} was logged with '
                'loop.call_exception_handler()'.format(msg_re))

        finally:
            self.loop.set_exception_handler(old_handler)

    def loop_exception_handler(self, loop, context):
        self.__unhandled_exceptions.append(context)
        loop.default_exception_handler(context)

    def _format_loop_exception(self, context, n):
        message = context.get('message', 'Unhandled exception in event loop')
        exception = context.get('exception')
        if exception is not None:
            exc_info = (type(exception), exception, exception.__traceback__)
        else:
            exc_info = None

        lines = []
        for key in sorted(context):
            if key in {'message', 'exception'}:
                continue
            value = context[key]
            if key == 'source_traceback':
                tb = ''.join(traceback.format_list(value))
                value = 'Object created at (most recent call last):\n'
                value += tb.rstrip()
            else:
                try:
                    value = repr(value)
                except Exception as ex:
                    value = ('Exception in __repr__ {!r}; '
                             'value type: {!r}'.format(ex, type(value)))
            lines.append('[{}]: {}\n\n'.format(key, value))

        if exc_info is not None:
            lines.append('[exception]:\n')
            formatted_exc = textwrap.indent(
                ''.join(traceback.format_exception(*exc_info)), '    ')
            lines.append(formatted_exc)

        details = textwrap.indent(''.join(lines), '    ')
        return '{:02d}. {}:\n{}\n'.format(n, message, details)


_default_cluster = None


def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
    cluster = ClusterCls(**cluster_kwargs)
    cluster.init(**(initdb_options or {}))
    cluster.trust_local_connections()
    atexit.register(_shutdown_cluster, cluster)
    return cluster


def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
                   initdb_options=None):
    cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
    cluster.start(port='dynamic', server_settings=server_settings)
    return cluster


def _get_initdb_options(initdb_options=None):
    if not initdb_options:
        initdb_options = {}
    else:
        initdb_options = dict(initdb_options)

    # Make the default superuser name stable.
    if 'username' not in initdb_options:
        initdb_options['username'] = 'postgres'

    return initdb_options


def _init_default_cluster(initdb_options=None):
    global _default_cluster

    if _default_cluster is None:
        pg_host = os.environ.get('PGHOST')
        if pg_host:
            # Using existing cluster, assuming it is initialized and running
            _default_cluster = pg_cluster.RunningCluster()
        else:
            _default_cluster = _init_cluster(
                pg_cluster.TempCluster, cluster_kwargs={},
                initdb_options=_get_initdb_options(initdb_options))

    return _default_cluster


def _shutdown_cluster(cluster):
    if cluster.get_status() == 'running':
        cluster.stop()
    if cluster.get_status() != 'not-initialized':
        cluster.destroy()


def create_pool(dsn=None, *,
                min_size=10,
                max_size=10,
                max_queries=50000,
                max_inactive_connection_lifetime=60.0,
                setup=None,
                init=None,
                loop=None,
                pool_class=pg_pool.Pool,
                connection_class=pg_connection.Connection,
                record_class=asyncpg.Record,
                **connect_kwargs):
    return pool_class(
        dsn,
        min_size=min_size, max_size=max_size,
        max_queries=max_queries, loop=loop, setup=setup, init=init,
        max_inactive_connection_lifetime=max_inactive_connection_lifetime,
        connection_class=connection_class,
        record_class=record_class,
        **connect_kwargs)


class ClusterTestCase(TestCase):
    @classmethod
    def get_server_settings(cls):
        settings = {
            'log_connections': 'on'
        }

        if cls.cluster.get_pg_version() >= (11, 0):
            # JITting messes up timing tests, and
            # is not essential for testing.
            settings['jit'] = 'off'

        return settings

    @classmethod
    def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}):
        cluster = _init_cluster(ClusterCls, cluster_kwargs,
                                _get_initdb_options(initdb_options))
        cls._clusters.append(cluster)
        return cluster

    @classmethod
    def start_cluster(cls, cluster, *, server_settings={}):
        cluster.start(port='dynamic', server_settings=server_settings)

    @classmethod
    def setup_cluster(cls):
        cls.cluster = _init_default_cluster()

        if cls.cluster.get_status() != 'running':
            cls.cluster.start(
                port='dynamic', server_settings=cls.get_server_settings())

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._clusters = []
        cls.setup_cluster()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        for cluster in cls._clusters:
            if cluster is not _default_cluster:
                cluster.stop()
                cluster.destroy()
        cls._clusters = []

    @classmethod
    def get_connection_spec(cls, kwargs={}):
        conn_spec = cls.cluster.get_connection_spec()
        if kwargs.get('dsn'):
            conn_spec.pop('host')
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            if 'database' not in conn_spec:
                conn_spec['database'] = 'postgres'
            if 'user' not in conn_spec:
                conn_spec['user'] = 'postgres'
        return conn_spec

    @classmethod
    def connect(cls, **kwargs):
        conn_spec = cls.get_connection_spec(kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    def setUp(self):
        super().setUp()
        self._pools = []

    def tearDown(self):
        super().tearDown()
        for pool in self._pools:
            pool.terminate()
        self._pools = []

    def create_pool(self, pool_class=pg_pool.Pool,
                    connection_class=pg_connection.Connection, **kwargs):
        conn_spec = self.get_connection_spec(kwargs)
        pool = create_pool(loop=self.loop, pool_class=pool_class,
                           connection_class=connection_class, **conn_spec)
        self._pools.append(pool)
        return pool


class ProxiedClusterTestCase(ClusterTestCase):
    @classmethod
    def get_server_settings(cls):
        settings = dict(super().get_server_settings())
        settings['listen_addresses'] = '127.0.0.1'
        return settings

    @classmethod
    def get_proxy_settings(cls):
        return {'fuzzing-mode': None}

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        conn_spec = cls.cluster.get_connection_spec()
        host = conn_spec.get('host')
        if not host:
            host = '127.0.0.1'
        elif host.startswith('/'):
            host = '127.0.0.1'
        cls.proxy = fuzzer.TCPFuzzingProxy(
            backend_host=host,
            backend_port=conn_spec['port'],
        )
        cls.proxy.start()

    @classmethod
    def tearDownClass(cls):
        cls.proxy.stop()
        super().tearDownClass()

    @classmethod
    def get_connection_spec(cls, kwargs):
        conn_spec = super().get_connection_spec(kwargs)
        conn_spec['host'] = cls.proxy.listening_addr
        conn_spec['port'] = cls.proxy.listening_port
        return conn_spec

    def tearDown(self):
        self.proxy.reset()
        super().tearDown()


def with_connection_options(**options):
    if not options:
        raise ValueError('no connection options were specified')

    def wrap(func):
        func.__connect_options__ = options
        return func

    return wrap


class ConnectedTestCase(ClusterTestCase):

    def setUp(self):
        super().setUp()

        # Extract options set up with `with_connection_options`.
        test_func = getattr(self, self._testMethodName).__func__
        opts = getattr(test_func, '__connect_options__', {})
        self.con = self.loop.run_until_complete(self.connect(**opts))
        self.server_version = self.con.get_server_version()

    def tearDown(self):
        try:
            self.loop.run_until_complete(self.con.close())
            self.con = None
        finally:
            super().tearDown()


class HotStandbyTestCase(ClusterTestCase):

    @classmethod
    def setup_cluster(cls):
        cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
        cls.start_cluster(
            cls.master_cluster,
            server_settings={
                'max_wal_senders': 10,
                'wal_level': 'hot_standby'
            }
        )

        con = None

        try:
            con = cls.loop.run_until_complete(
                cls.master_cluster.connect(
                    database='postgres', user='postgres', loop=cls.loop))

            cls.loop.run_until_complete(
                con.execute('''
                    CREATE ROLE replication WITH LOGIN REPLICATION
                '''))

            cls.master_cluster.trust_local_replication_by('replication')

            conn_spec = cls.master_cluster.get_connection_spec()

            cls.standby_cluster = cls.new_cluster(
                pg_cluster.HotStandbyCluster,
                cluster_kwargs={
                    'master': conn_spec,
                    'replication_user': 'replication'
                }
            )
            cls.start_cluster(
                cls.standby_cluster,
                server_settings={
                    'hot_standby': True
                }
            )

        finally:
            if con is not None:
                cls.loop.run_until_complete(con.close())

    @classmethod
    def get_cluster_connection_spec(cls, cluster, kwargs={}):
        conn_spec = cluster.get_connection_spec()
        if kwargs.get('dsn'):
            conn_spec.pop('host')
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            if 'database' not in conn_spec:
                conn_spec['database'] = 'postgres'
            if 'user' not in conn_spec:
                conn_spec['user'] = 'postgres'
        return conn_spec

    @classmethod
    def get_connection_spec(cls, kwargs={}):
        primary_spec = cls.get_cluster_connection_spec(
            cls.master_cluster, kwargs
        )
        standby_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster, kwargs
        )
        return {
            'host': [primary_spec['host'], standby_spec['host']],
            'port': [primary_spec['port'], standby_spec['port']],
            'database': primary_spec['database'],
            'user': primary_spec['user'],
            **kwargs
        }

    @classmethod
    def connect_primary(cls, **kwargs):
        conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    @classmethod
    def connect_standby(cls, **kwargs):
        conn_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster,
            kwargs
        )
        return pg_connection.connect(**conn_spec, loop=cls.loop)
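Note: tests written against this base class get event-loop plumbing and per-test timeouts injected by TestCaseMeta automatically. A hypothetical test module might look like this (TestSimpleQueries is illustrative, not part of the commit):

from asyncpg import _testbase as tb

class TestSimpleQueries(tb.ConnectedTestCase):

    @tb.with_timeout(10.0)  # per-test override of TEST_TIMEOUT
    async def test_select(self):
        # self.con is opened by ConnectedTestCase.setUp()
        self.assertEqual(await self.con.fetchval('SELECT 2 + 2'), 4)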
306
venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py
Normal file
@@ -0,0 +1,306 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import socket
import threading
import typing

from asyncpg import cluster


class StopServer(Exception):
    pass


class TCPFuzzingProxy:
    def __init__(self, *, listening_addr: str='127.0.0.1',
                 listening_port: typing.Optional[int]=None,
                 backend_host: str, backend_port: int,
                 settings: typing.Optional[dict]=None) -> None:
        self.listening_addr = listening_addr
        self.listening_port = listening_port
        self.backend_host = backend_host
        self.backend_port = backend_port
        self.settings = settings or {}
        self.loop = None
        self.connectivity = None
        self.connectivity_loss = None
        self.stop_event = None
        self.connections = {}
        self.sock = None
        self.listen_task = None

    async def _wait(self, work):
        work_task = asyncio.ensure_future(work)
        stop_event_task = asyncio.ensure_future(self.stop_event.wait())

        try:
            await asyncio.wait(
                [work_task, stop_event_task],
                return_when=asyncio.FIRST_COMPLETED)

            if self.stop_event.is_set():
                raise StopServer()
            else:
                return work_task.result()
        finally:
            if not work_task.done():
                work_task.cancel()
            if not stop_event_task.done():
                stop_event_task.cancel()

    def start(self):
        started = threading.Event()
        self.thread = threading.Thread(
            target=self._start_thread, args=(started,))
        self.thread.start()
        if not started.wait(timeout=2):
            raise RuntimeError('fuzzer proxy failed to start')

    def stop(self):
        self.loop.call_soon_threadsafe(self._stop)
        self.thread.join()

    def _stop(self):
        self.stop_event.set()

    def _start_thread(self, started_event):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        self.connectivity = asyncio.Event()
        self.connectivity.set()
        self.connectivity_loss = asyncio.Event()
        self.stop_event = asyncio.Event()

        if self.listening_port is None:
            self.listening_port = cluster.find_available_port()

        self.sock = socket.socket()
        self.sock.bind((self.listening_addr, self.listening_port))
        self.sock.listen(50)
        self.sock.setblocking(False)

        try:
            self.loop.run_until_complete(self._main(started_event))
        finally:
            self.loop.close()

    async def _main(self, started_event):
        self.listen_task = asyncio.ensure_future(self.listen())
        # Notify the main thread that we are ready to go.
        started_event.set()
        try:
            await self.listen_task
        finally:
            for c in list(self.connections):
                c.close()
            await asyncio.sleep(0.01)
            if hasattr(self.loop, 'remove_reader'):
                self.loop.remove_reader(self.sock.fileno())
            self.sock.close()

    async def listen(self):
        while True:
            try:
                client_sock, _ = await self._wait(
                    self.loop.sock_accept(self.sock))

                backend_sock = socket.socket()
                backend_sock.setblocking(False)

                await self._wait(self.loop.sock_connect(
                    backend_sock, (self.backend_host, self.backend_port)))
            except StopServer:
                break

            conn = Connection(client_sock, backend_sock, self)
            conn_task = self.loop.create_task(conn.handle())
            self.connections[conn] = conn_task

    def trigger_connectivity_loss(self):
        self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)

    def _trigger_connectivity_loss(self):
        self.connectivity.clear()
        self.connectivity_loss.set()

    def restore_connectivity(self):
        self.loop.call_soon_threadsafe(self._restore_connectivity)

    def _restore_connectivity(self):
        self.connectivity.set()
        self.connectivity_loss.clear()

    def reset(self):
        self.restore_connectivity()

    def _close_connection(self, connection):
        conn_task = self.connections.pop(connection, None)
        if conn_task is not None:
            conn_task.cancel()

    def close_all_connections(self):
        for conn in list(self.connections):
            self.loop.call_soon_threadsafe(self._close_connection, conn)


class Connection:
    def __init__(self, client_sock, backend_sock, proxy):
        self.client_sock = client_sock
        self.backend_sock = backend_sock
        self.proxy = proxy
        self.loop = proxy.loop
        self.connectivity = proxy.connectivity
        self.connectivity_loss = proxy.connectivity_loss
        self.proxy_to_backend_task = None
        self.proxy_from_backend_task = None
        self.is_closed = False

    def close(self):
        if self.is_closed:
            return

        self.is_closed = True

        if self.proxy_to_backend_task is not None:
            self.proxy_to_backend_task.cancel()
            self.proxy_to_backend_task = None

        if self.proxy_from_backend_task is not None:
            self.proxy_from_backend_task.cancel()
            self.proxy_from_backend_task = None

        self.proxy._close_connection(self)

    async def handle(self):
        self.proxy_to_backend_task = asyncio.ensure_future(
            self.proxy_to_backend())

        self.proxy_from_backend_task = asyncio.ensure_future(
            self.proxy_from_backend())

        try:
            await asyncio.wait(
                [self.proxy_to_backend_task, self.proxy_from_backend_task],
                return_when=asyncio.FIRST_COMPLETED)

        finally:
            if self.proxy_to_backend_task is not None:
                self.proxy_to_backend_task.cancel()

            if self.proxy_from_backend_task is not None:
                self.proxy_from_backend_task.cancel()

            # Asyncio fails to properly remove the readers and writers
            # when the task doing recv() or send() is cancelled, so
            # we must remove the readers and writers manually before
            # closing the sockets.
            self.loop.remove_reader(self.client_sock.fileno())
            self.loop.remove_writer(self.client_sock.fileno())
            self.loop.remove_reader(self.backend_sock.fileno())
            self.loop.remove_writer(self.backend_sock.fileno())

            self.client_sock.close()
            self.backend_sock.close()

    async def _read(self, sock, n):
        read_task = asyncio.ensure_future(
            self.loop.sock_recv(sock, n))
        conn_event_task = asyncio.ensure_future(
            self.connectivity_loss.wait())

        try:
            await asyncio.wait(
                [read_task, conn_event_task],
                return_when=asyncio.FIRST_COMPLETED)

            if self.connectivity_loss.is_set():
                return None
            else:
                return read_task.result()
        finally:
            if not self.loop.is_closed():
                if not read_task.done():
                    read_task.cancel()
                if not conn_event_task.done():
                    conn_event_task.cancel()

    async def _write(self, sock, data):
        write_task = asyncio.ensure_future(
            self.loop.sock_sendall(sock, data))
        conn_event_task = asyncio.ensure_future(
            self.connectivity_loss.wait())

        try:
            await asyncio.wait(
                [write_task, conn_event_task],
                return_when=asyncio.FIRST_COMPLETED)

            if self.connectivity_loss.is_set():
                return None
            else:
                return write_task.result()
        finally:
            if not self.loop.is_closed():
                if not write_task.done():
                    write_task.cancel()
                if not conn_event_task.done():
                    conn_event_task.cancel()

    async def proxy_to_backend(self):
        buf = None

        try:
            while True:
                await self.connectivity.wait()
                if buf is not None:
                    data = buf
                    buf = None
                else:
                    data = await self._read(self.client_sock, 4096)
                if data == b'':
                    break
                if self.connectivity_loss.is_set():
                    if data:
                        buf = data
                    continue
                await self._write(self.backend_sock, data)

        except ConnectionError:
            pass

        finally:
            if not self.loop.is_closed():
                self.loop.call_soon(self.close)

    async def proxy_from_backend(self):
        buf = None

        try:
            while True:
                await self.connectivity.wait()
                if buf is not None:
                    data = buf
                    buf = None
                else:
                    data = await self._read(self.backend_sock, 4096)
                if data == b'':
                    break
                if self.connectivity_loss.is_set():
                    if data:
                        buf = data
                    continue
                await self._write(self.client_sock, data)

        except ConnectionError:
            pass

        finally:
            if not self.loop.is_closed():
                self.loop.call_soon(self.close)
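Note: the proxy runs its own event loop on a background thread, so all control methods are marshalled in via call_soon_threadsafe. A usage sketch (host and port are placeholders):

proxy = TCPFuzzingProxy(backend_host='127.0.0.1', backend_port=5432)
proxy.start()
# ... point clients at (proxy.listening_addr, proxy.listening_port) ...
proxy.trigger_connectivity_loss()  # reads/writes park until restored
proxy.restore_connectivity()       # buffered data is forwarded again
proxy.stop()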
13
venv/lib/python3.12/site-packages/asyncpg/_version.py
Normal file
@@ -0,0 +1,13 @@
# This file MUST NOT contain anything but the __version__ assignment.
#
# When making a release, change the value of __version__
# to an appropriate value, and open a pull request against
# the correct branch (master if making a new feature release).
# The commit message MUST contain a properly formatted release
# log, and the commit must be signed.
#
# The release automation will: build and test the packages for the
# supported platforms, publish the packages on PyPI, merge the PR
# to the target branch, create a Git tag pointing to the commit.

__version__ = '0.29.0'
688
venv/lib/python3.12/site-packages/asyncpg/cluster.py
Normal file
@@ -0,0 +1,688 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import os
import os.path
import platform
import re
import shutil
import socket
import subprocess
import sys
import tempfile
import textwrap
import time

import asyncpg
from asyncpg import serverversion


_system = platform.uname().system

if _system == 'Windows':
    def platform_exe(name):
        if name.endswith('.exe'):
            return name
        return name + '.exe'
else:
    def platform_exe(name):
        return name


def find_available_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(('127.0.0.1', 0))
        return sock.getsockname()[1]
    except Exception:
        return None
    finally:
        sock.close()


class ClusterError(Exception):
    pass


class Cluster:
    def __init__(self, data_dir, *, pg_config_path=None):
        self._data_dir = data_dir
        self._pg_config_path = pg_config_path
        self._pg_bin_dir = (
            os.environ.get('PGINSTALLATION')
            or os.environ.get('PGBIN')
        )
        self._pg_ctl = None
        self._daemon_pid = None
        self._daemon_process = None
        self._connection_addr = None
        self._connection_spec_override = None

    def get_pg_version(self):
        return self._pg_version

    def is_managed(self):
        return True

    def get_data_dir(self):
        return self._data_dir

    def get_status(self):
        if self._pg_ctl is None:
            self._init_env()

        process = subprocess.run(
            [self._pg_ctl, 'status', '-D', self._data_dir],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.stdout, process.stderr

        if (process.returncode == 4 or not os.path.exists(self._data_dir) or
                not os.listdir(self._data_dir)):
            return 'not-initialized'
        elif process.returncode == 3:
            return 'stopped'
        elif process.returncode == 0:
            r = re.match(r'.*PID\s?:\s+(\d+).*', stdout.decode())
            if not r:
                raise ClusterError(
                    'could not parse pg_ctl status output: {}'.format(
                        stdout.decode()))
            self._daemon_pid = int(r.group(1))
            return self._test_connection(timeout=0)
        else:
            raise ClusterError(
                'pg_ctl status exited with status {:d}: {}'.format(
                    process.returncode, stderr))

    async def connect(self, loop=None, **kwargs):
        conn_info = self.get_connection_spec()
        conn_info.update(kwargs)
        return await asyncpg.connect(loop=loop, **conn_info)

    def init(self, **settings):
        """Initialize cluster."""
        if self.get_status() != 'not-initialized':
            raise ClusterError(
                'cluster in {!r} has already been initialized'.format(
                    self._data_dir))

        settings = dict(settings)
        if 'encoding' not in settings:
            settings['encoding'] = 'UTF-8'

        if settings:
            settings_args = ['--{}={}'.format(k, v)
                             for k, v in settings.items()]
            extra_args = ['-o'] + [' '.join(settings_args)]
        else:
            extra_args = []

        process = subprocess.run(
            [self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

        output = process.stdout

        if process.returncode != 0:
            raise ClusterError(
                'pg_ctl init exited with status {:d}:\n{}'.format(
                    process.returncode, output.decode()))

        return output.decode()

    def start(self, wait=60, *, server_settings={}, **opts):
        """Start the cluster."""
        status = self.get_status()
        if status == 'running':
            return
        elif status == 'not-initialized':
            raise ClusterError(
                'cluster in {!r} has not been initialized'.format(
                    self._data_dir))

        port = opts.pop('port', None)
        if port == 'dynamic':
            port = find_available_port()

        extra_args = ['--{}={}'.format(k, v) for k, v in opts.items()]
        extra_args.append('--port={}'.format(port))

        sockdir = server_settings.get('unix_socket_directories')
        if sockdir is None:
            sockdir = server_settings.get('unix_socket_directory')
        if sockdir is None and _system != 'Windows':
            sockdir = tempfile.gettempdir()

        ssl_key = server_settings.get('ssl_key_file')
        if ssl_key:
            # Make sure server certificate key file has correct permissions.
            keyfile = os.path.join(self._data_dir, 'srvkey.pem')
            shutil.copy(ssl_key, keyfile)
            os.chmod(keyfile, 0o600)
            server_settings = server_settings.copy()
            server_settings['ssl_key_file'] = keyfile

        if sockdir is not None:
            if self._pg_version < (9, 3):
                sockdir_opt = 'unix_socket_directory'
            else:
                sockdir_opt = 'unix_socket_directories'

            server_settings[sockdir_opt] = sockdir

        for k, v in server_settings.items():
            extra_args.extend(['-c', '{}={}'.format(k, v)])

        if _system == 'Windows':
            # On Windows we have to use pg_ctl as direct execution
            # of postgres daemon under an Administrative account
            # is not permitted and there is no easy way to drop
            # privileges.
            if os.getenv('ASYNCPG_DEBUG_SERVER'):
                stdout = sys.stdout
                print(
                    'asyncpg.cluster: Running',
                    ' '.join([
                        self._pg_ctl, 'start', '-D', self._data_dir,
                        '-o', ' '.join(extra_args)
                    ]),
                    file=sys.stderr,
                )
            else:
                stdout = subprocess.DEVNULL

            process = subprocess.run(
                [self._pg_ctl, 'start', '-D', self._data_dir,
                 '-o', ' '.join(extra_args)],
                stdout=stdout, stderr=subprocess.STDOUT)

            if process.returncode != 0:
                if process.stderr:
                    stderr = ':\n{}'.format(process.stderr.decode())
                else:
                    stderr = ''
                raise ClusterError(
                    'pg_ctl start exited with status {:d}{}'.format(
                        process.returncode, stderr))
        else:
            if os.getenv('ASYNCPG_DEBUG_SERVER'):
                stdout = sys.stdout
            else:
                stdout = subprocess.DEVNULL

            self._daemon_process = \
                subprocess.Popen(
                    [self._postgres, '-D', self._data_dir, *extra_args],
                    stdout=stdout, stderr=subprocess.STDOUT)

            self._daemon_pid = self._daemon_process.pid

        self._test_connection(timeout=wait)

    def reload(self):
        """Reload server configuration."""
        status = self.get_status()
        if status != 'running':
            raise ClusterError('cannot reload: cluster is not running')

        process = subprocess.run(
            [self._pg_ctl, 'reload', '-D', self._data_dir],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        stderr = process.stderr

        if process.returncode != 0:
            raise ClusterError(
                'pg_ctl stop exited with status {:d}: {}'.format(
                    process.returncode, stderr.decode()))

    def stop(self, wait=60):
        process = subprocess.run(
            [self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
             '-m', 'fast'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        stderr = process.stderr

        if process.returncode != 0:
            raise ClusterError(
                'pg_ctl stop exited with status {:d}: {}'.format(
                    process.returncode, stderr.decode()))

        if (self._daemon_process is not None and
                self._daemon_process.returncode is None):
            self._daemon_process.kill()

    def destroy(self):
        status = self.get_status()
        if status == 'stopped' or status == 'not-initialized':
            shutil.rmtree(self._data_dir)
        else:
            raise ClusterError('cannot destroy {} cluster'.format(status))

    def _get_connection_spec(self):
        if self._connection_addr is None:
            self._connection_addr = self._connection_addr_from_pidfile()

        if self._connection_addr is not None:
            if self._connection_spec_override:
                args = self._connection_addr.copy()
                args.update(self._connection_spec_override)
                return args
            else:
                return self._connection_addr

    def get_connection_spec(self):
        status = self.get_status()
        if status != 'running':
            raise ClusterError('cluster is not running')

        return self._get_connection_spec()

    def override_connection_spec(self, **kwargs):
        self._connection_spec_override = kwargs

    def reset_wal(self, *, oid=None, xid=None):
        status = self.get_status()
        if status == 'not-initialized':
            raise ClusterError(
                'cannot modify WAL status: cluster is not initialized')

        if status == 'running':
            raise ClusterError(
                'cannot modify WAL status: cluster is running')

        opts = []
        if oid is not None:
            opts.extend(['-o', str(oid)])
        if xid is not None:
            opts.extend(['-x', str(xid)])
        if not opts:
            return

        opts.append(self._data_dir)

        try:
            reset_wal = self._find_pg_binary('pg_resetwal')
        except ClusterError:
            reset_wal = self._find_pg_binary('pg_resetxlog')

        process = subprocess.run(
            [reset_wal] + opts,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        stderr = process.stderr

        if process.returncode != 0:
            raise ClusterError(
                'pg_resetwal exited with status {:d}: {}'.format(
                    process.returncode, stderr.decode()))

    def reset_hba(self):
        """Remove all records from pg_hba.conf."""
        status = self.get_status()
        if status == 'not-initialized':
            raise ClusterError(
                'cannot modify HBA records: cluster is not initialized')

        pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')

        try:
            with open(pg_hba, 'w'):
                pass
        except IOError as e:
            raise ClusterError(
                'cannot modify HBA records: {}'.format(e)) from e

    def add_hba_entry(self, *, type='host', database, user, address=None,
                      auth_method, auth_options=None):
        """Add a record to pg_hba.conf."""
        status = self.get_status()
        if status == 'not-initialized':
            raise ClusterError(
                'cannot modify HBA records: cluster is not initialized')

        if type not in {'local', 'host', 'hostssl', 'hostnossl'}:
            raise ValueError('invalid HBA record type: {!r}'.format(type))

        pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')

        record = '{} {} {}'.format(type, database, user)

        if type != 'local':
            if address is None:
                raise ValueError(
                    '{!r} entry requires a valid address'.format(type))
            else:
                record += ' {}'.format(address)

        record += ' {}'.format(auth_method)

        if auth_options is not None:
            record += ' ' + ' '.join(
                '{}={}'.format(k, v) for k, v in auth_options)

        try:
            with open(pg_hba, 'a') as f:
                print(record, file=f)
        except IOError as e:
            raise ClusterError(
                'cannot modify HBA records: {}'.format(e)) from e

    def trust_local_connections(self):
        self.reset_hba()

        if _system != 'Windows':
            self.add_hba_entry(type='local', database='all',
                               user='all', auth_method='trust')
        self.add_hba_entry(type='host', address='127.0.0.1/32',
                           database='all', user='all',
                           auth_method='trust')
        self.add_hba_entry(type='host', address='::1/128',
                           database='all', user='all',
                           auth_method='trust')
        status = self.get_status()
        if status == 'running':
            self.reload()

    def trust_local_replication_by(self, user):
        if _system != 'Windows':
            self.add_hba_entry(type='local', database='replication',
                               user=user, auth_method='trust')
        self.add_hba_entry(type='host', address='127.0.0.1/32',
                           database='replication', user=user,
                           auth_method='trust')
        self.add_hba_entry(type='host', address='::1/128',
                           database='replication', user=user,
                           auth_method='trust')
        status = self.get_status()
        if status == 'running':
            self.reload()

    def _init_env(self):
        if not self._pg_bin_dir:
            pg_config = self._find_pg_config(self._pg_config_path)
            pg_config_data = self._run_pg_config(pg_config)

            self._pg_bin_dir = pg_config_data.get('bindir')
            if not self._pg_bin_dir:
                raise ClusterError(
                    'pg_config output did not provide the BINDIR value')

        self._pg_ctl = self._find_pg_binary('pg_ctl')
        self._postgres = self._find_pg_binary('postgres')
        self._pg_version = self._get_pg_version()

    def _connection_addr_from_pidfile(self):
        pidfile = os.path.join(self._data_dir, 'postmaster.pid')

        try:
            with open(pidfile, 'rt') as f:
                piddata = f.read()
        except FileNotFoundError:
            return None

        lines = piddata.splitlines()

        if len(lines) < 6:
            # A complete postgres pidfile is at least 6 lines
            return None

        pmpid = int(lines[0])
        if self._daemon_pid and pmpid != self._daemon_pid:
            # This might be an old pidfile left from previous postgres
            # daemon run.
            return None

        portnum = lines[3]
        sockdir = lines[4]
        hostaddr = lines[5]

        if sockdir:
            if sockdir[0] != '/':
                # Relative sockdir
                sockdir = os.path.normpath(
                    os.path.join(self._data_dir, sockdir))
            host_str = sockdir
        else:
            host_str = hostaddr

        if host_str == '*':
            host_str = 'localhost'
        elif host_str == '0.0.0.0':
            host_str = '127.0.0.1'
        elif host_str == '::':
            host_str = '::1'

        return {
            'host': host_str,
            'port': portnum
        }

    def _test_connection(self, timeout=60):
        self._connection_addr = None

        loop = asyncio.new_event_loop()

        try:
            for i in range(timeout):
                if self._connection_addr is None:
                    conn_spec = self._get_connection_spec()
                    if conn_spec is None:
                        time.sleep(1)
                        continue

                try:
                    con = loop.run_until_complete(
                        asyncpg.connect(database='postgres',
                                        user='postgres',
                                        timeout=5, loop=loop,
                                        **self._connection_addr))
                except (OSError, asyncio.TimeoutError,
                        asyncpg.CannotConnectNowError,
                        asyncpg.PostgresConnectionError):
                    time.sleep(1)
                    continue
                except asyncpg.PostgresError:
                    # Any other error other than ServerNotReadyError or
                    # ConnectionError is interpreted to indicate the server is
                    # up.
                    break
                else:
                    loop.run_until_complete(con.close())
                    break
        finally:
            loop.close()

        return 'running'

    def _run_pg_config(self, pg_config_path):
        process = subprocess.run(
            pg_config_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.stdout, process.stderr

        if process.returncode != 0:
            raise ClusterError('pg_config exited with status {:d}: {}'.format(
                process.returncode, stderr))
        else:
            config = {}

            for line in stdout.splitlines():
                k, eq, v = line.decode('utf-8').partition('=')
                if eq:
                    config[k.strip().lower()] = v.strip()

            return config

    def _find_pg_config(self, pg_config_path):
        if pg_config_path is None:
            pg_install = (
                os.environ.get('PGINSTALLATION')
                or os.environ.get('PGBIN')
            )
            if pg_install:
                pg_config_path = platform_exe(
                    os.path.join(pg_install, 'pg_config'))
            else:
                pathenv = os.environ.get('PATH').split(os.pathsep)
                for path in pathenv:
                    pg_config_path = platform_exe(
                        os.path.join(path, 'pg_config'))
                    if os.path.exists(pg_config_path):
                        break
                else:
                    pg_config_path = None

        if not pg_config_path:
            raise ClusterError('could not find pg_config executable')

        if not os.path.isfile(pg_config_path):
            raise ClusterError('{!r} is not an executable'.format(
                pg_config_path))

        return pg_config_path

    def _find_pg_binary(self, binary):
        bpath = platform_exe(os.path.join(self._pg_bin_dir, binary))

        if not os.path.isfile(bpath):
            raise ClusterError(
                'could not find {} executable: '.format(binary) +
                '{!r} does not exist or is not a file'.format(bpath))

        return bpath

    def _get_pg_version(self):
        process = subprocess.run(
            [self._postgres, '--version'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.stdout, process.stderr

        if process.returncode != 0:
            raise ClusterError(
                'postgres --version exited with status {:d}: {}'.format(
                    process.returncode, stderr))

        version_string = stdout.decode('utf-8').strip(' \n')
        prefix = 'postgres (PostgreSQL) '
        if not version_string.startswith(prefix):
            raise ClusterError(
                'could not determine server version from {!r}'.format(
                    version_string))
        version_string = version_string[len(prefix):]

        return serverversion.split_server_version_string(version_string)


class TempCluster(Cluster):
    def __init__(self, *,
                 data_dir_suffix=None, data_dir_prefix=None,
                 data_dir_parent=None, pg_config_path=None):
        self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
                                          prefix=data_dir_prefix,
                                          dir=data_dir_parent)
        super().__init__(self._data_dir, pg_config_path=pg_config_path)


class HotStandbyCluster(TempCluster):
    def __init__(self, *,
                 master, replication_user,
                 data_dir_suffix=None, data_dir_prefix=None,
                 data_dir_parent=None, pg_config_path=None):
        self._master = master
        self._repl_user = replication_user
        super().__init__(
            data_dir_suffix=data_dir_suffix,
            data_dir_prefix=data_dir_prefix,
            data_dir_parent=data_dir_parent,
            pg_config_path=pg_config_path)

    def _init_env(self):
        super()._init_env()
        self._pg_basebackup = self._find_pg_binary('pg_basebackup')

    def init(self, **settings):
        """Initialize cluster."""
        if self.get_status() != 'not-initialized':
            raise ClusterError(
                'cluster in {!r} has already been initialized'.format(
                    self._data_dir))

        process = subprocess.run(
            [self._pg_basebackup, '-h', self._master['host'],
             '-p', self._master['port'], '-D', self._data_dir,
             '-U', self._repl_user],
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

        output = process.stdout

        if process.returncode != 0:
            raise ClusterError(
                'pg_basebackup init exited with status {:d}:\n{}'.format(
                    process.returncode, output.decode()))

        if self._pg_version < (12, 0):
            with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
                f.write(textwrap.dedent("""\
                    standby_mode = 'on'
                    primary_conninfo = 'host={host} port={port} user={user}'
                """.format(
                    host=self._master['host'],
                    port=self._master['port'],
                    user=self._repl_user)))
        else:
            f = open(os.path.join(self._data_dir, 'standby.signal'), 'w')
            f.close()

        return output.decode()

    def start(self, wait=60, *, server_settings={}, **opts):
        if self._pg_version >= (12, 0):
            server_settings = server_settings.copy()
            server_settings['primary_conninfo'] = (
                '"host={host} port={port} user={user}"'.format(
                    host=self._master['host'],
                    port=self._master['port'],
                    user=self._repl_user,
                )
            )

        super().start(wait=wait, server_settings=server_settings, **opts)


class RunningCluster(Cluster):
    def __init__(self, **kwargs):
        self.conn_spec = kwargs

    def is_managed(self):
        return False

    def get_connection_spec(self):
        return dict(self.conn_spec)

    def get_status(self):
        return 'running'

    def init(self, **settings):
        pass

    def start(self, wait=60, **settings):
        pass

    def stop(self, wait=60):
        pass

    def destroy(self):
        pass

    def reset_hba(self):
        raise ClusterError('cannot modify HBA records of unmanaged cluster')

    def add_hba_entry(self, *, type='host', database, user, address=None,
                      auth_method, auth_options=None):
        raise ClusterError('cannot modify HBA records of unmanaged cluster')
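Note: Cluster shells out to the pg_ctl/postgres binaries discovered via pg_config, PGINSTALLATION, or PGBIN. A sketch of the TempCluster lifecycle (assumes a PostgreSQL installation is on PATH; output values are illustrative):

cluster = TempCluster()
cluster.init(username='postgres')      # initdb into a mkdtemp() directory
cluster.trust_local_connections()      # rewrite pg_hba.conf
cluster.start(port='dynamic', server_settings={'log_connections': 'on'})
print(cluster.get_connection_spec())   # e.g. {'host': '/tmp', 'port': 54321}
cluster.stop()
cluster.destroy()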
61
venv/lib/python3.12/site-packages/asyncpg/compat.py
Normal file
@@ -0,0 +1,61 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import pathlib
import platform
import typing
import sys


SYSTEM = platform.uname().system


if SYSTEM == 'Windows':
    import ctypes.wintypes

    CSIDL_APPDATA = 0x001a

    def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
        # We cannot simply use expanduser() as that returns the user's
        # home directory, whereas Postgres stores its config in
        # %AppData% on Windows.
        buf = ctypes.create_unicode_buffer(ctypes.wintypes.MAX_PATH)
        r = ctypes.windll.shell32.SHGetFolderPathW(0, CSIDL_APPDATA, 0, 0, buf)
        if r:
            return None
        else:
            return pathlib.Path(buf.value) / 'postgresql'

else:
    def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
        try:
            return pathlib.Path.home()
        except (RuntimeError, KeyError):
            return None


async def wait_closed(stream):
    # Not all asyncio versions have StreamWriter.wait_closed().
    if hasattr(stream, 'wait_closed'):
        try:
            await stream.wait_closed()
        except ConnectionResetError:
            # On Windows wait_closed() sometimes propagates
            # ConnectionResetError which is totally unnecessary.
            pass


if sys.version_info < (3, 12):
    from ._asyncio_compat import wait_for as wait_for  # noqa: F401
else:
    from asyncio import wait_for as wait_for  # noqa: F401


if sys.version_info < (3, 11):
    from ._asyncio_compat import timeout_ctx as timeout  # noqa: F401
else:
    from asyncio import timeout as timeout  # noqa: F401
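Note: compat.py selects the stdlib implementation when it exists and falls back to the backport otherwise, so callers use a single import path on every supported Python. A minimal sketch (behaves the same on 3.10 with async_timeout installed and on 3.11+):

import asyncio
from asyncpg import compat

async def main():
    async with compat.timeout(0.5):  # asyncio.timeout on 3.11+
        await asyncio.sleep(0.1)

asyncio.run(main())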
1081
venv/lib/python3.12/site-packages/asyncpg/connect_utils.py
Normal file
File diff suppressed because it is too large
2655
venv/lib/python3.12/site-packages/asyncpg/connection.py
Normal file
File diff suppressed because it is too large
44
venv/lib/python3.12/site-packages/asyncpg/connresource.py
Normal file
@@ -0,0 +1,44 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import functools

from . import exceptions


def guarded(meth):
    """A decorator to add a sanity check to ConnectionResource methods."""

    @functools.wraps(meth)
    def _check(self, *args, **kwargs):
        self._check_conn_validity(meth.__name__)
        return meth(self, *args, **kwargs)

    return _check


class ConnectionResource:
    __slots__ = ('_connection', '_con_release_ctr')

    def __init__(self, connection):
        self._connection = connection
        self._con_release_ctr = connection._pool_release_ctr

    def _check_conn_validity(self, meth_name):
        con_release_ctr = self._connection._pool_release_ctr
        if con_release_ctr != self._con_release_ctr:
            raise exceptions.InterfaceError(
                'cannot call {}.{}(): '
                'the underlying connection has been released back '
                'to the pool'.format(self.__class__.__name__, meth_name))

        if self._connection.is_closed():
            raise exceptions.InterfaceError(
                'cannot call {}.{}(): '
                'the underlying connection is closed'.format(
                    self.__class__.__name__, meth_name))
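Note: @guarded makes stale resource handles fail loudly instead of silently touching a connection that was closed or handed to another pool user. A hypothetical subclass (ExampleResource is illustrative only, not part of the commit):

class ExampleResource(ConnectionResource):
    @guarded
    def describe(self):
        # Raises InterfaceError if the owning connection was closed
        # or released back to its pool since this object was created.
        return 'ok'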
323
venv/lib/python3.12/site-packages/asyncpg/cursor.py
Normal file
@@ -0,0 +1,323 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import collections

from . import connresource
from . import exceptions


class CursorFactory(connresource.ConnectionResource):
    """A cursor interface for the results of a query.

    A cursor interface can be used to initiate efficient traversal of the
    results of a large query.
    """

    __slots__ = (
        '_state',
        '_args',
        '_prefetch',
        '_query',
        '_timeout',
        '_record_class',
    )

    def __init__(
        self,
        connection,
        query,
        state,
        args,
        prefetch,
        timeout,
        record_class
    ):
        super().__init__(connection)
        self._args = args
        self._prefetch = prefetch
        self._query = query
        self._timeout = timeout
        self._state = state
        self._record_class = record_class
        if state is not None:
            state.attach()

    @connresource.guarded
    def __aiter__(self):
        prefetch = 50 if self._prefetch is None else self._prefetch
        return CursorIterator(
            self._connection,
            self._query,
            self._state,
            self._args,
            self._record_class,
            prefetch,
            self._timeout,
        )

    @connresource.guarded
    def __await__(self):
        if self._prefetch is not None:
            raise exceptions.InterfaceError(
                'prefetch argument can only be specified for iterable cursor')
        cursor = Cursor(
            self._connection,
            self._query,
            self._state,
            self._args,
            self._record_class,
        )
        return cursor._init(self._timeout).__await__()

    def __del__(self):
        if self._state is not None:
            self._state.detach()
            self._connection._maybe_gc_stmt(self._state)


class BaseCursor(connresource.ConnectionResource):

    __slots__ = (
        '_state',
        '_args',
        '_portal_name',
        '_exhausted',
        '_query',
        '_record_class',
    )

    def __init__(self, connection, query, state, args, record_class):
        super().__init__(connection)
        self._args = args
        self._state = state
        if state is not None:
            state.attach()
        self._portal_name = None
        self._exhausted = False
        self._query = query
        self._record_class = record_class

    def _check_ready(self):
        if self._state is None:
            raise exceptions.InterfaceError(
                'cursor: no associated prepared statement')

        if self._state.closed:
            raise exceptions.InterfaceError(
                'cursor: the prepared statement is closed')

        if not self._connection._top_xact:
            raise exceptions.NoActiveSQLTransactionError(
                'cursor cannot be created outside of a transaction')

    async def _bind_exec(self, n, timeout):
        self._check_ready()

        if self._portal_name:
            raise exceptions.InterfaceError(
                'cursor already has an open portal')

        con = self._connection
        protocol = con._protocol

        self._portal_name = con._get_unique_id('portal')
        buffer, _, self._exhausted = await protocol.bind_execute(
            self._state, self._args, self._portal_name, n, True, timeout)
        return buffer

    async def _bind(self, timeout):
        self._check_ready()

        if self._portal_name:
            raise exceptions.InterfaceError(
                'cursor already has an open portal')

        con = self._connection
        protocol = con._protocol

        self._portal_name = con._get_unique_id('portal')
        buffer = await protocol.bind(self._state, self._args,
                                     self._portal_name,
                                     timeout)
        return buffer

    async def _exec(self, n, timeout):
        self._check_ready()

        if not self._portal_name:
            raise exceptions.InterfaceError(
                'cursor does not have an open portal')

        protocol = self._connection._protocol
        buffer, _, self._exhausted = await protocol.execute(
            self._state, self._portal_name, n, True, timeout)
        return buffer

    async def _close_portal(self, timeout):
        self._check_ready()

        if not self._portal_name:
            raise exceptions.InterfaceError(
                'cursor does not have an open portal')

        protocol = self._connection._protocol
        await protocol.close_portal(self._portal_name, timeout)
        self._portal_name = None

    def __repr__(self):
        attrs = []
        if self._exhausted:
            attrs.append('exhausted')
        attrs.append('')  # to separate from id

        if self.__class__.__module__.startswith('asyncpg.'):
            mod = 'asyncpg'
        else:
            mod = self.__class__.__module__

        return '<{}.{} "{!s:.30}" {}{:#x}>'.format(
            mod, self.__class__.__name__,
            self._state.query,
            ' '.join(attrs), id(self))

    def __del__(self):
        if self._state is not None:
            self._state.detach()
            self._connection._maybe_gc_stmt(self._state)


class CursorIterator(BaseCursor):

    __slots__ = ('_buffer', '_prefetch', '_timeout')

    def __init__(
        self,
        connection,
        query,
        state,
        args,
        record_class,
        prefetch,
        timeout
    ):
        super().__init__(connection, query, state, args, record_class)

        if prefetch <= 0:
            raise exceptions.InterfaceError(
                'prefetch argument must be greater than zero')

        self._buffer = collections.deque()
        self._prefetch = prefetch
        self._timeout = timeout

    @connresource.guarded
    def __aiter__(self):
        return self

    @connresource.guarded
    async def __anext__(self):
        if self._state is None:
            self._state = await self._connection._get_statement(
                self._query,
                self._timeout,
                named=True,
                record_class=self._record_class,
            )
            self._state.attach()

        if not self._portal_name and not self._exhausted:
            buffer = await self._bind_exec(self._prefetch, self._timeout)
            self._buffer.extend(buffer)

        if not self._buffer and not self._exhausted:
            buffer = await self._exec(self._prefetch, self._timeout)
            self._buffer.extend(buffer)

        if self._portal_name and self._exhausted:
            await self._close_portal(self._timeout)

        if self._buffer:
            return self._buffer.popleft()

        raise StopAsyncIteration


class Cursor(BaseCursor):
    """An open *portal* into the results of a query."""

    __slots__ = ()

    async def _init(self, timeout):
        if self._state is None:
            self._state = await self._connection._get_statement(
                self._query,
                timeout,
                named=True,
                record_class=self._record_class,
            )
            self._state.attach()
        self._check_ready()
        await self._bind(timeout)
        return self

    @connresource.guarded
    async def fetch(self, n, *, timeout=None):
        r"""Return the next *n* rows as a list of :class:`Record` objects.

        :param float timeout: Optional timeout value in seconds.

        :return: A list of :class:`Record` instances.
        """
        self._check_ready()
        if n <= 0:
            raise exceptions.InterfaceError('n must be greater than zero')
        if self._exhausted:
            return []
        recs = await self._exec(n, timeout)
        if len(recs) < n:
            self._exhausted = True
        return recs

    @connresource.guarded
    async def fetchrow(self, *, timeout=None):
        r"""Return the next row.

        :param float timeout: Optional timeout value in seconds.

        :return: A :class:`Record` instance.
        """
        self._check_ready()
        if self._exhausted:
            return None
        recs = await self._exec(1, timeout)
        if len(recs) < 1:
            self._exhausted = True
            return None
        return recs[0]

    @connresource.guarded
    async def forward(self, n, *, timeout=None) -> int:
        r"""Skip over the next *n* rows.

        :param float timeout: Optional timeout value in seconds.

        :return: A number of rows actually skipped over (<= *n*).
        """
        self._check_ready()
        if n <= 0:
            raise exceptions.InterfaceError('n must be greater than zero')

        protocol = self._connection._protocol
        status = await protocol.query('MOVE FORWARD {:d} {}'.format(
            n, self._portal_name), timeout)

        advanced = int(status.split()[1])
        if advanced < n:
            self._exhausted = True

        return advanced
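A usage sketch of both cursor forms defined above; the DSN and table are assumptions. Note that, as enforced by _check_ready(), cursors only work inside a transaction:

import asyncio

import asyncpg


async def main():
    con = await asyncpg.connect(dsn='postgres://user@localhost/db')
    async with con.transaction():
        # Iterable form: CursorFactory.__aiter__ -> CursorIterator,
        # fetching rows from the portal `prefetch` rows at a time.
        async for row in con.cursor('SELECT * FROM big_table', prefetch=100):
            print(row)

        # Awaitable form: CursorFactory.__await__ -> Cursor portal.
        cur = await con.cursor('SELECT * FROM big_table')
        await cur.forward(500)       # issues MOVE FORWARD 500
        rows = await cur.fetch(10)   # next 10 rows, [] when exhausted
        print(len(rows))
    await con.close()


asyncio.run(main())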
venv/lib/python3.12/site-packages/asyncpg/exceptions/__init__.py (new file, 1198 lines; diff suppressed because it is too large)
venv/lib/python3.12/site-packages/asyncpg/exceptions/_base.py (new file, 299 lines)
@@ -0,0 +1,299 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncpg
import sys
import textwrap


__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
           'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
           'ClientConfigurationError',
           'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
           'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
           'UnsupportedServerFeatureError')


def _is_asyncpg_class(cls):
    modname = cls.__module__
    return modname == 'asyncpg' or modname.startswith('asyncpg.')


class PostgresMessageMeta(type):

    _message_map = {}
    _field_map = {
        'S': 'severity',
        'V': 'severity_en',
        'C': 'sqlstate',
        'M': 'message',
        'D': 'detail',
        'H': 'hint',
        'P': 'position',
        'p': 'internal_position',
        'q': 'internal_query',
        'W': 'context',
        's': 'schema_name',
        't': 'table_name',
        'c': 'column_name',
        'd': 'data_type_name',
        'n': 'constraint_name',
        'F': 'server_source_filename',
        'L': 'server_source_line',
        'R': 'server_source_function'
    }

    def __new__(mcls, name, bases, dct):
        cls = super().__new__(mcls, name, bases, dct)
        if cls.__module__ == mcls.__module__ and name == 'PostgresMessage':
            for f in mcls._field_map.values():
                setattr(cls, f, None)

        if _is_asyncpg_class(cls):
            mod = sys.modules[cls.__module__]
            if hasattr(mod, name):
                raise RuntimeError('exception class redefinition: {}'.format(
                    name))

        code = dct.get('sqlstate')
        if code is not None:
            existing = mcls._message_map.get(code)
            if existing is not None:
                raise TypeError('{} has duplicate SQLSTATE code, which is '
                                'already defined by {}'.format(
                                    name, existing.__name__))
            mcls._message_map[code] = cls

        return cls

    @classmethod
    def get_message_class_for_sqlstate(mcls, code):
        return mcls._message_map.get(code, UnknownPostgresError)


class PostgresMessage(metaclass=PostgresMessageMeta):

    @classmethod
    def _get_error_class(cls, fields):
        sqlstate = fields.get('C')
        return type(cls).get_message_class_for_sqlstate(sqlstate)

    @classmethod
    def _get_error_dict(cls, fields, query):
        dct = {
            'query': query
        }

        field_map = type(cls)._field_map
        for k, v in fields.items():
            field = field_map.get(k)
            if field:
                dct[field] = v

        return dct

    @classmethod
    def _make_constructor(cls, fields, query=None):
        dct = cls._get_error_dict(fields, query)

        exccls = cls._get_error_class(fields)
        message = dct.get('message', '')

        # PostgreSQL will raise an exception when it detects
        # that the result type of the query has changed from
        # when the statement was prepared.
        #
        # The original error is somewhat cryptic and unspecific,
        # so we raise a custom subclass that is easier to handle
        # and identify.
        #
        # Note that we specifically do not rely on the error
        # message, as it is localizable.
        is_icse = (
            exccls.__name__ == 'FeatureNotSupportedError' and
            _is_asyncpg_class(exccls) and
            dct.get('server_source_function') == 'RevalidateCachedQuery'
        )

        if is_icse:
            exceptions = sys.modules[exccls.__module__]
            exccls = exceptions.InvalidCachedStatementError
            message = ('cached statement plan is invalid due to a database '
                       'schema or configuration change')

        is_prepared_stmt_error = (
            exccls.__name__ in ('DuplicatePreparedStatementError',
                                'InvalidSQLStatementNameError') and
            _is_asyncpg_class(exccls)
        )

        if is_prepared_stmt_error:
            hint = dct.get('hint', '')
            hint += textwrap.dedent("""\

                NOTE: pgbouncer with pool_mode set to "transaction" or
                "statement" does not support prepared statements properly.
                You have two options:

                * if you are using pgbouncer for connection pooling to a
                  single server, switch to the connection pool functionality
                  provided by asyncpg, it is a much better option for this
                  purpose;

                * if you have no option of avoiding the use of pgbouncer,
                  then you can set statement_cache_size to 0 when creating
                  the asyncpg connection object.
            """)

            dct['hint'] = hint

        return exccls, message, dct

    def as_dict(self):
        dct = {}
        for f in type(self)._field_map.values():
            val = getattr(self, f)
            if val is not None:
                dct[f] = val
        return dct


class PostgresError(PostgresMessage, Exception):
    """Base class for all Postgres errors."""

    def __str__(self):
        msg = self.args[0]
        if self.detail:
            msg += '\nDETAIL: {}'.format(self.detail)
        if self.hint:
            msg += '\nHINT: {}'.format(self.hint)

        return msg

    @classmethod
    def new(cls, fields, query=None):
        exccls, message, dct = cls._make_constructor(fields, query)
        ex = exccls(message)
        ex.__dict__.update(dct)
        return ex


class FatalPostgresError(PostgresError):
    """A fatal error that should result in server disconnection."""


class UnknownPostgresError(FatalPostgresError):
    """An error with an unknown SQLSTATE code."""


class InterfaceMessage:
    def __init__(self, *, detail=None, hint=None):
        self.detail = detail
        self.hint = hint

    def __str__(self):
        msg = self.args[0]
        if self.detail:
            msg += '\nDETAIL: {}'.format(self.detail)
        if self.hint:
            msg += '\nHINT: {}'.format(self.hint)

        return msg


class InterfaceError(InterfaceMessage, Exception):
    """An error caused by improper use of asyncpg API."""

    def __init__(self, msg, *, detail=None, hint=None):
        InterfaceMessage.__init__(self, detail=detail, hint=hint)
        Exception.__init__(self, msg)

    def with_msg(self, msg):
        return type(self)(
            msg,
            detail=self.detail,
            hint=self.hint,
        ).with_traceback(
            self.__traceback__
        )


class ClientConfigurationError(InterfaceError, ValueError):
    """An error caused by improper client configuration."""


class DataError(InterfaceError, ValueError):
    """An error caused by invalid query input."""


class UnsupportedClientFeatureError(InterfaceError):
    """Requested feature is unsupported by asyncpg."""


class UnsupportedServerFeatureError(InterfaceError):
    """Requested feature is unsupported by PostgreSQL server."""


class InterfaceWarning(InterfaceMessage, UserWarning):
    """A warning caused by an improper use of asyncpg API."""

    def __init__(self, msg, *, detail=None, hint=None):
        InterfaceMessage.__init__(self, detail=detail, hint=hint)
        UserWarning.__init__(self, msg)


class InternalClientError(Exception):
    """All unexpected errors not classified otherwise."""


class ProtocolError(InternalClientError):
    """Unexpected condition in the handling of PostgreSQL protocol input."""


class TargetServerAttributeNotMatched(InternalClientError):
    """Could not find a host that satisfies the target attribute requirement."""


class OutdatedSchemaCacheError(InternalClientError):
    """A value decoding error caused by a schema change before row fetching."""

    def __init__(self, msg, *, schema=None, data_type=None, position=None):
        super().__init__(msg)
        self.schema_name = schema
        self.data_type_name = data_type
        self.position = position


class PostgresLogMessage(PostgresMessage):
    """A base class for non-error server messages."""

    def __str__(self):
        return '{}: {}'.format(type(self).__name__, self.message)

    def __setattr__(self, name, val):
        raise TypeError('instances of {} are immutable'.format(
            type(self).__name__))

    @classmethod
    def new(cls, fields, query=None):
        exccls, message_text, dct = cls._make_constructor(fields, query)

        if exccls is UnknownPostgresError:
            exccls = PostgresLogMessage

        if exccls is PostgresLogMessage:
            severity = dct.get('severity_en') or dct.get('severity')
            if severity and severity.upper() == 'WARNING':
                exccls = asyncpg.PostgresWarning

        if issubclass(exccls, (BaseException, Warning)):
            msg = exccls(message_text)
        else:
            msg = exccls()

        msg.__dict__.update(dct)
        return msg
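A short sketch of how these classes surface to application code; the DSN and table name are assumptions. The concrete subclass (here UndefinedTableError) is chosen by SQLSTATE through PostgresMessageMeta, so matching on the class is reliable even with a localized server:

import asyncio

import asyncpg


async def main():
    con = await asyncpg.connect(dsn='postgres://user@localhost/db')
    try:
        await con.execute('SELECT * FROM no_such_table')
    except asyncpg.PostgresError as e:
        # e is an SQLSTATE-specific subclass, e.g. UndefinedTableError.
        print(type(e).__name__, e.sqlstate)
        print(e.as_dict())  # severity, message, position, ...
    finally:
        await con.close()


asyncio.run(main())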
venv/lib/python3.12/site-packages/asyncpg/introspection.py (new file, 292 lines)
@@ -0,0 +1,292 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


_TYPEINFO_13 = '''\
    (
        SELECT
            t.oid AS oid,
            ns.nspname AS ns,
            t.typname AS name,
            t.typtype AS kind,
            (CASE WHEN t.typtype = 'd' THEN
                (WITH RECURSIVE typebases(oid, depth) AS (
                    SELECT
                        t2.typbasetype AS oid,
                        0 AS depth
                    FROM
                        pg_type t2
                    WHERE
                        t2.oid = t.oid

                    UNION ALL

                    SELECT
                        t2.typbasetype AS oid,
                        tb.depth + 1 AS depth
                    FROM
                        pg_type t2,
                        typebases tb
                    WHERE
                       tb.oid = t2.oid
                       AND t2.typbasetype != 0
               ) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1)

               ELSE NULL
            END) AS basetype,
            t.typelem AS elemtype,
            elem_t.typdelim AS elemdelim,
            range_t.rngsubtype AS range_subtype,
            (CASE WHEN t.typtype = 'c' THEN
                (SELECT
                    array_agg(ia.atttypid ORDER BY ia.attnum)
                 FROM
                    pg_attribute ia
                    INNER JOIN pg_class c
                        ON (ia.attrelid = c.oid)
                 WHERE
                    ia.attnum > 0 AND NOT ia.attisdropped
                    AND c.reltype = t.oid)

                ELSE NULL
            END) AS attrtypoids,
            (CASE WHEN t.typtype = 'c' THEN
                (SELECT
                    array_agg(ia.attname::text ORDER BY ia.attnum)
                 FROM
                    pg_attribute ia
                    INNER JOIN pg_class c
                        ON (ia.attrelid = c.oid)
                 WHERE
                    ia.attnum > 0 AND NOT ia.attisdropped
                    AND c.reltype = t.oid)

                ELSE NULL
            END) AS attrnames
        FROM
            pg_catalog.pg_type AS t
            INNER JOIN pg_catalog.pg_namespace ns ON (
                ns.oid = t.typnamespace)
            LEFT JOIN pg_type elem_t ON (
                t.typlen = -1 AND
                t.typelem != 0 AND
                t.typelem = elem_t.oid
            )
            LEFT JOIN pg_range range_t ON (
                t.oid = range_t.rngtypid
            )
    )
'''


INTRO_LOOKUP_TYPES_13 = '''\
WITH RECURSIVE typeinfo_tree(
    oid, ns, name, kind, basetype, elemtype, elemdelim,
    range_subtype, attrtypoids, attrnames, depth)
AS (
    SELECT
        ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
        ti.elemtype, ti.elemdelim, ti.range_subtype,
        ti.attrtypoids, ti.attrnames, 0
    FROM
        {typeinfo} AS ti
    WHERE
        ti.oid = any($1::oid[])

    UNION ALL

    SELECT
        ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
        ti.elemtype, ti.elemdelim, ti.range_subtype,
        ti.attrtypoids, ti.attrnames, tt.depth + 1
    FROM
        {typeinfo} ti,
        typeinfo_tree tt
    WHERE
        (tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype)
        OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids))
        OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype)
        OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype)
)

SELECT DISTINCT
    *,
    basetype::regtype::text AS basetype_name,
    elemtype::regtype::text AS elemtype_name,
    range_subtype::regtype::text AS range_subtype_name
FROM
    typeinfo_tree
ORDER BY
    depth DESC
'''.format(typeinfo=_TYPEINFO_13)


_TYPEINFO = '''\
    (
        SELECT
            t.oid AS oid,
            ns.nspname AS ns,
            t.typname AS name,
            t.typtype AS kind,
            (CASE WHEN t.typtype = 'd' THEN
                (WITH RECURSIVE typebases(oid, depth) AS (
                    SELECT
                        t2.typbasetype AS oid,
                        0 AS depth
                    FROM
                        pg_type t2
                    WHERE
                        t2.oid = t.oid

                    UNION ALL

                    SELECT
                        t2.typbasetype AS oid,
                        tb.depth + 1 AS depth
                    FROM
                        pg_type t2,
                        typebases tb
                    WHERE
                       tb.oid = t2.oid
                       AND t2.typbasetype != 0
               ) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1)

               ELSE NULL
            END) AS basetype,
            t.typelem AS elemtype,
            elem_t.typdelim AS elemdelim,
            COALESCE(
                range_t.rngsubtype,
                multirange_t.rngsubtype) AS range_subtype,
            (CASE WHEN t.typtype = 'c' THEN
                (SELECT
                    array_agg(ia.atttypid ORDER BY ia.attnum)
                 FROM
                    pg_attribute ia
                    INNER JOIN pg_class c
                        ON (ia.attrelid = c.oid)
                 WHERE
                    ia.attnum > 0 AND NOT ia.attisdropped
                    AND c.reltype = t.oid)

                ELSE NULL
            END) AS attrtypoids,
            (CASE WHEN t.typtype = 'c' THEN
                (SELECT
                    array_agg(ia.attname::text ORDER BY ia.attnum)
                 FROM
                    pg_attribute ia
                    INNER JOIN pg_class c
                        ON (ia.attrelid = c.oid)
                 WHERE
                    ia.attnum > 0 AND NOT ia.attisdropped
                    AND c.reltype = t.oid)

                ELSE NULL
            END) AS attrnames
        FROM
            pg_catalog.pg_type AS t
            INNER JOIN pg_catalog.pg_namespace ns ON (
                ns.oid = t.typnamespace)
            LEFT JOIN pg_type elem_t ON (
                t.typlen = -1 AND
                t.typelem != 0 AND
                t.typelem = elem_t.oid
            )
            LEFT JOIN pg_range range_t ON (
                t.oid = range_t.rngtypid
            )
            LEFT JOIN pg_range multirange_t ON (
                t.oid = multirange_t.rngmultitypid
            )
    )
'''


INTRO_LOOKUP_TYPES = '''\
WITH RECURSIVE typeinfo_tree(
    oid, ns, name, kind, basetype, elemtype, elemdelim,
    range_subtype, attrtypoids, attrnames, depth)
AS (
    SELECT
        ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
        ti.elemtype, ti.elemdelim, ti.range_subtype,
        ti.attrtypoids, ti.attrnames, 0
    FROM
        {typeinfo} AS ti
    WHERE
        ti.oid = any($1::oid[])

    UNION ALL

    SELECT
        ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
        ti.elemtype, ti.elemdelim, ti.range_subtype,
        ti.attrtypoids, ti.attrnames, tt.depth + 1
    FROM
        {typeinfo} ti,
        typeinfo_tree tt
    WHERE
        (tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype)
        OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids))
        OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype)
        OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype)
)

SELECT DISTINCT
    *,
    basetype::regtype::text AS basetype_name,
    elemtype::regtype::text AS elemtype_name,
    range_subtype::regtype::text AS range_subtype_name
FROM
    typeinfo_tree
ORDER BY
    depth DESC
'''.format(typeinfo=_TYPEINFO)


TYPE_BY_NAME = '''\
SELECT
    t.oid,
    t.typelem AS elemtype,
    t.typtype AS kind
FROM
    pg_catalog.pg_type AS t
    INNER JOIN pg_catalog.pg_namespace ns ON (ns.oid = t.typnamespace)
WHERE
    t.typname = $1 AND ns.nspname = $2
'''


TYPE_BY_OID = '''\
SELECT
    t.oid,
    t.typelem AS elemtype,
    t.typtype AS kind
FROM
    pg_catalog.pg_type AS t
WHERE
    t.oid = $1
'''


# 'b' for a base type, 'd' for a domain, 'e' for enum.
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')


def is_scalar_type(typeinfo) -> bool:
    return (
        typeinfo['kind'] in SCALAR_TYPE_KINDS and
        not typeinfo['elemtype']
    )


def is_domain_type(typeinfo) -> bool:
    return typeinfo['kind'] == b'd'


def is_composite_type(typeinfo) -> bool:
    return typeinfo['kind'] == b'c'
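These queries are normally run internally while building codec pipelines, but they can be exercised directly; a sketch assuming a reachable server (the DSN is an assumption, and introspection is a private module whose contents may change between releases):

import asyncio

import asyncpg
from asyncpg import introspection  # private module; interface may change


async def main():
    con = await asyncpg.connect(dsn='postgres://user@localhost/db')
    row = await con.fetchrow(introspection.TYPE_BY_NAME, 'int4', 'pg_catalog')
    print(row['oid'], row['kind'])            # e.g. 23 b'b'
    print(introspection.is_scalar_type(row))  # True
    await con.close()


asyncio.run(main())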
@@ -0,0 +1,5 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
@@ -0,0 +1,5 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
venv/lib/python3.12/site-packages/asyncpg/pgproto/buffer.pxd (new file, 136 lines)
@@ -0,0 +1,136 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef class WriteBuffer:
    cdef:
        # Preallocated small buffer
        bint _smallbuf_inuse
        char _smallbuf[_BUFFER_INITIAL_SIZE]

        char *_buf

        # Allocated size
        ssize_t _size

        # Length of data in the buffer
        ssize_t _length

        # Number of memoryviews attached to the buffer
        int _view_count

        # True if start_message was used
        bint _message_mode

    cdef inline len(self):
        return self._length

    cdef inline write_len_prefixed_utf8(self, str s):
        return self.write_len_prefixed_bytes(s.encode('utf-8'))

    cdef inline _check_readonly(self)
    cdef inline _ensure_alloced(self, ssize_t extra_length)
    cdef _reallocate(self, ssize_t new_size)
    cdef inline reset(self)
    cdef inline start_message(self, char type)
    cdef inline end_message(self)
    cdef write_buffer(self, WriteBuffer buf)
    cdef write_byte(self, char b)
    cdef write_bytes(self, bytes data)
    cdef write_len_prefixed_buffer(self, WriteBuffer buf)
    cdef write_len_prefixed_bytes(self, bytes data)
    cdef write_bytestring(self, bytes string)
    cdef write_str(self, str string, str encoding)
    cdef write_frbuf(self, FRBuffer *buf)
    cdef write_cstr(self, const char *data, ssize_t len)
    cdef write_int16(self, int16_t i)
    cdef write_int32(self, int32_t i)
    cdef write_int64(self, int64_t i)
    cdef write_float(self, float f)
    cdef write_double(self, double d)

    @staticmethod
    cdef WriteBuffer new_message(char type)

    @staticmethod
    cdef WriteBuffer new()


ctypedef const char * (*try_consume_message_method)(object, ssize_t*)
ctypedef int32_t (*take_message_type_method)(object, char) except -1
ctypedef int32_t (*take_message_method)(object) except -1
ctypedef char (*get_message_type_method)(object)


cdef class ReadBuffer:
    cdef:
        # A deque of buffers (bytes objects)
        object _bufs
        object _bufs_append
        object _bufs_popleft

        # A pointer to the first buffer in `_bufs`
        bytes _buf0

        # A pointer to the previous first buffer
        # (used to prolong the life of _buf0 when using
        # methods like _try_read_bytes)
        bytes _buf0_prev

        # Number of buffers in `_bufs`
        int32_t _bufs_len

        # A read position in the first buffer in `_bufs`
        ssize_t _pos0

        # Length of the first buffer in `_bufs`
        ssize_t _len0

        # A total number of buffered bytes in ReadBuffer
        ssize_t _length

        char _current_message_type
        int32_t _current_message_len
        ssize_t _current_message_len_unread
        bint _current_message_ready

    cdef inline len(self):
        return self._length

    cdef inline char get_message_type(self):
        return self._current_message_type

    cdef inline int32_t get_message_length(self):
        return self._current_message_len

    cdef feed_data(self, data)
    cdef inline _ensure_first_buf(self)
    cdef _switch_to_next_buf(self)
    cdef inline char read_byte(self) except? -1
    cdef inline const char* _try_read_bytes(self, ssize_t nbytes)
    cdef inline _read_into(self, char *buf, ssize_t nbytes)
    cdef inline _read_and_discard(self, ssize_t nbytes)
    cdef bytes read_bytes(self, ssize_t nbytes)
    cdef bytes read_len_prefixed_bytes(self)
    cdef str read_len_prefixed_utf8(self)
    cdef read_uuid(self)
    cdef inline int64_t read_int64(self) except? -1
    cdef inline int32_t read_int32(self) except? -1
    cdef inline int16_t read_int16(self) except? -1
    cdef inline read_null_str(self)
    cdef int32_t take_message(self) except -1
    cdef inline int32_t take_message_type(self, char mtype) except -1
    cdef int32_t put_message(self) except -1
    cdef inline const char* try_consume_message(self, ssize_t* len)
    cdef bytes consume_message(self)
    cdef discard_message(self)
    cdef redirect_messages(self, WriteBuffer buf, char mtype, int stop_at=?)
    cdef bytearray consume_messages(self, char mtype)
    cdef finish_message(self)
    cdef inline _finish_message(self)

    @staticmethod
    cdef ReadBuffer new_message_parser(object data)
venv/lib/python3.12/site-packages/asyncpg/pgproto/buffer.pyx (new file, 817 lines)
@@ -0,0 +1,817 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from libc.string cimport memcpy

import collections


class BufferError(Exception):
    pass


@cython.no_gc_clear
@cython.final
@cython.freelist(_BUFFER_FREELIST_SIZE)
cdef class WriteBuffer:

    def __cinit__(self):
        self._smallbuf_inuse = True
        self._buf = self._smallbuf
        self._size = _BUFFER_INITIAL_SIZE
        self._length = 0
        self._message_mode = 0

    def __dealloc__(self):
        if self._buf is not NULL and not self._smallbuf_inuse:
            cpython.PyMem_Free(self._buf)
            self._buf = NULL
            self._size = 0

        if self._view_count:
            raise BufferError(
                'Deallocating buffer with attached memoryviews')

    def __getbuffer__(self, Py_buffer *buffer, int flags):
        self._view_count += 1

        cpython.PyBuffer_FillInfo(
            buffer, self, self._buf, self._length,
            1,  # read-only
            flags)

    def __releasebuffer__(self, Py_buffer *buffer):
        self._view_count -= 1

    cdef inline _check_readonly(self):
        if self._view_count:
            raise BufferError('the buffer is in read-only mode')

    cdef inline _ensure_alloced(self, ssize_t extra_length):
        cdef ssize_t new_size = extra_length + self._length

        if new_size > self._size:
            self._reallocate(new_size)

    cdef _reallocate(self, ssize_t new_size):
        cdef char *new_buf

        if new_size < _BUFFER_MAX_GROW:
            new_size = _BUFFER_MAX_GROW
        else:
            # Add a little extra
            new_size += _BUFFER_INITIAL_SIZE

        if self._smallbuf_inuse:
            new_buf = <char*>cpython.PyMem_Malloc(
                sizeof(char) * <size_t>new_size)
            if new_buf is NULL:
                self._buf = NULL
                self._size = 0
                self._length = 0
                raise MemoryError
            memcpy(new_buf, self._buf, <size_t>self._size)
            self._size = new_size
            self._buf = new_buf
            self._smallbuf_inuse = False
        else:
            new_buf = <char*>cpython.PyMem_Realloc(
                <void*>self._buf, <size_t>new_size)
            if new_buf is NULL:
                cpython.PyMem_Free(self._buf)
                self._buf = NULL
                self._size = 0
                self._length = 0
                raise MemoryError
            self._buf = new_buf
            self._size = new_size

    cdef inline start_message(self, char type):
        if self._length != 0:
            raise BufferError(
                'cannot start_message for a non-empty buffer')
        self._ensure_alloced(5)
        self._message_mode = 1
        self._buf[0] = type
        self._length = 5

    cdef inline end_message(self):
        # "length-1" to exclude the message type byte
        cdef ssize_t mlen = self._length - 1

        self._check_readonly()
        if not self._message_mode:
            raise BufferError(
                'end_message can only be called with start_message')
        if self._length < 5:
            raise BufferError('end_message: buffer is too small')
        if mlen > _MAXINT32:
            raise BufferError('end_message: message is too large')

        hton.pack_int32(&self._buf[1], <int32_t>mlen)
        return self

    cdef inline reset(self):
        self._length = 0
        self._message_mode = 0

    cdef write_buffer(self, WriteBuffer buf):
        self._check_readonly()

        if not buf._length:
            return

        self._ensure_alloced(buf._length)
        memcpy(self._buf + self._length,
               <void*>buf._buf,
               <size_t>buf._length)
        self._length += buf._length

    cdef write_byte(self, char b):
        self._check_readonly()

        self._ensure_alloced(1)
        self._buf[self._length] = b
        self._length += 1

    cdef write_bytes(self, bytes data):
        cdef char* buf
        cdef ssize_t len

        cpython.PyBytes_AsStringAndSize(data, &buf, &len)
        self.write_cstr(buf, len)

    cdef write_bytestring(self, bytes string):
        cdef char* buf
        cdef ssize_t len

        cpython.PyBytes_AsStringAndSize(string, &buf, &len)
        # PyBytes_AsStringAndSize returns a null-terminated buffer,
        # but the null byte is not counted in len; hence the + 1.
        self.write_cstr(buf, len + 1)

    cdef write_str(self, str string, str encoding):
        self.write_bytestring(string.encode(encoding))

    cdef write_len_prefixed_buffer(self, WriteBuffer buf):
        # Write a length-prefixed (not NULL-terminated) bytes sequence.
        self.write_int32(<int32_t>buf.len())
        self.write_buffer(buf)

    cdef write_len_prefixed_bytes(self, bytes data):
        # Write a length-prefixed (not NULL-terminated) bytes sequence.
        cdef:
            char *buf
            ssize_t size

        cpython.PyBytes_AsStringAndSize(data, &buf, &size)
        if size > _MAXINT32:
            raise BufferError('string is too large')
        # `size` does not account for the NULL at the end.
        self.write_int32(<int32_t>size)
        self.write_cstr(buf, size)

    cdef write_frbuf(self, FRBuffer *buf):
        cdef:
            ssize_t buf_len = buf.len
        if buf_len > 0:
            self.write_cstr(frb_read_all(buf), buf_len)

    cdef write_cstr(self, const char *data, ssize_t len):
        self._check_readonly()
        self._ensure_alloced(len)

        memcpy(self._buf + self._length, <void*>data, <size_t>len)
        self._length += len

    cdef write_int16(self, int16_t i):
        self._check_readonly()
        self._ensure_alloced(2)

        hton.pack_int16(&self._buf[self._length], i)
        self._length += 2

    cdef write_int32(self, int32_t i):
        self._check_readonly()
        self._ensure_alloced(4)

        hton.pack_int32(&self._buf[self._length], i)
        self._length += 4

    cdef write_int64(self, int64_t i):
        self._check_readonly()
        self._ensure_alloced(8)

        hton.pack_int64(&self._buf[self._length], i)
        self._length += 8

    cdef write_float(self, float f):
        self._check_readonly()
        self._ensure_alloced(4)

        hton.pack_float(&self._buf[self._length], f)
        self._length += 4

    cdef write_double(self, double d):
        self._check_readonly()
        self._ensure_alloced(8)

        hton.pack_double(&self._buf[self._length], d)
        self._length += 8

    @staticmethod
    cdef WriteBuffer new_message(char type):
        cdef WriteBuffer buf
        buf = WriteBuffer.__new__(WriteBuffer)
        buf.start_message(type)
        return buf

    @staticmethod
    cdef WriteBuffer new():
        cdef WriteBuffer buf
        buf = WriteBuffer.__new__(WriteBuffer)
        return buf


@cython.no_gc_clear
@cython.final
@cython.freelist(_BUFFER_FREELIST_SIZE)
cdef class ReadBuffer:

    def __cinit__(self):
        self._bufs = collections.deque()
        self._bufs_append = self._bufs.append
        self._bufs_popleft = self._bufs.popleft
        self._bufs_len = 0
        self._buf0 = None
        self._buf0_prev = None
        self._pos0 = 0
        self._len0 = 0
        self._length = 0

        self._current_message_type = 0
        self._current_message_len = 0
        self._current_message_len_unread = 0
        self._current_message_ready = 0

    cdef feed_data(self, data):
        cdef:
            ssize_t dlen
            bytes data_bytes

        if not cpython.PyBytes_CheckExact(data):
            if cpythonx.PyByteArray_CheckExact(data):
                # ProactorEventLoop in Python 3.10+ seems to be sending
                # bytearray objects instead of bytes. Handle this here
                # to avoid duplicating this check in every data_received().
                data = bytes(data)
            else:
                raise BufferError(
                    'feed_data: a bytes or bytearray object expected')

        # Uncomment the below code to test code paths where reads of
        # single int/str/bytes sequences are split over multiple
        # received buffers.
        #
        # ll = 107
        # if len(data) > ll:
        #     self.feed_data(data[:ll])
        #     self.feed_data(data[ll:])
        #     return

        data_bytes = <bytes>data

        dlen = cpython.Py_SIZE(data_bytes)
        if dlen == 0:
            # EOF?
            return

        self._bufs_append(data_bytes)
        self._length += dlen

        if self._bufs_len == 0:
            # First buffer
            self._len0 = dlen
            self._buf0 = data_bytes

        self._bufs_len += 1

    cdef inline _ensure_first_buf(self):
        if PG_DEBUG:
            if self._len0 == 0:
                raise BufferError('empty first buffer')
            if self._length == 0:
                raise BufferError('empty buffer')

        if self._pos0 == self._len0:
            self._switch_to_next_buf()

    cdef _switch_to_next_buf(self):
        # The first buffer is fully read, discard it
        self._bufs_popleft()
        self._bufs_len -= 1

        # Shouldn't fail, since we've checked that `_length >= 1`
        # in _ensure_first_buf()
        self._buf0_prev = self._buf0
        self._buf0 = <bytes>self._bufs[0]

        self._pos0 = 0
        self._len0 = len(self._buf0)

        if PG_DEBUG:
            if self._len0 < 1:
                raise BufferError(
                    'debug: second buffer of ReadBuffer is empty')

    cdef inline const char* _try_read_bytes(self, ssize_t nbytes):
        # Try to read *nbytes* from the first buffer.
        #
        # Returns pointer to data if there is at least *nbytes*
        # in the buffer, NULL otherwise.
        #
        # Important: caller must call _ensure_first_buf() prior
        # to calling try_read_bytes, and must not overread

        cdef:
            const char *result

        if PG_DEBUG:
            if nbytes > self._length:
                return NULL

        if self._current_message_ready:
            if self._current_message_len_unread < nbytes:
                return NULL

        if self._pos0 + nbytes <= self._len0:
            result = cpython.PyBytes_AS_STRING(self._buf0)
            result += self._pos0
            self._pos0 += nbytes
            self._length -= nbytes
            if self._current_message_ready:
                self._current_message_len_unread -= nbytes
            return result
        else:
            return NULL

    cdef inline _read_into(self, char *buf, ssize_t nbytes):
        cdef:
            ssize_t nread
            char *buf0

        while True:
            buf0 = cpython.PyBytes_AS_STRING(self._buf0)

            if self._pos0 + nbytes > self._len0:
                nread = self._len0 - self._pos0
                memcpy(buf, buf0 + self._pos0, <size_t>nread)
                self._pos0 = self._len0
                self._length -= nread
                nbytes -= nread
                buf += nread
                self._ensure_first_buf()

            else:
                memcpy(buf, buf0 + self._pos0, <size_t>nbytes)
                self._pos0 += nbytes
                self._length -= nbytes
                break

    cdef inline _read_and_discard(self, ssize_t nbytes):
        cdef:
            ssize_t nread

        self._ensure_first_buf()
        while True:
            if self._pos0 + nbytes > self._len0:
                nread = self._len0 - self._pos0
                self._pos0 = self._len0
                self._length -= nread
                nbytes -= nread
                self._ensure_first_buf()

            else:
                self._pos0 += nbytes
                self._length -= nbytes
                break

    cdef bytes read_bytes(self, ssize_t nbytes):
        cdef:
            bytes result
            ssize_t nread
            const char *cbuf
            char *buf

        self._ensure_first_buf()
        cbuf = self._try_read_bytes(nbytes)
        if cbuf != NULL:
            return cpython.PyBytes_FromStringAndSize(cbuf, nbytes)

        if nbytes > self._length:
            raise BufferError(
                'not enough data to read {} bytes'.format(nbytes))

        if self._current_message_ready:
            self._current_message_len_unread -= nbytes
            if self._current_message_len_unread < 0:
                raise BufferError('buffer overread')

        result = cpython.PyBytes_FromStringAndSize(NULL, nbytes)
        buf = cpython.PyBytes_AS_STRING(result)
        self._read_into(buf, nbytes)
        return result

    cdef bytes read_len_prefixed_bytes(self):
        cdef int32_t size = self.read_int32()
        if size < 0:
            raise BufferError(
                'negative length for a len-prefixed bytes value')
        if size == 0:
            return b''
        return self.read_bytes(size)

    cdef str read_len_prefixed_utf8(self):
        cdef:
            int32_t size
            const char *cbuf

        size = self.read_int32()
        if size < 0:
            raise BufferError(
                'negative length for a len-prefixed bytes value')

        if size == 0:
            return ''

        self._ensure_first_buf()
        cbuf = self._try_read_bytes(size)
        if cbuf != NULL:
            return cpython.PyUnicode_DecodeUTF8(cbuf, size, NULL)
        else:
            return self.read_bytes(size).decode('utf-8')

    cdef read_uuid(self):
        cdef:
            bytes mem
            const char *cbuf

        self._ensure_first_buf()
        cbuf = self._try_read_bytes(16)
        if cbuf != NULL:
            return pg_uuid_from_buf(cbuf)
        else:
            return pg_UUID(self.read_bytes(16))

    cdef inline char read_byte(self) except? -1:
        cdef const char *first_byte

        if PG_DEBUG:
            if not self._buf0:
                raise BufferError(
                    'debug: first buffer of ReadBuffer is empty')

        self._ensure_first_buf()
        first_byte = self._try_read_bytes(1)
        if first_byte is NULL:
            raise BufferError('not enough data to read one byte')

        return first_byte[0]

    cdef inline int64_t read_int64(self) except? -1:
        cdef:
            bytes mem
            const char *cbuf

        self._ensure_first_buf()
        cbuf = self._try_read_bytes(8)
        if cbuf != NULL:
            return hton.unpack_int64(cbuf)
        else:
            mem = self.read_bytes(8)
            return hton.unpack_int64(cpython.PyBytes_AS_STRING(mem))

    cdef inline int32_t read_int32(self) except? -1:
        cdef:
            bytes mem
            const char *cbuf

        self._ensure_first_buf()
        cbuf = self._try_read_bytes(4)
        if cbuf != NULL:
            return hton.unpack_int32(cbuf)
        else:
            mem = self.read_bytes(4)
            return hton.unpack_int32(cpython.PyBytes_AS_STRING(mem))

    cdef inline int16_t read_int16(self) except? -1:
        cdef:
            bytes mem
            const char *cbuf

        self._ensure_first_buf()
        cbuf = self._try_read_bytes(2)
        if cbuf != NULL:
            return hton.unpack_int16(cbuf)
        else:
            mem = self.read_bytes(2)
            return hton.unpack_int16(cpython.PyBytes_AS_STRING(mem))

    cdef inline read_null_str(self):
        if not self._current_message_ready:
            raise BufferError(
                'read_null_str only works when the message guaranteed '
                'to be in the buffer')

        cdef:
            ssize_t pos
            ssize_t nread
            bytes result
            const char *buf
            const char *buf_start

        self._ensure_first_buf()

        buf_start = cpython.PyBytes_AS_STRING(self._buf0)
        buf = buf_start + self._pos0
        while buf - buf_start < self._len0:
            if buf[0] == 0:
                pos = buf - buf_start
                nread = pos - self._pos0
                buf = self._try_read_bytes(nread + 1)
                if buf != NULL:
                    return cpython.PyBytes_FromStringAndSize(buf, nread)
                else:
                    break
            else:
                buf += 1

        result = b''
        while True:
            pos = self._buf0.find(b'\x00', self._pos0)
            if pos >= 0:
                result += self._buf0[self._pos0 : pos]
                nread = pos - self._pos0 + 1
                self._pos0 = pos + 1
                self._length -= nread

                self._current_message_len_unread -= nread
                if self._current_message_len_unread < 0:
                    raise BufferError(
                        'read_null_str: buffer overread')

                return result

            else:
                result += self._buf0[self._pos0:]
                nread = self._len0 - self._pos0
                self._pos0 = self._len0
                self._length -= nread

                self._current_message_len_unread -= nread
                if self._current_message_len_unread < 0:
                    raise BufferError(
                        'read_null_str: buffer overread')

                self._ensure_first_buf()

    cdef int32_t take_message(self) except -1:
        cdef:
            const char *cbuf

        if self._current_message_ready:
            return 1

        if self._current_message_type == 0:
            if self._length < 1:
                return 0
            self._ensure_first_buf()
            cbuf = self._try_read_bytes(1)
            if cbuf == NULL:
                raise BufferError(
                    'failed to read one byte on a non-empty buffer')
            self._current_message_type = cbuf[0]

        if self._current_message_len == 0:
            if self._length < 4:
                return 0

            self._ensure_first_buf()
            cbuf = self._try_read_bytes(4)
            if cbuf != NULL:
                self._current_message_len = hton.unpack_int32(cbuf)
            else:
                self._current_message_len = self.read_int32()

            self._current_message_len_unread = self._current_message_len - 4

        if self._length < self._current_message_len_unread:
            return 0

        self._current_message_ready = 1
        return 1

    cdef inline int32_t take_message_type(self, char mtype) except -1:
        cdef const char *buf0

        if self._current_message_ready:
            return self._current_message_type == mtype
        elif self._length >= 1:
            self._ensure_first_buf()
            buf0 = cpython.PyBytes_AS_STRING(self._buf0)

            return buf0[self._pos0] == mtype and self.take_message()
        else:
            return 0

    cdef int32_t put_message(self) except -1:
        if not self._current_message_ready:
            raise BufferError(
                'cannot put message: no message taken')
        self._current_message_ready = False
        return 0

    cdef inline const char* try_consume_message(self, ssize_t* len):
        cdef:
            ssize_t buf_len
            const char *buf

        if not self._current_message_ready:
            return NULL

        self._ensure_first_buf()
        buf_len = self._current_message_len_unread
        buf = self._try_read_bytes(buf_len)
        if buf != NULL:
            len[0] = buf_len
            self._finish_message()
        return buf

    cdef discard_message(self):
        if not self._current_message_ready:
            raise BufferError('no message to discard')
        if self._current_message_len_unread > 0:
            self._read_and_discard(self._current_message_len_unread)
            self._current_message_len_unread = 0
        self._finish_message()

    cdef bytes consume_message(self):
        if not self._current_message_ready:
            raise BufferError('no message to consume')
        if self._current_message_len_unread > 0:
            mem = self.read_bytes(self._current_message_len_unread)
        else:
            mem = b''
        self._finish_message()
        return mem

    cdef redirect_messages(self, WriteBuffer buf, char mtype,
                           int stop_at=0):
        if not self._current_message_ready:
            raise BufferError(
                'consume_full_messages called on a buffer without a '
                'complete first message')
        if mtype != self._current_message_type:
            raise BufferError(
                'consume_full_messages called with a wrong mtype')
        if self._current_message_len_unread != self._current_message_len - 4:
            raise BufferError(
                'consume_full_messages called on a partially read message')

        cdef:
            const char* cbuf
            ssize_t cbuf_len
            int32_t msg_len
            ssize_t new_pos0
            ssize_t pos_delta
            int32_t done

        while True:
            buf.write_byte(mtype)
            buf.write_int32(self._current_message_len)

            cbuf = self.try_consume_message(&cbuf_len)
            if cbuf != NULL:
                buf.write_cstr(cbuf, cbuf_len)
            else:
                buf.write_bytes(self.consume_message())

            if self._length > 0:
                self._ensure_first_buf()
            else:
                return

            if stop_at and buf._length >= stop_at:
                return

            # Fast path: exhaust buf0 as efficiently as possible.
            if self._pos0 + 5 <= self._len0:
                cbuf = cpython.PyBytes_AS_STRING(self._buf0)
                new_pos0 = self._pos0
                cbuf_len = self._len0

                done = 0
                # Scan the first buffer and find the position of the
                # end of the last "mtype" message.
                while new_pos0 + 5 <= cbuf_len:
                    if (cbuf + new_pos0)[0] != mtype:
                        done = 1
                        break
                    if (stop_at and
                            (buf._length + new_pos0 - self._pos0) > stop_at):
                        done = 1
                        break
                    msg_len = hton.unpack_int32(cbuf + new_pos0 + 1) + 1
                    if new_pos0 + msg_len > cbuf_len:
                        break
                    new_pos0 += msg_len

                if new_pos0 != self._pos0:
                    assert self._pos0 < new_pos0 <= self._len0

                    pos_delta = new_pos0 - self._pos0
                    buf.write_cstr(
                        cbuf + self._pos0,
                        pos_delta)

                    self._pos0 = new_pos0
                    self._length -= pos_delta

                    assert self._length >= 0

                if done:
                    # The next message is of a different type.
                    return

            # Back to slow path.
            if not self.take_message_type(mtype):
                return

    cdef bytearray consume_messages(self, char mtype):
        """Consume consecutive messages of the same type."""
        cdef:
            char *buf
            ssize_t nbytes
            ssize_t total_bytes = 0
            bytearray result

        if not self.take_message_type(mtype):
            return None

        # consume_messages is a volume-oriented method, so
        # we assume that the remainder of the buffer will contain
        # messages of the requested type.
        result = cpythonx.PyByteArray_FromStringAndSize(NULL, self._length)
        buf = cpythonx.PyByteArray_AsString(result)

        while self.take_message_type(mtype):
            self._ensure_first_buf()
            nbytes = self._current_message_len_unread
            self._read_into(buf, nbytes)
            buf += nbytes
            total_bytes += nbytes
            self._finish_message()

        # Clamp the result to the actual size read.
        cpythonx.PyByteArray_Resize(result, total_bytes)

        return result

    cdef finish_message(self):
        if self._current_message_type == 0 or not self._current_message_ready:
            # The message has already been finished (e.g. by
            # consume_message()), or has been put back by put_message().
            return

        if self._current_message_len_unread:
            if PG_DEBUG:
                mtype = chr(self._current_message_type)

            discarded = self.consume_message()

            if PG_DEBUG:
                print('!!! discarding message {!r} unread data: {!r}'.format(
                    mtype,
                    discarded))

        self._finish_message()

    cdef inline _finish_message(self):
        self._current_message_type = 0
        self._current_message_len = 0
        self._current_message_ready = 0
        self._current_message_len_unread = 0

    @staticmethod
    cdef ReadBuffer new_message_parser(object data):
        cdef ReadBuffer buf

        buf = ReadBuffer.__new__(ReadBuffer)
        buf.feed_data(data)

        buf._current_message_ready = 1
        buf._current_message_len_unread = buf._len0

        return buf
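For intuition about take_message() above: PostgreSQL frames every backend message as a one-byte type followed by a big-endian int32 length that counts itself but not the type byte. A pure-Python sketch of that layout (an illustration only, not the Cython implementation):

import struct


def take_message(data: bytes):
    """Return (mtype, payload, rest) if a whole message is buffered, else None."""
    if len(data) < 5:
        return None                     # need the type byte + length field
    mtype = data[0:1]
    (mlen,) = struct.unpack('!i', data[1:5])
    if len(data) < 1 + mlen:
        return None                     # message not fully buffered yet
    payload = data[5:1 + mlen]          # mlen includes the length field itself
    return mtype, payload, data[1 + mlen:]


# Example: a ReadyForQuery ('Z') message carrying the status byte 'I'.
msg = b'Z' + struct.pack('!i', 5) + b'I'
print(take_message(msg))  # (b'Z', b'I', b'')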
@@ -0,0 +1,157 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef class CodecContext:

    cpdef get_text_codec(self)
    cdef is_encoding_utf8(self)
    cpdef get_json_decoder(self)
    cdef is_decoding_json(self)
    cpdef get_json_encoder(self)
    cdef is_encoding_json(self)


ctypedef object (*encode_func)(CodecContext settings,
                               WriteBuffer buf,
                               object obj)

ctypedef object (*decode_func)(CodecContext settings,
                               FRBuffer *buf)


# Datetime
cdef date_encode(CodecContext settings, WriteBuffer buf, obj)
cdef date_decode(CodecContext settings, FRBuffer * buf)
cdef date_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef date_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef timestamp_encode(CodecContext settings, WriteBuffer buf, obj)
cdef timestamp_decode(CodecContext settings, FRBuffer * buf)
cdef timestamp_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef timestamp_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef timestamptz_encode(CodecContext settings, WriteBuffer buf, obj)
cdef timestamptz_decode(CodecContext settings, FRBuffer * buf)
cdef time_encode(CodecContext settings, WriteBuffer buf, obj)
cdef time_decode(CodecContext settings, FRBuffer * buf)
cdef time_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef time_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef timetz_encode(CodecContext settings, WriteBuffer buf, obj)
cdef timetz_decode(CodecContext settings, FRBuffer * buf)
cdef timetz_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef timetz_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef interval_encode(CodecContext settings, WriteBuffer buf, obj)
cdef interval_decode(CodecContext settings, FRBuffer * buf)
cdef interval_encode_tuple(CodecContext settings, WriteBuffer buf, tuple obj)
cdef interval_decode_tuple(CodecContext settings, FRBuffer * buf)


# Bits
cdef bits_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef bits_decode(CodecContext settings, FRBuffer * buf)


# Bools
cdef bool_encode(CodecContext settings, WriteBuffer buf, obj)
cdef bool_decode(CodecContext settings, FRBuffer * buf)


# Geometry
cdef box_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef box_decode(CodecContext settings, FRBuffer * buf)
cdef line_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef line_decode(CodecContext settings, FRBuffer * buf)
cdef lseg_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef lseg_decode(CodecContext settings, FRBuffer * buf)
cdef point_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef point_decode(CodecContext settings, FRBuffer * buf)
cdef path_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef path_decode(CodecContext settings, FRBuffer * buf)
cdef poly_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef poly_decode(CodecContext settings, FRBuffer * buf)
cdef circle_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef circle_decode(CodecContext settings, FRBuffer * buf)


# Hstore
cdef hstore_encode(CodecContext settings, WriteBuffer buf, obj)
cdef hstore_decode(CodecContext settings, FRBuffer * buf)


# Ints
cdef int2_encode(CodecContext settings, WriteBuffer buf, obj)
cdef int2_decode(CodecContext settings, FRBuffer * buf)
cdef int4_encode(CodecContext settings, WriteBuffer buf, obj)
cdef int4_decode(CodecContext settings, FRBuffer * buf)
cdef uint4_encode(CodecContext settings, WriteBuffer buf, obj)
cdef uint4_decode(CodecContext settings, FRBuffer * buf)
cdef int8_encode(CodecContext settings, WriteBuffer buf, obj)
cdef int8_decode(CodecContext settings, FRBuffer * buf)
cdef uint8_encode(CodecContext settings, WriteBuffer buf, obj)
cdef uint8_decode(CodecContext settings, FRBuffer * buf)


# Floats
cdef float4_encode(CodecContext settings, WriteBuffer buf, obj)
cdef float4_decode(CodecContext settings, FRBuffer * buf)
cdef float8_encode(CodecContext settings, WriteBuffer buf, obj)
cdef float8_decode(CodecContext settings, FRBuffer * buf)
|
||||
|
||||
|
||||
# JSON
|
||||
cdef jsonb_encode(CodecContext settings, WriteBuffer buf, obj)
|
||||
cdef jsonb_decode(CodecContext settings, FRBuffer * buf)
|
||||
|
||||
|
||||
# JSON path
|
||||
cdef jsonpath_encode(CodecContext settings, WriteBuffer buf, obj)
|
||||
cdef jsonpath_decode(CodecContext settings, FRBuffer * buf)
|
||||
|
||||
|
||||
# Text
|
||||
cdef as_pg_string_and_size(
|
||||
CodecContext settings, obj, char **cstr, ssize_t *size)
|
||||
cdef text_encode(CodecContext settings, WriteBuffer buf, obj)
|
||||
cdef text_decode(CodecContext settings, FRBuffer * buf)
|
||||
|
||||
# Bytea
|
||||
cdef bytea_encode(CodecContext settings, WriteBuffer wbuf, obj)
|
||||
cdef bytea_decode(CodecContext settings, FRBuffer * buf)
|
||||
|
||||
|
||||
# UUID
|
||||
cdef uuid_encode(CodecContext settings, WriteBuffer wbuf, obj)
|
||||
cdef uuid_decode(CodecContext settings, FRBuffer * buf)
|
||||
|
||||
|
||||
# Numeric
|
||||
cdef numeric_encode_text(CodecContext settings, WriteBuffer buf, obj)
|
||||
cdef numeric_decode_text(CodecContext settings, FRBuffer * buf)
|
||||
cdef numeric_encode_binary(CodecContext settings, WriteBuffer buf, obj)
|
||||
cdef numeric_decode_binary(CodecContext settings, FRBuffer * buf)
|
||||
cdef numeric_decode_binary_ex(CodecContext settings, FRBuffer * buf,
|
||||
bint trail_fract_zero)
|
||||
|
||||
|
||||
# Void
|
||||
cdef void_encode(CodecContext settings, WriteBuffer buf, obj)
|
||||
cdef void_decode(CodecContext settings, FRBuffer * buf)
|
||||
|
||||
|
||||
# tid
|
||||
cdef tid_encode(CodecContext settings, WriteBuffer buf, obj)
|
||||
cdef tid_decode(CodecContext settings, FRBuffer * buf)
|
||||
|
||||
|
||||
# Network
|
||||
cdef cidr_encode(CodecContext settings, WriteBuffer buf, obj)
|
||||
cdef cidr_decode(CodecContext settings, FRBuffer * buf)
|
||||
cdef inet_encode(CodecContext settings, WriteBuffer buf, obj)
|
||||
cdef inet_decode(CodecContext settings, FRBuffer * buf)
|
||||
|
||||
|
||||
# pg_snapshot
|
||||
cdef pg_snapshot_encode(CodecContext settings, WriteBuffer buf, obj)
|
||||
cdef pg_snapshot_decode(CodecContext settings, FRBuffer * buf)
|
||||
@@ -0,0 +1,47 @@
|
||||
# Copyright (C) 2016-present the asyncpg authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of asyncpg and is released under
|
||||
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
|
||||
cdef bits_encode(CodecContext settings, WriteBuffer wbuf, obj):
|
||||
cdef:
|
||||
Py_buffer pybuf
|
||||
bint pybuf_used = False
|
||||
char *buf
|
||||
ssize_t len
|
||||
ssize_t bitlen
|
||||
|
||||
if cpython.PyBytes_CheckExact(obj):
|
||||
buf = cpython.PyBytes_AS_STRING(obj)
|
||||
len = cpython.Py_SIZE(obj)
|
||||
bitlen = len * 8
|
||||
elif isinstance(obj, pgproto_types.BitString):
|
||||
cpython.PyBytes_AsStringAndSize(obj.bytes, &buf, &len)
|
||||
bitlen = obj.__len__()
|
||||
else:
|
||||
cpython.PyObject_GetBuffer(obj, &pybuf, cpython.PyBUF_SIMPLE)
|
||||
pybuf_used = True
|
||||
buf = <char*>pybuf.buf
|
||||
len = pybuf.len
|
||||
bitlen = len * 8
|
||||
|
||||
try:
|
||||
if bitlen > _MAXINT32:
|
||||
raise ValueError('bit value too long')
|
||||
wbuf.write_int32(4 + <int32_t>len)
|
||||
wbuf.write_int32(<int32_t>bitlen)
|
||||
wbuf.write_cstr(buf, len)
|
||||
finally:
|
||||
if pybuf_used:
|
||||
cpython.PyBuffer_Release(&pybuf)
|
||||
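
# Illustration of the layout produced above (editor's note; the values are
# hypothetical): a 10-bit BitString backed by the two bytes b'\xff\xc0'
# goes out as int32 length = 4 + 2 (the bitlen field plus payload bytes),
# int32 bitlen = 10, then the two payload bytes -- so consumers must use
# bitlen, not the byte count, to recover the bit width.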


cdef bits_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        int32_t bitlen = hton.unpack_int32(frb_read(buf, 4))
        ssize_t buf_len = buf.len

    bytes_ = cpython.PyBytes_FromStringAndSize(frb_read_all(buf), buf_len)
    return pgproto_types.BitString.frombytes(bytes_, bitlen)
@@ -0,0 +1,34 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef bytea_encode(CodecContext settings, WriteBuffer wbuf, obj):
    cdef:
        Py_buffer pybuf
        bint pybuf_used = False
        char *buf
        ssize_t len

    if cpython.PyBytes_CheckExact(obj):
        buf = cpython.PyBytes_AS_STRING(obj)
        len = cpython.Py_SIZE(obj)
    else:
        cpython.PyObject_GetBuffer(obj, &pybuf, cpython.PyBUF_SIMPLE)
        pybuf_used = True
        buf = <char*>pybuf.buf
        len = pybuf.len

    try:
        wbuf.write_int32(<int32_t>len)
        wbuf.write_cstr(buf, len)
    finally:
        if pybuf_used:
            cpython.PyBuffer_Release(&pybuf)


cdef bytea_decode(CodecContext settings, FRBuffer *buf):
    cdef ssize_t buf_len = buf.len
    return cpython.PyBytes_FromStringAndSize(frb_read_all(buf), buf_len)
@@ -0,0 +1,26 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef class CodecContext:

    cpdef get_text_codec(self):
        raise NotImplementedError

    cdef is_encoding_utf8(self):
        raise NotImplementedError

    cpdef get_json_decoder(self):
        raise NotImplementedError

    cdef is_decoding_json(self):
        return False

    cpdef get_json_encoder(self):
        raise NotImplementedError

    cdef is_encoding_json(self):
        return False
@@ -0,0 +1,423 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cimport cpython.datetime
import datetime

cpython.datetime.import_datetime()

utc = datetime.timezone.utc
date_from_ordinal = datetime.date.fromordinal
timedelta = datetime.timedelta

pg_epoch_datetime = datetime.datetime(2000, 1, 1)
cdef int32_t pg_epoch_datetime_ts = \
    <int32_t>cpython.PyLong_AsLong(int(pg_epoch_datetime.timestamp()))

pg_epoch_datetime_utc = datetime.datetime(2000, 1, 1, tzinfo=utc)
cdef int32_t pg_epoch_datetime_utc_ts = \
    <int32_t>cpython.PyLong_AsLong(int(pg_epoch_datetime_utc.timestamp()))

pg_epoch_date = datetime.date(2000, 1, 1)
cdef int32_t pg_date_offset_ord = \
    <int32_t>cpython.PyLong_AsLong(pg_epoch_date.toordinal())

# Binary representations of infinity for datetimes.
cdef int64_t pg_time64_infinity = 0x7fffffffffffffff
cdef int64_t pg_time64_negative_infinity = <int64_t>0x8000000000000000
cdef int32_t pg_date_infinity = 0x7fffffff
cdef int32_t pg_date_negative_infinity = <int32_t>0x80000000

infinity_datetime = datetime.datetime(
    datetime.MAXYEAR, 12, 31, 23, 59, 59, 999999)

cdef int32_t infinity_datetime_ord = <int32_t>cpython.PyLong_AsLong(
    infinity_datetime.toordinal())

cdef int64_t infinity_datetime_ts = 252455615999999999

negative_infinity_datetime = datetime.datetime(
    datetime.MINYEAR, 1, 1, 0, 0, 0, 0)

cdef int32_t negative_infinity_datetime_ord = <int32_t>cpython.PyLong_AsLong(
    negative_infinity_datetime.toordinal())

cdef int64_t negative_infinity_datetime_ts = -63082281600000000

infinity_date = datetime.date(datetime.MAXYEAR, 12, 31)

cdef int32_t infinity_date_ord = <int32_t>cpython.PyLong_AsLong(
    infinity_date.toordinal())

negative_infinity_date = datetime.date(datetime.MINYEAR, 1, 1)

cdef int32_t negative_infinity_date_ord = <int32_t>cpython.PyLong_AsLong(
    negative_infinity_date.toordinal())


cdef inline _local_timezone():
    d = datetime.datetime.now(datetime.timezone.utc).astimezone()
    return datetime.timezone(d.utcoffset())


cdef inline _encode_time(WriteBuffer buf, int64_t seconds,
                         int32_t microseconds):
    # XXX: add support for double timestamps
    # int64 timestamps,
    cdef int64_t ts = seconds * 1000000 + microseconds

    if ts == infinity_datetime_ts:
        buf.write_int64(pg_time64_infinity)
    elif ts == negative_infinity_datetime_ts:
        buf.write_int64(pg_time64_negative_infinity)
    else:
        buf.write_int64(ts)


cdef inline int32_t _decode_time(FRBuffer *buf, int64_t *seconds,
                                 int32_t *microseconds):
    cdef int64_t ts = hton.unpack_int64(frb_read(buf, 8))

    if ts == pg_time64_infinity:
        return 1
    elif ts == pg_time64_negative_infinity:
        return -1
    else:
        seconds[0] = ts // 1000000
        microseconds[0] = <int32_t>(ts % 1000000)
        return 0
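
# Worked example (editor's illustration): a wire value of
# ts = 1_000_000_500_000 microseconds since 2000-01-01 splits into
# seconds = 1_000_000 and microseconds = 500_000, which the decoders
# below turn into 2000-01-12 13:46:40.500000.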


cdef date_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        int32_t ordinal = <int32_t>cpython.PyLong_AsLong(obj.toordinal())
        int32_t pg_ordinal

    if ordinal == infinity_date_ord:
        pg_ordinal = pg_date_infinity
    elif ordinal == negative_infinity_date_ord:
        pg_ordinal = pg_date_negative_infinity
    else:
        pg_ordinal = ordinal - pg_date_offset_ord

    buf.write_int32(4)
    buf.write_int32(pg_ordinal)


cdef date_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        int32_t pg_ordinal

    if len(obj) != 1:
        raise ValueError(
            'date tuple encoder: expecting 1 element '
            'in tuple, got {}'.format(len(obj)))

    pg_ordinal = obj[0]
    buf.write_int32(4)
    buf.write_int32(pg_ordinal)


cdef date_decode(CodecContext settings, FRBuffer *buf):
    cdef int32_t pg_ordinal = hton.unpack_int32(frb_read(buf, 4))

    if pg_ordinal == pg_date_infinity:
        return infinity_date
    elif pg_ordinal == pg_date_negative_infinity:
        return negative_infinity_date
    else:
        return date_from_ordinal(pg_ordinal + pg_date_offset_ord)


cdef date_decode_tuple(CodecContext settings, FRBuffer *buf):
    cdef int32_t pg_ordinal = hton.unpack_int32(frb_read(buf, 4))

    return (pg_ordinal,)


cdef timestamp_encode(CodecContext settings, WriteBuffer buf, obj):
    if not cpython.datetime.PyDateTime_Check(obj):
        if cpython.datetime.PyDate_Check(obj):
            obj = datetime.datetime(obj.year, obj.month, obj.day)
        else:
            raise TypeError(
                'expected a datetime.date or datetime.datetime instance, '
                'got {!r}'.format(type(obj).__name__)
            )

    delta = obj - pg_epoch_datetime
    cdef:
        int64_t seconds = cpython.PyLong_AsLongLong(delta.days) * 86400 + \
            cpython.PyLong_AsLong(delta.seconds)
        int32_t microseconds = <int32_t>cpython.PyLong_AsLong(
            delta.microseconds)

    buf.write_int32(8)
    _encode_time(buf, seconds, microseconds)


cdef timestamp_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        int64_t microseconds

    if len(obj) != 1:
        raise ValueError(
            'timestamp tuple encoder: expecting 1 element '
            'in tuple, got {}'.format(len(obj)))

    microseconds = obj[0]

    buf.write_int32(8)
    buf.write_int64(microseconds)


cdef timestamp_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        int64_t seconds = 0
        int32_t microseconds = 0
        int32_t inf = _decode_time(buf, &seconds, &microseconds)

    if inf > 0:
        # positive infinity
        return infinity_datetime
    elif inf < 0:
        # negative infinity
        return negative_infinity_datetime
    else:
        return pg_epoch_datetime.__add__(
            timedelta(0, seconds, microseconds))


cdef timestamp_decode_tuple(CodecContext settings, FRBuffer *buf):
    cdef:
        int64_t ts = hton.unpack_int64(frb_read(buf, 8))

    return (ts,)


cdef timestamptz_encode(CodecContext settings, WriteBuffer buf, obj):
    if not cpython.datetime.PyDateTime_Check(obj):
        if cpython.datetime.PyDate_Check(obj):
            obj = datetime.datetime(obj.year, obj.month, obj.day,
                                    tzinfo=_local_timezone())
        else:
            raise TypeError(
                'expected a datetime.date or datetime.datetime instance, '
                'got {!r}'.format(type(obj).__name__)
            )

    buf.write_int32(8)

    if obj == infinity_datetime:
        buf.write_int64(pg_time64_infinity)
        return
    elif obj == negative_infinity_datetime:
        buf.write_int64(pg_time64_negative_infinity)
        return

    utc_dt = obj.astimezone(utc)

    delta = utc_dt - pg_epoch_datetime_utc
    cdef:
        int64_t seconds = cpython.PyLong_AsLongLong(delta.days) * 86400 + \
            cpython.PyLong_AsLong(delta.seconds)
        int32_t microseconds = <int32_t>cpython.PyLong_AsLong(
            delta.microseconds)

    _encode_time(buf, seconds, microseconds)


cdef timestamptz_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        int64_t seconds = 0
        int32_t microseconds = 0
        int32_t inf = _decode_time(buf, &seconds, &microseconds)

    if inf > 0:
        # positive infinity
        return infinity_datetime
    elif inf < 0:
        # negative infinity
        return negative_infinity_datetime
    else:
        return pg_epoch_datetime_utc.__add__(
            timedelta(0, seconds, microseconds))


cdef time_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        int64_t seconds = cpython.PyLong_AsLong(obj.hour) * 3600 + \
            cpython.PyLong_AsLong(obj.minute) * 60 + \
            cpython.PyLong_AsLong(obj.second)
        int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microsecond)

    buf.write_int32(8)
    _encode_time(buf, seconds, microseconds)


cdef time_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        int64_t microseconds

    if len(obj) != 1:
        raise ValueError(
            'time tuple encoder: expecting 1 element '
            'in tuple, got {}'.format(len(obj)))

    microseconds = obj[0]

    buf.write_int32(8)
    buf.write_int64(microseconds)


cdef time_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        int64_t seconds = 0
        int32_t microseconds = 0

    _decode_time(buf, &seconds, &microseconds)

    cdef:
        int64_t minutes = <int64_t>(seconds / 60)
        int64_t sec = seconds % 60
        int64_t hours = <int64_t>(minutes / 60)
        int64_t min = minutes % 60

    return datetime.time(hours, min, sec, microseconds)


cdef time_decode_tuple(CodecContext settings, FRBuffer *buf):
    cdef:
        int64_t ts = hton.unpack_int64(frb_read(buf, 8))

    return (ts,)


cdef timetz_encode(CodecContext settings, WriteBuffer buf, obj):
    offset = obj.tzinfo.utcoffset(None)

    cdef:
        int32_t offset_sec = \
            <int32_t>cpython.PyLong_AsLong(offset.days) * 24 * 60 * 60 + \
            <int32_t>cpython.PyLong_AsLong(offset.seconds)

        int64_t seconds = cpython.PyLong_AsLong(obj.hour) * 3600 + \
            cpython.PyLong_AsLong(obj.minute) * 60 + \
            cpython.PyLong_AsLong(obj.second)

        int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microsecond)

    buf.write_int32(12)
    _encode_time(buf, seconds, microseconds)
    # In Python utcoffset() is the difference between the local time
    # and the UTC, whereas in PostgreSQL it's the opposite,
    # so we need to flip the sign.
    buf.write_int32(-offset_sec)
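
# Example (editor's illustration): for a zone of UTC+02:00,
# obj.tzinfo.utcoffset(None) is +7200 seconds, so the value written to
# the wire is -7200, i.e. the west-positive convention described in the
# comment above.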


cdef timetz_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        int64_t microseconds
        int32_t offset_sec

    if len(obj) != 2:
        raise ValueError(
            'timetz tuple encoder: expecting 2 elements '
            'in tuple, got {}'.format(len(obj)))

    microseconds = obj[0]
    offset_sec = obj[1]

    buf.write_int32(12)
    buf.write_int64(microseconds)
    buf.write_int32(offset_sec)


cdef timetz_decode(CodecContext settings, FRBuffer *buf):
    time = time_decode(settings, buf)
    cdef int32_t offset = <int32_t>(hton.unpack_int32(frb_read(buf, 4)) / 60)
    # See the comment in the `timetz_encode` method.
    return time.replace(tzinfo=datetime.timezone(timedelta(minutes=-offset)))


cdef timetz_decode_tuple(CodecContext settings, FRBuffer *buf):
    cdef:
        int64_t microseconds = hton.unpack_int64(frb_read(buf, 8))
        int32_t offset_sec = hton.unpack_int32(frb_read(buf, 4))

    return (microseconds, offset_sec)


cdef interval_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        int32_t days = <int32_t>cpython.PyLong_AsLong(obj.days)
        int64_t seconds = cpython.PyLong_AsLongLong(obj.seconds)
        int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microseconds)

    buf.write_int32(16)
    _encode_time(buf, seconds, microseconds)
    buf.write_int32(days)
    buf.write_int32(0)  # Months


cdef interval_encode_tuple(CodecContext settings, WriteBuffer buf,
                           tuple obj):
    cdef:
        int32_t months
        int32_t days
        int64_t microseconds

    if len(obj) != 3:
        raise ValueError(
            'interval tuple encoder: expecting 3 elements '
            'in tuple, got {}'.format(len(obj)))

    months = obj[0]
    days = obj[1]
    microseconds = obj[2]

    buf.write_int32(16)
    buf.write_int64(microseconds)
    buf.write_int32(days)
    buf.write_int32(months)


cdef interval_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        int32_t days
        int32_t months
        int32_t years
        int64_t seconds = 0
        int32_t microseconds = 0

    _decode_time(buf, &seconds, &microseconds)

    days = hton.unpack_int32(frb_read(buf, 4))
    months = hton.unpack_int32(frb_read(buf, 4))

    if months < 0:
        years = -<int32_t>(-months // 12)
        months = -<int32_t>(-months % 12)
    else:
        years = <int32_t>(months // 12)
        months = <int32_t>(months % 12)

    return datetime.timedelta(days=days + months * 30 + years * 365,
                              seconds=seconds, microseconds=microseconds)
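
# Worked example (editor's illustration): a wire interval of months=14,
# days=3, microseconds=0 normalizes to years=1, months=2, giving
# timedelta(days=3 + 2 * 30 + 1 * 365) == timedelta(days=428).  The
# 30-day month and 365-day year are approximations, so round-tripping
# month-bearing intervals through timedelta is inherently lossy.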


cdef interval_decode_tuple(CodecContext settings, FRBuffer *buf):
    cdef:
        int32_t days
        int32_t months
        int64_t microseconds

    microseconds = hton.unpack_int64(frb_read(buf, 8))
    days = hton.unpack_int32(frb_read(buf, 4))
    months = hton.unpack_int32(frb_read(buf, 4))

    return (months, days, microseconds)
@@ -0,0 +1,34 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from libc cimport math


cdef float4_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef double dval = cpython.PyFloat_AsDouble(obj)
    cdef float fval = <float>dval
    if math.isinf(fval) and not math.isinf(dval):
        raise ValueError('value out of float32 range')

    buf.write_int32(4)
    buf.write_float(fval)
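
# Example (editor's illustration): 1e39 is a valid Python float but
# overflows IEEE 754 binary32, so the <float> narrowing above yields inf
# while the original double is finite, and the check rejects the value
# instead of silently sending infinity.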


cdef float4_decode(CodecContext settings, FRBuffer *buf):
    cdef float f = hton.unpack_float(frb_read(buf, 4))
    return cpython.PyFloat_FromDouble(f)


cdef float8_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef double dval = cpython.PyFloat_AsDouble(obj)
    buf.write_int32(8)
    buf.write_double(dval)


cdef float8_decode(CodecContext settings, FRBuffer *buf):
    cdef double f = hton.unpack_double(frb_read(buf, 8))
    return cpython.PyFloat_FromDouble(f)
@@ -0,0 +1,164 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef inline _encode_points(WriteBuffer wbuf, object points):
    cdef object point

    for point in points:
        wbuf.write_double(point[0])
        wbuf.write_double(point[1])


cdef inline _decode_points(FRBuffer *buf):
    cdef:
        int32_t npts = hton.unpack_int32(frb_read(buf, 4))
        pts = cpython.PyTuple_New(npts)
        int32_t i
        object point
        double x
        double y

    for i in range(npts):
        x = hton.unpack_double(frb_read(buf, 8))
        y = hton.unpack_double(frb_read(buf, 8))
        point = pgproto_types.Point(x, y)
        cpython.Py_INCREF(point)
        cpython.PyTuple_SET_ITEM(pts, i, point)

    return pts


cdef box_encode(CodecContext settings, WriteBuffer wbuf, obj):
    wbuf.write_int32(32)
    _encode_points(wbuf, (obj[0], obj[1]))


cdef box_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        double high_x = hton.unpack_double(frb_read(buf, 8))
        double high_y = hton.unpack_double(frb_read(buf, 8))
        double low_x = hton.unpack_double(frb_read(buf, 8))
        double low_y = hton.unpack_double(frb_read(buf, 8))

    return pgproto_types.Box(
        pgproto_types.Point(high_x, high_y),
        pgproto_types.Point(low_x, low_y))


cdef line_encode(CodecContext settings, WriteBuffer wbuf, obj):
    wbuf.write_int32(24)
    wbuf.write_double(obj[0])
    wbuf.write_double(obj[1])
    wbuf.write_double(obj[2])


cdef line_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        double A = hton.unpack_double(frb_read(buf, 8))
        double B = hton.unpack_double(frb_read(buf, 8))
        double C = hton.unpack_double(frb_read(buf, 8))

    return pgproto_types.Line(A, B, C)


cdef lseg_encode(CodecContext settings, WriteBuffer wbuf, obj):
    wbuf.write_int32(32)
    _encode_points(wbuf, (obj[0], obj[1]))


cdef lseg_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        double p1_x = hton.unpack_double(frb_read(buf, 8))
        double p1_y = hton.unpack_double(frb_read(buf, 8))
        double p2_x = hton.unpack_double(frb_read(buf, 8))
        double p2_y = hton.unpack_double(frb_read(buf, 8))

    return pgproto_types.LineSegment((p1_x, p1_y), (p2_x, p2_y))


cdef point_encode(CodecContext settings, WriteBuffer wbuf, obj):
    wbuf.write_int32(16)
    wbuf.write_double(obj[0])
    wbuf.write_double(obj[1])


cdef point_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        double x = hton.unpack_double(frb_read(buf, 8))
        double y = hton.unpack_double(frb_read(buf, 8))

    return pgproto_types.Point(x, y)


cdef path_encode(CodecContext settings, WriteBuffer wbuf, obj):
    cdef:
        int8_t is_closed = 0
        ssize_t npts
        ssize_t encoded_len
        int32_t i

    if cpython.PyTuple_Check(obj):
        is_closed = 1
    elif cpython.PyList_Check(obj):
        is_closed = 0
    elif isinstance(obj, pgproto_types.Path):
        is_closed = obj.is_closed

    npts = len(obj)
    encoded_len = 1 + 4 + 16 * npts
    if encoded_len > _MAXINT32:
        raise ValueError('path value too long')

    wbuf.write_int32(<int32_t>encoded_len)
    wbuf.write_byte(is_closed)
    wbuf.write_int32(<int32_t>npts)

    _encode_points(wbuf, obj)


cdef path_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        int8_t is_closed = <int8_t>(frb_read(buf, 1)[0])

    return pgproto_types.Path(*_decode_points(buf), is_closed=is_closed == 1)


cdef poly_encode(CodecContext settings, WriteBuffer wbuf, obj):
    cdef:
        bint is_closed
        ssize_t npts
        ssize_t encoded_len
        int32_t i

    npts = len(obj)
    encoded_len = 4 + 16 * npts
    if encoded_len > _MAXINT32:
        raise ValueError('polygon value too long')

    wbuf.write_int32(<int32_t>encoded_len)
    wbuf.write_int32(<int32_t>npts)
    _encode_points(wbuf, obj)


cdef poly_decode(CodecContext settings, FRBuffer *buf):
    return pgproto_types.Polygon(*_decode_points(buf))


cdef circle_encode(CodecContext settings, WriteBuffer wbuf, obj):
    wbuf.write_int32(24)
    wbuf.write_double(obj[0][0])
    wbuf.write_double(obj[0][1])
    wbuf.write_double(obj[1])


cdef circle_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        double center_x = hton.unpack_double(frb_read(buf, 8))
        double center_y = hton.unpack_double(frb_read(buf, 8))
        double radius = hton.unpack_double(frb_read(buf, 8))

    return pgproto_types.Circle((center_x, center_y), radius)
@@ -0,0 +1,73 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef hstore_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        char *str
        ssize_t size
        ssize_t count
        object items
        WriteBuffer item_buf = WriteBuffer.new()

    count = len(obj)
    if count > _MAXINT32:
        raise ValueError('hstore value is too large')
    item_buf.write_int32(<int32_t>count)

    if hasattr(obj, 'items'):
        items = obj.items()
    else:
        items = obj

    for k, v in items:
        if k is None:
            raise ValueError('null value not allowed in hstore key')
        as_pg_string_and_size(settings, k, &str, &size)
        item_buf.write_int32(<int32_t>size)
        item_buf.write_cstr(str, size)
        if v is None:
            item_buf.write_int32(<int32_t>-1)
        else:
            as_pg_string_and_size(settings, v, &str, &size)
            item_buf.write_int32(<int32_t>size)
            item_buf.write_cstr(str, size)

    buf.write_int32(item_buf.len())
    buf.write_buffer(item_buf)
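
# Illustration (editor's note, hypothetical values): {'a': '1', 'b': None}
# is framed as a 23-byte payload -- int32 count = 2, then per pair an
# int32 key length, the key bytes, an int32 value length (-1 marks a NULL
# value) and the value bytes -- with the int32 payload length (23)
# written in front of it.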


cdef hstore_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        dict result
        uint32_t elem_count
        int32_t elem_len
        uint32_t i
        str k
        str v

    result = {}

    elem_count = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
    if elem_count == 0:
        return result

    for i in range(elem_count):
        elem_len = hton.unpack_int32(frb_read(buf, 4))
        if elem_len < 0:
            raise ValueError('null value not allowed in hstore key')

        k = decode_pg_string(settings, frb_read(buf, elem_len), elem_len)

        elem_len = hton.unpack_int32(frb_read(buf, 4))
        if elem_len < 0:
            v = None
        else:
            v = decode_pg_string(settings, frb_read(buf, elem_len), elem_len)

        result[k] = v

    return result
144
venv/lib/python3.12/site-packages/asyncpg/pgproto/codecs/int.pyx
Normal file
@@ -0,0 +1,144 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef bool_encode(CodecContext settings, WriteBuffer buf, obj):
    if not cpython.PyBool_Check(obj):
        raise TypeError('a boolean is required (got type {})'.format(
            type(obj).__name__))

    buf.write_int32(1)
    buf.write_byte(b'\x01' if obj is True else b'\x00')


cdef bool_decode(CodecContext settings, FRBuffer *buf):
    return frb_read(buf, 1)[0] is b'\x01'


cdef int2_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef int overflow = 0
    cdef long val

    try:
        if type(obj) is not int and hasattr(type(obj), '__int__'):
            # Silence a Python warning about implicit __int__
            # conversion.
            obj = int(obj)
        val = cpython.PyLong_AsLong(obj)
    except OverflowError:
        overflow = 1

    if overflow or val < INT16_MIN or val > INT16_MAX:
        raise OverflowError('value out of int16 range')

    buf.write_int32(2)
    buf.write_int16(<int16_t>val)


cdef int2_decode(CodecContext settings, FRBuffer *buf):
    return cpython.PyLong_FromLong(hton.unpack_int16(frb_read(buf, 2)))


cdef int4_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef int overflow = 0
    cdef long val = 0

    try:
        if type(obj) is not int and hasattr(type(obj), '__int__'):
            # Silence a Python warning about implicit __int__
            # conversion.
            obj = int(obj)
        val = cpython.PyLong_AsLong(obj)
    except OverflowError:
        overflow = 1

    # "long" and "long long" have the same size for x86_64, need an extra check
    if overflow or (sizeof(val) > 4 and (val < INT32_MIN or val > INT32_MAX)):
        raise OverflowError('value out of int32 range')

    buf.write_int32(4)
    buf.write_int32(<int32_t>val)


cdef int4_decode(CodecContext settings, FRBuffer *buf):
    return cpython.PyLong_FromLong(hton.unpack_int32(frb_read(buf, 4)))


cdef uint4_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef int overflow = 0
    cdef unsigned long val = 0

    try:
        if type(obj) is not int and hasattr(type(obj), '__int__'):
            # Silence a Python warning about implicit __int__
            # conversion.
            obj = int(obj)
        val = cpython.PyLong_AsUnsignedLong(obj)
    except OverflowError:
        overflow = 1

    # "long" and "long long" have the same size for x86_64, need an extra check
    if overflow or (sizeof(val) > 4 and val > UINT32_MAX):
        raise OverflowError('value out of uint32 range')

    buf.write_int32(4)
    buf.write_int32(<int32_t>val)


cdef uint4_decode(CodecContext settings, FRBuffer *buf):
    return cpython.PyLong_FromUnsignedLong(
        <uint32_t>hton.unpack_int32(frb_read(buf, 4)))
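
# Example (editor's illustration): the wire bytes ff ff ff ff unpack as
# the signed int32 -1; the <uint32_t> cast above reinterprets them so the
# caller sees 4294967295 instead.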


cdef int8_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef int overflow = 0
    cdef long long val

    try:
        if type(obj) is not int and hasattr(type(obj), '__int__'):
            # Silence a Python warning about implicit __int__
            # conversion.
            obj = int(obj)
        val = cpython.PyLong_AsLongLong(obj)
    except OverflowError:
        overflow = 1

    # Just in case for systems with "long long" bigger than 8 bytes
    if overflow or (sizeof(val) > 8 and (val < INT64_MIN or val > INT64_MAX)):
        raise OverflowError('value out of int64 range')

    buf.write_int32(8)
    buf.write_int64(<int64_t>val)


cdef int8_decode(CodecContext settings, FRBuffer *buf):
    return cpython.PyLong_FromLongLong(hton.unpack_int64(frb_read(buf, 8)))


cdef uint8_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef int overflow = 0
    cdef unsigned long long val = 0

    try:
        if type(obj) is not int and hasattr(type(obj), '__int__'):
            # Silence a Python warning about implicit __int__
            # conversion.
            obj = int(obj)
        val = cpython.PyLong_AsUnsignedLongLong(obj)
    except OverflowError:
        overflow = 1

    # Just in case for systems with "long long" bigger than 8 bytes
    if overflow or (sizeof(val) > 8 and val > UINT64_MAX):
        raise OverflowError('value out of uint64 range')

    buf.write_int32(8)
    buf.write_int64(<int64_t>val)


cdef uint8_decode(CodecContext settings, FRBuffer *buf):
    return cpython.PyLong_FromUnsignedLongLong(
        <uint64_t>hton.unpack_int64(frb_read(buf, 8)))
@@ -0,0 +1,57 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef jsonb_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        char *str
        ssize_t size

    if settings.is_encoding_json():
        obj = settings.get_json_encoder().encode(obj)

    as_pg_string_and_size(settings, obj, &str, &size)

    if size > 0x7fffffff - 1:
        raise ValueError('string too long')

    buf.write_int32(<int32_t>size + 1)
    buf.write_byte(1)  # JSONB format version
    buf.write_cstr(str, size)
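
# Illustration (editor's note): for the document '{"a": 1}' (8 bytes) the
# frame is int32 length = 9, one version byte 0x01, then the 8 UTF-8
# bytes; the decoder below rejects any version byte other than 1.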


cdef jsonb_decode(CodecContext settings, FRBuffer *buf):
    cdef uint8_t format = <uint8_t>(frb_read(buf, 1)[0])

    if format != 1:
        raise ValueError('unexpected JSONB format: {}'.format(format))

    rv = text_decode(settings, buf)

    if settings.is_decoding_json():
        rv = settings.get_json_decoder().decode(rv)

    return rv


cdef json_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        char *str
        ssize_t size

    if settings.is_encoding_json():
        obj = settings.get_json_encoder().encode(obj)

    text_encode(settings, buf, obj)


cdef json_decode(CodecContext settings, FRBuffer *buf):
    rv = text_decode(settings, buf)

    if settings.is_decoding_json():
        rv = settings.get_json_decoder().decode(rv)

    return rv
@@ -0,0 +1,29 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef jsonpath_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        char *str
        ssize_t size

    as_pg_string_and_size(settings, obj, &str, &size)

    if size > 0x7fffffff - 1:
        raise ValueError('string too long')

    buf.write_int32(<int32_t>size + 1)
    buf.write_byte(1)  # jsonpath format version
    buf.write_cstr(str, size)


cdef jsonpath_decode(CodecContext settings, FRBuffer *buf):
    cdef uint8_t format = <uint8_t>(frb_read(buf, 1)[0])

    if format != 1:
        raise ValueError('unexpected jsonpath format: {}'.format(format))

    return text_decode(settings, buf)
@@ -0,0 +1,16 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef void_encode(CodecContext settings, WriteBuffer buf, obj):
    # Void is zero bytes
    buf.write_int32(0)


cdef void_decode(CodecContext settings, FRBuffer *buf):
    # Do nothing; void will be passed as NULL so this function
    # will never be called.
    pass
@@ -0,0 +1,139 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import ipaddress


# defined in postgresql/src/include/inet.h
#
DEF PGSQL_AF_INET = 2  # AF_INET
DEF PGSQL_AF_INET6 = 3  # AF_INET + 1


_ipaddr = ipaddress.ip_address
_ipiface = ipaddress.ip_interface
_ipnet = ipaddress.ip_network


cdef inline uint8_t _ip_max_prefix_len(int32_t family):
    # Maximum number of bits in the network prefix of the specified
    # IP protocol version.
    if family == PGSQL_AF_INET:
        return 32
    else:
        return 128


cdef inline int32_t _ip_addr_len(int32_t family):
    # Length of address in bytes for the specified IP protocol version.
    if family == PGSQL_AF_INET:
        return 4
    else:
        return 16


cdef inline int8_t _ver_to_family(int32_t version):
    if version == 4:
        return PGSQL_AF_INET
    else:
        return PGSQL_AF_INET6


cdef inline _net_encode(WriteBuffer buf, int8_t family, uint32_t bits,
                        int8_t is_cidr, bytes addr):

    cdef:
        char *addrbytes
        ssize_t addrlen

    cpython.PyBytes_AsStringAndSize(addr, &addrbytes, &addrlen)

    buf.write_int32(4 + <int32_t>addrlen)
    buf.write_byte(family)
    buf.write_byte(<int8_t>bits)
    buf.write_byte(is_cidr)
    buf.write_byte(<int8_t>addrlen)
    buf.write_cstr(addrbytes, addrlen)
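
# Illustration (editor's note, hypothetical value): the plain address
# 192.168.0.1 sent as inet is framed as int32 length = 8, then
# family=2, bits=32, is_cidr=0, addrlen=4 and the four address bytes
# c0 a8 00 01.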


cdef net_decode(CodecContext settings, FRBuffer *buf, bint as_cidr):
    cdef:
        int32_t family = <int32_t>frb_read(buf, 1)[0]
        uint8_t bits = <uint8_t>frb_read(buf, 1)[0]
        int prefix_len
        int32_t is_cidr = <int32_t>frb_read(buf, 1)[0]
        int32_t addrlen = <int32_t>frb_read(buf, 1)[0]
        bytes addr
        uint8_t max_prefix_len = _ip_max_prefix_len(family)

    if is_cidr != as_cidr:
        raise ValueError('unexpected CIDR flag set in non-cidr value')

    if family != PGSQL_AF_INET and family != PGSQL_AF_INET6:
        raise ValueError('invalid address family in "{}" value'.format(
            'cidr' if is_cidr else 'inet'
        ))

    max_prefix_len = _ip_max_prefix_len(family)

    if bits > max_prefix_len:
        raise ValueError('invalid network prefix length in "{}" value'.format(
            'cidr' if is_cidr else 'inet'
        ))

    if addrlen != _ip_addr_len(family):
        raise ValueError('invalid address length in "{}" value'.format(
            'cidr' if is_cidr else 'inet'
        ))

    addr = cpython.PyBytes_FromStringAndSize(frb_read(buf, addrlen), addrlen)

    if as_cidr or bits != max_prefix_len:
        prefix_len = cpython.PyLong_FromLong(bits)

        if as_cidr:
            return _ipnet((addr, prefix_len))
        else:
            return _ipiface((addr, prefix_len))
    else:
        return _ipaddr(addr)


cdef cidr_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        object ipnet
        int8_t family

    ipnet = _ipnet(obj)
    family = _ver_to_family(ipnet.version)
    _net_encode(buf, family, ipnet.prefixlen, 1, ipnet.network_address.packed)


cdef cidr_decode(CodecContext settings, FRBuffer *buf):
    return net_decode(settings, buf, True)


cdef inet_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        object ipaddr
        int8_t family

    try:
        ipaddr = _ipaddr(obj)
    except ValueError:
        # PostgreSQL accepts *both* CIDR and host values
        # for the host datatype.
        ipaddr = _ipiface(obj)
        family = _ver_to_family(ipaddr.version)
        _net_encode(buf, family, ipaddr.network.prefixlen, 1, ipaddr.packed)
    else:
        family = _ver_to_family(ipaddr.version)
        _net_encode(buf, family, _ip_max_prefix_len(family), 0, ipaddr.packed)


cdef inet_decode(CodecContext settings, FRBuffer *buf):
    return net_decode(settings, buf, False)
@@ -0,0 +1,356 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from libc.math cimport abs, log10
from libc.stdio cimport snprintf

import decimal

# defined in postgresql/src/backend/utils/adt/numeric.c
DEF DEC_DIGITS = 4
DEF MAX_DSCALE = 0x3FFF
DEF NUMERIC_POS = 0x0000
DEF NUMERIC_NEG = 0x4000
DEF NUMERIC_NAN = 0xC000
DEF NUMERIC_PINF = 0xD000
DEF NUMERIC_NINF = 0xF000

_Dec = decimal.Decimal


cdef numeric_encode_text(CodecContext settings, WriteBuffer buf, obj):
    text_encode(settings, buf, str(obj))


cdef numeric_decode_text(CodecContext settings, FRBuffer *buf):
    return _Dec(text_decode(settings, buf))


cdef numeric_encode_binary(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        object dec
        object dt
        int64_t exponent
        int64_t i
        int64_t j
        tuple pydigits
        int64_t num_pydigits
        int16_t pgdigit
        int64_t num_pgdigits
        int16_t dscale
        int64_t dweight
        int64_t weight
        uint16_t sign
        int64_t padding_size = 0

    if isinstance(obj, _Dec):
        dec = obj
    else:
        dec = _Dec(obj)

    dt = dec.as_tuple()

    if dt.exponent == 'n' or dt.exponent == 'N':
        # NaN
        sign = NUMERIC_NAN
        num_pgdigits = 0
        weight = 0
        dscale = 0
    elif dt.exponent == 'F':
        # Infinity
        if dt.sign:
            sign = NUMERIC_NINF
        else:
            sign = NUMERIC_PINF
        num_pgdigits = 0
        weight = 0
        dscale = 0
    else:
        exponent = dt.exponent
        if exponent < 0 and -exponent > MAX_DSCALE:
            raise ValueError(
                'cannot encode Decimal value into numeric: '
                'exponent is too small')

        if dt.sign:
            sign = NUMERIC_NEG
        else:
            sign = NUMERIC_POS

        pydigits = dt.digits
        num_pydigits = len(pydigits)

        dweight = num_pydigits + exponent - 1
        if dweight >= 0:
            weight = (dweight + DEC_DIGITS) // DEC_DIGITS - 1
        else:
            weight = -((-dweight - 1) // DEC_DIGITS + 1)

        if weight > 2 ** 16 - 1:
            raise ValueError(
                'cannot encode Decimal value into numeric: '
                'exponent is too large')

        padding_size = \
            (weight + 1) * DEC_DIGITS - (dweight + 1)
        num_pgdigits = \
            (num_pydigits + padding_size + DEC_DIGITS - 1) // DEC_DIGITS

        if num_pgdigits > 2 ** 16 - 1:
            raise ValueError(
                'cannot encode Decimal value into numeric: '
                'number of digits is too large')

        # Pad decimal digits to provide room for correct Postgres
        # digit alignment in the digit computation loop.
        pydigits = (0,) * DEC_DIGITS + pydigits + (0,) * DEC_DIGITS

        if exponent < 0:
            if -exponent > MAX_DSCALE:
                raise ValueError(
                    'cannot encode Decimal value into numeric: '
                    'exponent is too small')
            dscale = <int16_t>-exponent
        else:
            dscale = 0

    buf.write_int32(2 + 2 + 2 + 2 + 2 * <uint16_t>num_pgdigits)
    buf.write_int16(<int16_t>num_pgdigits)
    buf.write_int16(<int16_t>weight)
    buf.write_int16(<int16_t>sign)
    buf.write_int16(dscale)

    j = DEC_DIGITS - padding_size

    for i in range(num_pgdigits):
        pgdigit = (pydigits[j] * 1000 + pydigits[j + 1] * 100 +
                   pydigits[j + 2] * 10 + pydigits[j + 3])
        j += DEC_DIGITS
        buf.write_int16(pgdigit)
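
# Worked example (editor's illustration): Decimal('123.45') has digits
# (1, 2, 3, 4, 5) and exponent -2, so dweight = 2, weight = 0, dscale = 2
# and padding_size = 1; the loop above emits the two base-10000 digits
# 123 and 4500, i.e. 0123.4500 aligned on a four-decimal-digit boundary.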


# The decoding strategy here is to form a string representation of
# the numeric var, as it is faster than passing an iterable of digits.
# For this reason the below code is pure overhead and is ~25% slower
# than the simple text decoder above. That said, we need the binary
# decoder to support binary COPY with numeric values.
cdef numeric_decode_binary_ex(
    CodecContext settings,
    FRBuffer *buf,
    bint trail_fract_zero,
):
    cdef:
        uint16_t num_pgdigits = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
        int16_t weight = hton.unpack_int16(frb_read(buf, 2))
        uint16_t sign = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
        uint16_t dscale = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
        int16_t pgdigit0
        ssize_t i
        int16_t pgdigit
        object pydigits
        ssize_t num_pydigits
        ssize_t actual_num_pydigits
        ssize_t buf_size
        int64_t exponent
        int64_t abs_exponent
        ssize_t exponent_chars
        ssize_t front_padding = 0
        ssize_t num_fract_digits
        ssize_t trailing_fract_zeros_adj
        char smallbuf[_NUMERIC_DECODER_SMALLBUF_SIZE]
        char *charbuf
        char *bufptr
        bint buf_allocated = False

    if sign == NUMERIC_NAN:
        # Not-a-number
        return _Dec('NaN')
    elif sign == NUMERIC_PINF:
        # +Infinity
        return _Dec('Infinity')
    elif sign == NUMERIC_NINF:
        # -Infinity
        return _Dec('-Infinity')

    if num_pgdigits == 0:
        # Zero
        return _Dec('0e-' + str(dscale))

    pgdigit0 = hton.unpack_int16(frb_read(buf, 2))
    if weight >= 0:
        if pgdigit0 < 10:
            front_padding = 3
        elif pgdigit0 < 100:
            front_padding = 2
        elif pgdigit0 < 1000:
            front_padding = 1

    # The number of fractional decimal digits actually encoded in
    # base-DEC_DIGITS digits sent by Postgres.
    num_fract_digits = (num_pgdigits - weight - 1) * DEC_DIGITS

    # The trailing zero adjustment necessary to obtain exactly
    # dscale number of fractional digits in output. May be negative,
    # which indicates that trailing zeros in the last input digit
    # should be discarded.
    trailing_fract_zeros_adj = dscale - num_fract_digits

    # Maximum possible number of decimal digits in base 10.
    # The actual number might be up to 3 digits smaller due to
    # leading zeros in first input digit.
    num_pydigits = num_pgdigits * DEC_DIGITS
    if trailing_fract_zeros_adj > 0:
        num_pydigits += trailing_fract_zeros_adj

    # Exponent.
    exponent = (weight + 1) * DEC_DIGITS - front_padding
    abs_exponent = abs(exponent)
    if abs_exponent != 0:
        # Number of characters required to render absolute exponent value
        # in decimal.
        exponent_chars = <ssize_t>log10(<double>abs_exponent) + 1
    else:
        exponent_chars = 0

    # Output buffer size.
    buf_size = (
        1 +  # sign
        1 +  # leading zero
        1 +  # decimal dot
        num_pydigits +  # digits
        1 +  # possible trailing zero padding
        2 +  # exponent indicator (E-,E+)
        exponent_chars +  # exponent
        1  # null terminator char
    )

    if buf_size > _NUMERIC_DECODER_SMALLBUF_SIZE:
        charbuf = <char *>cpython.PyMem_Malloc(<size_t>buf_size)
        buf_allocated = True
    else:
        charbuf = smallbuf

    try:
        bufptr = charbuf

        if sign == NUMERIC_NEG:
            bufptr[0] = b'-'
            bufptr += 1

        bufptr[0] = b'0'
        bufptr[1] = b'.'
        bufptr += 2

        if weight >= 0:
            bufptr = _unpack_digit_stripping_lzeros(bufptr, pgdigit0)
        else:
            bufptr = _unpack_digit(bufptr, pgdigit0)

        for i in range(1, num_pgdigits):
            pgdigit = hton.unpack_int16(frb_read(buf, 2))
            bufptr = _unpack_digit(bufptr, pgdigit)

        if dscale:
            if trailing_fract_zeros_adj > 0:
                for i in range(trailing_fract_zeros_adj):
                    bufptr[i] = <char>b'0'

            # If display scale is _less_ than the number of rendered digits,
            # trailing_fract_zeros_adj will be negative and this will strip
            # the excess trailing zeros.
            bufptr += trailing_fract_zeros_adj

        if trail_fract_zero:
            # Check if the number of rendered digits matches the exponent,
            # and if so, add another trailing zero, so the result always
            # appears with a decimal point.
            actual_num_pydigits = bufptr - charbuf - 2
            if sign == NUMERIC_NEG:
                actual_num_pydigits -= 1

            if actual_num_pydigits == abs_exponent:
                bufptr[0] = <char>b'0'
                bufptr += 1

        if exponent != 0:
            bufptr[0] = b'E'
            if exponent < 0:
                bufptr[1] = b'-'
            else:
                bufptr[1] = b'+'
            bufptr += 2
            snprintf(bufptr, <size_t>exponent_chars + 1, '%d',
                     <int>abs_exponent)
            bufptr += exponent_chars

        bufptr[0] = 0

        pydigits = cpythonx.PyUnicode_FromString(charbuf)

        return _Dec(pydigits)

    finally:
        if buf_allocated:
            cpython.PyMem_Free(charbuf)


cdef numeric_decode_binary(CodecContext settings, FRBuffer *buf):
    return numeric_decode_binary_ex(settings, buf, False)


cdef inline char *_unpack_digit_stripping_lzeros(char *buf, int64_t pgdigit):
    cdef:
        int64_t d
        bint significant

    d = pgdigit // 1000
    significant = (d > 0)
    if significant:
        pgdigit -= d * 1000
        buf[0] = <char>(d + <int32_t>b'0')
        buf += 1

    d = pgdigit // 100
    significant |= (d > 0)
    if significant:
        pgdigit -= d * 100
        buf[0] = <char>(d + <int32_t>b'0')
        buf += 1

    d = pgdigit // 10
    significant |= (d > 0)
    if significant:
        pgdigit -= d * 10
        buf[0] = <char>(d + <int32_t>b'0')
        buf += 1

    buf[0] = <char>(pgdigit + <int32_t>b'0')
    buf += 1

    return buf


cdef inline char *_unpack_digit(char *buf, int64_t pgdigit):
    cdef:
        int64_t d

    d = pgdigit // 1000
    pgdigit -= d * 1000
    buf[0] = <char>(d + <int32_t>b'0')

    d = pgdigit // 100
    pgdigit -= d * 100
    buf[1] = <char>(d + <int32_t>b'0')

    d = pgdigit // 10
    pgdigit -= d * 10
    buf[2] = <char>(d + <int32_t>b'0')

    buf[3] = <char>(pgdigit + <int32_t>b'0')
    buf += 4

    return buf
@@ -0,0 +1,63 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef pg_snapshot_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        ssize_t nxip
        uint64_t xmin
        uint64_t xmax
        int i
        WriteBuffer xip_buf = WriteBuffer.new()

    if not (cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj)):
        raise TypeError(
            'list or tuple expected (got type {})'.format(type(obj)))

    if len(obj) != 3:
        raise ValueError(
            'invalid number of elements in txid_snapshot tuple, expecting 3')
|
||||
nxip = len(obj[2])
|
||||
if nxip > _MAXINT32:
|
||||
raise ValueError('txid_snapshot value is too long')
|
||||
|
||||
xmin = obj[0]
|
||||
xmax = obj[1]
|
||||
|
||||
for i in range(nxip):
|
||||
xip_buf.write_int64(
|
||||
<int64_t>cpython.PyLong_AsUnsignedLongLong(obj[2][i]))
|
||||
|
||||
buf.write_int32(20 + xip_buf.len())
|
||||
|
||||
buf.write_int32(<int32_t>nxip)
|
||||
buf.write_int64(<int64_t>xmin)
|
||||
buf.write_int64(<int64_t>xmax)
|
||||
buf.write_buffer(xip_buf)
|
||||
|
||||
|
||||
cdef pg_snapshot_decode(CodecContext settings, FRBuffer *buf):
|
||||
cdef:
|
||||
int32_t nxip
|
||||
uint64_t xmin
|
||||
uint64_t xmax
|
||||
tuple xip_tup
|
||||
int32_t i
|
||||
object xip
|
||||
|
||||
nxip = hton.unpack_int32(frb_read(buf, 4))
|
||||
xmin = <uint64_t>hton.unpack_int64(frb_read(buf, 8))
|
||||
xmax = <uint64_t>hton.unpack_int64(frb_read(buf, 8))
|
||||
|
||||
xip_tup = cpython.PyTuple_New(nxip)
|
||||
for i in range(nxip):
|
||||
xip = cpython.PyLong_FromUnsignedLongLong(
|
||||
<uint64_t>hton.unpack_int64(frb_read(buf, 8)))
|
||||
cpython.Py_INCREF(xip)
|
||||
cpython.PyTuple_SET_ITEM(xip_tup, i, xip)
|
||||
|
||||
return (xmin, xmax, xip_tup)
|
||||
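The snapshot payload written above is: an int32 total-length prefix, then int32 xip count, int64 xmin, int64 xmax, and one int64 per in-progress transaction id, all big-endian (hence the 20 + xip_buf.len() length). A struct-based Python sketch of that layout (helper name illustrative):

import struct

def encode_txid_snapshot(xmin: int, xmax: int, xip_list) -> bytes:
    # payload: int32 nxip, int64 xmin, int64 xmax, nxip * int64 xip
    payload = struct.pack('!iqq', len(xip_list), xmin, xmax)
    payload += b''.join(struct.pack('!q', x) for x in xip_list)
    return struct.pack('!i', len(payload)) + payload  # length prefix

assert len(encode_txid_snapshot(100, 110, [103, 105])) == 4 + 20 + 16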
@@ -0,0 +1,48 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef inline as_pg_string_and_size(
        CodecContext settings, obj, char **cstr, ssize_t *size):

    if not cpython.PyUnicode_Check(obj):
        raise TypeError('expected str, got {}'.format(type(obj).__name__))

    if settings.is_encoding_utf8():
        cstr[0] = <char*>cpythonx.PyUnicode_AsUTF8AndSize(obj, size)
    else:
        encoded = settings.get_text_codec().encode(obj)[0]
        cpython.PyBytes_AsStringAndSize(encoded, cstr, size)

    if size[0] > 0x7fffffff:
        raise ValueError('string too long')


cdef text_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef:
        char *str
        ssize_t size

    as_pg_string_and_size(settings, obj, &str, &size)

    buf.write_int32(<int32_t>size)
    buf.write_cstr(str, size)


cdef inline decode_pg_string(CodecContext settings, const char* data,
                             ssize_t len):

    if settings.is_encoding_utf8():
        # decode UTF-8 in strict mode
        return cpython.PyUnicode_DecodeUTF8(data, len, NULL)
    else:
        bytes = cpython.PyBytes_FromStringAndSize(data, len)
        return settings.get_text_codec().decode(bytes)[0]


cdef text_decode(CodecContext settings, FRBuffer *buf):
    cdef ssize_t buf_len = buf.len
    return decode_pg_string(settings, frb_read_all(buf), buf_len)
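text_encode frames a string as an int32 byte length followed by the encoded bytes. An equivalent Python sketch of the UTF-8 fast path (helper name illustrative):

import struct

def encode_pg_text(s: str) -> bytes:
    data = s.encode('utf-8')  # strict UTF-8, as in text_encode
    if len(data) > 0x7fffffff:
        raise ValueError('string too long')
    return struct.pack('!i', len(data)) + data

assert encode_pg_text('abc') == b'\x00\x00\x00\x03abc'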
@@ -0,0 +1,51 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef tid_encode(CodecContext settings, WriteBuffer buf, obj):
    cdef int overflow = 0
    cdef unsigned long block, offset

    if not (cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj)):
        raise TypeError(
            'list or tuple expected (got type {})'.format(type(obj)))

    if len(obj) != 2:
        raise ValueError(
            'invalid number of elements in tid tuple, expecting 2')

    try:
        block = cpython.PyLong_AsUnsignedLong(obj[0])
    except OverflowError:
        overflow = 1

    # "long" and "long long" have the same size for x86_64, need an extra check
    if overflow or (sizeof(block) > 4 and block > UINT32_MAX):
        raise OverflowError('tuple id block value out of uint32 range')

    try:
        offset = cpython.PyLong_AsUnsignedLong(obj[1])
        overflow = 0
    except OverflowError:
        overflow = 1

    if overflow or offset > 65535:
        raise OverflowError('tuple id offset value out of uint16 range')

    buf.write_int32(6)
    buf.write_int32(<int32_t>block)
    buf.write_int16(<int16_t>offset)


cdef tid_decode(CodecContext settings, FRBuffer *buf):
    cdef:
        uint32_t block
        uint16_t offset

    block = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
    offset = <uint16_t>hton.unpack_int16(frb_read(buf, 2))

    return (block, offset)
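The tid payload is a fixed six bytes: uint32 block number and uint16 offset, big-endian, after the int32 length prefix that tid_encode writes as 6. A Python sketch with the same range checks (helper name illustrative):

import struct

def encode_pg_tid(block: int, offset: int) -> bytes:
    if block > 0xffffffff:
        raise OverflowError('tuple id block value out of uint32 range')
    if offset > 65535:
        raise OverflowError('tuple id offset value out of uint16 range')
    return struct.pack('!iIH', 6, block, offset)  # length, block, offset

assert encode_pg_tid(1, 2) == b'\x00\x00\x00\x06\x00\x00\x00\x01\x00\x02'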
@@ -0,0 +1,27 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef uuid_encode(CodecContext settings, WriteBuffer wbuf, obj):
    cdef:
        char buf[16]

    if type(obj) is pg_UUID:
        wbuf.write_int32(<int32_t>16)
        wbuf.write_cstr((<UUID>obj)._data, 16)
    elif cpython.PyUnicode_Check(obj):
        pg_uuid_bytes_from_str(obj, buf)
        wbuf.write_int32(<int32_t>16)
        wbuf.write_cstr(buf, 16)
    else:
        bytea_encode(settings, wbuf, obj.bytes)


cdef uuid_decode(CodecContext settings, FRBuffer *buf):
    if buf.len != 16:
        raise TypeError(
            f'cannot decode UUID, expected 16 bytes, got {buf.len}')
    return pg_uuid_from_buf(frb_read_all(buf))
12
venv/lib/python3.12/site-packages/asyncpg/pgproto/consts.pxi
Normal file
@@ -0,0 +1,12 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


DEF _BUFFER_INITIAL_SIZE = 1024
DEF _BUFFER_MAX_GROW = 65536
DEF _BUFFER_FREELIST_SIZE = 256
DEF _MAXINT32 = 2**31 - 1
DEF _NUMERIC_DECODER_SMALLBUF_SIZE = 256
@@ -0,0 +1,23 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from cpython cimport Py_buffer

cdef extern from "Python.h":
    int PyUnicode_1BYTE_KIND

    int PyByteArray_CheckExact(object)
    int PyByteArray_Resize(object, ssize_t) except -1
    object PyByteArray_FromStringAndSize(const char *, ssize_t)
    char* PyByteArray_AsString(object)

    object PyUnicode_FromString(const char *u)
    const char* PyUnicode_AsUTF8AndSize(
        object unicode, ssize_t *size) except NULL

    object PyUnicode_FromKindAndData(
        int kind, const void *buffer, Py_ssize_t size)
10
venv/lib/python3.12/site-packages/asyncpg/pgproto/debug.pxd
Normal file
@@ -0,0 +1,10 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef extern from "debug.h":

    cdef int PG_DEBUG
48
venv/lib/python3.12/site-packages/asyncpg/pgproto/frb.pxd
Normal file
@@ -0,0 +1,48 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef:

    struct FRBuffer:
        const char* buf
        ssize_t len

    inline ssize_t frb_get_len(FRBuffer *frb):
        return frb.len

    inline void frb_set_len(FRBuffer *frb, ssize_t new_len):
        frb.len = new_len

    inline void frb_init(FRBuffer *frb, const char *buf, ssize_t len):
        frb.buf = buf
        frb.len = len

    inline const char* frb_read(FRBuffer *frb, ssize_t n) except NULL:
        cdef const char *result

        frb_check(frb, n)

        result = frb.buf
        frb.buf += n
        frb.len -= n

        return result

    inline const char* frb_read_all(FRBuffer *frb):
        cdef const char *result
        result = frb.buf
        frb.buf += frb.len
        frb.len = 0
        return result

    inline FRBuffer *frb_slice_from(FRBuffer *frb,
                                    FRBuffer* source, ssize_t len):
        frb.buf = frb_read(source, len)
        frb.len = len
        return frb

    object frb_check(FRBuffer *frb, ssize_t n)
12
venv/lib/python3.12/site-packages/asyncpg/pgproto/frb.pyx
Normal file
@@ -0,0 +1,12 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef object frb_check(FRBuffer *frb, ssize_t n):
    if n > frb.len:
        raise AssertionError(
            f'insufficient data in buffer: requested {n} '
            f'remaining {frb.len}')
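FRBuffer is a zero-copy read cursor over a received message: frb_read advances the pointer and shrinks the remaining length, and frb_check raises before a read could run past the end. A pure-Python analogue of the same contract (class name illustrative):

class ReadBuffer:
    """Minimal pure-Python analogue of FRBuffer: a cursor over bytes."""

    def __init__(self, data: bytes):
        self._data = data
        self._pos = 0

    def read(self, n: int) -> bytes:
        # Mirrors frb_check + frb_read: refuse to read past the end.
        remaining = len(self._data) - self._pos
        if n > remaining:
            raise AssertionError(
                f'insufficient data in buffer: requested {n} '
                f'remaining {remaining}')
        chunk = self._data[self._pos:self._pos + n]
        self._pos += n
        return chunk

buf = ReadBuffer(b'\x00\x00\x00\x2a')
assert int.from_bytes(buf.read(4), 'big') == 42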
24
venv/lib/python3.12/site-packages/asyncpg/pgproto/hton.pxd
Normal file
@@ -0,0 +1,24 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from libc.stdint cimport int16_t, int32_t, uint16_t, uint32_t, int64_t, uint64_t


cdef extern from "./hton.h":
    cdef void pack_int16(char *buf, int16_t x);
    cdef void pack_int32(char *buf, int32_t x);
    cdef void pack_int64(char *buf, int64_t x);
    cdef void pack_float(char *buf, float f);
    cdef void pack_double(char *buf, double f);
    cdef int16_t unpack_int16(const char *buf);
    cdef uint16_t unpack_uint16(const char *buf);
    cdef int32_t unpack_int32(const char *buf);
    cdef uint32_t unpack_uint32(const char *buf);
    cdef int64_t unpack_int64(const char *buf);
    cdef uint64_t unpack_uint64(const char *buf);
    cdef float unpack_float(const char *buf);
    cdef double unpack_double(const char *buf);
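These helpers convert between host and network (big-endian) byte order, which the PostgreSQL wire protocol uses for every integer and float field. Python's struct module with the '!' format prefix behaves the same way:

import struct

# Big-endian round trips matching pack_int32/unpack_int32 and friends.
assert struct.pack('!i', -2) == b'\xff\xff\xff\xfe'
assert struct.unpack('!i', b'\xff\xff\xff\xfe')[0] == -2
assert struct.unpack('!d', struct.pack('!d', 1.5))[0] == 1.5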
Binary file not shown.
@@ -0,0 +1,19 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cimport cython
cimport cpython

from libc.stdint cimport int16_t, int32_t, uint16_t, uint32_t, int64_t, uint64_t


include "./consts.pxi"
include "./frb.pxd"
include "./buffer.pxd"


include "./codecs/__init__.pxd"
@@ -0,0 +1,49 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cimport cython
cimport cpython

from . cimport cpythonx

from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
                         int32_t, uint32_t, int64_t, uint64_t, \
                         INT16_MIN, INT16_MAX, INT32_MIN, INT32_MAX, \
                         UINT32_MAX, INT64_MIN, INT64_MAX, UINT64_MAX


from . cimport hton
from . cimport tohex

from .debug cimport PG_DEBUG
from . import types as pgproto_types


include "./consts.pxi"
include "./frb.pyx"
include "./buffer.pyx"
include "./uuid.pyx"

include "./codecs/context.pyx"

include "./codecs/bytea.pyx"
include "./codecs/text.pyx"

include "./codecs/datetime.pyx"
include "./codecs/float.pyx"
include "./codecs/int.pyx"
include "./codecs/json.pyx"
include "./codecs/jsonpath.pyx"
include "./codecs/uuid.pyx"
include "./codecs/numeric.pyx"
include "./codecs/bits.pyx"
include "./codecs/geometry.pyx"
include "./codecs/hstore.pyx"
include "./codecs/misc.pyx"
include "./codecs/network.pyx"
include "./codecs/tid.pyx"
include "./codecs/pg_snapshot.pyx"
10
venv/lib/python3.12/site-packages/asyncpg/pgproto/tohex.pxd
Normal file
@@ -0,0 +1,10 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef extern from "./tohex.h":
    cdef void uuid_to_str(const char *source, char *dest)
    cdef void uuid_to_hex(const char *source, char *dest)
423
venv/lib/python3.12/site-packages/asyncpg/pgproto/types.py
Normal file
@@ -0,0 +1,423 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import builtins
import sys
import typing

if sys.version_info >= (3, 8):
    from typing import Literal, SupportsIndex
else:
    from typing_extensions import Literal, SupportsIndex


__all__ = (
    'BitString', 'Point', 'Path', 'Polygon',
    'Box', 'Line', 'LineSegment', 'Circle',
)

_BitString = typing.TypeVar('_BitString', bound='BitString')
_BitOrderType = Literal['big', 'little']


class BitString:
    """Immutable representation of PostgreSQL `bit` and `varbit` types."""

    __slots__ = '_bytes', '_bitlength'

    def __init__(self,
                 bitstring: typing.Optional[builtins.bytes] = None) -> None:
        if not bitstring:
            self._bytes = bytes()
            self._bitlength = 0
        else:
            bytelen = len(bitstring) // 8 + 1
            bytes_ = bytearray(bytelen)
            byte = 0
            byte_pos = 0
            bit_pos = 0

            for i, bit in enumerate(bitstring):
                if bit == ' ':  # type: ignore
                    continue
                bit = int(bit)
                if bit != 0 and bit != 1:
                    raise ValueError(
                        'invalid bit value at position {}'.format(i))

                byte |= bit << (8 - bit_pos - 1)
                bit_pos += 1
                if bit_pos == 8:
                    bytes_[byte_pos] = byte
                    byte = 0
                    byte_pos += 1
                    bit_pos = 0

            if bit_pos != 0:
                bytes_[byte_pos] = byte

            bitlen = byte_pos * 8 + bit_pos
            bytelen = byte_pos + (1 if bit_pos else 0)

            self._bytes = bytes(bytes_[:bytelen])
            self._bitlength = bitlen

    @classmethod
    def frombytes(cls: typing.Type[_BitString],
                  bytes_: typing.Optional[builtins.bytes] = None,
                  bitlength: typing.Optional[int] = None) -> _BitString:
        if bitlength is None:
            if bytes_ is None:
                bytes_ = bytes()
                bitlength = 0
            else:
                bitlength = len(bytes_) * 8
        else:
            if bytes_ is None:
                bytes_ = bytes(bitlength // 8 + 1)
                bitlength = bitlength
            else:
                bytes_len = len(bytes_) * 8

                if bytes_len == 0 and bitlength != 0:
                    raise ValueError('invalid bit length specified')

                if bytes_len != 0 and bitlength == 0:
                    raise ValueError('invalid bit length specified')

                if bitlength < bytes_len - 8:
                    raise ValueError('invalid bit length specified')

                if bitlength > bytes_len:
                    raise ValueError('invalid bit length specified')

        result = cls()
        result._bytes = bytes_
        result._bitlength = bitlength

        return result

    @property
    def bytes(self) -> builtins.bytes:
        return self._bytes

    def as_string(self) -> str:
        s = ''

        for i in range(self._bitlength):
            s += str(self._getitem(i))
            if i % 4 == 3:
                s += ' '

        return s.strip()

    def to_int(self, bitorder: _BitOrderType = 'big',
               *, signed: bool = False) -> int:
        """Interpret the BitString as a Python int.
        Acts similarly to int.from_bytes.

        :param bitorder:
            Determines the bit order used to interpret the BitString. By
            default, this function uses Postgres conventions for casting bits
            to ints. If bitorder is 'big', the most significant bit is at the
            start of the string (this is the same as the default). If bitorder
            is 'little', the most significant bit is at the end of the string.

        :param bool signed:
            Determines whether two's complement is used to interpret the
            BitString. If signed is False, the returned value is always
            non-negative.

        :return int: An integer representing the BitString. Information about
                     the BitString's exact length is lost.

        .. versionadded:: 0.18.0
        """
        x = int.from_bytes(self._bytes, byteorder='big')
        x >>= -self._bitlength % 8
        if bitorder == 'big':
            pass
        elif bitorder == 'little':
            x = int(bin(x)[:1:-1].ljust(self._bitlength, '0'), 2)
        else:
            raise ValueError("bitorder must be either 'big' or 'little'")

        if signed and self._bitlength > 0 and x & (1 << (self._bitlength - 1)):
            x -= 1 << self._bitlength
        return x

    @classmethod
    def from_int(cls: typing.Type[_BitString], x: int, length: int,
                 bitorder: _BitOrderType = 'big', *, signed: bool = False) \
            -> _BitString:
        """Represent the Python int x as a BitString.
        Acts similarly to int.to_bytes.

        :param int x:
            An integer to represent. Negative integers are represented in two's
            complement form, unless the argument signed is False, in which case
            negative integers raise an OverflowError.

        :param int length:
            The length of the resulting BitString. An OverflowError is raised
            if the integer is not representable in this many bits.

        :param bitorder:
            Determines the bit order used in the BitString representation. By
            default, this function uses Postgres conventions for casting ints
            to bits. If bitorder is 'big', the most significant bit is at the
            start of the string (this is the same as the default). If bitorder
            is 'little', the most significant bit is at the end of the string.

        :param bool signed:
            Determines whether two's complement is used in the BitString
            representation. If signed is False and a negative integer is given,
            an OverflowError is raised.

        :return BitString: A BitString representing the input integer, in the
                           form specified by the other input args.

        .. versionadded:: 0.18.0
        """
        # Exception types are by analogy to int.to_bytes
        if length < 0:
            raise ValueError("length argument must be non-negative")
        elif length < x.bit_length():
            raise OverflowError("int too big to convert")

        if x < 0:
            if not signed:
                raise OverflowError("can't convert negative int to unsigned")
            x &= (1 << length) - 1

        if bitorder == 'big':
            pass
        elif bitorder == 'little':
            x = int(bin(x)[:1:-1].ljust(length, '0'), 2)
        else:
            raise ValueError("bitorder must be either 'big' or 'little'")

        x <<= (-length % 8)
        bytes_ = x.to_bytes((length + 7) // 8, byteorder='big')
        return cls.frombytes(bytes_, length)

    def __repr__(self) -> str:
        return '<BitString {}>'.format(self.as_string())

    __str__: typing.Callable[['BitString'], str] = __repr__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, BitString):
            return NotImplemented

        return (self._bytes == other._bytes and
                self._bitlength == other._bitlength)

    def __hash__(self) -> int:
        return hash((self._bytes, self._bitlength))

    def _getitem(self, i: int) -> int:
        byte = self._bytes[i // 8]
        shift = 8 - i % 8 - 1
        return (byte >> shift) & 0x1

    def __getitem__(self, i: int) -> int:
        if isinstance(i, slice):
            raise NotImplementedError('BitString does not support slices')

        if i >= self._bitlength:
            raise IndexError('index out of range')

        return self._getitem(i)

    def __len__(self) -> int:
        return self._bitlength


class Point(typing.Tuple[float, float]):
    """Immutable representation of PostgreSQL `point` type."""

    __slots__ = ()

    def __new__(cls,
                x: typing.Union[typing.SupportsFloat,
                                SupportsIndex,
                                typing.Text,
                                builtins.bytes,
                                builtins.bytearray],
                y: typing.Union[typing.SupportsFloat,
                                SupportsIndex,
                                typing.Text,
                                builtins.bytes,
                                builtins.bytearray]) -> 'Point':
        return super().__new__(cls,
                               typing.cast(typing.Any, (float(x), float(y))))

    def __repr__(self) -> str:
        return '{}.{}({})'.format(
            type(self).__module__,
            type(self).__name__,
            tuple.__repr__(self)
        )

    @property
    def x(self) -> float:
        return self[0]

    @property
    def y(self) -> float:
        return self[1]


class Box(typing.Tuple[Point, Point]):
    """Immutable representation of PostgreSQL `box` type."""

    __slots__ = ()

    def __new__(cls, high: typing.Sequence[float],
                low: typing.Sequence[float]) -> 'Box':
        return super().__new__(cls,
                               typing.cast(typing.Any, (Point(*high),
                                                        Point(*low))))

    def __repr__(self) -> str:
        return '{}.{}({})'.format(
            type(self).__module__,
            type(self).__name__,
            tuple.__repr__(self)
        )

    @property
    def high(self) -> Point:
        return self[0]

    @property
    def low(self) -> Point:
        return self[1]


class Line(typing.Tuple[float, float, float]):
    """Immutable representation of PostgreSQL `line` type."""

    __slots__ = ()

    def __new__(cls, A: float, B: float, C: float) -> 'Line':
        return super().__new__(cls, typing.cast(typing.Any, (A, B, C)))

    @property
    def A(self) -> float:
        return self[0]

    @property
    def B(self) -> float:
        return self[1]

    @property
    def C(self) -> float:
        return self[2]


class LineSegment(typing.Tuple[Point, Point]):
    """Immutable representation of PostgreSQL `lseg` type."""

    __slots__ = ()

    def __new__(cls, p1: typing.Sequence[float],
                p2: typing.Sequence[float]) -> 'LineSegment':
        return super().__new__(cls,
                               typing.cast(typing.Any, (Point(*p1),
                                                        Point(*p2))))

    def __repr__(self) -> str:
        return '{}.{}({})'.format(
            type(self).__module__,
            type(self).__name__,
            tuple.__repr__(self)
        )

    @property
    def p1(self) -> Point:
        return self[0]

    @property
    def p2(self) -> Point:
        return self[1]


class Path:
    """Immutable representation of PostgreSQL `path` type."""

    __slots__ = '_is_closed', 'points'

    points: typing.Tuple[Point, ...]

    def __init__(self, *points: typing.Sequence[float],
                 is_closed: bool = False) -> None:
        self.points = tuple(Point(*p) for p in points)
        self._is_closed = is_closed

    @property
    def is_closed(self) -> bool:
        return self._is_closed

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Path):
            return NotImplemented

        return (self.points == other.points and
                self._is_closed == other._is_closed)

    def __hash__(self) -> int:
        return hash((self.points, self.is_closed))

    def __iter__(self) -> typing.Iterator[Point]:
        return iter(self.points)

    def __len__(self) -> int:
        return len(self.points)

    @typing.overload
    def __getitem__(self, i: int) -> Point:
        ...

    @typing.overload
    def __getitem__(self, i: slice) -> typing.Tuple[Point, ...]:
        ...

    def __getitem__(self, i: typing.Union[int, slice]) \
            -> typing.Union[Point, typing.Tuple[Point, ...]]:
        return self.points[i]

    def __contains__(self, point: object) -> bool:
        return point in self.points


class Polygon(Path):
    """Immutable representation of PostgreSQL `polygon` type."""

    __slots__ = ()

    def __init__(self, *points: typing.Sequence[float]) -> None:
        # polygon is always closed
        super().__init__(*points, is_closed=True)


class Circle(typing.Tuple[Point, float]):
    """Immutable representation of PostgreSQL `circle` type."""

    __slots__ = ()

    def __new__(cls, center: Point, radius: float) -> 'Circle':
        return super().__new__(cls, typing.cast(typing.Any, (center, radius)))

    @property
    def center(self) -> Point:
        return self[0]

    @property
    def radius(self) -> float:
        return self[1]
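A usage sketch for the BitString type defined above, assuming the package layout shown in this diff (asyncpg.pgproto.types):

from asyncpg.pgproto.types import BitString

bs = BitString('1010')
assert len(bs) == 4
assert bs.to_int() == 10                       # Postgres-style bit order
assert BitString.from_int(10, length=4) == bs  # round trip
assert bs.as_string() == '1010'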
353
venv/lib/python3.12/site-packages/asyncpg/pgproto/uuid.pyx
Normal file
@@ -0,0 +1,353 @@
import functools
import uuid

cimport cython
cimport cpython

from libc.stdint cimport uint8_t, int8_t
from libc.string cimport memcpy, memcmp


cdef extern from "Python.h":
    int PyUnicode_1BYTE_KIND
    const char* PyUnicode_AsUTF8AndSize(
        object unicode, Py_ssize_t *size) except NULL
    object PyUnicode_FromKindAndData(
        int kind, const void *buffer, Py_ssize_t size)


cdef extern from "./tohex.h":
    cdef void uuid_to_str(const char *source, char *dest)
    cdef void uuid_to_hex(const char *source, char *dest)


# A more efficient UUID type implementation
# (6-7x faster than the standard uuid.UUID):
#
# -= Benchmark results (less is better): =-
#
# std_UUID(bytes):      1.2368
# c_UUID(bytes):      * 0.1645 (7.52x)
# object():             0.1483
#
# std_UUID(str):        1.8038
# c_UUID(str):        * 0.2313 (7.80x)
#
# str(std_UUID()):      1.4625
# str(c_UUID()):      * 0.2681 (5.46x)
# str(object()):        0.5975
#
# std_UUID().bytes:     0.3508
# c_UUID().bytes:     * 0.1068 (3.28x)
#
# std_UUID().int:       0.0871
# c_UUID().int:       * 0.0856
#
# std_UUID().hex:       0.4871
# c_UUID().hex:       * 0.1405
#
# hash(std_UUID()):     0.3635
# hash(c_UUID()):     * 0.1564 (2.32x)
#
# dct[std_UUID()]:      0.3319
# dct[c_UUID()]:      * 0.1570 (2.11x)
#
# std_UUID() ==:        0.3478
# c_UUID() ==:        * 0.0915 (3.80x)


cdef char _hextable[256]
_hextable[:] = [
    -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1, 0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,-1,10,11,12,13,14,15,-1,
    -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1,10,11,12,13,14,15,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
]


cdef std_UUID = uuid.UUID


cdef pg_uuid_bytes_from_str(str u, char *out):
    cdef:
        const char *orig_buf
        Py_ssize_t size
        unsigned char ch
        uint8_t acc, part, acc_set
        int i, j

    orig_buf = PyUnicode_AsUTF8AndSize(u, &size)
    if size > 36 or size < 32:
        raise ValueError(
            f'invalid UUID {u!r}: '
            f'length must be between 32..36 characters, got {size}')

    acc_set = 0
    j = 0
    for i in range(size):
        ch = <unsigned char>orig_buf[i]
        if ch == <unsigned char>b'-':
            continue

        part = <uint8_t><int8_t>_hextable[ch]
        if part == <uint8_t>-1:
            if ch >= 0x20 and ch <= 0x7e:
                raise ValueError(
                    f'invalid UUID {u!r}: unexpected character {chr(ch)!r}')
            else:
                raise ValueError(f'invalid UUID {u!r}: unexpected character')

        if acc_set:
            acc |= part
            out[j] = <char>acc
            acc_set = 0
            j += 1
        else:
            acc = <uint8_t>(part << 4)
            acc_set = 1

    if j > 16 or (j == 16 and acc_set):
        raise ValueError(
            f'invalid UUID {u!r}: decodes to more than 16 bytes')

    if j != 16:
        raise ValueError(
            f'invalid UUID {u!r}: decodes to less than 16 bytes')


cdef class __UUIDReplaceMe:
    pass


cdef pg_uuid_from_buf(const char *buf):
    cdef:
        UUID u = UUID.__new__(UUID)
    memcpy(u._data, buf, 16)
    return u


@cython.final
@cython.no_gc_clear
cdef class UUID(__UUIDReplaceMe):

    cdef:
        char _data[16]
        object _int
        object _hash
        object __weakref__

    def __cinit__(self):
        self._int = None
        self._hash = None

    def __init__(self, inp):
        cdef:
            char *buf
            Py_ssize_t size

        if cpython.PyBytes_Check(inp):
            cpython.PyBytes_AsStringAndSize(inp, &buf, &size)
            if size != 16:
                raise ValueError(f'16 bytes were expected, got {size}')
            memcpy(self._data, buf, 16)

        elif cpython.PyUnicode_Check(inp):
            pg_uuid_bytes_from_str(inp, self._data)
        else:
            raise TypeError(f'a bytes or str object expected, got {inp!r}')

    @property
    def bytes(self):
        return cpython.PyBytes_FromStringAndSize(self._data, 16)

    @property
    def int(self):
        if self._int is None:
            # The cache is important because `self.int` can be
            # used multiple times by __hash__ etc.
            self._int = int.from_bytes(self.bytes, 'big')
        return self._int

    @property
    def is_safe(self):
        return uuid.SafeUUID.unknown

    def __str__(self):
        cdef char out[36]
        uuid_to_str(self._data, out)
        return PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, <void*>out, 36)

    @property
    def hex(self):
        cdef char out[32]
        uuid_to_hex(self._data, out)
        return PyUnicode_FromKindAndData(PyUnicode_1BYTE_KIND, <void*>out, 32)

    def __repr__(self):
        return f"UUID('{self}')"

    def __reduce__(self):
        return (type(self), (self.bytes,))

    def __eq__(self, other):
        if type(other) is UUID:
            return memcmp(self._data, (<UUID>other)._data, 16) == 0
        if isinstance(other, std_UUID):
            return self.int == other.int
        return NotImplemented

    def __ne__(self, other):
        if type(other) is UUID:
            return memcmp(self._data, (<UUID>other)._data, 16) != 0
        if isinstance(other, std_UUID):
            return self.int != other.int
        return NotImplemented

    def __lt__(self, other):
        if type(other) is UUID:
            return memcmp(self._data, (<UUID>other)._data, 16) < 0
        if isinstance(other, std_UUID):
            return self.int < other.int
        return NotImplemented

    def __gt__(self, other):
        if type(other) is UUID:
            return memcmp(self._data, (<UUID>other)._data, 16) > 0
        if isinstance(other, std_UUID):
            return self.int > other.int
        return NotImplemented

    def __le__(self, other):
        if type(other) is UUID:
            return memcmp(self._data, (<UUID>other)._data, 16) <= 0
        if isinstance(other, std_UUID):
            return self.int <= other.int
        return NotImplemented

    def __ge__(self, other):
        if type(other) is UUID:
            return memcmp(self._data, (<UUID>other)._data, 16) >= 0
        if isinstance(other, std_UUID):
            return self.int >= other.int
        return NotImplemented

    def __hash__(self):
        # In EdgeDB every schema object has a uuid and there are
        # huge hash-maps of them. We want UUID.__hash__ to be
        # as fast as possible.
        if self._hash is not None:
            return self._hash

        self._hash = hash(self.int)
        return self._hash

    def __int__(self):
        return self.int

    @property
    def bytes_le(self):
        bytes = self.bytes
        return (bytes[4-1::-1] + bytes[6-1:4-1:-1] + bytes[8-1:6-1:-1] +
                bytes[8:])

    @property
    def fields(self):
        return (self.time_low, self.time_mid, self.time_hi_version,
                self.clock_seq_hi_variant, self.clock_seq_low, self.node)

    @property
    def time_low(self):
        return self.int >> 96

    @property
    def time_mid(self):
        return (self.int >> 80) & 0xffff

    @property
    def time_hi_version(self):
        return (self.int >> 64) & 0xffff

    @property
    def clock_seq_hi_variant(self):
        return (self.int >> 56) & 0xff

    @property
    def clock_seq_low(self):
        return (self.int >> 48) & 0xff

    @property
    def time(self):
        return (((self.time_hi_version & 0x0fff) << 48) |
                (self.time_mid << 32) | self.time_low)

    @property
    def clock_seq(self):
        return (((self.clock_seq_hi_variant & 0x3f) << 8) |
                self.clock_seq_low)

    @property
    def node(self):
        return self.int & 0xffffffffffff

    @property
    def urn(self):
        return 'urn:uuid:' + str(self)

    @property
    def variant(self):
        if not self.int & (0x8000 << 48):
            return uuid.RESERVED_NCS
        elif not self.int & (0x4000 << 48):
            return uuid.RFC_4122
        elif not self.int & (0x2000 << 48):
            return uuid.RESERVED_MICROSOFT
        else:
            return uuid.RESERVED_FUTURE

    @property
    def version(self):
        # The version bits are only meaningful for RFC 4122 UUIDs.
        if self.variant == uuid.RFC_4122:
            return int((self.int >> 76) & 0xf)


# <hack>
# In order for `isinstance(pgproto.UUID, uuid.UUID)` to work,
# patch __bases__ and __mro__ by injecting `uuid.UUID`.
#
# We apply brute-force here because the following pattern stopped
# working with Python 3.8:
#
#    cdef class OurUUID:
#        ...
#
#    class UUID(OurUUID, uuid.UUID):
#        ...
#
# With Python 3.8 it now produces
#
#    "TypeError: multiple bases have instance lay-out conflict"
#
# error. Maybe it's possible to fix this some other way, but
# the best solution possible would be to just contribute our
# faster UUID to the standard library and not have this problem
# at all. For now this hack is pretty safe and should be
# compatible with future Pythons for long enough.
#
assert UUID.__bases__[0] is __UUIDReplaceMe
assert UUID.__mro__[1] is __UUIDReplaceMe
cpython.Py_INCREF(std_UUID)
cpython.PyTuple_SET_ITEM(UUID.__bases__, 0, std_UUID)
cpython.Py_INCREF(std_UUID)
cpython.PyTuple_SET_ITEM(UUID.__mro__, 1, std_UUID)
# </hack>


cdef pg_UUID = UUID
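The observable effect of the __bases__/__mro__ patch above: the Cython UUID passes isinstance checks against the stdlib class and compares equal to it through the int property. A sketch, assuming the compiled extension is importable as asyncpg.pgproto.pgproto:

import uuid
from asyncpg.pgproto import pgproto

u = pgproto.UUID('12345678-1234-5678-1234-567812345678')
assert isinstance(u, uuid.UUID)  # thanks to the injected base class
assert u == uuid.UUID('12345678-1234-5678-1234-567812345678')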
1130
venv/lib/python3.12/site-packages/asyncpg/pool.py
Normal file
File diff suppressed because it is too large
259
venv/lib/python3.12/site-packages/asyncpg/prepared_stmt.py
Normal file
@@ -0,0 +1,259 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import json

from . import connresource
from . import cursor
from . import exceptions


class PreparedStatement(connresource.ConnectionResource):
    """A representation of a prepared statement."""

    __slots__ = ('_state', '_query', '_last_status')

    def __init__(self, connection, query, state):
        super().__init__(connection)
        self._state = state
        self._query = query
        state.attach()
        self._last_status = None

    @connresource.guarded
    def get_name(self) -> str:
        """Return the name of this prepared statement.

        .. versionadded:: 0.25.0
        """
        return self._state.name

    @connresource.guarded
    def get_query(self) -> str:
        """Return the text of the query for this prepared statement.

        Example::

            stmt = await connection.prepare('SELECT $1::int')
            assert stmt.get_query() == "SELECT $1::int"
        """
        return self._query

    @connresource.guarded
    def get_statusmsg(self) -> str:
        """Return the status of the executed command.

        Example::

            stmt = await connection.prepare('CREATE TABLE mytab (a int)')
            await stmt.fetch()
            assert stmt.get_statusmsg() == "CREATE TABLE"
        """
        if self._last_status is None:
            return self._last_status
        return self._last_status.decode()

    @connresource.guarded
    def get_parameters(self):
        """Return a description of statement parameters types.

        :return: A tuple of :class:`asyncpg.types.Type`.

        Example::

            stmt = await connection.prepare('SELECT ($1::int, $2::text)')
            print(stmt.get_parameters())

            # Will print:
            #   (Type(oid=23, name='int4', kind='scalar', schema='pg_catalog'),
            #    Type(oid=25, name='text', kind='scalar', schema='pg_catalog'))
        """
        return self._state._get_parameters()

    @connresource.guarded
    def get_attributes(self):
        """Return a description of relation attributes (columns).

        :return: A tuple of :class:`asyncpg.types.Attribute`.

        Example::

            st = await self.con.prepare('''
                SELECT typname, typnamespace FROM pg_type
            ''')
            print(st.get_attributes())

            # Will print:
            #   (Attribute(
            #       name='typname',
            #       type=Type(oid=19, name='name', kind='scalar',
            #                 schema='pg_catalog')),
            #    Attribute(
            #       name='typnamespace',
            #       type=Type(oid=26, name='oid', kind='scalar',
            #                 schema='pg_catalog')))
        """
        return self._state._get_attributes()

    @connresource.guarded
    def cursor(self, *args, prefetch=None,
               timeout=None) -> cursor.CursorFactory:
        """Return a *cursor factory* for the prepared statement.

        :param args: Query arguments.
        :param int prefetch: The number of rows the *cursor iterator*
                             will prefetch (defaults to ``50``.)
        :param float timeout: Optional timeout in seconds.

        :return: A :class:`~cursor.CursorFactory` object.
        """
        return cursor.CursorFactory(
            self._connection,
            self._query,
            self._state,
            args,
            prefetch,
            timeout,
            self._state.record_class,
        )

    @connresource.guarded
    async def explain(self, *args, analyze=False):
        """Return the execution plan of the statement.

        :param args: Query arguments.
        :param analyze: If ``True``, the statement will be executed and
                        the run time statistics added to the return value.

        :return: An object representing the execution plan. This value
                 is actually a deserialized JSON output of the SQL
                 ``EXPLAIN`` command.
        """
        query = 'EXPLAIN (FORMAT JSON, VERBOSE'
        if analyze:
            query += ', ANALYZE) '
        else:
            query += ') '
        query += self._state.query

        if analyze:
            # From PostgreSQL docs:
            # Important: Keep in mind that the statement is actually
            # executed when the ANALYZE option is used. Although EXPLAIN
            # will discard any output that a SELECT would return, other
            # side effects of the statement will happen as usual. If you
            # wish to use EXPLAIN ANALYZE on an INSERT, UPDATE, DELETE,
            # CREATE TABLE AS, or EXECUTE statement without letting the
            # command affect your data, use this approach:
            #     BEGIN;
            #     EXPLAIN ANALYZE ...;
            #     ROLLBACK;
            tr = self._connection.transaction()
            await tr.start()
            try:
                data = await self._connection.fetchval(query, *args)
            finally:
                await tr.rollback()
        else:
            data = await self._connection.fetchval(query, *args)

        return json.loads(data)

    @connresource.guarded
    async def fetch(self, *args, timeout=None):
        r"""Execute the statement and return a list of :class:`Record` objects.

        :param str query: Query text
        :param args: Query arguments
        :param float timeout: Optional timeout value in seconds.

        :return: A list of :class:`Record` instances.
        """
        data = await self.__bind_execute(args, 0, timeout)
        return data

    @connresource.guarded
    async def fetchval(self, *args, column=0, timeout=None):
        """Execute the statement and return a value in the first row.

        :param args: Query arguments.
        :param int column: Numeric index within the record of the value to
                           return (defaults to 0).
        :param float timeout: Optional timeout value in seconds.
                              If not specified, defaults to the value of
                              ``command_timeout`` argument to the ``Connection``
                              instance constructor.

        :return: The value of the specified column of the first record.
        """
        data = await self.__bind_execute(args, 1, timeout)
        if not data:
            return None
        return data[0][column]

    @connresource.guarded
    async def fetchrow(self, *args, timeout=None):
        """Execute the statement and return the first row.

        :param str query: Query text
        :param args: Query arguments
        :param float timeout: Optional timeout value in seconds.

        :return: The first row as a :class:`Record` instance.
        """
        data = await self.__bind_execute(args, 1, timeout)
        if not data:
            return None
        return data[0]

    @connresource.guarded
    async def executemany(self, args, *, timeout: float=None):
        """Execute the statement for each sequence of arguments in *args*.

        :param args: An iterable containing sequences of arguments.
        :param float timeout: Optional timeout value in seconds.
        :return None: This method discards the results of the operations.

        .. versionadded:: 0.22.0
        """
        return await self.__do_execute(
            lambda protocol: protocol.bind_execute_many(
                self._state, args, '', timeout))

    async def __do_execute(self, executor):
        protocol = self._connection._protocol
        try:
            return await executor(protocol)
        except exceptions.OutdatedSchemaCacheError:
            await self._connection.reload_schema_state()
            # We can not find all manually created prepared statements, so just
            # drop known cached ones in the `self._connection`.
            # Other manually created prepared statements will fail and
            # invalidate themselves (unfortunately, clearing caches again).
            self._state.mark_closed()
            raise

    async def __bind_execute(self, args, limit, timeout):
        data, status, _ = await self.__do_execute(
            lambda protocol: protocol.bind_execute(
                self._state, args, '', limit, True, timeout))
        self._last_status = status
        return data

    def _check_open(self, meth_name):
        if self._state.closed:
            raise exceptions.InterfaceError(
                'cannot call PreparedStmt.{}(): '
                'the prepared statement is closed'.format(meth_name))

    def _check_conn_validity(self, meth_name):
        self._check_open(meth_name)
        super()._check_conn_validity(meth_name)

    def __del__(self):
        self._state.detach()
        self._connection._maybe_gc_stmt(self._state)
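A usage sketch tying the PreparedStatement methods above together (the DSN is a placeholder):

import asyncio
import asyncpg

async def main():
    conn = await asyncpg.connect('postgresql://user@localhost/db')
    try:
        stmt = await conn.prepare('SELECT $1::int + 1')
        assert stmt.get_query() == 'SELECT $1::int + 1'
        # fetchval executes the statement and returns the first column
        # of the first row:
        assert await stmt.fetchval(41) == 42
        # explain(analyze=True) runs inside a transaction that is rolled
        # back, per the comment in explain() above:
        plan = await stmt.explain(41, analyze=True)
        print(plan[0]['Plan']['Node Type'])
    finally:
        await conn.close()

asyncio.run(main())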
@@ -0,0 +1,9 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

# flake8: NOQA

from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP
@@ -0,0 +1,875 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from collections.abc import (Iterable as IterableABC,
                             Mapping as MappingABC,
                             Sized as SizedABC)

from asyncpg import exceptions


DEF ARRAY_MAXDIM = 6  # defined in postgresql/src/includes/c.h

# "NULL"
cdef Py_UCS4 *APG_NULL = [0x004E, 0x0055, 0x004C, 0x004C, 0x0000]


ctypedef object (*encode_func_ex)(ConnectionSettings settings,
                                  WriteBuffer buf,
                                  object obj,
                                  const void *arg)


ctypedef object (*decode_func_ex)(ConnectionSettings settings,
                                  FRBuffer *buf,
                                  const void *arg)


cdef inline bint _is_trivial_container(object obj):
    return cpython.PyUnicode_Check(obj) or cpython.PyBytes_Check(obj) or \
        cpythonx.PyByteArray_Check(obj) or cpythonx.PyMemoryView_Check(obj)


cdef inline _is_array_iterable(object obj):
    return (
        isinstance(obj, IterableABC) and
        isinstance(obj, SizedABC) and
        not _is_trivial_container(obj) and
        not isinstance(obj, MappingABC)
    )


cdef inline _is_sub_array_iterable(object obj):
    # Sub-arrays have a specialized check, because we treat
    # nested tuples as records.
    return _is_array_iterable(obj) and not cpython.PyTuple_Check(obj)


cdef _get_array_shape(object obj, int32_t *dims, int32_t *ndims):
    cdef:
        ssize_t mylen = len(obj)
        ssize_t elemlen = -2
        object it

    if mylen > _MAXINT32:
        raise ValueError('too many elements in array value')

    if ndims[0] > ARRAY_MAXDIM:
        raise ValueError(
            'number of array dimensions ({}) exceeds the maximum expected ({})'.
            format(ndims[0], ARRAY_MAXDIM))

    dims[ndims[0] - 1] = <int32_t>mylen

    for elem in obj:
        if _is_sub_array_iterable(elem):
            if elemlen == -2:
                elemlen = len(elem)
                if elemlen > _MAXINT32:
                    raise ValueError('too many elements in array value')
                ndims[0] += 1
                _get_array_shape(elem, dims, ndims)
            else:
                if len(elem) != elemlen:
                    raise ValueError('non-homogeneous array')
        else:
            if elemlen >= 0:
                raise ValueError('non-homogeneous array')
            else:
                elemlen = -1


cdef _write_array_data(ConnectionSettings settings, object obj, int32_t ndims,
                       int32_t dim, WriteBuffer elem_data,
                       encode_func_ex encoder, const void *encoder_arg):
    if dim < ndims - 1:
        for item in obj:
            _write_array_data(settings, item, ndims, dim + 1, elem_data,
                              encoder, encoder_arg)
    else:
        for item in obj:
            if item is None:
                elem_data.write_int32(-1)
            else:
                try:
                    encoder(settings, elem_data, item, encoder_arg)
                except TypeError as e:
                    raise ValueError(
                        'invalid array element: {}'.format(e.args[0])) from None


cdef inline array_encode(ConnectionSettings settings, WriteBuffer buf,
                         object obj, uint32_t elem_oid,
                         encode_func_ex encoder, const void *encoder_arg):
    cdef:
        WriteBuffer elem_data
        int32_t dims[ARRAY_MAXDIM]
        int32_t ndims = 1
        int32_t i

    if not _is_array_iterable(obj):
        raise TypeError(
            'a sized iterable container expected (got type {!r})'.format(
                type(obj).__name__))

    _get_array_shape(obj, dims, &ndims)

    elem_data = WriteBuffer.new()

    if ndims > 1:
        _write_array_data(settings, obj, ndims, 0, elem_data,
                          encoder, encoder_arg)
    else:
        for i, item in enumerate(obj):
            if item is None:
                elem_data.write_int32(-1)
            else:
                try:
                    encoder(settings, elem_data, item, encoder_arg)
                except TypeError as e:
                    raise ValueError(
                        'invalid array element at index {}: {}'.format(
                            i, e.args[0])) from None

    buf.write_int32(12 + 8 * ndims + elem_data.len())
    # Number of dimensions
    buf.write_int32(ndims)
    # flags
    buf.write_int32(0)
    # element type
    buf.write_int32(<int32_t>elem_oid)
    # upper / lower bounds
    for i in range(ndims):
        buf.write_int32(dims[i])
        buf.write_int32(1)
    # element data
    buf.write_buffer(elem_data)
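array_encode emits the standard PostgreSQL binary array header: int32 ndims, int32 flags, int32 element OID, then one int32 size and one int32 lower bound (always 1 here) per dimension, followed by length-prefixed elements with -1 marking NULL. A struct-based sketch for a flat int4 array (OID 23; helper name illustrative):

import struct

INT4OID = 23  # pg_type OID for int4

def encode_int4_array(values) -> bytes:
    # header: ndims, flags, element oid, then (size, lower bound) per dim
    out = struct.pack('!iii', 1, 0, INT4OID)
    out += struct.pack('!ii', len(values), 1)
    for v in values:
        if v is None:
            out += struct.pack('!i', -1)     # NULL element
        else:
            out += struct.pack('!ii', 4, v)  # element length, payload
    return out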
cdef _write_textarray_data(ConnectionSettings settings, object obj,
|
||||
int32_t ndims, int32_t dim, WriteBuffer array_data,
|
||||
encode_func_ex encoder, const void *encoder_arg,
|
||||
Py_UCS4 typdelim):
|
||||
cdef:
|
||||
ssize_t i = 0
|
||||
int8_t delim = <int8_t>typdelim
|
||||
WriteBuffer elem_data
|
||||
Py_buffer pybuf
|
||||
const char *elem_str
|
||||
char ch
|
||||
ssize_t elem_len
|
||||
ssize_t quoted_elem_len
|
||||
bint need_quoting
|
||||
|
||||
array_data.write_byte(b'{')
|
||||
|
||||
if dim < ndims - 1:
|
||||
for item in obj:
|
||||
if i > 0:
|
||||
array_data.write_byte(delim)
|
||||
array_data.write_byte(b' ')
|
||||
_write_textarray_data(settings, item, ndims, dim + 1, array_data,
|
||||
encoder, encoder_arg, typdelim)
|
||||
i += 1
|
||||
else:
|
||||
for item in obj:
|
||||
elem_data = WriteBuffer.new()
|
||||
|
||||
if i > 0:
|
||||
array_data.write_byte(delim)
|
||||
array_data.write_byte(b' ')
|
||||
|
||||
if item is None:
|
||||
array_data.write_bytes(b'NULL')
|
||||
i += 1
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
encoder(settings, elem_data, item, encoder_arg)
|
||||
except TypeError as e:
|
||||
raise ValueError(
|
||||
'invalid array element: {}'.format(
|
||||
e.args[0])) from None
|
||||
|
||||
# element string length (first four bytes are the encoded length.)
|
||||
elem_len = elem_data.len() - 4
|
||||
|
||||
if elem_len == 0:
|
||||
# Empty string
|
||||
array_data.write_bytes(b'""')
|
||||
else:
|
||||
cpython.PyObject_GetBuffer(
|
||||
elem_data, &pybuf, cpython.PyBUF_SIMPLE)
|
||||
|
||||
elem_str = <const char*>(pybuf.buf) + 4
|
||||
|
||||
try:
|
||||
if not apg_strcasecmp_char(elem_str, b'NULL'):
|
||||
array_data.write_byte(b'"')
|
||||
array_data.write_cstr(elem_str, 4)
|
||||
array_data.write_byte(b'"')
|
||||
else:
|
||||
quoted_elem_len = elem_len
|
||||
need_quoting = False
|
||||
|
||||
for i in range(elem_len):
|
||||
ch = elem_str[i]
|
||||
if ch == b'"' or ch == b'\\':
|
||||
# Quotes and backslashes need escaping.
|
||||
quoted_elem_len += 1
|
||||
need_quoting = True
|
||||
elif (ch == b'{' or ch == b'}' or ch == delim or
|
||||
apg_ascii_isspace(<uint32_t>ch)):
|
||||
need_quoting = True
|
||||
|
||||
if need_quoting:
|
||||
array_data.write_byte(b'"')
|
||||
|
||||
if quoted_elem_len == elem_len:
|
||||
array_data.write_cstr(elem_str, elem_len)
|
||||
else:
|
||||
# Escaping required.
|
||||
for i in range(elem_len):
|
||||
ch = elem_str[i]
|
||||
if ch == b'"' or ch == b'\\':
|
||||
array_data.write_byte(b'\\')
|
||||
array_data.write_byte(ch)
|
||||
|
||||
array_data.write_byte(b'"')
|
||||
else:
|
||||
array_data.write_cstr(elem_str, elem_len)
|
||||
finally:
|
||||
cpython.PyBuffer_Release(&pybuf)
|
||||
|
||||
i += 1
|
||||
|
||||
array_data.write_byte(b'}')
|
||||
|
||||
|
||||
cdef inline textarray_encode(ConnectionSettings settings, WriteBuffer buf,
|
||||
object obj, encode_func_ex encoder,
|
||||
const void *encoder_arg, Py_UCS4 typdelim):
|
||||
cdef:
|
||||
WriteBuffer array_data
|
||||
int32_t dims[ARRAY_MAXDIM]
|
||||
int32_t ndims = 1
|
||||
int32_t i
|
||||
|
||||
if not _is_array_iterable(obj):
|
||||
raise TypeError(
|
||||
'a sized iterable container expected (got type {!r})'.format(
|
||||
type(obj).__name__))
|
||||
|
||||
_get_array_shape(obj, dims, &ndims)
|
||||
|
||||
array_data = WriteBuffer.new()
|
||||
_write_textarray_data(settings, obj, ndims, 0, array_data,
|
||||
encoder, encoder_arg, typdelim)
|
||||
buf.write_int32(array_data.len())
|
||||
buf.write_buffer(array_data)
|
||||
|
||||
|
||||
cdef inline array_decode(ConnectionSettings settings, FRBuffer *buf,
|
||||
decode_func_ex decoder, const void *decoder_arg):
|
||||
cdef:
|
||||
int32_t ndims = hton.unpack_int32(frb_read(buf, 4))
|
||||
int32_t flags = hton.unpack_int32(frb_read(buf, 4))
|
||||
uint32_t elem_oid = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
|
||||
list result
|
||||
int i
|
||||
int32_t elem_len
|
||||
int32_t elem_count = 1
|
||||
FRBuffer elem_buf
|
||||
int32_t dims[ARRAY_MAXDIM]
|
||||
Codec elem_codec
|
||||
|
||||
if ndims == 0:
|
||||
return []
|
||||
|
||||
if ndims > ARRAY_MAXDIM:
|
||||
raise exceptions.ProtocolError(
|
||||
'number of array dimensions ({}) exceed the maximum expected ({})'.
|
||||
format(ndims, ARRAY_MAXDIM))
|
||||
elif ndims < 0:
|
||||
raise exceptions.ProtocolError(
|
||||
'unexpected array dimensions value: {}'.format(ndims))
|
||||
|
||||
for i in range(ndims):
|
||||
dims[i] = hton.unpack_int32(frb_read(buf, 4))
|
||||
if dims[i] < 0:
|
||||
raise exceptions.ProtocolError(
|
||||
'unexpected array dimension size: {}'.format(dims[i]))
|
||||
# Ignore the lower bound information
|
||||
frb_read(buf, 4)
|
||||
|
||||
if ndims == 1:
|
||||
# Fast path for flat arrays
|
||||
elem_count = dims[0]
|
||||
result = cpython.PyList_New(elem_count)
|
||||
|
||||
for i in range(elem_count):
|
||||
elem_len = hton.unpack_int32(frb_read(buf, 4))
|
||||
if elem_len == -1:
|
||||
elem = None
|
||||
else:
|
||||
frb_slice_from(&elem_buf, buf, elem_len)
|
||||
elem = decoder(settings, &elem_buf, decoder_arg)
|
||||
|
||||
cpython.Py_INCREF(elem)
|
||||
cpython.PyList_SET_ITEM(result, i, elem)
|
||||
|
||||
else:
|
||||
result = _nested_array_decode(settings, buf,
|
||||
decoder, decoder_arg, ndims, dims,
|
||||
&elem_buf)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
cdef _nested_array_decode(ConnectionSettings settings,
                          FRBuffer *buf,
                          decode_func_ex decoder,
                          const void *decoder_arg,
                          int32_t ndims, int32_t *dims,
                          FRBuffer *elem_buf):

    cdef:
        int32_t elem_len
        int64_t i, j
        int64_t array_len = 1
        object elem, stride
        # An array of pointers to lists for each current array level.
        void *strides[ARRAY_MAXDIM]
        # An array of current positions at each array level.
        int32_t indexes[ARRAY_MAXDIM]

    for i in range(ndims):
        array_len *= dims[i]
        indexes[i] = 0
        strides[i] = NULL

    if array_len == 0:
        # A multidimensional array with a zero-sized dimension?
        return []

    elif array_len < 0:
        # Array length overflow
        raise exceptions.ProtocolError('array length overflow')

    for i in range(array_len):
        # Decode the element.
        elem_len = hton.unpack_int32(frb_read(buf, 4))
        if elem_len == -1:
            elem = None
        else:
            elem = decoder(settings,
                           frb_slice_from(elem_buf, buf, elem_len),
                           decoder_arg)

        # Take an explicit reference, as PyList_SET_ITEM in the
        # loop below expects this.
        cpython.Py_INCREF(elem)

        # Iterate over array dimensions and put the element in
        # the correctly nested sublist.
        for j in reversed(range(ndims)):
            if indexes[j] == 0:
                # Allocate the list for this array level.
                stride = cpython.PyList_New(dims[j])

                strides[j] = <void*><cpython.PyObject>stride
                # Take an explicit reference, as PyList_SET_ITEM
                # below expects this.
                cpython.Py_INCREF(stride)

            stride = <object><cpython.PyObject*>strides[j]
            cpython.PyList_SET_ITEM(stride, indexes[j], elem)
            indexes[j] += 1

            if indexes[j] == dims[j] and j != 0:
                # This array level is full, continue the
                # ascent in the dimensions so that this level
                # sublist will be appended to the parent list.
                elem = stride
                # Reset the index; this will cause a
                # new list to be allocated on the next
                # iteration on this array axis.
                indexes[j] = 0
            else:
                break

    stride = <object><cpython.PyObject*>strides[0]
    # Since each element in strides has a refcount of 1,
    # returning strides[0] will increment it to 2, so
    # balance that.
    cpython.Py_DECREF(stride)
    return stride


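# A pure-Python sketch of the strides/indexes bookkeeping performed by
# _nested_array_decode() above (an illustration only, not part of this
# module): the flat element stream is regrouped from the innermost
# dimension outwards.
#
#     def nest(flat, dims):
#         # nest([1, 2, 3, 4], [2, 2]) -> [[1, 2], [3, 4]]
#         result = flat
#         for dim in reversed(dims[1:]):
#             result = [result[i:i + dim]
#                       for i in range(0, len(result), dim)]
#         return result

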
cdef textarray_decode(ConnectionSettings settings, FRBuffer *buf,
                      decode_func_ex decoder, const void *decoder_arg,
                      Py_UCS4 typdelim):
    cdef:
        Py_UCS4 *array_text
        str s

    # Make a copy of array data since we will be mutating it for
    # the purposes of element decoding.
    s = pgproto.text_decode(settings, buf)
    array_text = cpythonx.PyUnicode_AsUCS4Copy(s)

    try:
        return _textarray_decode(
            settings, array_text, decoder, decoder_arg, typdelim)
    except ValueError as e:
        raise exceptions.ProtocolError(
            'malformed array literal {!r}: {}'.format(s, e.args[0]))
    finally:
        cpython.PyMem_Free(array_text)


cdef _textarray_decode(ConnectionSettings settings,
                       Py_UCS4 *array_text,
                       decode_func_ex decoder,
                       const void *decoder_arg,
                       Py_UCS4 typdelim):

    cdef:
        bytearray array_bytes
        list result
        list new_stride
        Py_UCS4 *ptr
        int32_t ndims = 0
        int32_t ubound = 0
        int32_t lbound = 0
        int32_t dims[ARRAY_MAXDIM]
        int32_t inferred_dims[ARRAY_MAXDIM]
        int32_t inferred_ndims = 0
        void *strides[ARRAY_MAXDIM]
        int32_t indexes[ARRAY_MAXDIM]
        int32_t nest_level = 0
        int32_t item_level = 0
        bint end_of_array = False

        bint end_of_item = False
        bint has_quoting = False
        bint strip_spaces = False
        bint in_quotes = False
        Py_UCS4 *item_start
        Py_UCS4 *item_ptr
        Py_UCS4 *item_end

        int i
        object item
        str item_text
        FRBuffer item_buf
        char *pg_item_str
        ssize_t pg_item_len

    ptr = array_text

    while True:
        while apg_ascii_isspace(ptr[0]):
            ptr += 1

        if ptr[0] != '[':
            # Finished parsing dimensions spec.
            break

        ptr += 1  # '['

        if ndims > ARRAY_MAXDIM:
            raise ValueError(
                'number of array dimensions ({}) exceeds the '
                'maximum expected ({})'.format(ndims, ARRAY_MAXDIM))

        ptr = apg_parse_int32(ptr, &ubound)
        if ptr == NULL:
            raise ValueError('missing array dimension value')

        if ptr[0] == ':':
            ptr += 1
            lbound = ubound

            # [lower:upper] spec.  We disregard the lbound for decoding.
            ptr = apg_parse_int32(ptr, &ubound)
            if ptr == NULL:
                raise ValueError('missing array dimension value')
        else:
            lbound = 1

        if ptr[0] != ']':
            raise ValueError('missing \']\' after array dimensions')

        ptr += 1  # ']'

        dims[ndims] = ubound - lbound + 1
        ndims += 1

    if ndims != 0:
        # If dimensions were given, the '=' token is expected.
        if ptr[0] != '=':
            raise ValueError('missing \'=\' after array dimensions')

        ptr += 1  # '='

        # Skip any whitespace after the '='; whitespace
        # before it was consumed in the above loop.
        while apg_ascii_isspace(ptr[0]):
            ptr += 1

        # Infer the dimensions from the brace structure in the
        # array literal body, and check that it matches the explicit
        # spec.  This also validates that the array literal is sane.
        _infer_array_dims(ptr, typdelim, inferred_dims, &inferred_ndims)

        if inferred_ndims != ndims:
            raise ValueError(
                'specified array dimensions do not match array content')

        for i in range(ndims):
            if inferred_dims[i] != dims[i]:
                raise ValueError(
                    'specified array dimensions do not match array content')
    else:
        # Infer the dimensions from the brace structure in the array literal
        # body.  This also validates that the array literal is sane.
        _infer_array_dims(ptr, typdelim, dims, &ndims)

    while not end_of_array:
        # We iterate over the literal character by character
        # and modify the string in-place, removing the array-specific
        # quoting and determining the boundaries of each element.
        end_of_item = has_quoting = in_quotes = False
        strip_spaces = True

        # Pointers to array element start, end, and the current pointer
        # tracking the position where characters are written when
        # escaping is folded.
        item_start = item_end = item_ptr = ptr
        item_level = 0

        while not end_of_item:
            if ptr[0] == '"':
                in_quotes = not in_quotes
                if in_quotes:
                    strip_spaces = False
                else:
                    item_end = item_ptr
                has_quoting = True

            elif ptr[0] == '\\':
                # Quoted character, collapse the backslash.
                ptr += 1
                has_quoting = True
                item_ptr[0] = ptr[0]
                item_ptr += 1
                strip_spaces = False
                item_end = item_ptr

            elif in_quotes:
                # Consume the string until we see the closing quote.
                item_ptr[0] = ptr[0]
                item_ptr += 1

            elif ptr[0] == '{':
                # Nesting level increase.
                nest_level += 1

                indexes[nest_level - 1] = 0
                new_stride = cpython.PyList_New(dims[nest_level - 1])
                strides[nest_level - 1] = \
                    <void*>(<cpython.PyObject>new_stride)

                if nest_level > 1:
                    cpython.Py_INCREF(new_stride)
                    cpython.PyList_SET_ITEM(
                        <object><cpython.PyObject*>strides[nest_level - 2],
                        indexes[nest_level - 2],
                        new_stride)
                else:
                    result = new_stride

            elif ptr[0] == '}':
                if item_level == 0:
                    # Make sure we keep track of which nesting
                    # level the item belongs to, as the loop
                    # will continue to consume closing braces
                    # until the delimiter or the end of input.
                    item_level = nest_level

                nest_level -= 1

                if nest_level == 0:
                    end_of_array = end_of_item = True

            elif ptr[0] == typdelim:
                # Array element delimiter.
                end_of_item = True
                if item_level == 0:
                    item_level = nest_level

            elif apg_ascii_isspace(ptr[0]):
                if not strip_spaces:
                    item_ptr[0] = ptr[0]
                    item_ptr += 1
                # Ignore the leading literal whitespace.

            else:
                item_ptr[0] = ptr[0]
                item_ptr += 1
                strip_spaces = False
                item_end = item_ptr

            ptr += 1

        # end while not end_of_item

        if item_end == item_start:
            # Empty array
            continue

        item_end[0] = '\0'

        if not has_quoting and apg_strcasecmp(item_start, APG_NULL) == 0:
            # NULL element.
            item = None
        else:
            # XXX: find a way to avoid the redundant encode/decode
            # cycle here.
            item_text = cpythonx.PyUnicode_FromKindAndData(
                cpythonx.PyUnicode_4BYTE_KIND,
                <void *>item_start,
                item_end - item_start)

            # Prepare the element buffer and call the text decoder
            # for the element type.
            pgproto.as_pg_string_and_size(
                settings, item_text, &pg_item_str, &pg_item_len)
            frb_init(&item_buf, pg_item_str, pg_item_len)
            item = decoder(settings, &item_buf, decoder_arg)

        # Place the decoded element in the array.
        cpython.Py_INCREF(item)
        cpython.PyList_SET_ITEM(
            <object><cpython.PyObject*>strides[item_level - 1],
            indexes[item_level - 1],
            item)

        if nest_level > 0:
            indexes[nest_level - 1] += 1

    return result


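# For reference, a dimensioned text array literal as handled by
# _textarray_decode() above looks like this (values invented):
#
#     [1:2][1:2]={{1,2},{3,4}}
#
# The optional leading [lower:upper] blocks give explicit bounds for each
# dimension, '=' separates them from the body, and the brace structure of
# the body must agree with the declared dimensions.

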
cdef enum _ArrayParseState:
    APS_START = 1
    APS_STRIDE_STARTED = 2
    APS_STRIDE_DONE = 3
    APS_STRIDE_DELIMITED = 4
    APS_ELEM_STARTED = 5
    APS_ELEM_DELIMITED = 6


cdef _UnexpectedCharacter(const Py_UCS4 *array_text, const Py_UCS4 *ptr):
    return ValueError('unexpected character {!r} at position {}'.format(
        cpython.PyUnicode_FromOrdinal(<int>ptr[0]), ptr - array_text + 1))


cdef _infer_array_dims(const Py_UCS4 *array_text,
                       Py_UCS4 typdelim,
                       int32_t *dims,
                       int32_t *ndims):
    cdef:
        const Py_UCS4 *ptr = array_text
        int i
        int nest_level = 0
        bint end_of_array = False
        bint end_of_item = False
        bint in_quotes = False
        bint array_is_empty = True
        int stride_len[ARRAY_MAXDIM]
        int prev_stride_len[ARRAY_MAXDIM]
        _ArrayParseState parse_state = APS_START

    for i in range(ARRAY_MAXDIM):
        dims[i] = prev_stride_len[i] = 0
        stride_len[i] = 1

    while not end_of_array:
        end_of_item = False

        while not end_of_item:
            if ptr[0] == '\0':
                raise ValueError('unexpected end of string')

            elif ptr[0] == '"':
                if (parse_state not in (APS_STRIDE_STARTED,
                                        APS_ELEM_DELIMITED) and
                        not (parse_state == APS_ELEM_STARTED and in_quotes)):
                    raise _UnexpectedCharacter(array_text, ptr)

                in_quotes = not in_quotes
                if in_quotes:
                    parse_state = APS_ELEM_STARTED
                    array_is_empty = False

            elif ptr[0] == '\\':
                if parse_state not in (APS_STRIDE_STARTED,
                                       APS_ELEM_STARTED,
                                       APS_ELEM_DELIMITED):
                    raise _UnexpectedCharacter(array_text, ptr)

                parse_state = APS_ELEM_STARTED
                array_is_empty = False

                if ptr[1] != '\0':
                    ptr += 1
                else:
                    raise ValueError('unexpected end of string')

            elif in_quotes:
                # Ignore everything inside the quotes.
                pass

            elif ptr[0] == '{':
                if parse_state not in (APS_START,
                                       APS_STRIDE_STARTED,
                                       APS_STRIDE_DELIMITED):
                    raise _UnexpectedCharacter(array_text, ptr)

                parse_state = APS_STRIDE_STARTED
                if nest_level >= ARRAY_MAXDIM:
                    raise ValueError(
                        'number of array dimensions ({}) exceeds the '
                        'maximum expected ({})'.format(
                            nest_level, ARRAY_MAXDIM))

                dims[nest_level] = 0
                nest_level += 1
                if ndims[0] < nest_level:
                    ndims[0] = nest_level

            elif ptr[0] == '}':
                if (parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE) and
                        not (nest_level == 1 and
                             parse_state == APS_STRIDE_STARTED)):
                    raise _UnexpectedCharacter(array_text, ptr)

                parse_state = APS_STRIDE_DONE

                if nest_level == 0:
                    raise _UnexpectedCharacter(array_text, ptr)

                nest_level -= 1

                if (prev_stride_len[nest_level] != 0 and
                        stride_len[nest_level] != prev_stride_len[nest_level]):
                    raise ValueError(
                        'inconsistent sub-array dimensions'
                        ' at position {}'.format(
                            ptr - array_text + 1))

                prev_stride_len[nest_level] = stride_len[nest_level]
                stride_len[nest_level] = 1
                if nest_level == 0:
                    end_of_array = end_of_item = True
                else:
                    dims[nest_level - 1] += 1

            elif ptr[0] == typdelim:
                if parse_state not in (APS_ELEM_STARTED, APS_STRIDE_DONE):
                    raise _UnexpectedCharacter(array_text, ptr)

                if parse_state == APS_STRIDE_DONE:
                    parse_state = APS_STRIDE_DELIMITED
                else:
                    parse_state = APS_ELEM_DELIMITED
                end_of_item = True
                stride_len[nest_level - 1] += 1

            elif not apg_ascii_isspace(ptr[0]):
                if parse_state not in (APS_STRIDE_STARTED,
                                       APS_ELEM_STARTED,
                                       APS_ELEM_DELIMITED):
                    raise _UnexpectedCharacter(array_text, ptr)

                parse_state = APS_ELEM_STARTED
                array_is_empty = False

            if not end_of_item:
                ptr += 1

        if not array_is_empty:
            dims[ndims[0] - 1] += 1

        ptr += 1

    # Only whitespace is allowed after the closing brace.
    while ptr[0] != '\0':
        if not apg_ascii_isspace(ptr[0]):
            raise _UnexpectedCharacter(array_text, ptr)

        ptr += 1

    if array_is_empty:
        ndims[0] = 0


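# A worked example of the inference above (a sketch): for the literal
# '{{1,2,3},{4,5,6}}' the opening braces take the nesting to level 2, so
# ndims becomes 2; each inner '}' increments dims[0], giving 2 rows; and
# the per-element exits of the inner loop count the elements of the last
# stride, giving dims[1] == 3.  Ragged input such as '{{1,2},{3}}' is
# rejected by the stride_len/prev_stride_len consistency check.

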
cdef uint4_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj,
                     const void *arg):
    return pgproto.uint4_encode(settings, buf, obj)


cdef uint4_decode_ex(ConnectionSettings settings, FRBuffer *buf,
                     const void *arg):
    return pgproto.uint4_decode(settings, buf)


cdef arrayoid_encode(ConnectionSettings settings, WriteBuffer buf, items):
    array_encode(settings, buf, items, OIDOID,
                 <encode_func_ex>&uint4_encode_ex, NULL)


cdef arrayoid_decode(ConnectionSettings settings, FRBuffer *buf):
    return array_decode(settings, buf, <decode_func_ex>&uint4_decode_ex, NULL)


cdef text_encode_ex(ConnectionSettings settings, WriteBuffer buf, object obj,
                    const void *arg):
    return pgproto.text_encode(settings, buf, obj)


cdef text_decode_ex(ConnectionSettings settings, FRBuffer *buf,
                    const void *arg):
    return pgproto.text_decode(settings, buf)


cdef arraytext_encode(ConnectionSettings settings, WriteBuffer buf, items):
    array_encode(settings, buf, items, TEXTOID,
                 <encode_func_ex>&text_encode_ex, NULL)


cdef arraytext_decode(ConnectionSettings settings, FRBuffer *buf):
    return array_decode(settings, buf, <decode_func_ex>&text_decode_ex, NULL)


cdef init_array_codecs():
    # oid[] and text[] are registered as core codecs
    # to make the type introspection query work.
    #
    register_core_codec(_OIDOID,
                        <encode_func>&arrayoid_encode,
                        <decode_func>&arrayoid_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(_TEXTOID,
                        <encode_func>&arraytext_encode,
                        <decode_func>&arraytext_decode,
                        PG_FORMAT_BINARY)


init_array_codecs()


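# A hedged usage sketch (not part of this module): once the array codecs
# are registered, PostgreSQL arrays round-trip as Python lists, e.g.
#
#     value = await conn.fetchval("SELECT '{1,2,3}'::int4[]")
#     assert value == [1, 2, 3]
#
# where `conn` is assumed to be an established asyncpg Connection.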
@@ -0,0 +1,187 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


ctypedef object (*encode_func)(ConnectionSettings settings,
                               WriteBuffer buf,
                               object obj)

ctypedef object (*decode_func)(ConnectionSettings settings,
                               FRBuffer *buf)

ctypedef object (*codec_encode_func)(Codec codec,
                                     ConnectionSettings settings,
                                     WriteBuffer buf,
                                     object obj)

ctypedef object (*codec_decode_func)(Codec codec,
                                     ConnectionSettings settings,
                                     FRBuffer *buf)


cdef enum CodecType:
    CODEC_UNDEFINED = 0
    CODEC_C = 1
    CODEC_PY = 2
    CODEC_ARRAY = 3
    CODEC_COMPOSITE = 4
    CODEC_RANGE = 5
    CODEC_MULTIRANGE = 6


cdef enum ServerDataFormat:
    PG_FORMAT_ANY = -1
    PG_FORMAT_TEXT = 0
    PG_FORMAT_BINARY = 1


cdef enum ClientExchangeFormat:
    PG_XFORMAT_OBJECT = 1
    PG_XFORMAT_TUPLE = 2


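# ServerDataFormat describes the wire encoding negotiated with the server
# (text or binary), while ClientExchangeFormat describes what the Python
# side exchanges with a codec: decoded high-level objects
# (PG_XFORMAT_OBJECT) or raw low-level tuples (PG_XFORMAT_TUPLE), the
# latter being what Connection.set_type_codec(..., format='tuple')
# selects.

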
cdef class Codec:
    cdef:
        uint32_t oid

        str name
        str schema
        str kind

        CodecType type
        ServerDataFormat format
        ClientExchangeFormat xformat

        encode_func c_encoder
        decode_func c_decoder
        Codec base_codec

        object py_encoder
        object py_decoder

        # arrays
        Codec element_codec
        Py_UCS4 element_delimiter

        # composite types
        tuple element_type_oids
        object element_names
        object record_desc
        list element_codecs

        # Pointers to actual encoder/decoder functions for this codec
        codec_encode_func encoder
        codec_decode_func decoder

    cdef init(self, str name, str schema, str kind,
              CodecType type, ServerDataFormat format,
              ClientExchangeFormat xformat,
              encode_func c_encoder, decode_func c_decoder,
              Codec base_codec,
              object py_encoder, object py_decoder,
              Codec element_codec, tuple element_type_oids,
              object element_names, list element_codecs,
              Py_UCS4 element_delimiter)

    cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf,
                       object obj)

    cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf,
                      object obj)

    cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf,
                           object obj)

    cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf,
                      object obj)

    cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf,
                           object obj)

    cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf,
                          object obj)

    cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf,
                          object obj)

    cdef decode_scalar(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_array(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_array_text(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_composite(self, ConnectionSettings settings, FRBuffer *buf)

    cdef decode_in_python(self, ConnectionSettings settings, FRBuffer *buf)

    cdef inline encode(self,
                       ConnectionSettings settings,
                       WriteBuffer buf,
                       object obj)

    cdef inline decode(self, ConnectionSettings settings, FRBuffer *buf)

    cdef has_encoder(self)
    cdef has_decoder(self)
    cdef is_binary(self)

    cdef inline Codec copy(self)

    @staticmethod
    cdef Codec new_array_codec(uint32_t oid,
                               str name,
                               str schema,
                               Codec element_codec,
                               Py_UCS4 element_delimiter)

    @staticmethod
    cdef Codec new_range_codec(uint32_t oid,
                               str name,
                               str schema,
                               Codec element_codec)

    @staticmethod
    cdef Codec new_multirange_codec(uint32_t oid,
                                    str name,
                                    str schema,
                                    Codec element_codec)

    @staticmethod
    cdef Codec new_composite_codec(uint32_t oid,
                                   str name,
                                   str schema,
                                   ServerDataFormat format,
                                   list element_codecs,
                                   tuple element_type_oids,
                                   object element_names)

    @staticmethod
    cdef Codec new_python_codec(uint32_t oid,
                                str name,
                                str schema,
                                str kind,
                                object encoder,
                                object decoder,
                                encode_func c_encoder,
                                decode_func c_decoder,
                                Codec base_codec,
                                ServerDataFormat format,
                                ClientExchangeFormat xformat)


cdef class DataCodecConfig:
    cdef:
        dict _derived_type_codecs
        dict _custom_type_codecs

    cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
                                bint ignore_custom_codec=*)
    cdef inline Codec get_custom_codec(self, uint32_t oid,
                                       ServerDataFormat format)
@@ -0,0 +1,895 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from collections.abc import Mapping as MappingABC

import asyncpg
from asyncpg import exceptions


cdef void* binary_codec_map[(MAXSUPPORTEDOID + 1) * 2]
cdef void* text_codec_map[(MAXSUPPORTEDOID + 1) * 2]
cdef dict EXTRA_CODECS = {}


@cython.final
cdef class Codec:

    def __cinit__(self, uint32_t oid):
        self.oid = oid
        self.type = CODEC_UNDEFINED

    cdef init(
        self,
        str name,
        str schema,
        str kind,
        CodecType type,
        ServerDataFormat format,
        ClientExchangeFormat xformat,
        encode_func c_encoder,
        decode_func c_decoder,
        Codec base_codec,
        object py_encoder,
        object py_decoder,
        Codec element_codec,
        tuple element_type_oids,
        object element_names,
        list element_codecs,
        Py_UCS4 element_delimiter,
    ):

        self.name = name
        self.schema = schema
        self.kind = kind
        self.type = type
        self.format = format
        self.xformat = xformat
        self.c_encoder = c_encoder
        self.c_decoder = c_decoder
        self.base_codec = base_codec
        self.py_encoder = py_encoder
        self.py_decoder = py_decoder
        self.element_codec = element_codec
        self.element_type_oids = element_type_oids
        self.element_codecs = element_codecs
        self.element_delimiter = element_delimiter
        self.element_names = element_names

        if base_codec is not None:
            if c_encoder != NULL or c_decoder != NULL:
                raise exceptions.InternalClientError(
                    'base_codec is mutually exclusive with c_encoder/c_decoder'
                )

        if element_names is not None:
            self.record_desc = record.ApgRecordDesc_New(
                element_names, tuple(element_names))
        else:
            self.record_desc = None

        if type == CODEC_C:
            self.encoder = <codec_encode_func>&self.encode_scalar
            self.decoder = <codec_decode_func>&self.decode_scalar
        elif type == CODEC_ARRAY:
            if format == PG_FORMAT_BINARY:
                self.encoder = <codec_encode_func>&self.encode_array
                self.decoder = <codec_decode_func>&self.decode_array
            else:
                self.encoder = <codec_encode_func>&self.encode_array_text
                self.decoder = <codec_decode_func>&self.decode_array_text
        elif type == CODEC_RANGE:
            if format != PG_FORMAT_BINARY:
                raise exceptions.UnsupportedClientFeatureError(
                    'cannot decode type "{}"."{}": text encoding of '
                    'range types is not supported'.format(schema, name))
            self.encoder = <codec_encode_func>&self.encode_range
            self.decoder = <codec_decode_func>&self.decode_range
        elif type == CODEC_MULTIRANGE:
            if format != PG_FORMAT_BINARY:
                raise exceptions.UnsupportedClientFeatureError(
                    'cannot decode type "{}"."{}": text encoding of '
                    'multirange types is not supported'.format(schema, name))
            self.encoder = <codec_encode_func>&self.encode_multirange
            self.decoder = <codec_decode_func>&self.decode_multirange
        elif type == CODEC_COMPOSITE:
            if format != PG_FORMAT_BINARY:
                raise exceptions.UnsupportedClientFeatureError(
                    'cannot decode type "{}"."{}": text encoding of '
                    'composite types is not supported'.format(schema, name))
            self.encoder = <codec_encode_func>&self.encode_composite
            self.decoder = <codec_decode_func>&self.decode_composite
        elif type == CODEC_PY:
            self.encoder = <codec_encode_func>&self.encode_in_python
            self.decoder = <codec_decode_func>&self.decode_in_python
        else:
            raise exceptions.InternalClientError(
                'unexpected codec type: {}'.format(type))

    cdef Codec copy(self):
        cdef Codec codec

        codec = Codec(self.oid)
        codec.init(self.name, self.schema, self.kind,
                   self.type, self.format, self.xformat,
                   self.c_encoder, self.c_decoder, self.base_codec,
                   self.py_encoder, self.py_decoder,
                   self.element_codec,
                   self.element_type_oids, self.element_names,
                   self.element_codecs, self.element_delimiter)

        return codec

    cdef encode_scalar(self, ConnectionSettings settings, WriteBuffer buf,
                       object obj):
        self.c_encoder(settings, buf, obj)

    cdef encode_array(self, ConnectionSettings settings, WriteBuffer buf,
                      object obj):
        array_encode(settings, buf, obj, self.element_codec.oid,
                     codec_encode_func_ex,
                     <void*>(<cpython.PyObject>self.element_codec))

    cdef encode_array_text(self, ConnectionSettings settings, WriteBuffer buf,
                           object obj):
        return textarray_encode(settings, buf, obj,
                                codec_encode_func_ex,
                                <void*>(<cpython.PyObject>self.element_codec),
                                self.element_delimiter)

    cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf,
                      object obj):
        range_encode(settings, buf, obj, self.element_codec.oid,
                     codec_encode_func_ex,
                     <void*>(<cpython.PyObject>self.element_codec))

    cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf,
                           object obj):
        multirange_encode(settings, buf, obj, self.element_codec.oid,
                          codec_encode_func_ex,
                          <void*>(<cpython.PyObject>self.element_codec))

    cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf,
                          object obj):
        cdef:
            WriteBuffer elem_data
            int i
            list elem_codecs = self.element_codecs
            ssize_t count
            ssize_t composite_size
            tuple rec

        if isinstance(obj, MappingABC):
            # Input is dict-like, form a tuple
            composite_size = len(self.element_type_oids)
            rec = cpython.PyTuple_New(composite_size)

            for i in range(composite_size):
                cpython.Py_INCREF(None)
                cpython.PyTuple_SET_ITEM(rec, i, None)

            for field in obj:
                try:
                    i = self.element_names[field]
                except KeyError:
                    raise ValueError(
                        '{!r} is not a valid element of composite '
                        'type {}'.format(field, self.name)) from None

                item = obj[field]
                cpython.Py_INCREF(item)
                cpython.PyTuple_SET_ITEM(rec, i, item)

            obj = rec

        count = len(obj)
        if count > _MAXINT32:
            raise ValueError('too many elements in composite type record')

        elem_data = WriteBuffer.new()
        i = 0
        for item in obj:
            elem_data.write_int32(<int32_t>self.element_type_oids[i])
            if item is None:
                elem_data.write_int32(-1)
            else:
                (<Codec>elem_codecs[i]).encode(settings, elem_data, item)
            i += 1

        record_encode_frame(settings, buf, elem_data, <int32_t>count)

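    # A hedged usage sketch for the mapping branch above (not part of
    # this module): given CREATE TYPE mytype AS (a int, b text), a
    # composite argument may be passed either as a tuple or as a
    # dict-like object keyed by attribute name, e.g.
    #
    #     await conn.fetchval('SELECT ($1::mytype).b', {'a': 1, 'b': 'x'})
    #
    # Attributes missing from the mapping are encoded as None.
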
    cdef encode_in_python(self, ConnectionSettings settings, WriteBuffer buf,
                          object obj):
        data = self.py_encoder(obj)
        if self.xformat == PG_XFORMAT_OBJECT:
            if self.format == PG_FORMAT_BINARY:
                pgproto.bytea_encode(settings, buf, data)
            elif self.format == PG_FORMAT_TEXT:
                pgproto.text_encode(settings, buf, data)
            else:
                raise exceptions.InternalClientError(
                    'unexpected data format: {}'.format(self.format))
        elif self.xformat == PG_XFORMAT_TUPLE:
            if self.base_codec is not None:
                self.base_codec.encode(settings, buf, data)
            else:
                self.c_encoder(settings, buf, data)
        else:
            raise exceptions.InternalClientError(
                'unexpected exchange format: {}'.format(self.xformat))

    cdef encode(self, ConnectionSettings settings, WriteBuffer buf,
                object obj):
        return self.encoder(self, settings, buf, obj)

    cdef decode_scalar(self, ConnectionSettings settings, FRBuffer *buf):
        return self.c_decoder(settings, buf)

    cdef decode_array(self, ConnectionSettings settings, FRBuffer *buf):
        return array_decode(settings, buf, codec_decode_func_ex,
                            <void*>(<cpython.PyObject>self.element_codec))

    cdef decode_array_text(self, ConnectionSettings settings,
                           FRBuffer *buf):
        return textarray_decode(settings, buf, codec_decode_func_ex,
                                <void*>(<cpython.PyObject>self.element_codec),
                                self.element_delimiter)

    cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf):
        return range_decode(settings, buf, codec_decode_func_ex,
                            <void*>(<cpython.PyObject>self.element_codec))

    cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf):
        return multirange_decode(settings, buf, codec_decode_func_ex,
                                 <void*>(<cpython.PyObject>self.element_codec))

    cdef decode_composite(self, ConnectionSettings settings,
                          FRBuffer *buf):
        cdef:
            object result
            ssize_t elem_count
            ssize_t i
            int32_t elem_len
            uint32_t elem_typ
            uint32_t received_elem_typ
            Codec elem_codec
            FRBuffer elem_buf

        elem_count = <ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4))
        if elem_count != len(self.element_type_oids):
            raise exceptions.OutdatedSchemaCacheError(
                'unexpected number of attributes of composite type: '
                '{}, expected {}'
                .format(
                    elem_count,
                    len(self.element_type_oids),
                ),
                schema=self.schema,
                data_type=self.name,
            )
        result = record.ApgRecord_New(
            asyncpg.Record, self.record_desc, elem_count)
        for i in range(elem_count):
            elem_typ = self.element_type_oids[i]
            received_elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4))

            if received_elem_typ != elem_typ:
                raise exceptions.OutdatedSchemaCacheError(
                    'unexpected data type of composite type attribute {}: '
                    '{!r}, expected {!r}'
                    .format(
                        i,
                        BUILTIN_TYPE_OID_MAP.get(
                            received_elem_typ, received_elem_typ),
                        BUILTIN_TYPE_OID_MAP.get(
                            elem_typ, elem_typ)
                    ),
                    schema=self.schema,
                    data_type=self.name,
                    position=i,
                )

            elem_len = hton.unpack_int32(frb_read(buf, 4))
            if elem_len == -1:
                elem = None
            else:
                elem_codec = self.element_codecs[i]
                elem = elem_codec.decode(
                    settings, frb_slice_from(&elem_buf, buf, elem_len))

            cpython.Py_INCREF(elem)
            record.ApgRecord_SET_ITEM(result, i, elem)

        return result

    cdef decode_in_python(self, ConnectionSettings settings,
                          FRBuffer *buf):
        if self.xformat == PG_XFORMAT_OBJECT:
            if self.format == PG_FORMAT_BINARY:
                data = pgproto.bytea_decode(settings, buf)
            elif self.format == PG_FORMAT_TEXT:
                data = pgproto.text_decode(settings, buf)
            else:
                raise exceptions.InternalClientError(
                    'unexpected data format: {}'.format(self.format))
        elif self.xformat == PG_XFORMAT_TUPLE:
            if self.base_codec is not None:
                data = self.base_codec.decode(settings, buf)
            else:
                data = self.c_decoder(settings, buf)
        else:
            raise exceptions.InternalClientError(
                'unexpected exchange format: {}'.format(self.xformat))

        return self.py_decoder(data)

    cdef inline decode(self, ConnectionSettings settings, FRBuffer *buf):
        return self.decoder(self, settings, buf)

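    # For reference, the binary composite value decoded above is laid out
    # as an int32 attribute count followed, for each attribute, by its
    # uint32 type OID and an int32 length (-1 for NULL) with that many
    # bytes of data; a mismatch against the cached introspection raises
    # OutdatedSchemaCacheError.
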
    cdef inline has_encoder(self):
        cdef Codec elem_codec

        if self.c_encoder is not NULL or self.py_encoder is not None:
            return True

        elif (
            self.type == CODEC_ARRAY
            or self.type == CODEC_RANGE
            or self.type == CODEC_MULTIRANGE
        ):
            return self.element_codec.has_encoder()

        elif self.type == CODEC_COMPOSITE:
            for elem_codec in self.element_codecs:
                if not elem_codec.has_encoder():
                    return False
            return True

        else:
            return False

    cdef has_decoder(self):
        cdef Codec elem_codec

        if self.c_decoder is not NULL or self.py_decoder is not None:
            return True

        elif (
            self.type == CODEC_ARRAY
            or self.type == CODEC_RANGE
            or self.type == CODEC_MULTIRANGE
        ):
            return self.element_codec.has_decoder()

        elif self.type == CODEC_COMPOSITE:
            for elem_codec in self.element_codecs:
                if not elem_codec.has_decoder():
                    return False
            return True

        else:
            return False

    cdef is_binary(self):
        return self.format == PG_FORMAT_BINARY

    def __repr__(self):
        return '<Codec oid={} elem_oid={} core={}>'.format(
            self.oid,
            'NA' if self.element_codec is None else self.element_codec.oid,
            has_core_codec(self.oid))

    @staticmethod
    cdef Codec new_array_codec(uint32_t oid,
                               str name,
                               str schema,
                               Codec element_codec,
                               Py_UCS4 element_delimiter):
        cdef Codec codec
        codec = Codec(oid)
        codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format,
                   PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
                   element_codec, None, None, None, element_delimiter)
        return codec

    @staticmethod
    cdef Codec new_range_codec(uint32_t oid,
                               str name,
                               str schema,
                               Codec element_codec):
        cdef Codec codec
        codec = Codec(oid)
        codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format,
                   PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
                   element_codec, None, None, None, 0)
        return codec

    @staticmethod
    cdef Codec new_multirange_codec(uint32_t oid,
                                    str name,
                                    str schema,
                                    Codec element_codec):
        cdef Codec codec
        codec = Codec(oid)
        codec.init(name, schema, 'multirange', CODEC_MULTIRANGE,
                   element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None,
                   None, None, element_codec, None, None, None, 0)
        return codec

    @staticmethod
    cdef Codec new_composite_codec(uint32_t oid,
                                   str name,
                                   str schema,
                                   ServerDataFormat format,
                                   list element_codecs,
                                   tuple element_type_oids,
                                   object element_names):
        cdef Codec codec
        codec = Codec(oid)
        codec.init(name, schema, 'composite', CODEC_COMPOSITE,
                   format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
                   None, element_type_oids, element_names, element_codecs, 0)
        return codec

    @staticmethod
    cdef Codec new_python_codec(uint32_t oid,
                                str name,
                                str schema,
                                str kind,
                                object encoder,
                                object decoder,
                                encode_func c_encoder,
                                decode_func c_decoder,
                                Codec base_codec,
                                ServerDataFormat format,
                                ClientExchangeFormat xformat):
        cdef Codec codec
        codec = Codec(oid)
        codec.init(name, schema, kind, CODEC_PY, format, xformat,
                   c_encoder, c_decoder, base_codec, encoder, decoder,
                   None, None, None, None, 0)
        return codec


# Encode callback for arrays
cdef codec_encode_func_ex(ConnectionSettings settings, WriteBuffer buf,
                          object obj, const void *arg):
    return (<Codec>arg).encode(settings, buf, obj)


# Decode callback for arrays
cdef codec_decode_func_ex(ConnectionSettings settings, FRBuffer *buf,
                          const void *arg):
    return (<Codec>arg).decode(settings, buf)


cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
    cdef:
        int64_t oid = 0
        bint overflow = False

    try:
        oid = cpython.PyLong_AsLongLong(val)
    except OverflowError:
        overflow = True

    if overflow or (oid < 0 or oid > UINT32_MAX):
        raise OverflowError('OID value too large: {!r}'.format(val))

    return <uint32_t>val


cdef class DataCodecConfig:
    def __init__(self, cache_key):
        # Codec instance cache for derived types:
        # composites, arrays, ranges, domains and their combinations.
        self._derived_type_codecs = {}
        # Codec instances set up by the user for the connection.
        self._custom_type_codecs = {}

    def add_types(self, types):
        cdef:
            Codec elem_codec
            list comp_elem_codecs
            ServerDataFormat format
            ServerDataFormat elem_format
            bint has_text_elements
            Py_UCS4 elem_delim

        for ti in types:
            oid = ti['oid']

            if self.get_codec(oid, PG_FORMAT_ANY) is not None:
                continue

            name = ti['name']
            schema = ti['ns']
            array_element_oid = ti['elemtype']
            range_subtype_oid = ti['range_subtype']
            if ti['attrtypoids']:
                comp_type_attrs = tuple(ti['attrtypoids'])
            else:
                comp_type_attrs = None
            base_type = ti['basetype']

            if array_element_oid:
                # Array type (note, there is no separate 'kind' for arrays)

                # Canonicalize type name to "elemtype[]"
                if name.startswith('_'):
                    name = name[1:]
                name = '{}[]'.format(name)

                elem_codec = self.get_codec(array_element_oid, PG_FORMAT_ANY)
                if elem_codec is None:
                    elem_codec = self.declare_fallback_codec(
                        array_element_oid, ti['elemtype_name'], schema)

                elem_delim = <Py_UCS4>ti['elemdelim'][0]

                self._derived_type_codecs[oid, elem_codec.format] = \
                    Codec.new_array_codec(
                        oid, name, schema, elem_codec, elem_delim)

            elif ti['kind'] == b'c':
                # Composite type

                if not comp_type_attrs:
                    raise exceptions.InternalClientError(
                        f'type record missing field types for composite {oid}')

                comp_elem_codecs = []
                has_text_elements = False

                for typoid in comp_type_attrs:
                    elem_codec = self.get_codec(typoid, PG_FORMAT_ANY)
                    if elem_codec is None:
                        raise exceptions.InternalClientError(
                            f'no codec for composite attribute type {typoid}')
                    if elem_codec.format is PG_FORMAT_TEXT:
                        has_text_elements = True
                    comp_elem_codecs.append(elem_codec)

                element_names = collections.OrderedDict()
                for i, attrname in enumerate(ti['attrnames']):
                    element_names[attrname] = i

                # If at least one element is text-encoded, we must
                # encode the whole composite as text.
                if has_text_elements:
                    elem_format = PG_FORMAT_TEXT
                else:
                    elem_format = PG_FORMAT_BINARY

                self._derived_type_codecs[oid, elem_format] = \
                    Codec.new_composite_codec(
                        oid, name, schema, elem_format, comp_elem_codecs,
                        comp_type_attrs, element_names)

            elif ti['kind'] == b'd':
                # Domain type

                if not base_type:
                    raise exceptions.InternalClientError(
                        f'type record missing base type for domain {oid}')

                elem_codec = self.get_codec(base_type, PG_FORMAT_ANY)
                if elem_codec is None:
                    elem_codec = self.declare_fallback_codec(
                        base_type, ti['basetype_name'], schema)

                self._derived_type_codecs[oid, elem_codec.format] = elem_codec

            elif ti['kind'] == b'r':
                # Range type

                if not range_subtype_oid:
                    raise exceptions.InternalClientError(
                        f'type record missing base type for range {oid}')

                elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
                if elem_codec is None:
                    elem_codec = self.declare_fallback_codec(
                        range_subtype_oid, ti['range_subtype_name'], schema)

                self._derived_type_codecs[oid, elem_codec.format] = \
                    Codec.new_range_codec(oid, name, schema, elem_codec)

            elif ti['kind'] == b'm':
                # Multirange type

                if not range_subtype_oid:
                    raise exceptions.InternalClientError(
                        f'type record missing base type for multirange {oid}')

                elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
                if elem_codec is None:
                    elem_codec = self.declare_fallback_codec(
                        range_subtype_oid, ti['range_subtype_name'], schema)

                self._derived_type_codecs[oid, elem_codec.format] = \
                    Codec.new_multirange_codec(oid, name, schema, elem_codec)

            elif ti['kind'] == b'e':
                # Enum types are essentially text
                self._set_builtin_type_codec(oid, name, schema, 'scalar',
                                             TEXTOID, PG_FORMAT_ANY)
            else:
                self.declare_fallback_codec(oid, name, schema)

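    # For reference, add_types() above consumes introspection records
    # whose 'kind' discriminates derived types the way pg_type.typtype
    # does: b'c' composite, b'd' domain, b'r' range, b'm' multirange and
    # b'e' enum, with array types recognized by a non-zero 'elemtype'
    # rather than a separate kind.
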
    def add_python_codec(self, typeoid, typename, typeschema, typekind,
                         typeinfos, encoder, decoder, format, xformat):
        cdef:
            Codec core_codec = None
            encode_func c_encoder = NULL
            decode_func c_decoder = NULL
            Codec base_codec = None
            uint32_t oid = pylong_as_oid(typeoid)
            bint codec_set = False

        # Clear all previous overrides (this also clears the type cache).
        self.remove_python_codec(typeoid, typename, typeschema)

        if typeinfos:
            self.add_types(typeinfos)

        if format == PG_FORMAT_ANY:
            formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY)
        else:
            formats = (format,)

        for fmt in formats:
            if xformat == PG_XFORMAT_TUPLE:
                if typekind == "scalar":
                    core_codec = get_core_codec(oid, fmt, xformat)
                    if core_codec is None:
                        continue
                    c_encoder = core_codec.c_encoder
                    c_decoder = core_codec.c_decoder
                elif typekind == "composite":
                    base_codec = self.get_codec(oid, fmt)
                    if base_codec is None:
                        continue

            self._custom_type_codecs[typeoid, fmt] = \
                Codec.new_python_codec(oid, typename, typeschema, typekind,
                                       encoder, decoder, c_encoder, c_decoder,
                                       base_codec, fmt, xformat)
            codec_set = True

        if not codec_set:
            raise exceptions.InterfaceError(
                "{} type does not support the 'tuple' exchange format".format(
                    typename))

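    # A hedged usage sketch (not part of this module): add_python_codec()
    # above backs the public Connection.set_type_codec() API, e.g.
    #
    #     import json
    #     await conn.set_type_codec(
    #         'json', schema='pg_catalog',
    #         encoder=json.dumps, decoder=json.loads,
    #         format='text')
    #
    # after which json values round-trip through json.dumps/json.loads.
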
    def remove_python_codec(self, typeoid, typename, typeschema):
        for fmt in (PG_FORMAT_BINARY, PG_FORMAT_TEXT):
            self._custom_type_codecs.pop((typeoid, fmt), None)
        self.clear_type_cache()

    def _set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
                                alias_to, format=PG_FORMAT_ANY):
        cdef:
            Codec codec
            Codec target_codec
            uint32_t oid = pylong_as_oid(typeoid)
            uint32_t alias_oid = 0
            bint codec_set = False

        if format == PG_FORMAT_ANY:
            formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT)
        else:
            formats = (format,)

        if isinstance(alias_to, int):
            alias_oid = pylong_as_oid(alias_to)
        else:
            alias_oid = BUILTIN_TYPE_NAME_MAP.get(alias_to, 0)

        for format in formats:
            if alias_oid != 0:
                target_codec = self.get_codec(alias_oid, format)
            else:
                target_codec = get_extra_codec(alias_to, format)

            if target_codec is None:
                continue

            codec = target_codec.copy()
            codec.oid = typeoid
            codec.name = typename
            codec.schema = typeschema
            codec.kind = typekind

            self._custom_type_codecs[typeoid, format] = codec
            codec_set = True

        if not codec_set:
            if format == PG_FORMAT_BINARY:
                codec_str = 'binary'
            elif format == PG_FORMAT_TEXT:
                codec_str = 'text'
            else:
                codec_str = 'text or binary'

            raise exceptions.InterfaceError(
                f'cannot alias {typename} to {alias_to}: '
                f'there is no {codec_str} codec for {alias_to}')

    def set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
                               alias_to, format=PG_FORMAT_ANY):
        self._set_builtin_type_codec(typeoid, typename, typeschema, typekind,
                                     alias_to, format)
        self.clear_type_cache()

    def clear_type_cache(self):
        self._derived_type_codecs.clear()

    def declare_fallback_codec(self, uint32_t oid, str name, str schema):
        cdef Codec codec

        if oid <= MAXBUILTINOID:
            # This is a BKI type, for which asyncpg has no
            # defined codec.  This should only happen for newly
            # added builtin types, for which this version of
            # asyncpg is lacking support.
            #
            raise exceptions.UnsupportedClientFeatureError(
                f'unhandled standard data type {name!r} (OID {oid})')
        else:
            # This is a non-BKI type, and as such, has no
            # stable OID, so no possibility of a builtin codec.
            # In this case, fallback to text format.  Applications
            # can avoid this by specifying a codec for this type
            # using Connection.set_type_codec().
            #
            self._set_builtin_type_codec(oid, name, schema, 'scalar',
                                         TEXTOID, PG_FORMAT_TEXT)

            codec = self.get_codec(oid, PG_FORMAT_TEXT)

        return codec

    cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format,
                                bint ignore_custom_codec=False):
        cdef Codec codec

        if format == PG_FORMAT_ANY:
            codec = self.get_codec(
                oid, PG_FORMAT_BINARY, ignore_custom_codec)
            if codec is None:
                codec = self.get_codec(
                    oid, PG_FORMAT_TEXT, ignore_custom_codec)
            return codec
        else:
            if not ignore_custom_codec:
                codec = self.get_custom_codec(oid, PG_FORMAT_ANY)
                if codec is not None:
                    if codec.format != format:
                        # The codec for this OID has been overridden by
                        # set_{builtin}_type_codec with a different format.
                        # We must respect that and not return a core codec.
                        return None
                    else:
                        return codec

            codec = get_core_codec(oid, format)
            if codec is not None:
                return codec
            else:
                try:
                    return self._derived_type_codecs[oid, format]
                except KeyError:
                    return None

    cdef inline Codec get_custom_codec(
        self,
        uint32_t oid,
        ServerDataFormat format
    ):
        cdef Codec codec

        if format == PG_FORMAT_ANY:
            codec = self.get_custom_codec(oid, PG_FORMAT_BINARY)
            if codec is None:
                codec = self.get_custom_codec(oid, PG_FORMAT_TEXT)
        else:
            codec = self._custom_type_codecs.get((oid, format))

        return codec


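# For reference, the lookup order implemented by get_codec() above is:
# user-registered custom codecs first (which may veto a core codec
# registered for a different format), then the static core codec tables,
# and finally the per-connection cache of derived codecs built by
# add_types().

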
cdef inline Codec get_core_codec(
        uint32_t oid, ServerDataFormat format,
        ClientExchangeFormat xformat=PG_XFORMAT_OBJECT):
    cdef:
        void *ptr = NULL

    if oid > MAXSUPPORTEDOID:
        return None
    if format == PG_FORMAT_BINARY:
        ptr = binary_codec_map[oid * xformat]
    elif format == PG_FORMAT_TEXT:
        ptr = text_codec_map[oid * xformat]

    if ptr is NULL:
        return None
    else:
        return <Codec>ptr


cdef inline Codec get_any_core_codec(
        uint32_t oid, ServerDataFormat format,
        ClientExchangeFormat xformat=PG_XFORMAT_OBJECT):
    """A version of get_core_codec that accepts PG_FORMAT_ANY."""
    cdef:
        Codec codec

    if format == PG_FORMAT_ANY:
        codec = get_core_codec(oid, PG_FORMAT_BINARY, xformat)
        if codec is None:
            codec = get_core_codec(oid, PG_FORMAT_TEXT, xformat)
    else:
        codec = get_core_codec(oid, format, xformat)

    return codec


cdef inline int has_core_codec(uint32_t oid):
    return binary_codec_map[oid] != NULL or text_codec_map[oid] != NULL


cdef register_core_codec(uint32_t oid,
                         encode_func encode,
                         decode_func decode,
                         ServerDataFormat format,
                         ClientExchangeFormat xformat=PG_XFORMAT_OBJECT):

    if oid > MAXSUPPORTEDOID:
        raise exceptions.InternalClientError(
            'cannot register core codec for OID {}: it is greater '
            'than MAXSUPPORTEDOID ({})'.format(oid, MAXSUPPORTEDOID))

    cdef:
        Codec codec
        str name
        str kind

    name = BUILTIN_TYPE_OID_MAP[oid]
    kind = 'array' if oid in ARRAY_TYPES else 'scalar'

    codec = Codec(oid)
    codec.init(name, 'pg_catalog', kind, CODEC_C, format, xformat,
               encode, decode, None, None, None, None, None, None, None, 0)
    cpython.Py_INCREF(codec)  # immortalize

    if format == PG_FORMAT_BINARY:
        binary_codec_map[oid * xformat] = <void*>codec
    elif format == PG_FORMAT_TEXT:
        text_codec_map[oid * xformat] = <void*>codec
    else:
        raise exceptions.InternalClientError(
            'invalid data format: {}'.format(format))


cdef register_extra_codec(str name,
                          encode_func encode,
                          decode_func decode,
                          ServerDataFormat format):
    cdef:
        Codec codec
        str kind

    kind = 'scalar'

    codec = Codec(INVALIDOID)
    codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT,
               encode, decode, None, None, None, None, None, None, None, 0)
    EXTRA_CODECS[name, format] = codec


cdef inline Codec get_extra_codec(str name, ServerDataFormat format):
    return EXTRA_CODECS.get((name, format))


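# A note on the codec map layout used above (as read from the code):
# binary_codec_map and text_codec_map are flat arrays of
# (MAXSUPPORTEDOID + 1) * 2 slots indexed by `oid * xformat`, where
# PG_XFORMAT_OBJECT == 1 and PG_XFORMAT_TUPLE == 2, so the object-format
# codec for an OID lives at index `oid` and its tuple-format variant at
# `2 * oid`.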
@@ -0,0 +1,484 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef init_bits_codecs():
    register_core_codec(BITOID,
                        <encode_func>pgproto.bits_encode,
                        <decode_func>pgproto.bits_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(VARBITOID,
                        <encode_func>pgproto.bits_encode,
                        <decode_func>pgproto.bits_decode,
                        PG_FORMAT_BINARY)


cdef init_bytea_codecs():
    register_core_codec(BYTEAOID,
                        <encode_func>pgproto.bytea_encode,
                        <decode_func>pgproto.bytea_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(CHAROID,
                        <encode_func>pgproto.bytea_encode,
                        <decode_func>pgproto.bytea_decode,
                        PG_FORMAT_BINARY)


cdef init_datetime_codecs():
    register_core_codec(DATEOID,
                        <encode_func>pgproto.date_encode,
                        <decode_func>pgproto.date_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(DATEOID,
                        <encode_func>pgproto.date_encode_tuple,
                        <decode_func>pgproto.date_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    register_core_codec(TIMEOID,
                        <encode_func>pgproto.time_encode,
                        <decode_func>pgproto.time_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(TIMEOID,
                        <encode_func>pgproto.time_encode_tuple,
                        <decode_func>pgproto.time_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    register_core_codec(TIMETZOID,
                        <encode_func>pgproto.timetz_encode,
                        <decode_func>pgproto.timetz_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(TIMETZOID,
                        <encode_func>pgproto.timetz_encode_tuple,
                        <decode_func>pgproto.timetz_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    register_core_codec(TIMESTAMPOID,
                        <encode_func>pgproto.timestamp_encode,
                        <decode_func>pgproto.timestamp_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(TIMESTAMPOID,
                        <encode_func>pgproto.timestamp_encode_tuple,
                        <decode_func>pgproto.timestamp_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    register_core_codec(TIMESTAMPTZOID,
                        <encode_func>pgproto.timestamptz_encode,
                        <decode_func>pgproto.timestamptz_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(TIMESTAMPTZOID,
                        <encode_func>pgproto.timestamp_encode_tuple,
                        <decode_func>pgproto.timestamp_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    register_core_codec(INTERVALOID,
                        <encode_func>pgproto.interval_encode,
                        <decode_func>pgproto.interval_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(INTERVALOID,
                        <encode_func>pgproto.interval_encode_tuple,
                        <decode_func>pgproto.interval_decode_tuple,
                        PG_FORMAT_BINARY,
                        PG_XFORMAT_TUPLE)

    # For obsolete abstime/reltime/tinterval, we do not bother to
    # interpret the value, and simply return and pass it as text.
    #
    register_core_codec(ABSTIMEOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    register_core_codec(RELTIMEOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)

    register_core_codec(TINTERVALOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_TEXT)


cdef init_float_codecs():
    register_core_codec(FLOAT4OID,
                        <encode_func>pgproto.float4_encode,
                        <decode_func>pgproto.float4_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(FLOAT8OID,
                        <encode_func>pgproto.float8_encode,
                        <decode_func>pgproto.float8_decode,
                        PG_FORMAT_BINARY)


cdef init_geometry_codecs():
    register_core_codec(BOXOID,
                        <encode_func>pgproto.box_encode,
                        <decode_func>pgproto.box_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(LINEOID,
                        <encode_func>pgproto.line_encode,
                        <decode_func>pgproto.line_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(LSEGOID,
                        <encode_func>pgproto.lseg_encode,
                        <decode_func>pgproto.lseg_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(POINTOID,
                        <encode_func>pgproto.point_encode,
                        <decode_func>pgproto.point_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(PATHOID,
                        <encode_func>pgproto.path_encode,
                        <decode_func>pgproto.path_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(POLYGONOID,
                        <encode_func>pgproto.poly_encode,
                        <decode_func>pgproto.poly_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(CIRCLEOID,
                        <encode_func>pgproto.circle_encode,
                        <decode_func>pgproto.circle_decode,
                        PG_FORMAT_BINARY)


cdef init_hstore_codecs():
    register_extra_codec('pg_contrib.hstore',
                         <encode_func>pgproto.hstore_encode,
                         <decode_func>pgproto.hstore_decode,
                         PG_FORMAT_BINARY)


cdef init_json_codecs():
    register_core_codec(JSONOID,
                        <encode_func>pgproto.text_encode,
                        <decode_func>pgproto.text_decode,
                        PG_FORMAT_BINARY)
    register_core_codec(JSONBOID,
                        <encode_func>pgproto.jsonb_encode,
                        <decode_func>pgproto.jsonb_decode,
                        PG_FORMAT_BINARY)
    register_core_codec(JSONPATHOID,
                        <encode_func>pgproto.jsonpath_encode,
                        <decode_func>pgproto.jsonpath_decode,
                        PG_FORMAT_BINARY)


cdef init_int_codecs():

    register_core_codec(BOOLOID,
                        <encode_func>pgproto.bool_encode,
                        <decode_func>pgproto.bool_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(INT2OID,
                        <encode_func>pgproto.int2_encode,
                        <decode_func>pgproto.int2_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(INT4OID,
                        <encode_func>pgproto.int4_encode,
                        <decode_func>pgproto.int4_decode,
                        PG_FORMAT_BINARY)

    register_core_codec(INT8OID,
                        <encode_func>pgproto.int8_encode,
                        <decode_func>pgproto.int8_decode,
                        PG_FORMAT_BINARY)


cdef init_pseudo_codecs():
|
||||
# Void type is returned by SELECT void_returning_function()
|
||||
register_core_codec(VOIDOID,
|
||||
<encode_func>pgproto.void_encode,
|
||||
<decode_func>pgproto.void_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
# Unknown type, always decoded as text
|
||||
register_core_codec(UNKNOWNOID,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
# OID and friends
|
||||
oid_types = [
|
||||
OIDOID, XIDOID, CIDOID
|
||||
]
|
||||
|
||||
for oid_type in oid_types:
|
||||
register_core_codec(oid_type,
|
||||
<encode_func>pgproto.uint4_encode,
|
||||
<decode_func>pgproto.uint4_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
# 64-bit OID types
|
||||
oid8_types = [
|
||||
XID8OID,
|
||||
]
|
||||
|
||||
for oid_type in oid8_types:
|
||||
register_core_codec(oid_type,
|
||||
<encode_func>pgproto.uint8_encode,
|
||||
<decode_func>pgproto.uint8_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
# reg* types -- these are really system catalog OIDs, but
|
||||
# allow the catalog object name as an input. We could just
|
||||
# decode these as OIDs, but handling them as text seems more
|
||||
# useful.
|
||||
#
|
||||
reg_types = [
|
||||
REGPROCOID, REGPROCEDUREOID, REGOPEROID, REGOPERATOROID,
|
||||
REGCLASSOID, REGTYPEOID, REGCONFIGOID, REGDICTIONARYOID,
|
||||
REGNAMESPACEOID, REGROLEOID, REFCURSOROID, REGCOLLATIONOID,
|
||||
]
|
||||
|
||||
for reg_type in reg_types:
|
||||
register_core_codec(reg_type,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
# cstring type is used by Postgres' I/O functions
|
||||
register_core_codec(CSTRINGOID,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
# various system pseudotypes with no I/O
|
||||
no_io_types = [
|
||||
ANYOID, TRIGGEROID, EVENT_TRIGGEROID, LANGUAGE_HANDLEROID,
|
||||
FDW_HANDLEROID, TSM_HANDLEROID, INTERNALOID, OPAQUEOID,
|
||||
ANYELEMENTOID, ANYNONARRAYOID, ANYCOMPATIBLEOID,
|
||||
ANYCOMPATIBLEARRAYOID, ANYCOMPATIBLENONARRAYOID,
|
||||
ANYCOMPATIBLERANGEOID, ANYCOMPATIBLEMULTIRANGEOID,
|
||||
ANYRANGEOID, ANYMULTIRANGEOID, ANYARRAYOID,
|
||||
PG_DDL_COMMANDOID, INDEX_AM_HANDLEROID, TABLE_AM_HANDLEROID,
|
||||
]
|
||||
|
||||
register_core_codec(ANYENUMOID,
|
||||
NULL,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
for no_io_type in no_io_types:
|
||||
register_core_codec(no_io_type,
|
||||
NULL,
|
||||
NULL,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
# ACL specification string
|
||||
register_core_codec(ACLITEMOID,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
# Postgres' serialized expression tree type
|
||||
register_core_codec(PG_NODE_TREEOID,
|
||||
NULL,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
# pg_lsn type -- a pointer to a location in the XLOG.
|
||||
register_core_codec(PG_LSNOID,
|
||||
<encode_func>pgproto.int8_encode,
|
||||
<decode_func>pgproto.int8_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
register_core_codec(SMGROID,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
# pg_dependencies and pg_ndistinct are special types
|
||||
# used in pg_statistic_ext columns.
|
||||
register_core_codec(PG_DEPENDENCIESOID,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
register_core_codec(PG_NDISTINCTOID,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
# pg_mcv_list is a special type used in pg_statistic_ext_data
|
||||
# system catalog
|
||||
register_core_codec(PG_MCV_LISTOID,
|
||||
<encode_func>pgproto.bytea_encode,
|
||||
<decode_func>pgproto.bytea_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
# These two are internal to BRIN index support and are unlikely
|
||||
# to be sent, but since I/O functions for these exist, add decoders
|
||||
# nonetheless.
|
||||
register_core_codec(PG_BRIN_BLOOM_SUMMARYOID,
|
||||
NULL,
|
||||
<decode_func>pgproto.bytea_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
register_core_codec(PG_BRIN_MINMAX_MULTI_SUMMARYOID,
|
||||
NULL,
|
||||
<decode_func>pgproto.bytea_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
|
||||
cdef init_text_codecs():
|
||||
textoids = [
|
||||
NAMEOID,
|
||||
BPCHAROID,
|
||||
VARCHAROID,
|
||||
TEXTOID,
|
||||
XMLOID
|
||||
]
|
||||
|
||||
for oid in textoids:
|
||||
register_core_codec(oid,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
register_core_codec(oid,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
|
||||
cdef init_tid_codecs():
|
||||
register_core_codec(TIDOID,
|
||||
<encode_func>pgproto.tid_encode,
|
||||
<decode_func>pgproto.tid_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
|
||||
cdef init_txid_codecs():
|
||||
register_core_codec(TXID_SNAPSHOTOID,
|
||||
<encode_func>pgproto.pg_snapshot_encode,
|
||||
<decode_func>pgproto.pg_snapshot_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
register_core_codec(PG_SNAPSHOTOID,
|
||||
<encode_func>pgproto.pg_snapshot_encode,
|
||||
<decode_func>pgproto.pg_snapshot_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
|
||||
cdef init_tsearch_codecs():
|
||||
ts_oids = [
|
||||
TSQUERYOID,
|
||||
TSVECTOROID,
|
||||
]
|
||||
|
||||
for oid in ts_oids:
|
||||
register_core_codec(oid,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
register_core_codec(GTSVECTOROID,
|
||||
NULL,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
|
||||
cdef init_uuid_codecs():
|
||||
register_core_codec(UUIDOID,
|
||||
<encode_func>pgproto.uuid_encode,
|
||||
<decode_func>pgproto.uuid_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
|
||||
cdef init_numeric_codecs():
|
||||
register_core_codec(NUMERICOID,
|
||||
<encode_func>pgproto.numeric_encode_text,
|
||||
<decode_func>pgproto.numeric_decode_text,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
register_core_codec(NUMERICOID,
|
||||
<encode_func>pgproto.numeric_encode_binary,
|
||||
<decode_func>pgproto.numeric_decode_binary,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
|
||||
cdef init_network_codecs():
|
||||
register_core_codec(CIDROID,
|
||||
<encode_func>pgproto.cidr_encode,
|
||||
<decode_func>pgproto.cidr_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
register_core_codec(INETOID,
|
||||
<encode_func>pgproto.inet_encode,
|
||||
<decode_func>pgproto.inet_decode,
|
||||
PG_FORMAT_BINARY)
|
||||
|
||||
register_core_codec(MACADDROID,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
register_core_codec(MACADDR8OID,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
|
||||
cdef init_monetary_codecs():
|
||||
moneyoids = [
|
||||
MONEYOID,
|
||||
]
|
||||
|
||||
for oid in moneyoids:
|
||||
register_core_codec(oid,
|
||||
<encode_func>pgproto.text_encode,
|
||||
<decode_func>pgproto.text_decode,
|
||||
PG_FORMAT_TEXT)
|
||||
|
||||
|
||||
cdef init_all_pgproto_codecs():
|
||||
# Builtin types, in lexicographical order.
|
||||
init_bits_codecs()
|
||||
init_bytea_codecs()
|
||||
init_datetime_codecs()
|
||||
init_float_codecs()
|
||||
init_geometry_codecs()
|
||||
init_int_codecs()
|
||||
init_json_codecs()
|
||||
init_monetary_codecs()
|
||||
init_network_codecs()
|
||||
init_numeric_codecs()
|
||||
init_text_codecs()
|
||||
init_tid_codecs()
|
||||
init_tsearch_codecs()
|
||||
init_txid_codecs()
|
||||
init_uuid_codecs()
|
||||
|
||||
# Various pseudotypes and system types
|
||||
init_pseudo_codecs()
|
||||
|
||||
# contrib
|
||||
init_hstore_codecs()
|
||||
|
||||
|
||||
init_all_pgproto_codecs()
|
||||
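
Taken together, the init_* calls above populate asyncpg's codec registry at
import time, keyed by type OID and wire format, which is why fetched values
arrive as ready-made Python objects. A minimal usage sketch (the DSN is a
placeholder; assumes a reachable server):

import asyncio
import asyncpg

async def main():
    conn = await asyncpg.connect('postgresql://localhost/postgres')
    try:
        # Each value round-trips through a core codec registered above
        # (int8, float8 and interval, respectively).
        row = await conn.fetchrow(
            "SELECT 1::int8 AS i, 3.14::float8 AS f, now() - now() AS d")
        print(row['i'], row['f'], row['d'])
    finally:
        await conn.close()

asyncio.run(main())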
@@ -0,0 +1,207 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from asyncpg import types as apg_types

from collections.abc import Sequence as SequenceABC

# defined in postgresql/src/include/utils/rangetypes.h
DEF RANGE_EMPTY = 0x01   # range is empty
DEF RANGE_LB_INC = 0x02  # lower bound is inclusive
DEF RANGE_UB_INC = 0x04  # upper bound is inclusive
DEF RANGE_LB_INF = 0x08  # lower bound is -infinity
DEF RANGE_UB_INF = 0x10  # upper bound is +infinity


cdef enum _RangeArgumentType:
    _RANGE_ARGUMENT_INVALID = 0
    _RANGE_ARGUMENT_TUPLE = 1
    _RANGE_ARGUMENT_RANGE = 2


cdef inline bint _range_has_lbound(uint8_t flags):
    return not (flags & (RANGE_EMPTY | RANGE_LB_INF))


cdef inline bint _range_has_ubound(uint8_t flags):
    return not (flags & (RANGE_EMPTY | RANGE_UB_INF))


cdef inline _RangeArgumentType _range_type(object obj):
    if cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj):
        return _RANGE_ARGUMENT_TUPLE
    elif isinstance(obj, apg_types.Range):
        return _RANGE_ARGUMENT_RANGE
    else:
        return _RANGE_ARGUMENT_INVALID


cdef range_encode(ConnectionSettings settings, WriteBuffer buf,
                  object obj, uint32_t elem_oid,
                  encode_func_ex encoder, const void *encoder_arg):
    cdef:
        ssize_t obj_len
        uint8_t flags = 0
        object lower = None
        object upper = None
        WriteBuffer bounds_data = WriteBuffer.new()
        _RangeArgumentType arg_type = _range_type(obj)

    if arg_type == _RANGE_ARGUMENT_INVALID:
        raise TypeError(
            'list, tuple or Range object expected (got type {})'.format(
                type(obj)))

    elif arg_type == _RANGE_ARGUMENT_TUPLE:
        obj_len = len(obj)
        if obj_len == 2:
            lower = obj[0]
            upper = obj[1]

            if lower is None:
                flags |= RANGE_LB_INF

            if upper is None:
                flags |= RANGE_UB_INF

            flags |= RANGE_LB_INC | RANGE_UB_INC

        elif obj_len == 1:
            lower = obj[0]
            flags |= RANGE_LB_INC | RANGE_UB_INF

        elif obj_len == 0:
            flags |= RANGE_EMPTY

        else:
            raise ValueError(
                'expected 0, 1 or 2 elements in range (got {})'.format(
                    obj_len))

    else:
        if obj.isempty:
            flags |= RANGE_EMPTY
        else:
            lower = obj.lower
            upper = obj.upper

            if obj.lower_inc:
                flags |= RANGE_LB_INC
            elif lower is None:
                flags |= RANGE_LB_INF

            if obj.upper_inc:
                flags |= RANGE_UB_INC
            elif upper is None:
                flags |= RANGE_UB_INF

    if _range_has_lbound(flags):
        encoder(settings, bounds_data, lower, encoder_arg)

    if _range_has_ubound(flags):
        encoder(settings, bounds_data, upper, encoder_arg)

    buf.write_int32(1 + bounds_data.len())
    buf.write_byte(<int8_t>flags)
    buf.write_buffer(bounds_data)


cdef range_decode(ConnectionSettings settings, FRBuffer *buf,
                  decode_func_ex decoder, const void *decoder_arg):
    cdef:
        uint8_t flags = <uint8_t>frb_read(buf, 1)[0]
        int32_t bound_len
        object lower = None
        object upper = None
        FRBuffer bound_buf

    if _range_has_lbound(flags):
        bound_len = hton.unpack_int32(frb_read(buf, 4))
        if bound_len == -1:
            lower = None
        else:
            frb_slice_from(&bound_buf, buf, bound_len)
            lower = decoder(settings, &bound_buf, decoder_arg)

    if _range_has_ubound(flags):
        bound_len = hton.unpack_int32(frb_read(buf, 4))
        if bound_len == -1:
            upper = None
        else:
            frb_slice_from(&bound_buf, buf, bound_len)
            upper = decoder(settings, &bound_buf, decoder_arg)

    return apg_types.Range(lower=lower, upper=upper,
                           lower_inc=(flags & RANGE_LB_INC) != 0,
                           upper_inc=(flags & RANGE_UB_INC) != 0,
                           empty=(flags & RANGE_EMPTY) != 0)


cdef multirange_encode(ConnectionSettings settings, WriteBuffer buf,
                       object obj, uint32_t elem_oid,
                       encode_func_ex encoder, const void *encoder_arg):
    cdef:
        WriteBuffer elem_data
        ssize_t elem_data_len
        ssize_t elem_count

    if not isinstance(obj, SequenceABC):
        raise TypeError(
            'expected a sequence (got type {!r})'.format(type(obj).__name__)
        )

    elem_data = WriteBuffer.new()

    for elem in obj:
        range_encode(settings, elem_data, elem, elem_oid, encoder, encoder_arg)

    elem_count = len(obj)
    if elem_count > INT32_MAX:
        raise OverflowError(f'too many elements in multirange value')

    elem_data_len = elem_data.len()
    if elem_data_len > INT32_MAX - 4:
        raise OverflowError(
            f'size of encoded multirange datum exceeds the maximum allowed'
            f' {INT32_MAX - 4} bytes')

    # Datum length
    buf.write_int32(4 + <int32_t>elem_data_len)
    # Number of elements in multirange
    buf.write_int32(<int32_t>elem_count)
    buf.write_buffer(elem_data)


cdef multirange_decode(ConnectionSettings settings, FRBuffer *buf,
                       decode_func_ex decoder, const void *decoder_arg):
    cdef:
        int32_t nelems = hton.unpack_int32(frb_read(buf, 4))
        FRBuffer elem_buf
        int32_t elem_len
        int i
        list result

    if nelems == 0:
        return []

    if nelems < 0:
        raise exceptions.ProtocolError(
            'unexpected multirange size value: {}'.format(nelems))

    result = cpython.PyList_New(nelems)
    for i in range(nelems):
        elem_len = hton.unpack_int32(frb_read(buf, 4))
        if elem_len == -1:
            raise exceptions.ProtocolError(
                'unexpected NULL element in multirange value')
        else:
            frb_slice_from(&elem_buf, buf, elem_len)
            elem = range_decode(settings, &elem_buf, decoder, decoder_arg)
            cpython.Py_INCREF(elem)
            cpython.PyList_SET_ITEM(result, i, elem)

    return result
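
The binary range format produced and consumed above is one flags byte
followed by length-prefixed bounds. A standalone sketch that parses such a
payload for int4range (pure Python; the function name and layout constants
are illustrative, not part of asyncpg):

import struct

RANGE_EMPTY, RANGE_LB_INC, RANGE_UB_INC = 0x01, 0x02, 0x04
RANGE_LB_INF, RANGE_UB_INF = 0x08, 0x10

def parse_int4range(payload: bytes):
    # flags byte, then each present bound as int32 length + int32 value
    flags = payload[0]
    pos, lower, upper = 1, None, None
    if not flags & (RANGE_EMPTY | RANGE_LB_INF):
        blen = struct.unpack_from('!i', payload, pos)[0]
        lower = struct.unpack_from('!i', payload, pos + 4)[0]
        pos += 4 + blen
    if not flags & (RANGE_EMPTY | RANGE_UB_INF):
        upper = struct.unpack_from('!i', payload, pos + 4)[0]
    return lower, upper, bool(flags & RANGE_EMPTY)

# 0x06 = RANGE_LB_INC | RANGE_UB_INC, followed by the bounds 1 and 10.
payload = bytes([0x06]) + struct.pack('!ii', 4, 1) + struct.pack('!ii', 4, 10)
print(parse_int4range(payload))  # (1, 10, False)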
@@ -0,0 +1,71 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from asyncpg import exceptions


cdef inline record_encode_frame(ConnectionSettings settings, WriteBuffer buf,
                                WriteBuffer elem_data, int32_t elem_count):
    buf.write_int32(4 + elem_data.len())
    # attribute count
    buf.write_int32(elem_count)
    # encoded attribute data
    buf.write_buffer(elem_data)


cdef anonymous_record_decode(ConnectionSettings settings, FRBuffer *buf):
    cdef:
        tuple result
        ssize_t elem_count
        ssize_t i
        int32_t elem_len
        uint32_t elem_typ
        Codec elem_codec
        FRBuffer elem_buf

    elem_count = <ssize_t><uint32_t>hton.unpack_int32(frb_read(buf, 4))
    result = cpython.PyTuple_New(elem_count)

    for i in range(elem_count):
        elem_typ = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
        elem_len = hton.unpack_int32(frb_read(buf, 4))

        if elem_len == -1:
            elem = None
        else:
            elem_codec = settings.get_data_codec(elem_typ)
            if elem_codec is None or not elem_codec.has_decoder():
                raise exceptions.InternalClientError(
                    'no decoder for composite type element in '
                    'position {} of type OID {}'.format(i, elem_typ))
            elem = elem_codec.decode(settings,
                                     frb_slice_from(&elem_buf, buf, elem_len))

        cpython.Py_INCREF(elem)
        cpython.PyTuple_SET_ITEM(result, i, elem)

    return result


cdef anonymous_record_encode(ConnectionSettings settings, WriteBuffer buf, obj):
    raise exceptions.UnsupportedClientFeatureError(
        'input of anonymous composite types is not supported',
        hint=(
            'Consider declaring an explicit composite type and '
            'using it to cast the argument.'
        ),
        detail='PostgreSQL does not implement anonymous composite type input.'
    )


cdef init_record_codecs():
    register_core_codec(RECORDOID,
                        <encode_func>anonymous_record_encode,
                        <decode_func>anonymous_record_decode,
                        PG_FORMAT_BINARY)

init_record_codecs()
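
anonymous_record_encode above rejects tuple input for the anonymous RECORD
type, mirroring the server-side limitation its hint describes: decoding
works, but encoding needs a declared composite type. A sketch of the
suggested workaround (the 'pair' type is hypothetical; assumes a reachable
local server):

import asyncio
import asyncpg

async def main():
    conn = await asyncpg.connect()
    try:
        await conn.execute(
            'DROP TYPE IF EXISTS pair CASCADE; '
            'CREATE TYPE pair AS (a int, b text)')
        # Decoding an anonymous record is supported:
        print(await conn.fetchval("SELECT ROW(1, 'x')"))
        # Encoding requires casting to the declared composite type:
        print(await conn.fetchval("SELECT $1::pair", (1, 'x')))
    finally:
        await conn.close()

asyncio.run(main())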
@@ -0,0 +1,99 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef inline uint32_t _apg_tolower(uint32_t c):
    if c >= <uint32_t><Py_UCS4>'A' and c <= <uint32_t><Py_UCS4>'Z':
        return c + <uint32_t><Py_UCS4>'a' - <uint32_t><Py_UCS4>'A'
    else:
        return c


cdef int apg_strcasecmp(const Py_UCS4 *s1, const Py_UCS4 *s2):
    cdef:
        uint32_t c1
        uint32_t c2
        int i = 0

    while True:
        c1 = s1[i]
        c2 = s2[i]

        if c1 != c2:
            c1 = _apg_tolower(c1)
            c2 = _apg_tolower(c2)
            if c1 != c2:
                return <int32_t>c1 - <int32_t>c2

        if c1 == 0 or c2 == 0:
            break

        i += 1

    return 0


cdef int apg_strcasecmp_char(const char *s1, const char *s2):
    cdef:
        uint8_t c1
        uint8_t c2
        int i = 0

    while True:
        c1 = <uint8_t>s1[i]
        c2 = <uint8_t>s2[i]

        if c1 != c2:
            c1 = <uint8_t>_apg_tolower(c1)
            c2 = <uint8_t>_apg_tolower(c2)
            if c1 != c2:
                return <int8_t>c1 - <int8_t>c2

        if c1 == 0 or c2 == 0:
            break

        i += 1

    return 0


cdef inline bint apg_ascii_isspace(Py_UCS4 ch):
    return (
        ch == ' ' or
        ch == '\n' or
        ch == '\r' or
        ch == '\t' or
        ch == '\v' or
        ch == '\f'
    )


cdef Py_UCS4 *apg_parse_int32(Py_UCS4 *buf, int32_t *num):
    cdef:
        Py_UCS4 *p
        int32_t n = 0
        int32_t neg = 0

    if buf[0] == '-':
        neg = 1
        buf += 1
    elif buf[0] == '+':
        buf += 1

    p = buf
    while <int>p[0] >= <int><Py_UCS4>'0' and <int>p[0] <= <int><Py_UCS4>'9':
        n = 10 * n - (<int>p[0] - <int32_t><Py_UCS4>'0')
        p += 1

    if p == buf:
        return NULL

    if not neg:
        n = -n

    num[0] = n

    return p
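
apg_parse_int32 above accumulates the number as a negative value and flips
the sign at the end, so that -2147483648 (INT32_MIN, whose magnitude does
not fit in a positive int32) parses without overflow. The same trick in
plain Python (illustrative only, since Python ints do not overflow):

def parse_int32(s: str):
    i = 1 if s[:1] in ('-', '+') else 0
    neg = s[:1] == '-'
    n, start = 0, i
    while i < len(s) and s[i].isdigit():
        n = 10 * n - (ord(s[i]) - ord('0'))  # accumulate negatively
        i += 1
    if i == start:
        return None  # no digits consumed
    return n if neg else -n

print(parse_int32('-2147483648'))  # -2147483648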
@@ -0,0 +1,12 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


DEF _MAXINT32 = 2**31 - 1
DEF _COPY_BUFFER_SIZE = 524288
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
DEF _EXECUTE_MANY_BUF_NUM = 4
DEF _EXECUTE_MANY_BUF_SIZE = 32768
195
venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd
Normal file
@@ -0,0 +1,195 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


include "scram.pxd"


cdef enum ConnectionStatus:
    CONNECTION_OK = 1
    CONNECTION_BAD = 2
    CONNECTION_STARTED = 3  # Waiting for connection to be made.


cdef enum ProtocolState:
    PROTOCOL_IDLE = 0

    PROTOCOL_FAILED = 1
    PROTOCOL_ERROR_CONSUME = 2
    PROTOCOL_CANCELLED = 3
    PROTOCOL_TERMINATING = 4

    PROTOCOL_AUTH = 10
    PROTOCOL_PREPARE = 11
    PROTOCOL_BIND_EXECUTE = 12
    PROTOCOL_BIND_EXECUTE_MANY = 13
    PROTOCOL_CLOSE_STMT_PORTAL = 14
    PROTOCOL_SIMPLE_QUERY = 15
    PROTOCOL_EXECUTE = 16
    PROTOCOL_BIND = 17
    PROTOCOL_COPY_OUT = 18
    PROTOCOL_COPY_OUT_DATA = 19
    PROTOCOL_COPY_OUT_DONE = 20
    PROTOCOL_COPY_IN = 21
    PROTOCOL_COPY_IN_DATA = 22


cdef enum AuthenticationMessage:
    AUTH_SUCCESSFUL = 0
    AUTH_REQUIRED_KERBEROS = 2
    AUTH_REQUIRED_PASSWORD = 3
    AUTH_REQUIRED_PASSWORDMD5 = 5
    AUTH_REQUIRED_SCMCRED = 6
    AUTH_REQUIRED_GSS = 7
    AUTH_REQUIRED_GSS_CONTINUE = 8
    AUTH_REQUIRED_SSPI = 9
    AUTH_REQUIRED_SASL = 10
    AUTH_SASL_CONTINUE = 11
    AUTH_SASL_FINAL = 12


AUTH_METHOD_NAME = {
    AUTH_REQUIRED_KERBEROS: 'kerberosv5',
    AUTH_REQUIRED_PASSWORD: 'password',
    AUTH_REQUIRED_PASSWORDMD5: 'md5',
    AUTH_REQUIRED_GSS: 'gss',
    AUTH_REQUIRED_SASL: 'scram-sha-256',
    AUTH_REQUIRED_SSPI: 'sspi',
}


cdef enum ResultType:
    RESULT_OK = 1
    RESULT_FAILED = 2


cdef enum TransactionStatus:
    PQTRANS_IDLE = 0     # connection idle
    PQTRANS_ACTIVE = 1   # command in progress
    PQTRANS_INTRANS = 2  # idle, within transaction block
    PQTRANS_INERROR = 3  # idle, within failed transaction
    PQTRANS_UNKNOWN = 4  # cannot determine status


ctypedef object (*decode_row_method)(object, const char*, ssize_t)


cdef class CoreProtocol:
    cdef:
        ReadBuffer buffer
        bint _skip_discard
        bint _discard_data

        # executemany support data
        object _execute_iter
        str _execute_portal_name
        str _execute_stmt_name

        ConnectionStatus con_status
        ProtocolState state
        TransactionStatus xact_status

        str encoding

        object transport

        # Instance of _ConnectionParameters
        object con_params
        # Instance of SCRAMAuthentication
        SCRAMAuthentication scram

        readonly int32_t backend_pid
        readonly int32_t backend_secret

        ## Result
        ResultType result_type
        object result
        bytes result_param_desc
        bytes result_row_desc
        bytes result_status_msg

        # True - completed, False - suspended
        bint result_execute_completed

    cpdef is_in_transaction(self)
    cdef _process__auth(self, char mtype)
    cdef _process__prepare(self, char mtype)
    cdef _process__bind_execute(self, char mtype)
    cdef _process__bind_execute_many(self, char mtype)
    cdef _process__close_stmt_portal(self, char mtype)
    cdef _process__simple_query(self, char mtype)
    cdef _process__bind(self, char mtype)
    cdef _process__copy_out(self, char mtype)
    cdef _process__copy_out_data(self, char mtype)
    cdef _process__copy_in(self, char mtype)
    cdef _process__copy_in_data(self, char mtype)

    cdef _parse_msg_authentication(self)
    cdef _parse_msg_parameter_status(self)
    cdef _parse_msg_notification(self)
    cdef _parse_msg_backend_key_data(self)
    cdef _parse_msg_ready_for_query(self)
    cdef _parse_data_msgs(self)
    cdef _parse_copy_data_msgs(self)
    cdef _parse_msg_error_response(self, is_error)
    cdef _parse_msg_command_complete(self)

    cdef _write_copy_data_msg(self, object data)
    cdef _write_copy_done_msg(self)
    cdef _write_copy_fail_msg(self, str cause)

    cdef _auth_password_message_cleartext(self)
    cdef _auth_password_message_md5(self, bytes salt)
    cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
    cdef _auth_password_message_sasl_continue(self, bytes server_response)

    cdef _write(self, buf)
    cdef _writelines(self, list buffers)

    cdef _read_server_messages(self)

    cdef _push_result(self)
    cdef _reset_result(self)
    cdef _set_state(self, ProtocolState new_state)

    cdef _ensure_connected(self)

    cdef WriteBuffer _build_parse_message(self, str stmt_name, str query)
    cdef WriteBuffer _build_bind_message(self, str portal_name,
                                         str stmt_name,
                                         WriteBuffer bind_data)
    cdef WriteBuffer _build_empty_bind_data(self)
    cdef WriteBuffer _build_execute_message(self, str portal_name,
                                            int32_t limit)

    cdef _connect(self)
    cdef _prepare_and_describe(self, str stmt_name, str query)
    cdef _send_parse_message(self, str stmt_name, str query)
    cdef _send_bind_message(self, str portal_name, str stmt_name,
                            WriteBuffer bind_data, int32_t limit)
    cdef _bind_execute(self, str portal_name, str stmt_name,
                       WriteBuffer bind_data, int32_t limit)
    cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
                                 object bind_data)
    cdef bint _bind_execute_many_more(self, bint first=*)
    cdef _bind_execute_many_fail(self, object error, bint first=*)
    cdef _bind(self, str portal_name, str stmt_name,
               WriteBuffer bind_data)
    cdef _execute(self, str portal_name, int32_t limit)
    cdef _close(self, str name, bint is_portal)
    cdef _simple_query(self, str query)
    cdef _copy_out(self, str copy_stmt)
    cdef _copy_in(self, str copy_stmt)
    cdef _terminate(self)

    cdef _decode_row(self, const char* buf, ssize_t buf_len)

    cdef _on_result(self)
    cdef _on_notification(self, pid, channel, payload)
    cdef _on_notice(self, parsed)
    cdef _set_server_parameter(self, name, val)
    cdef _on_connection_lost(self, exc)
1153
venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,19 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef extern from "Python.h":
    int PyByteArray_Check(object)

    int PyMemoryView_Check(object)
    Py_buffer *PyMemoryView_GET_BUFFER(object)
    object PyMemoryView_GetContiguous(object, int buffertype, char order)

    Py_UCS4* PyUnicode_AsUCS4Copy(object) except NULL
    object PyUnicode_FromKindAndData(
        int kind, const void *buffer, Py_ssize_t size)

    int PyUnicode_4BYTE_KIND
@@ -0,0 +1,63 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


'''Map PostgreSQL encoding names to Python encoding names

https://www.postgresql.org/docs/current/static/multibyte.html#CHARSET-TABLE
'''

cdef dict ENCODINGS_MAP = {
    'abc': 'cp1258',
    'alt': 'cp866',
    'euc_cn': 'euccn',
    'euc_jp': 'eucjp',
    'euc_kr': 'euckr',
    'koi8r': 'koi8_r',
    'koi8u': 'koi8_u',
    'shift_jis_2004': 'euc_jis_2004',
    'sjis': 'shift_jis',
    'sql_ascii': 'ascii',
    'vscii': 'cp1258',
    'tcvn': 'cp1258',
    'tcvn5712': 'cp1258',
    'unicode': 'utf_8',
    'win': 'cp1521',
    'win1250': 'cp1250',
    'win1251': 'cp1251',
    'win1252': 'cp1252',
    'win1253': 'cp1253',
    'win1254': 'cp1254',
    'win1255': 'cp1255',
    'win1256': 'cp1256',
    'win1257': 'cp1257',
    'win1258': 'cp1258',
    'win866': 'cp866',
    'win874': 'cp874',
    'win932': 'cp932',
    'win936': 'cp936',
    'win949': 'cp949',
    'win950': 'cp950',
    'windows1250': 'cp1250',
    'windows1251': 'cp1251',
    'windows1252': 'cp1252',
    'windows1253': 'cp1253',
    'windows1254': 'cp1254',
    'windows1255': 'cp1255',
    'windows1256': 'cp1256',
    'windows1257': 'cp1257',
    'windows1258': 'cp1258',
    'windows866': 'cp866',
    'windows874': 'cp874',
    'windows932': 'cp932',
    'windows936': 'cp936',
    'windows949': 'cp949',
    'windows950': 'cp950',
}


cdef get_python_encoding(pg_encoding):
    return ENCODINGS_MAP.get(pg_encoding.lower(), pg_encoding.lower())
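
A quick illustration of the normalization above: the lookup lower-cases the
PostgreSQL encoding name and falls through to it unchanged when unmapped.
The same logic in plain Python, with an excerpted mapping:

ENCODINGS_MAP = {'sql_ascii': 'ascii', 'win1251': 'cp1251'}  # excerpt

def get_python_encoding(pg_encoding):
    return ENCODINGS_MAP.get(pg_encoding.lower(), pg_encoding.lower())

assert get_python_encoding('SQL_ASCII') == 'ascii'
assert get_python_encoding('UTF8') == 'utf8'  # unmapped names pass through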
266
venv/lib/python3.12/site-packages/asyncpg/protocol/pgtypes.pxi
Normal file
@@ -0,0 +1,266 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


# GENERATED FROM pg_catalog.pg_type
# DO NOT MODIFY, use tools/generate_type_map.py to update

DEF INVALIDOID = 0
DEF MAXBUILTINOID = 9999
DEF MAXSUPPORTEDOID = 5080

DEF BOOLOID = 16
DEF BYTEAOID = 17
DEF CHAROID = 18
DEF NAMEOID = 19
DEF INT8OID = 20
DEF INT2OID = 21
DEF INT4OID = 23
DEF REGPROCOID = 24
DEF TEXTOID = 25
DEF OIDOID = 26
DEF TIDOID = 27
DEF XIDOID = 28
DEF CIDOID = 29
DEF PG_DDL_COMMANDOID = 32
DEF JSONOID = 114
DEF XMLOID = 142
DEF PG_NODE_TREEOID = 194
DEF SMGROID = 210
DEF TABLE_AM_HANDLEROID = 269
DEF INDEX_AM_HANDLEROID = 325
DEF POINTOID = 600
DEF LSEGOID = 601
DEF PATHOID = 602
DEF BOXOID = 603
DEF POLYGONOID = 604
DEF LINEOID = 628
DEF CIDROID = 650
DEF FLOAT4OID = 700
DEF FLOAT8OID = 701
DEF ABSTIMEOID = 702
DEF RELTIMEOID = 703
DEF TINTERVALOID = 704
DEF UNKNOWNOID = 705
DEF CIRCLEOID = 718
DEF MACADDR8OID = 774
DEF MONEYOID = 790
DEF MACADDROID = 829
DEF INETOID = 869
DEF _TEXTOID = 1009
DEF _OIDOID = 1028
DEF ACLITEMOID = 1033
DEF BPCHAROID = 1042
DEF VARCHAROID = 1043
DEF DATEOID = 1082
DEF TIMEOID = 1083
DEF TIMESTAMPOID = 1114
DEF TIMESTAMPTZOID = 1184
DEF INTERVALOID = 1186
DEF TIMETZOID = 1266
DEF BITOID = 1560
DEF VARBITOID = 1562
DEF NUMERICOID = 1700
DEF REFCURSOROID = 1790
DEF REGPROCEDUREOID = 2202
DEF REGOPEROID = 2203
DEF REGOPERATOROID = 2204
DEF REGCLASSOID = 2205
DEF REGTYPEOID = 2206
DEF RECORDOID = 2249
DEF CSTRINGOID = 2275
DEF ANYOID = 2276
DEF ANYARRAYOID = 2277
DEF VOIDOID = 2278
DEF TRIGGEROID = 2279
DEF LANGUAGE_HANDLEROID = 2280
DEF INTERNALOID = 2281
DEF OPAQUEOID = 2282
DEF ANYELEMENTOID = 2283
DEF ANYNONARRAYOID = 2776
DEF UUIDOID = 2950
DEF TXID_SNAPSHOTOID = 2970
DEF FDW_HANDLEROID = 3115
DEF PG_LSNOID = 3220
DEF TSM_HANDLEROID = 3310
DEF PG_NDISTINCTOID = 3361
DEF PG_DEPENDENCIESOID = 3402
DEF ANYENUMOID = 3500
DEF TSVECTOROID = 3614
DEF TSQUERYOID = 3615
DEF GTSVECTOROID = 3642
DEF REGCONFIGOID = 3734
DEF REGDICTIONARYOID = 3769
DEF JSONBOID = 3802
DEF ANYRANGEOID = 3831
DEF EVENT_TRIGGEROID = 3838
DEF JSONPATHOID = 4072
DEF REGNAMESPACEOID = 4089
DEF REGROLEOID = 4096
DEF REGCOLLATIONOID = 4191
DEF ANYMULTIRANGEOID = 4537
DEF ANYCOMPATIBLEMULTIRANGEOID = 4538
DEF PG_BRIN_BLOOM_SUMMARYOID = 4600
DEF PG_BRIN_MINMAX_MULTI_SUMMARYOID = 4601
DEF PG_MCV_LISTOID = 5017
DEF PG_SNAPSHOTOID = 5038
DEF XID8OID = 5069
DEF ANYCOMPATIBLEOID = 5077
DEF ANYCOMPATIBLEARRAYOID = 5078
DEF ANYCOMPATIBLENONARRAYOID = 5079
DEF ANYCOMPATIBLERANGEOID = 5080

cdef ARRAY_TYPES = (_TEXTOID, _OIDOID,)

BUILTIN_TYPE_OID_MAP = {
    ABSTIMEOID: 'abstime',
    ACLITEMOID: 'aclitem',
    ANYARRAYOID: 'anyarray',
    ANYCOMPATIBLEARRAYOID: 'anycompatiblearray',
    ANYCOMPATIBLEMULTIRANGEOID: 'anycompatiblemultirange',
    ANYCOMPATIBLENONARRAYOID: 'anycompatiblenonarray',
    ANYCOMPATIBLEOID: 'anycompatible',
    ANYCOMPATIBLERANGEOID: 'anycompatiblerange',
    ANYELEMENTOID: 'anyelement',
    ANYENUMOID: 'anyenum',
    ANYMULTIRANGEOID: 'anymultirange',
    ANYNONARRAYOID: 'anynonarray',
    ANYOID: 'any',
    ANYRANGEOID: 'anyrange',
    BITOID: 'bit',
    BOOLOID: 'bool',
    BOXOID: 'box',
    BPCHAROID: 'bpchar',
    BYTEAOID: 'bytea',
    CHAROID: 'char',
    CIDOID: 'cid',
    CIDROID: 'cidr',
    CIRCLEOID: 'circle',
    CSTRINGOID: 'cstring',
    DATEOID: 'date',
    EVENT_TRIGGEROID: 'event_trigger',
    FDW_HANDLEROID: 'fdw_handler',
    FLOAT4OID: 'float4',
    FLOAT8OID: 'float8',
    GTSVECTOROID: 'gtsvector',
    INDEX_AM_HANDLEROID: 'index_am_handler',
    INETOID: 'inet',
    INT2OID: 'int2',
    INT4OID: 'int4',
    INT8OID: 'int8',
    INTERNALOID: 'internal',
    INTERVALOID: 'interval',
    JSONBOID: 'jsonb',
    JSONOID: 'json',
    JSONPATHOID: 'jsonpath',
    LANGUAGE_HANDLEROID: 'language_handler',
    LINEOID: 'line',
    LSEGOID: 'lseg',
    MACADDR8OID: 'macaddr8',
    MACADDROID: 'macaddr',
    MONEYOID: 'money',
    NAMEOID: 'name',
    NUMERICOID: 'numeric',
    OIDOID: 'oid',
    OPAQUEOID: 'opaque',
    PATHOID: 'path',
    PG_BRIN_BLOOM_SUMMARYOID: 'pg_brin_bloom_summary',
    PG_BRIN_MINMAX_MULTI_SUMMARYOID: 'pg_brin_minmax_multi_summary',
    PG_DDL_COMMANDOID: 'pg_ddl_command',
    PG_DEPENDENCIESOID: 'pg_dependencies',
    PG_LSNOID: 'pg_lsn',
    PG_MCV_LISTOID: 'pg_mcv_list',
    PG_NDISTINCTOID: 'pg_ndistinct',
    PG_NODE_TREEOID: 'pg_node_tree',
    PG_SNAPSHOTOID: 'pg_snapshot',
    POINTOID: 'point',
    POLYGONOID: 'polygon',
    RECORDOID: 'record',
    REFCURSOROID: 'refcursor',
    REGCLASSOID: 'regclass',
    REGCOLLATIONOID: 'regcollation',
    REGCONFIGOID: 'regconfig',
    REGDICTIONARYOID: 'regdictionary',
    REGNAMESPACEOID: 'regnamespace',
    REGOPERATOROID: 'regoperator',
    REGOPEROID: 'regoper',
    REGPROCEDUREOID: 'regprocedure',
    REGPROCOID: 'regproc',
    REGROLEOID: 'regrole',
    REGTYPEOID: 'regtype',
    RELTIMEOID: 'reltime',
    SMGROID: 'smgr',
    TABLE_AM_HANDLEROID: 'table_am_handler',
    TEXTOID: 'text',
    TIDOID: 'tid',
    TIMEOID: 'time',
    TIMESTAMPOID: 'timestamp',
    TIMESTAMPTZOID: 'timestamptz',
    TIMETZOID: 'timetz',
    TINTERVALOID: 'tinterval',
    TRIGGEROID: 'trigger',
    TSM_HANDLEROID: 'tsm_handler',
    TSQUERYOID: 'tsquery',
    TSVECTOROID: 'tsvector',
    TXID_SNAPSHOTOID: 'txid_snapshot',
    UNKNOWNOID: 'unknown',
    UUIDOID: 'uuid',
    VARBITOID: 'varbit',
    VARCHAROID: 'varchar',
    VOIDOID: 'void',
    XID8OID: 'xid8',
    XIDOID: 'xid',
    XMLOID: 'xml',
    _OIDOID: 'oid[]',
    _TEXTOID: 'text[]'
}

BUILTIN_TYPE_NAME_MAP = {v: k for k, v in BUILTIN_TYPE_OID_MAP.items()}

BUILTIN_TYPE_NAME_MAP['smallint'] = \
    BUILTIN_TYPE_NAME_MAP['int2']

BUILTIN_TYPE_NAME_MAP['int'] = \
    BUILTIN_TYPE_NAME_MAP['int4']

BUILTIN_TYPE_NAME_MAP['integer'] = \
    BUILTIN_TYPE_NAME_MAP['int4']

BUILTIN_TYPE_NAME_MAP['bigint'] = \
    BUILTIN_TYPE_NAME_MAP['int8']

BUILTIN_TYPE_NAME_MAP['decimal'] = \
    BUILTIN_TYPE_NAME_MAP['numeric']

BUILTIN_TYPE_NAME_MAP['real'] = \
    BUILTIN_TYPE_NAME_MAP['float4']

BUILTIN_TYPE_NAME_MAP['double precision'] = \
    BUILTIN_TYPE_NAME_MAP['float8']

BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \
    BUILTIN_TYPE_NAME_MAP['timestamptz']

BUILTIN_TYPE_NAME_MAP['timestamp without timezone'] = \
    BUILTIN_TYPE_NAME_MAP['timestamp']

BUILTIN_TYPE_NAME_MAP['time with timezone'] = \
    BUILTIN_TYPE_NAME_MAP['timetz']

BUILTIN_TYPE_NAME_MAP['time without timezone'] = \
    BUILTIN_TYPE_NAME_MAP['time']

BUILTIN_TYPE_NAME_MAP['char'] = \
    BUILTIN_TYPE_NAME_MAP['bpchar']

BUILTIN_TYPE_NAME_MAP['character'] = \
    BUILTIN_TYPE_NAME_MAP['bpchar']

BUILTIN_TYPE_NAME_MAP['character varying'] = \
    BUILTIN_TYPE_NAME_MAP['varchar']

BUILTIN_TYPE_NAME_MAP['bit varying'] = \
    BUILTIN_TYPE_NAME_MAP['varbit']
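
A small sketch of how these generated maps resolve SQL type names to OIDs,
using an excerpt with the same values as the DEFs above:

BUILTIN_TYPE_OID_MAP = {20: 'int8', 23: 'int4', 1700: 'numeric'}  # excerpt
BUILTIN_TYPE_NAME_MAP = {v: k for k, v in BUILTIN_TYPE_OID_MAP.items()}
BUILTIN_TYPE_NAME_MAP['bigint'] = BUILTIN_TYPE_NAME_MAP['int8']
BUILTIN_TYPE_NAME_MAP['decimal'] = BUILTIN_TYPE_NAME_MAP['numeric']

assert BUILTIN_TYPE_NAME_MAP['bigint'] == 20
assert BUILTIN_TYPE_OID_MAP[BUILTIN_TYPE_NAME_MAP['decimal']] == 'numeric'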
@@ -0,0 +1,39 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef class PreparedStatementState:
    cdef:
        readonly str name
        readonly str query
        readonly bint closed
        readonly bint prepared
        readonly int refs
        readonly type record_class
        readonly bint ignore_custom_codec

        list row_desc
        list parameters_desc

        ConnectionSettings settings

        int16_t args_num
        bint have_text_args
        tuple args_codecs

        int16_t cols_num
        object cols_desc
        bint have_text_cols
        tuple rows_codecs

    cdef _encode_bind_msg(self, args, int seqno = ?)
    cpdef _init_codecs(self)
    cdef _ensure_rows_decoder(self)
    cdef _ensure_args_encoder(self)
    cdef _set_row_desc(self, object desc)
    cdef _set_args_desc(self, object desc)
    cdef _decode_row(self, const char* cbuf, ssize_t buf_len)
@@ -0,0 +1,395 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from asyncpg import exceptions


@cython.final
cdef class PreparedStatementState:

    def __cinit__(
        self,
        str name,
        str query,
        BaseProtocol protocol,
        type record_class,
        bint ignore_custom_codec
    ):
        self.name = name
        self.query = query
        self.settings = protocol.settings
        self.row_desc = self.parameters_desc = None
        self.args_codecs = self.rows_codecs = None
        self.args_num = self.cols_num = 0
        self.cols_desc = None
        self.closed = False
        self.prepared = True
        self.refs = 0
        self.record_class = record_class
        self.ignore_custom_codec = ignore_custom_codec

    def _get_parameters(self):
        cdef Codec codec

        result = []
        for oid in self.parameters_desc:
            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))
            result.append(apg_types.Type(
                oid, codec.name, codec.kind, codec.schema))

        return tuple(result)

    def _get_attributes(self):
        cdef Codec codec

        if not self.row_desc:
            return ()

        result = []
        for d in self.row_desc:
            name = d[0]
            oid = d[3]

            codec = self.settings.get_data_codec(oid)
            if codec is None:
                raise exceptions.InternalClientError(
                    'missing codec information for OID {}'.format(oid))

            name = name.decode(self.settings._encoding)

            result.append(
                apg_types.Attribute(name,
                    apg_types.Type(oid, codec.name, codec.kind, codec.schema)))

        return tuple(result)

    def _init_types(self):
        cdef:
            Codec codec
            set missing = set()

        if self.parameters_desc:
            for p_oid in self.parameters_desc:
                codec = self.settings.get_data_codec(<uint32_t>p_oid)
                if codec is None or not codec.has_encoder():
                    missing.add(p_oid)

        if self.row_desc:
            for rdesc in self.row_desc:
                codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
                if codec is None or not codec.has_decoder():
                    missing.add(rdesc[3])

        return missing

    cpdef _init_codecs(self):
        self._ensure_args_encoder()
        self._ensure_rows_decoder()

    def attach(self):
        self.refs += 1

    def detach(self):
        self.refs -= 1

    def mark_closed(self):
        self.closed = True

    def mark_unprepared(self):
        if self.name:
            raise exceptions.InternalClientError(
                "named prepared statements cannot be marked unprepared")
        self.prepared = False

    cdef _encode_bind_msg(self, args, int seqno = -1):
        cdef:
            int idx
            WriteBuffer writer
            Codec codec

        if not cpython.PySequence_Check(args):
            if seqno >= 0:
                raise exceptions.DataError(
                    f'invalid input in executemany() argument sequence '
                    f'element #{seqno}: expected a sequence, got '
                    f'{type(args).__name__}'
                )
            else:
                # Non executemany() callers do not pass user input directly,
                # so bad input is a bug.
                raise exceptions.InternalClientError(
                    f'Bind: expected a sequence, got {type(args).__name__}')

        if len(args) > 32767:
            raise exceptions.InterfaceError(
                'the number of query arguments cannot exceed 32767')

        writer = WriteBuffer.new()

        num_args_passed = len(args)
        if self.args_num != num_args_passed:
            hint = 'Check the query against the passed list of arguments.'

            if self.args_num == 0:
                # If the server was expecting zero arguments, it is likely
                # that the user tried to parametrize a statement that does
                # not support parameters.
                hint += (r' Note that parameters are supported only in'
                         r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
                         r' statements, and will *not* work in statements'
                         r' like CREATE VIEW or DECLARE CURSOR.')

            raise exceptions.InterfaceError(
                'the server expects {x} argument{s} for this query, '
                '{y} {w} passed'.format(
                    x=self.args_num, s='s' if self.args_num != 1 else '',
                    y=num_args_passed,
                    w='was' if num_args_passed == 1 else 'were'),
                hint=hint)

        if self.have_text_args:
            writer.write_int16(self.args_num)
            for idx in range(self.args_num):
                codec = <Codec>(self.args_codecs[idx])
                writer.write_int16(<int16_t>codec.format)
        else:
            # All arguments are in binary format
            writer.write_int32(0x00010001)

        writer.write_int16(self.args_num)

        for idx in range(self.args_num):
            arg = args[idx]
            if arg is None:
                writer.write_int32(-1)
            else:
                codec = <Codec>(self.args_codecs[idx])
                try:
                    codec.encode(self.settings, writer, arg)
                except (AssertionError, exceptions.InternalClientError):
                    # These are internal errors and should raise as-is.
                    raise
                except exceptions.InterfaceError as e:
                    # This is already a descriptive error, but annotate
                    # with argument name for clarity.
                    pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    raise e.with_msg(
                        f'query argument {pos}: {e.args[0]}'
                    ) from None
                except Exception as e:
                    # Everything else is assumed to be an encoding error
                    # due to invalid input.
                    pos = f'${idx + 1}'
                    if seqno >= 0:
                        pos = (
                            f'{pos} in element #{seqno} of'
                            f' executemany() sequence'
                        )
                    value_repr = repr(arg)
                    if len(value_repr) > 40:
                        value_repr = value_repr[:40] + '...'

                    raise exceptions.DataError(
                        f'invalid input for query argument'
                        f' {pos}: {value_repr} ({e})'
                    ) from e

        if self.have_text_cols:
            writer.write_int16(self.cols_num)
            for idx in range(self.cols_num):
                codec = <Codec>(self.rows_codecs[idx])
                writer.write_int16(<int16_t>codec.format)
        else:
            # All columns are in binary format
            writer.write_int32(0x00010001)

        return writer

    cdef _ensure_rows_decoder(self):
        cdef:
            list cols_names
            object cols_mapping
            tuple row
            uint32_t oid
            Codec codec
            list codecs

        if self.cols_desc is not None:
            return

        if self.cols_num == 0:
            self.cols_desc = record.ApgRecordDesc_New({}, ())
            return

        cols_mapping = collections.OrderedDict()
        cols_names = []
        codecs = []
        for i from 0 <= i < self.cols_num:
            row = self.row_desc[i]
            col_name = row[0].decode(self.settings._encoding)
            cols_mapping[col_name] = i
            cols_names.append(col_name)
            oid = row[3]
            codec = self.settings.get_data_codec(
                oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_decoder():
                raise exceptions.InternalClientError(
                    'no decoder for OID {}'.format(oid))
            if not codec.is_binary():
                self.have_text_cols = True

            codecs.append(codec)

        self.cols_desc = record.ApgRecordDesc_New(
            cols_mapping, tuple(cols_names))

        self.rows_codecs = tuple(codecs)

    cdef _ensure_args_encoder(self):
        cdef:
            uint32_t p_oid
            Codec codec
            list codecs = []

        if self.args_num == 0 or self.args_codecs is not None:
            return

        for i from 0 <= i < self.args_num:
            p_oid = self.parameters_desc[i]
            codec = self.settings.get_data_codec(
                p_oid, ignore_custom_codec=self.ignore_custom_codec)
            if codec is None or not codec.has_encoder():
                raise exceptions.InternalClientError(
                    'no encoder for OID {}'.format(p_oid))
            if codec.type not in {}:
                self.have_text_args = True

            codecs.append(codec)

        self.args_codecs = tuple(codecs)

    cdef _set_row_desc(self, object desc):
        self.row_desc = _decode_row_desc(desc)
        self.cols_num = <int16_t>(len(self.row_desc))

    cdef _set_args_desc(self, object desc):
        self.parameters_desc = _decode_parameters_desc(desc)
        self.args_num = <int16_t>(len(self.parameters_desc))

    cdef _decode_row(self, const char* cbuf, ssize_t buf_len):
        cdef:
            Codec codec
            int16_t fnum
            int32_t flen
            object dec_row
            tuple rows_codecs = self.rows_codecs
            ConnectionSettings settings = self.settings
            int32_t i
            FRBuffer rbuf
            ssize_t bl

        frb_init(&rbuf, cbuf, buf_len)

        fnum = hton.unpack_int16(frb_read(&rbuf, 2))

        if fnum != self.cols_num:
            raise exceptions.ProtocolError(
                'the number of columns in the result row ({}) is '
                'different from what was described ({})'.format(
                    fnum, self.cols_num))

        dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum)
        for i in range(fnum):
            flen = hton.unpack_int32(frb_read(&rbuf, 4))

            if flen == -1:
                val = None
            else:
                # Clamp buffer size to that of the reported field length
                # to make sure that codecs can rely on read_all() working
                # properly.
                bl = frb_get_len(&rbuf)
                if flen > bl:
                    frb_check(&rbuf, flen)
                frb_set_len(&rbuf, flen)
                codec = <Codec>cpython.PyTuple_GET_ITEM(rows_codecs, i)
                val = codec.decode(settings, &rbuf)
                if frb_get_len(&rbuf) != 0:
                    raise BufferError(
                        'unexpected trailing {} bytes in buffer'.format(
                            frb_get_len(&rbuf)))
                frb_set_len(&rbuf, bl - flen)

            cpython.Py_INCREF(val)
            record.ApgRecord_SET_ITEM(dec_row, i, val)

        if frb_get_len(&rbuf) != 0:
            raise BufferError('unexpected trailing {} bytes in buffer'.format(
                frb_get_len(&rbuf)))

        return dec_row


cdef _decode_parameters_desc(object desc):
    cdef:
        ReadBuffer reader
        int16_t nparams
        uint32_t p_oid
        list result = []

    reader = ReadBuffer.new_message_parser(desc)
    nparams = reader.read_int16()

    for i from 0 <= i < nparams:
        p_oid = <uint32_t>reader.read_int32()
        result.append(p_oid)

    return result


cdef _decode_row_desc(object desc):
    cdef:
        ReadBuffer reader

        int16_t nfields

        bytes f_name
        uint32_t f_table_oid
        int16_t f_column_num
        uint32_t f_dt_oid
        int16_t f_dt_size
        int32_t f_dt_mod
        int16_t f_format

        list result

    reader = ReadBuffer.new_message_parser(desc)
    nfields = reader.read_int16()
    result = []

    for i from 0 <= i < nfields:
        f_name = reader.read_null_str()
        f_table_oid = <uint32_t>reader.read_int32()
        f_column_num = reader.read_int16()
        f_dt_oid = <uint32_t>reader.read_int32()
        f_dt_size = reader.read_int16()
        f_dt_mod = reader.read_int32()
        f_format = reader.read_int16()

        result.append(
            (f_name, f_table_oid, f_column_num, f_dt_oid,
             f_dt_size, f_dt_mod, f_format))

    return result
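
In _encode_bind_msg above, writing the single int32 0x00010001 is a compact
way to emit two int16 fields of the Bind message at once: a format-code
count of 1 followed by format code 1 (binary), used when every argument (or
result column) is binary. A sketch of the equivalence:

import struct

combined = struct.pack('!i', 0x00010001)  # one int32
explicit = struct.pack('!hh', 1, 1)       # count=1, format=1 (binary)
assert combined == explicit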
Binary file not shown.
@@ -0,0 +1,78 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from libc.stdint cimport int16_t, int32_t, uint16_t, \
                         uint32_t, int64_t, uint64_t

from asyncpg.pgproto.debug cimport PG_DEBUG

from asyncpg.pgproto.pgproto cimport (
    WriteBuffer,
    ReadBuffer,
    FRBuffer,
)

from asyncpg.pgproto cimport pgproto

include "consts.pxi"
include "pgtypes.pxi"

include "codecs/base.pxd"
include "settings.pxd"
include "coreproto.pxd"
include "prepared_stmt.pxd"


cdef class BaseProtocol(CoreProtocol):

    cdef:
        object loop
        object address
        ConnectionSettings settings
        object cancel_sent_waiter
        object cancel_waiter
        object waiter
        bint return_extra
        object create_future
        object timeout_handle
        object conref
        type record_class
        bint is_reading

        str last_query

        bint writing_paused
        bint closing

        readonly uint64_t queries_count

        bint _is_ssl

        PreparedStatementState statement

    cdef get_connection(self)

    cdef _get_timeout_impl(self, timeout)
    cdef _check_state(self)
    cdef _new_waiter(self, timeout)
    cdef _coreproto_error(self)

    cdef _on_result__connect(self, object waiter)
    cdef _on_result__prepare(self, object waiter)
    cdef _on_result__bind_and_exec(self, object waiter)
    cdef _on_result__close_stmt_or_portal(self, object waiter)
    cdef _on_result__simple_query(self, object waiter)
    cdef _on_result__bind(self, object waiter)
    cdef _on_result__copy_out(self, object waiter)
    cdef _on_result__copy_in(self, object waiter)

    cdef _handle_waiter_on_connection_lost(self, cause)

    cdef _dispatch_result(self)

    cdef inline resume_reading(self)
    cdef inline pause_reading(self)
1064
venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,19 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cimport cpython


cdef extern from "record/recordobj.h":

    cpython.PyTypeObject *ApgRecord_InitTypes() except NULL

    int ApgRecord_CheckExact(object)
    object ApgRecord_New(type, object, int)
    void ApgRecord_SET_ITEM(object, int, object)

    object ApgRecordDesc_New(object, object)
31
venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pxd
Normal file
@@ -0,0 +1,31 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef class SCRAMAuthentication:
    cdef:
        readonly bytes authentication_method
        readonly bytes authorization_message
        readonly bytes client_channel_binding
        readonly bytes client_first_message_bare
        readonly bytes client_nonce
        readonly bytes client_proof
        readonly bytes password_salt
        readonly int password_iterations
        readonly bytes server_first_message
        # server_key is an instance of hmac.HMAC
        readonly object server_key
        readonly bytes server_nonce

    cdef create_client_first_message(self, str username)
    cdef create_client_final_message(self, str password)
    cdef parse_server_first_message(self, bytes server_response)
    cdef verify_server_final_message(self, bytes server_final_message)
    cdef _bytes_xor(self, bytes a, bytes b)
    cdef _generate_client_nonce(self, int num_bytes)
    cdef _generate_client_proof(self, str password)
    cdef _generate_salted_password(self, str password, bytes salt, int iterations)
    cdef _normalize_password(self, str original_password)
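
The methods declared above drive the SCRAM-SHA-256 exchange implemented in
scram.pyx below. A condensed sketch of the client-proof arithmetic from
RFC 5802 (standalone; variable names follow the RFC and the inputs are
dummy values):

import hashlib
import hmac

def client_proof(password: bytes, salt: bytes, iterations: int,
                 auth_message: bytes) -> bytes:
    salted = hashlib.pbkdf2_hmac('sha256', password, salt, iterations)
    client_key = hmac.new(salted, b'Client Key', hashlib.sha256).digest()
    stored_key = hashlib.sha256(client_key).digest()
    signature = hmac.new(stored_key, auth_message, hashlib.sha256).digest()
    # ClientProof = ClientKey XOR ClientSignature
    return bytes(a ^ b for a, b in zip(client_key, signature))

print(client_proof(b'secret', b'salt', 4096, b'dummy-auth-message').hex())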
341
venv/lib/python3.12/site-packages/asyncpg/protocol/scram.pyx
Normal file
@@ -0,0 +1,341 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import base64
import hashlib
import hmac
import re
import secrets
import stringprep
import unicodedata


@cython.final
cdef class SCRAMAuthentication:
    """Contains the protocol for generating a SCRAM hashed password.

    Since PostgreSQL 10, the option to hash passwords using the SCRAM-SHA-256
    method was added. This module follows the defined protocol, which can be
    referenced from here:

    https://www.postgresql.org/docs/current/sasl-authentication.html#SASL-SCRAM-SHA-256

    libpq references the following RFCs that it uses for implementation:

        * RFC 5802
        * RFC 5803
        * RFC 7677

    The protocol works as such:

    - A client connects to the server. The server requests the client to
      begin SASL authentication using SCRAM and presents the client with the
      methods it supports. At present, those are SCRAM-SHA-256 and, on
      servers that are built with OpenSSL and are PG11+, SCRAM-SHA-256-PLUS
      (which supports channel binding, more on that below).

    - The client sends a "first message" to the server, where it chooses
      which method to authenticate with, and sends, along with the method,
      an indication of channel binding (we disable it for now), a nonce, and
      the username. (Technically, PostgreSQL ignores the username as it
      already has it from the initial connection, but we add it for
      completeness.)

    - The server responds with a "first message" in which it extends the
      nonce, and sends a password salt and the number of iterations to hash
      the password with. The client validates that the new nonce contains
      the first part of the client's original nonce.

    - The client generates a salted password, but does not send it to the
      server. Instead, the client follows the SCRAM algorithm (RFC 5802) to
      generate a proof. This proof is sent as part of a client "final
      message" to the server for it to validate.

    - The server validates the proof. If it is valid, the server sends a
      verification code for the client to verify that the server came to the
      same proof the client did. PostgreSQL immediately sends an
      AuthenticationOK response right after a valid negotiation. If the
      password the client provided was invalid, then authentication fails.

    (The beauty of this is that the salted password is never transmitted
    over the wire!)

    PostgreSQL 11 added support for channel binding (i.e.
    SCRAM-SHA-256-PLUS), but due to some ongoing discussion, there is a
    conscious decision by several driver authors to not support it as of
    yet. As such, the channel binding parameter is hard-coded to "n" for
    now, but can be updated to support other channel binding methods in the
    future.
    """
    AUTHENTICATION_METHODS = [b"SCRAM-SHA-256"]
    DEFAULT_CLIENT_NONCE_BYTES = 24
    DIGEST = hashlib.sha256
    REQUIREMENTS_CLIENT_FINAL_MESSAGE = ['client_channel_binding',
                                         'server_nonce']
    REQUIREMENTS_CLIENT_PROOF = ['password_iterations', 'password_salt',
                                 'server_first_message', 'server_nonce']
    SASLPREP_PROHIBITED = (
        stringprep.in_table_a1,     # PostgreSQL treats this as prohibited
        stringprep.in_table_c12,
        stringprep.in_table_c21_c22,
        stringprep.in_table_c3,
        stringprep.in_table_c4,
        stringprep.in_table_c5,
        stringprep.in_table_c6,
        stringprep.in_table_c7,
        stringprep.in_table_c8,
        stringprep.in_table_c9,
    )

    def __cinit__(self, bytes authentication_method):
        self.authentication_method = authentication_method
        self.authorization_message = None
        # channel binding is turned off for the time being
        self.client_channel_binding = b"n,,"
        self.client_first_message_bare = None
        self.client_nonce = None
        self.client_proof = None
        self.password_salt = None
        # self.password_iterations = None
        self.server_first_message = None
        self.server_key = None
        self.server_nonce = None

    cdef create_client_first_message(self, str username):
        """Create the initial client message for SCRAM authentication"""
        cdef:
            bytes msg
            bytes client_first_message

        self.client_nonce = \
            self._generate_client_nonce(self.DEFAULT_CLIENT_NONCE_BYTES)
        # set the client first message bare here, as it's used in a later step
        self.client_first_message_bare = b"n=" + username.encode("utf-8") + \
            b",r=" + self.client_nonce
        # put together the full message here
        msg = bytes()
        msg += self.authentication_method + b"\0"
        client_first_message = self.client_channel_binding + \
            self.client_first_message_bare
        msg += (len(client_first_message)).to_bytes(4, byteorder='big') + \
            client_first_message
        return msg

    cdef create_client_final_message(self, str password):
        """Create the final client message as part of SCRAM authentication"""
        cdef:
            bytes msg

        if any([getattr(self, val) is None for val in
                self.REQUIREMENTS_CLIENT_FINAL_MESSAGE]):
            raise Exception(
                "you need values from server to generate a client proof")

        # normalize the password using the SASLprep algorithm in RFC 4013
        password = self._normalize_password(password)

        # generate the client proof
        self.client_proof = self._generate_client_proof(password=password)
        msg = bytes()
        msg += b"c=" + base64.b64encode(self.client_channel_binding) + \
            b",r=" + self.server_nonce + \
            b",p=" + base64.b64encode(self.client_proof)
        return msg

    cdef parse_server_first_message(self, bytes server_response):
        """Parse the response from the first message from the server"""
        self.server_first_message = server_response
        try:
            self.server_nonce = re.search(rb'r=([^,]+),',
                                          self.server_first_message).group(1)
        except (AttributeError, IndexError):
            raise Exception("could not get nonce")
        if not self.server_nonce.startswith(self.client_nonce):
            raise Exception("invalid nonce")
        try:
            self.password_salt = re.search(rb',s=([^,]+),',
                                           self.server_first_message).group(1)
        except (AttributeError, IndexError):
            raise Exception("could not get salt")
        try:
            self.password_iterations = int(re.search(
                rb',i=(\d+),?', self.server_first_message).group(1))
        except (AttributeError, IndexError, TypeError, ValueError):
            raise Exception("could not get iterations")

    cdef verify_server_final_message(self, bytes server_final_message):
        """Verify the final message from the server"""
        cdef:
            bytes server_signature

        try:
            server_signature = re.search(rb'v=([^,]+)',
                                         server_final_message).group(1)
        except (AttributeError, IndexError):
            raise Exception("could not get server signature")

        verify_server_signature = hmac.new(self.server_key.digest(),
                                           self.authorization_message,
                                           self.DIGEST)
        # validate the server signature against the verifier
        return server_signature == base64.b64encode(
            verify_server_signature.digest())

    cdef _bytes_xor(self, bytes a, bytes b):
        """XOR two bytestrings together"""
        return bytes(a_i ^ b_i for a_i, b_i in zip(a, b))

    cdef _generate_client_nonce(self, int num_bytes):
        cdef:
            bytes token

        token = secrets.token_bytes(num_bytes)

        return base64.b64encode(token)

    cdef _generate_client_proof(self, str password):
        """Generate the client proof; requires values from the server
        response, i.e. parse_server_first_message() must have run."""
        cdef:
            bytes salted_password

        if any([getattr(self, val) is None for val in
                self.REQUIREMENTS_CLIENT_PROOF]):
            raise Exception(
                "you need values from server to generate a client proof")
        # generate a salted password
        salted_password = self._generate_salted_password(
            password, self.password_salt, self.password_iterations)
        # client key is derived from the salted password
        client_key = hmac.new(salted_password, b"Client Key", self.DIGEST)
        # this allows us to compute the stored key that is residing on the
        # server
        stored_key = self.DIGEST(client_key.digest())
        # as well as compute the server key
        self.server_key = hmac.new(salted_password, b"Server Key",
                                   self.DIGEST)
        # build the authorization message that will be used in the
        # client signature
        # the "c=" portion is for the channel binding, but this is not
        # presently implemented
        self.authorization_message = self.client_first_message_bare + b"," + \
            self.server_first_message + b",c=" + \
            base64.b64encode(self.client_channel_binding) + \
            b",r=" + self.server_nonce
        # sign!
        client_signature = hmac.new(stored_key.digest(),
                                    self.authorization_message, self.DIGEST)
        # and the proof
        return self._bytes_xor(client_key.digest(), client_signature.digest())

    cdef _generate_salted_password(self, str password, bytes salt,
                                   int iterations):
        """This follows the "Hi" algorithm specified in RFC 5802"""
        cdef:
            bytes p
            bytes s
            bytes u

        # convert the password to a binary string - UTF8 is safe for SASL
        # (though there are SASLprep rules)
        p = password.encode("utf8")
        # the salt needs to be base64 decoded -- full binary must be used
        s = base64.b64decode(salt)
        # the initial signature is the HMAC of the salt concatenated with a
        # 32-bit big-endian integer with the value 1
        ui = hmac.new(p, s + b'\x00\x00\x00\x01', self.DIGEST)
        # grab the initial digest
        u = ui.digest()
        # for X number of iterations, recompute the HMAC signature against
        # the password and the latest iteration of the hash, and XOR it with
        # the previous version
        for x in range(iterations - 1):
            ui = hmac.new(p, ui.digest(), hashlib.sha256)
            # XOR the latest digest into the running result
            u = self._bytes_xor(u, ui.digest())
        return u
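    # Sanity-check sketch (my addition, not part of asyncpg): RFC 5802's
    # Hi() is PBKDF2 with HMAC as the PRF and a single output block, so the
    # loop above should agree with hashlib's C implementation:
    #
    #     import hashlib
    #     expected = hashlib.pbkdf2_hmac('sha256', p, s, iterations)
    #     assert u == expected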

    cdef _normalize_password(self, str original_password):
        """Normalize the password using SASLprep from RFC 4013"""
        cdef:
            str normalized_password

        # Note: Per the PostgreSQL documentation, PostgreSQL does not require
        # UTF-8 to be used for the password, but will perform SASLprep on the
        # password regardless.
        # If the password is not valid UTF-8, PostgreSQL will then **not**
        # use SASLprep processing.
        # If the password fails SASLprep, the password should still be sent.
        # See: https://www.postgresql.org/docs/current/sasl-authentication.html
        # and
        # https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/common/saslprep.c
        # using the `pg_saslprep` function
        normalized_password = original_password
        # if the original password is an ASCII string or fails to encode as
        # a UTF-8 string, then no further action is needed
        try:
            original_password.encode("ascii")
        except UnicodeEncodeError:
            pass
        else:
            return original_password

        # Step 1 of SASLprep: Map. Per the algorithm, we map non-ASCII space
        # characters to ASCII spaces (\x20 or \u0020, but we will use ' ')
        # and characters that are commonly mapped to nothing are removed
        # Table C.1.2 -- non-ASCII spaces
        # Table B.1 -- "Commonly mapped to nothing"
        normalized_password = u"".join(
            ' ' if stringprep.in_table_c12(c) else c
            for c in tuple(normalized_password)
            if not stringprep.in_table_b1(c)
        )

        # If at this point the password is empty, PostgreSQL uses the
        # original password
        if not normalized_password:
            return original_password

        # Step 2 of SASLprep: Normalize. Normalize the password using the
        # Unicode normalization algorithm to NFKC form
        normalized_password = unicodedata.normalize(
            'NFKC', normalized_password)

        # If the password is empty after normalization, PostgreSQL uses the
        # original password
        if not normalized_password:
            return original_password

        normalized_password_tuple = tuple(normalized_password)

        # Step 3 of SASLprep: Prohibited characters. If PostgreSQL detects
        # any of the prohibited characters in SASLprep, it will use the
        # original password
        # We also include "unassigned code points" in the prohibited
        # character category as PostgreSQL does the same
        for c in normalized_password_tuple:
            if any(
                in_prohibited_table(c)
                for in_prohibited_table in self.SASLPREP_PROHIBITED
            ):
                return original_password

        # Step 4 of SASLprep: Bi-directional characters. PostgreSQL follows
        # the rules for bi-directional characters laid out in RFC 3454
        # Sec. 6 which are:
        # 1. Characters in RFC 3454 Sec 5.8 are prohibited (C.8)
        # 2. If a string contains a RandALCat character, it cannot contain
        #    any LCat character
        # 3. If the string contains any RandALCat character, a RandALCat
        #    character must be the first and last character of the string
        # RandALCat characters are found in table D.1, whereas LCat are in
        # D.2
        if any(stringprep.in_table_d1(c) for c in normalized_password_tuple):
            # if the first character or the last character are not in D.1,
            # return the original password
            if not (stringprep.in_table_d1(normalized_password_tuple[0]) and
                    stringprep.in_table_d1(normalized_password_tuple[-1])):
                return original_password

            # if any characters are in D.2, use the original password
            if any(
                stringprep.in_table_d2(c)
                for c in normalized_password_tuple
            ):
                return original_password

        # return the normalized password
        return normalized_password
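
# A quick illustration of the SASLprep steps above (my example):
#
#     '\u00ADpass\u00A0word'
#     # Table B.1: the soft hyphen (U+00AD) is dropped;
#     # Table C.1.2: the no-break space (U+00A0) becomes an ASCII space
#     # -> 'pass word'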
30
venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pxd
Normal file
@@ -0,0 +1,30 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


cdef class ConnectionSettings(pgproto.CodecContext):
    cdef:
        str _encoding
        object _codec
        dict _settings
        bint _is_utf8
        DataCodecConfig _data_codecs

    cdef add_setting(self, str name, str val)
    cdef is_encoding_utf8(self)
    cpdef get_text_codec(self)
    cpdef inline register_data_types(self, types)
    cpdef inline add_python_codec(
        self, typeoid, typename, typeschema, typeinfos, typekind, encoder,
        decoder, format)
    cpdef inline remove_python_codec(
        self, typeoid, typename, typeschema)
    cpdef inline clear_type_cache(self)
    cpdef inline set_builtin_type_codec(
        self, typeoid, typename, typeschema, typekind, alias_to, format)
    cpdef inline Codec get_data_codec(
        self, uint32_t oid, ServerDataFormat format=*,
        bint ignore_custom_codec=*)
106
venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx
Normal file
@@ -0,0 +1,106 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from asyncpg import exceptions


@cython.final
cdef class ConnectionSettings(pgproto.CodecContext):

    def __cinit__(self, conn_key):
        self._encoding = 'utf-8'
        self._is_utf8 = True
        self._settings = {}
        self._codec = codecs.lookup('utf-8')
        self._data_codecs = DataCodecConfig(conn_key)

    cdef add_setting(self, str name, str val):
        self._settings[name] = val
        if name == 'client_encoding':
            py_enc = get_python_encoding(val)
            self._codec = codecs.lookup(py_enc)
            self._encoding = self._codec.name
            self._is_utf8 = self._encoding == 'utf-8'

    cdef is_encoding_utf8(self):
        return self._is_utf8

    cpdef get_text_codec(self):
        return self._codec

    cpdef inline register_data_types(self, types):
        self._data_codecs.add_types(types)

    cpdef inline add_python_codec(self, typeoid, typename, typeschema,
                                  typeinfos, typekind, encoder, decoder,
                                  format):
        cdef:
            ServerDataFormat _format
            ClientExchangeFormat xformat

        if format == 'binary':
            _format = PG_FORMAT_BINARY
            xformat = PG_XFORMAT_OBJECT
        elif format == 'text':
            _format = PG_FORMAT_TEXT
            xformat = PG_XFORMAT_OBJECT
        elif format == 'tuple':
            _format = PG_FORMAT_ANY
            xformat = PG_XFORMAT_TUPLE
        else:
            raise exceptions.InterfaceError(
                'invalid `format` argument, expected {}, got {!r}'.format(
                    "'text', 'binary' or 'tuple'", format
                ))

        self._data_codecs.add_python_codec(typeoid, typename, typeschema,
                                           typekind, typeinfos,
                                           encoder, decoder,
                                           _format, xformat)
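
    # This is the plumbing behind asyncpg's public custom-codec API; a
    # typical call that ends up here (a sketch of documented usage):
    #
    #     import json
    #     await conn.set_type_codec(
    #         'json', encoder=json.dumps, decoder=json.loads,
    #         schema='pg_catalog', format='text')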

    cpdef inline remove_python_codec(self, typeoid, typename, typeschema):
        self._data_codecs.remove_python_codec(typeoid, typename, typeschema)

    cpdef inline clear_type_cache(self):
        self._data_codecs.clear_type_cache()

    cpdef inline set_builtin_type_codec(self, typeoid, typename, typeschema,
                                        typekind, alias_to, format):
        cdef:
            ServerDataFormat _format

        if format is None:
            _format = PG_FORMAT_ANY
        elif format == 'binary':
            _format = PG_FORMAT_BINARY
        elif format == 'text':
            _format = PG_FORMAT_TEXT
        else:
            raise exceptions.InterfaceError(
                'invalid `format` argument, expected {}, got {!r}'.format(
                    "'text' or 'binary'", format
                ))

        self._data_codecs.set_builtin_type_codec(typeoid, typename,
                                                 typeschema, typekind,
                                                 alias_to, _format)

    cpdef inline Codec get_data_codec(self, uint32_t oid,
                                      ServerDataFormat format=PG_FORMAT_ANY,
                                      bint ignore_custom_codec=False):
        return self._data_codecs.get_codec(oid, format, ignore_custom_codec)

    def __getattr__(self, name):
        if not name.startswith('_'):
            try:
                return self._settings[name]
            except KeyError:
                raise AttributeError(name) from None

        return object.__getattribute__(self, name)

    def __repr__(self):
        return '<ConnectionSettings {!r}>'.format(self._settings)
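
# Server-reported parameters stored via add_setting() are readable as plain
# attributes thanks to __getattr__ above; a sketch of typical use:
#
#     settings = conn.get_settings()
#     settings.client_encoding   # e.g. 'UTF8'
#     settings.server_version    # raw string as reported by the server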
60
venv/lib/python3.12/site-packages/asyncpg/serverversion.py
Normal file
@@ -0,0 +1,60 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import re

from .types import ServerVersion

version_regex = re.compile(
    r"(Postgre[^\s]*)?\s*"
    r"(?P<major>[0-9]+)\.?"
    r"((?P<minor>[0-9]+)\.?)?"
    r"(?P<micro>[0-9]+)?"
    r"(?P<releaselevel>[a-z]+)?"
    r"(?P<serial>[0-9]+)?"
)


def split_server_version_string(version_string):
    version_match = version_regex.search(version_string)

    if version_match is None:
        raise ValueError(
            "Unable to parse Postgres "
            f'version from "{version_string}"'
        )

    version = version_match.groupdict()
    for ver_key, ver_value in version.items():
        # Cast all possible version parts to int
        try:
            version[ver_key] = int(ver_value)
        except (TypeError, ValueError):
            pass

    if version.get("major") < 10:
        return ServerVersion(
            version.get("major"),
            version.get("minor") or 0,
            version.get("micro") or 0,
            version.get("releaselevel") or "final",
            version.get("serial") or 0,
        )

    # Since PostgreSQL 10 the versioning scheme has changed.
    # 10.x really means 10.0.x. While parsing 10.1
    # as (10, 1) may seem less confusing, in practice most
    # version checks are written as version[:2], and we
    # want to keep that behaviour consistent, i.e. not fail
    # a major version check due to a bugfix release.
    return ServerVersion(
        version.get("major"),
        0,
        version.get("minor") or 0,
        version.get("releaselevel") or "final",
        version.get("serial") or 0,
    )
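
# Illustrative results (my examples):
#
#     split_server_version_string('PostgreSQL 9.6.5')
#     # -> ServerVersion(major=9, minor=6, micro=5,
#     #                  releaselevel='final', serial=0)
#     split_server_version_string('PostgreSQL 10.1')
#     # -> ServerVersion(major=10, minor=0, micro=1,
#     #                  releaselevel='final', serial=0)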
246
venv/lib/python3.12/site-packages/asyncpg/transaction.py
Normal file
@@ -0,0 +1,246 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import enum

from . import connresource
from . import exceptions as apg_errors


class TransactionState(enum.Enum):
    NEW = 0
    STARTED = 1
    COMMITTED = 2
    ROLLEDBACK = 3
    FAILED = 4


ISOLATION_LEVELS = {
    'read_committed',
    'read_uncommitted',
    'serializable',
    'repeatable_read',
}
ISOLATION_LEVELS_BY_VALUE = {
    'read committed': 'read_committed',
    'read uncommitted': 'read_uncommitted',
    'serializable': 'serializable',
    'repeatable read': 'repeatable_read',
}


class Transaction(connresource.ConnectionResource):
    """Represents a transaction or savepoint block.

    Transactions are created by calling the
    :meth:`Connection.transaction() <connection.Connection.transaction>`
    function.
    """

    __slots__ = ('_connection', '_isolation', '_readonly', '_deferrable',
                 '_state', '_nested', '_id', '_managed')

    def __init__(self, connection, isolation, readonly, deferrable):
        super().__init__(connection)

        if isolation and isolation not in ISOLATION_LEVELS:
            raise ValueError(
                'isolation is expected to be either of {}, '
                'got {!r}'.format(ISOLATION_LEVELS, isolation))

        self._isolation = isolation
        self._readonly = readonly
        self._deferrable = deferrable
        self._state = TransactionState.NEW
        self._nested = False
        self._id = None
        self._managed = False

    async def __aenter__(self):
        if self._managed:
            raise apg_errors.InterfaceError(
                'cannot enter context: already in an `async with` block')
        self._managed = True
        await self.start()

    async def __aexit__(self, extype, ex, tb):
        try:
            self._check_conn_validity('__aexit__')
        except apg_errors.InterfaceError:
            if extype is GeneratorExit:
                # When a PoolAcquireContext is being exited, and there
                # is an open transaction in an async generator that has
                # not been iterated fully, there is a possibility that
                # Pool.release() would race with this __aexit__(), since
                # both would be in concurrent tasks.  In such case we
                # yield to Pool.release() to do the ROLLBACK for us.
                # See https://github.com/MagicStack/asyncpg/issues/232
                # for an example.
                return
            else:
                raise

        try:
            if extype is not None:
                await self.__rollback()
            else:
                await self.__commit()
        finally:
            self._managed = False

    @connresource.guarded
    async def start(self):
        """Enter the transaction or savepoint block."""
        self.__check_state_base('start')
        if self._state is TransactionState.STARTED:
            raise apg_errors.InterfaceError(
                'cannot start; the transaction is already started')

        con = self._connection

        if con._top_xact is None:
            if con._protocol.is_in_transaction():
                raise apg_errors.InterfaceError(
                    'cannot use Connection.transaction() in '
                    'a manually started transaction')
            con._top_xact = self
        else:
            # Nested transaction block
            if self._isolation:
                top_xact_isolation = con._top_xact._isolation
                if top_xact_isolation is None:
                    top_xact_isolation = ISOLATION_LEVELS_BY_VALUE[
                        await self._connection.fetchval(
                            'SHOW transaction_isolation;')]
                if self._isolation != top_xact_isolation:
                    raise apg_errors.InterfaceError(
                        'nested transaction has a different isolation level: '
                        'current {!r} != outer {!r}'.format(
                            self._isolation, top_xact_isolation))
            self._nested = True

        if self._nested:
            self._id = con._get_unique_id('savepoint')
            query = 'SAVEPOINT {};'.format(self._id)
        else:
            query = 'BEGIN'
            if self._isolation == 'read_committed':
                query += ' ISOLATION LEVEL READ COMMITTED'
            elif self._isolation == 'read_uncommitted':
                query += ' ISOLATION LEVEL READ UNCOMMITTED'
            elif self._isolation == 'repeatable_read':
                query += ' ISOLATION LEVEL REPEATABLE READ'
            elif self._isolation == 'serializable':
                query += ' ISOLATION LEVEL SERIALIZABLE'
            if self._readonly:
                query += ' READ ONLY'
            if self._deferrable:
                query += ' DEFERRABLE'
            query += ';'

        try:
            await self._connection.execute(query)
        except BaseException:
            self._state = TransactionState.FAILED
            raise
        else:
            self._state = TransactionState.STARTED
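
    # For illustration (my examples), start() ends up issuing SQL such as:
    #
    #     BEGIN ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE;
    #
    # or, for a nested block, a savepoint with a connection-generated name:
    #
    #     SAVEPOINT __asyncpg_sv_...__;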

    def __check_state_base(self, opname):
        if self._state is TransactionState.COMMITTED:
            raise apg_errors.InterfaceError(
                'cannot {}; the transaction is already committed'.format(
                    opname))
        if self._state is TransactionState.ROLLEDBACK:
            raise apg_errors.InterfaceError(
                'cannot {}; the transaction is already rolled back'.format(
                    opname))
        if self._state is TransactionState.FAILED:
            raise apg_errors.InterfaceError(
                'cannot {}; the transaction is in error state'.format(
                    opname))

    def __check_state(self, opname):
        if self._state is not TransactionState.STARTED:
            if self._state is TransactionState.NEW:
                raise apg_errors.InterfaceError(
                    'cannot {}; the transaction is not yet started'.format(
                        opname))
            self.__check_state_base(opname)

    async def __commit(self):
        self.__check_state('commit')

        if self._connection._top_xact is self:
            self._connection._top_xact = None

        if self._nested:
            query = 'RELEASE SAVEPOINT {};'.format(self._id)
        else:
            query = 'COMMIT;'

        try:
            await self._connection.execute(query)
        except BaseException:
            self._state = TransactionState.FAILED
            raise
        else:
            self._state = TransactionState.COMMITTED

    async def __rollback(self):
        self.__check_state('rollback')

        if self._connection._top_xact is self:
            self._connection._top_xact = None

        if self._nested:
            query = 'ROLLBACK TO {};'.format(self._id)
        else:
            query = 'ROLLBACK;'

        try:
            await self._connection.execute(query)
        except BaseException:
            self._state = TransactionState.FAILED
            raise
        else:
            self._state = TransactionState.ROLLEDBACK

    @connresource.guarded
    async def commit(self):
        """Exit the transaction or savepoint block and commit changes."""
        if self._managed:
            raise apg_errors.InterfaceError(
                'cannot manually commit from within an `async with` block')
        await self.__commit()

    @connresource.guarded
    async def rollback(self):
        """Exit the transaction or savepoint block and roll back changes."""
        if self._managed:
            raise apg_errors.InterfaceError(
                'cannot manually rollback from within an `async with` block')
        await self.__rollback()

    def __repr__(self):
        attrs = []
        attrs.append('state:{}'.format(self._state.name.lower()))

        if self._isolation is not None:
            attrs.append(self._isolation)
        if self._readonly:
            attrs.append('readonly')
        if self._deferrable:
            attrs.append('deferrable')

        if self.__class__.__module__.startswith('asyncpg.'):
            mod = 'asyncpg'
        else:
            mod = self.__class__.__module__

        return '<{}.{} {} {:#x}>'.format(
            mod, self.__class__.__name__, ' '.join(attrs), id(self))
177
venv/lib/python3.12/site-packages/asyncpg/types.py
Normal file
@@ -0,0 +1,177 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import collections

from asyncpg.pgproto.types import (
    BitString, Point, Path, Polygon,
    Box, Line, LineSegment, Circle,
)


__all__ = (
    'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon',
    'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion',
)


Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema'])
Type.__doc__ = 'Database data type.'
Type.oid.__doc__ = 'OID of the type.'
Type.name.__doc__ = 'Type name.  For example "int2".'
Type.kind.__doc__ = \
    'Type kind.  Can be "scalar", "array", "composite" or "range".'
Type.schema.__doc__ = 'Name of the database schema that defines the type.'


Attribute = collections.namedtuple('Attribute', ['name', 'type'])
Attribute.__doc__ = 'Database relation attribute.'
Attribute.name.__doc__ = 'Attribute name.'
Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.'


ServerVersion = collections.namedtuple(
    'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial'])
ServerVersion.__doc__ = 'PostgreSQL server version tuple.'


class Range:
    """Immutable representation of PostgreSQL `range` type."""

    __slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty'

    def __init__(self, lower=None, upper=None, *,
                 lower_inc=True, upper_inc=False,
                 empty=False):
        self._empty = empty
        if empty:
            self._lower = self._upper = None
            self._lower_inc = self._upper_inc = False
        else:
            self._lower = lower
            self._upper = upper
            self._lower_inc = lower is not None and lower_inc
            self._upper_inc = upper is not None and upper_inc

    @property
    def lower(self):
        return self._lower

    @property
    def lower_inc(self):
        return self._lower_inc

    @property
    def lower_inf(self):
        return self._lower is None and not self._empty

    @property
    def upper(self):
        return self._upper

    @property
    def upper_inc(self):
        return self._upper_inc

    @property
    def upper_inf(self):
        return self._upper is None and not self._empty

    @property
    def isempty(self):
        return self._empty

    def _issubset_lower(self, other):
        if other._lower is None:
            return True
        if self._lower is None:
            return False

        return self._lower > other._lower or (
            self._lower == other._lower
            and (other._lower_inc or not self._lower_inc)
        )

    def _issubset_upper(self, other):
        if other._upper is None:
            return True
        if self._upper is None:
            return False

        return self._upper < other._upper or (
            self._upper == other._upper
            and (other._upper_inc or not self._upper_inc)
        )

    def issubset(self, other):
        if self._empty:
            return True
        if other._empty:
            return False

        return self._issubset_lower(other) and self._issubset_upper(other)

    def issuperset(self, other):
        return other.issubset(self)

    def __bool__(self):
        return not self._empty

    def __eq__(self, other):
        if not isinstance(other, Range):
            return NotImplemented

        return (
            self._lower,
            self._upper,
            self._lower_inc,
            self._upper_inc,
            self._empty
        ) == (
            other._lower,
            other._upper,
            other._lower_inc,
            other._upper_inc,
            other._empty
        )

    def __hash__(self):
        return hash((
            self._lower,
            self._upper,
            self._lower_inc,
            self._upper_inc,
            self._empty
        ))

    def __repr__(self):
        if self._empty:
            desc = 'empty'
        else:
            if self._lower is None or not self._lower_inc:
                lb = '('
            else:
                lb = '['

            if self._lower is not None:
                lb += repr(self._lower)

            if self._upper is not None:
                ub = repr(self._upper)
            else:
                ub = ''

            if self._upper is None or not self._upper_inc:
                ub += ')'
            else:
                ub += ']'

            desc = '{}, {}'.format(lb, ub)

        return '<Range {}>'.format(desc)

    __str__ = __repr__
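
# A few illustrative uses of Range (my examples):
#
#     r = Range(1, 10)               # lower-inclusive, upper-exclusive
#     repr(r)                        # '<Range [1, 10)>'
#     r.issubset(Range(0, 20))       # True
#     Range(empty=True).issubset(r)  # True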
45
venv/lib/python3.12/site-packages/asyncpg/utils.py
Normal file
@@ -0,0 +1,45 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import re


def _quote_ident(ident):
    return '"{}"'.format(ident.replace('"', '""'))


def _quote_literal(string):
    return "'{}'".format(string.replace("'", "''"))


async def _mogrify(conn, query, args):
    """Safely inline arguments to query text."""
    # Introspect the target query for argument types and
    # build a list of safely-quoted fully-qualified type names.
    ps = await conn.prepare(query)
    paramtypes = []
    for t in ps.get_parameters():
        if t.name.endswith('[]'):
            pname = '_' + t.name[:-2]
        else:
            pname = t.name

        paramtypes.append('{}.{}'.format(
            _quote_ident(t.schema), _quote_ident(pname)))
    del ps

    # Use Postgres to convert arguments to text representation
    # by casting each value to text.
    cols = ['quote_literal(${}::{}::text)'.format(i, t)
            for i, t in enumerate(paramtypes, start=1)]

    textified = await conn.fetchrow(
        'SELECT {cols}'.format(cols=', '.join(cols)), *args)

    # Finally, replace $n references with text values.
    return re.sub(
        r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query)
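
# Usage sketch (my example, assuming a live connection `conn`):
#
#     text = await _mogrify(conn, 'SELECT $1::int + $2::int', (2, 3))
#     # -> "SELECT '2'::int + '3'::int"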