Major fixes and new features

2025-09-25 15:51:48 +09:00
parent dd7349bb4c
commit ddce9f5125
5586 changed files with 1470941 additions and 0 deletions

View File

@@ -0,0 +1,527 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import atexit
import contextlib
import functools
import inspect
import logging
import os
import re
import textwrap
import time
import traceback
import unittest

import asyncpg
from asyncpg import cluster as pg_cluster
from asyncpg import connection as pg_connection
from asyncpg import pool as pg_pool

from . import fuzzer

@contextlib.contextmanager
def silence_asyncio_long_exec_warning():
    def flt(log_record):
        msg = log_record.getMessage()
        return not msg.startswith('Executing ')

    logger = logging.getLogger('asyncio')
    logger.addFilter(flt)
    try:
        yield
    finally:
        logger.removeFilter(flt)

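
# The decorator below attaches a per-test timeout that TestCaseMeta enforces
# with asyncio.wait_for().  Illustrative (hypothetical) usage:
#
#     class MyTests(ConnectedTestCase):
#         @with_timeout(5.0)
#         async def test_something_slow(self):
#             ...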
def with_timeout(timeout):
    def wrap(func):
        func.__timeout__ = timeout
        return func

    return wrap

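
# TestCaseMeta rewrites every coroutine method whose name starts with
# 'test_' so that it runs on the class event loop via
# loop.run_until_complete(), honouring an optional __timeout__ attribute
# set by with_timeout().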
class TestCaseMeta(type(unittest.TestCase)):
    TEST_TIMEOUT = None

    @staticmethod
    def _iter_methods(bases, ns):
        for base in bases:
            for methname in dir(base):
                if not methname.startswith('test_'):
                    continue

                meth = getattr(base, methname)
                if not inspect.iscoroutinefunction(meth):
                    continue

                yield methname, meth

        for methname, meth in ns.items():
            if not methname.startswith('test_'):
                continue

            if not inspect.iscoroutinefunction(meth):
                continue

            yield methname, meth

    def __new__(mcls, name, bases, ns):
        for methname, meth in mcls._iter_methods(bases, ns):
            @functools.wraps(meth)
            def wrapper(self, *args, __meth__=meth, **kwargs):
                coro = __meth__(self, *args, **kwargs)
                timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT)
                if timeout:
                    coro = asyncio.wait_for(coro, timeout)
                    try:
                        self.loop.run_until_complete(coro)
                    except asyncio.TimeoutError:
                        raise self.failureException(
                            'test timed out after {} seconds'.format(
                                timeout)) from None
                else:
                    self.loop.run_until_complete(coro)
            ns[methname] = wrapper

        return super().__new__(mcls, name, bases, ns)

class TestCase(unittest.TestCase, metaclass=TestCaseMeta):

    @classmethod
    def setUpClass(cls):
        if os.environ.get('USE_UVLOOP'):
            import uvloop
            asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)
        cls.loop = loop

    @classmethod
    def tearDownClass(cls):
        cls.loop.close()
        asyncio.set_event_loop(None)

    def setUp(self):
        self.loop.set_exception_handler(self.loop_exception_handler)
        self.__unhandled_exceptions = []

    def tearDown(self):
        if self.__unhandled_exceptions:
            formatted = []

            for i, context in enumerate(self.__unhandled_exceptions):
                formatted.append(self._format_loop_exception(context, i + 1))

            self.fail(
                'unexpected exceptions in asynchronous code:\n' +
                '\n'.join(formatted))

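    # assertRunUnder(delta) fails the test if the wrapped block takes longer
    # than `delta` seconds of wall-clock time.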
    @contextlib.contextmanager
    def assertRunUnder(self, delta):
        st = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - st
            if elapsed > delta:
                raise AssertionError(
                    'running block took {:0.3f}s which is longer '
                    'than the expected maximum of {:0.3f}s'.format(
                        elapsed, delta))

    @contextlib.contextmanager
    def assertLoopErrorHandlerCalled(self, msg_re: str):
        contexts = []

        def handler(loop, ctx):
            contexts.append(ctx)

        old_handler = self.loop.get_exception_handler()
        self.loop.set_exception_handler(handler)
        try:
            yield

            for ctx in contexts:
                msg = ctx.get('message')
                if msg and re.search(msg_re, msg):
                    return

            raise AssertionError(
                'no message matching {!r} was logged with '
                'loop.call_exception_handler()'.format(msg_re))

        finally:
            self.loop.set_exception_handler(old_handler)

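    # Exceptions reported through loop.call_exception_handler() are collected
    # here and turned into a test failure by tearDown() above.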
    def loop_exception_handler(self, loop, context):
        self.__unhandled_exceptions.append(context)
        loop.default_exception_handler(context)

    def _format_loop_exception(self, context, n):
        message = context.get('message', 'Unhandled exception in event loop')
        exception = context.get('exception')
        if exception is not None:
            exc_info = (type(exception), exception, exception.__traceback__)
        else:
            exc_info = None

        lines = []
        for key in sorted(context):
            if key in {'message', 'exception'}:
                continue
            value = context[key]
            if key == 'source_traceback':
                tb = ''.join(traceback.format_list(value))
                value = 'Object created at (most recent call last):\n'
                value += tb.rstrip()
            else:
                try:
                    value = repr(value)
                except Exception as ex:
                    value = ('Exception in __repr__ {!r}; '
                             'value type: {!r}'.format(ex, type(value)))
            lines.append('[{}]: {}\n\n'.format(key, value))

        if exc_info is not None:
            lines.append('[exception]:\n')
            formatted_exc = textwrap.indent(
                ''.join(traceback.format_exception(*exc_info)), '    ')
            lines.append(formatted_exc)

        details = textwrap.indent(''.join(lines), '    ')
        return '{:02d}. {}:\n{}\n'.format(n, message, details)


_default_cluster = None


def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
    cluster = ClusterCls(**cluster_kwargs)
    cluster.init(**(initdb_options or {}))
    cluster.trust_local_connections()
    atexit.register(_shutdown_cluster, cluster)
    return cluster


def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
                   initdb_options=None):
    cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
    cluster.start(port='dynamic', server_settings=server_settings)
    return cluster


def _get_initdb_options(initdb_options=None):
    if not initdb_options:
        initdb_options = {}
    else:
        initdb_options = dict(initdb_options)

    # Make the default superuser name stable.
    if 'username' not in initdb_options:
        initdb_options['username'] = 'postgres'

    return initdb_options


def _init_default_cluster(initdb_options=None):
    global _default_cluster

    if _default_cluster is None:
        pg_host = os.environ.get('PGHOST')
        if pg_host:
            # Using existing cluster, assuming it is initialized and running
            _default_cluster = pg_cluster.RunningCluster()
        else:
            _default_cluster = _init_cluster(
                pg_cluster.TempCluster, cluster_kwargs={},
                initdb_options=_get_initdb_options(initdb_options))

    return _default_cluster


def _shutdown_cluster(cluster):
    if cluster.get_status() == 'running':
        cluster.stop()
    if cluster.get_status() != 'not-initialized':
        cluster.destroy()

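
# create_pool mirrors asyncpg.create_pool() but lets tests substitute the
# pool, connection and record classes; the defaults below are the regular
# asyncpg classes.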
def create_pool(dsn=None, *,
                min_size=10,
                max_size=10,
                max_queries=50000,
                max_inactive_connection_lifetime=60.0,
                setup=None,
                init=None,
                loop=None,
                pool_class=pg_pool.Pool,
                connection_class=pg_connection.Connection,
                record_class=asyncpg.Record,
                **connect_kwargs):
    return pool_class(
        dsn,
        min_size=min_size, max_size=max_size,
        max_queries=max_queries, loop=loop, setup=setup, init=init,
        max_inactive_connection_lifetime=max_inactive_connection_lifetime,
        connection_class=connection_class,
        record_class=record_class,
        **connect_kwargs)

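
# ClusterTestCase starts (or reuses, when PGHOST is set) a PostgreSQL cluster
# for the whole test class and tears down any extra clusters and pools
# created during the tests.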
class ClusterTestCase(TestCase):
    @classmethod
    def get_server_settings(cls):
        settings = {
            'log_connections': 'on'
        }

        if cls.cluster.get_pg_version() >= (11, 0):
            # JITting messes up timing tests, and
            # is not essential for testing.
            settings['jit'] = 'off'

        return settings

    @classmethod
    def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}):
        cluster = _init_cluster(ClusterCls, cluster_kwargs,
                                _get_initdb_options(initdb_options))
        cls._clusters.append(cluster)
        return cluster

    @classmethod
    def start_cluster(cls, cluster, *, server_settings={}):
        cluster.start(port='dynamic', server_settings=server_settings)

    @classmethod
    def setup_cluster(cls):
        cls.cluster = _init_default_cluster()

        if cls.cluster.get_status() != 'running':
            cls.cluster.start(
                port='dynamic', server_settings=cls.get_server_settings())

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._clusters = []
        cls.setup_cluster()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        for cluster in cls._clusters:
            if cluster is not _default_cluster:
                cluster.stop()
                cluster.destroy()
        cls._clusters = []

    @classmethod
    def get_connection_spec(cls, kwargs={}):
        conn_spec = cls.cluster.get_connection_spec()
        if kwargs.get('dsn'):
            conn_spec.pop('host')
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            if 'database' not in conn_spec:
                conn_spec['database'] = 'postgres'
            if 'user' not in conn_spec:
                conn_spec['user'] = 'postgres'
        return conn_spec

    @classmethod
    def connect(cls, **kwargs):
        conn_spec = cls.get_connection_spec(kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    def setUp(self):
        super().setUp()
        self._pools = []

    def tearDown(self):
        super().tearDown()
        for pool in self._pools:
            pool.terminate()
        self._pools = []

    def create_pool(self, pool_class=pg_pool.Pool,
                    connection_class=pg_connection.Connection, **kwargs):
        conn_spec = self.get_connection_spec(kwargs)
        pool = create_pool(loop=self.loop, pool_class=pool_class,
                           connection_class=connection_class, **conn_spec)
        self._pools.append(pool)
        return pool

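
# ProxiedClusterTestCase routes all test connections through the
# TCPFuzzingProxy defined in fuzzer.py, so tests can simulate a loss of
# network connectivity between the client and the server.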
class ProxiedClusterTestCase(ClusterTestCase):
    @classmethod
    def get_server_settings(cls):
        settings = dict(super().get_server_settings())
        settings['listen_addresses'] = '127.0.0.1'
        return settings

    @classmethod
    def get_proxy_settings(cls):
        return {'fuzzing-mode': None}

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        conn_spec = cls.cluster.get_connection_spec()
        host = conn_spec.get('host')
        if not host:
            host = '127.0.0.1'
        elif host.startswith('/'):
            host = '127.0.0.1'
        cls.proxy = fuzzer.TCPFuzzingProxy(
            backend_host=host,
            backend_port=conn_spec['port'],
        )
        cls.proxy.start()

    @classmethod
    def tearDownClass(cls):
        cls.proxy.stop()
        super().tearDownClass()

    @classmethod
    def get_connection_spec(cls, kwargs):
        conn_spec = super().get_connection_spec(kwargs)
        conn_spec['host'] = cls.proxy.listening_addr
        conn_spec['port'] = cls.proxy.listening_port
        return conn_spec

    def tearDown(self):
        self.proxy.reset()
        super().tearDown()

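
# with_connection_options stores connection keyword arguments on the test
# method; ConnectedTestCase.setUp() picks them up when opening self.con.
# Illustrative (hypothetical) usage:
#
#     @with_connection_options(server_settings={'application_name': 'test'})
#     async def test_custom_options(self):
#         ...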
def with_connection_options(**options):
    if not options:
        raise ValueError('no connection options were specified')

    def wrap(func):
        func.__connect_options__ = options
        return func

    return wrap

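
# ConnectedTestCase opens a connection (self.con) to the test cluster before
# each test and closes it afterwards.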
class ConnectedTestCase(ClusterTestCase):

    def setUp(self):
        super().setUp()

        # Extract options set up with `with_connection_options`.
        test_func = getattr(self, self._testMethodName).__func__
        opts = getattr(test_func, '__connect_options__', {})
        self.con = self.loop.run_until_complete(self.connect(**opts))
        self.server_version = self.con.get_server_version()

    def tearDown(self):
        try:
            self.loop.run_until_complete(self.con.close())
            self.con = None
        finally:
            super().tearDown()

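
# HotStandbyTestCase provisions a primary cluster plus a hot-standby replica
# streaming from it, and exposes helpers to connect to either node.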
class HotStandbyTestCase(ClusterTestCase):

    @classmethod
    def setup_cluster(cls):
        cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
        cls.start_cluster(
            cls.master_cluster,
            server_settings={
                'max_wal_senders': 10,
                'wal_level': 'hot_standby'
            }
        )

        con = None

        try:
            con = cls.loop.run_until_complete(
                cls.master_cluster.connect(
                    database='postgres', user='postgres', loop=cls.loop))

            cls.loop.run_until_complete(
                con.execute('''
                    CREATE ROLE replication WITH LOGIN REPLICATION
                '''))

            cls.master_cluster.trust_local_replication_by('replication')

            conn_spec = cls.master_cluster.get_connection_spec()

            cls.standby_cluster = cls.new_cluster(
                pg_cluster.HotStandbyCluster,
                cluster_kwargs={
                    'master': conn_spec,
                    'replication_user': 'replication'
                }
            )
            cls.start_cluster(
                cls.standby_cluster,
                server_settings={
                    'hot_standby': True
                }
            )

        finally:
            if con is not None:
                cls.loop.run_until_complete(con.close())

    @classmethod
    def get_cluster_connection_spec(cls, cluster, kwargs={}):
        conn_spec = cluster.get_connection_spec()
        if kwargs.get('dsn'):
            conn_spec.pop('host')
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            if 'database' not in conn_spec:
                conn_spec['database'] = 'postgres'
            if 'user' not in conn_spec:
                conn_spec['user'] = 'postgres'
        return conn_spec

    @classmethod
    def get_connection_spec(cls, kwargs={}):
        primary_spec = cls.get_cluster_connection_spec(
            cls.master_cluster, kwargs
        )
        standby_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster, kwargs
        )
        return {
            'host': [primary_spec['host'], standby_spec['host']],
            'port': [primary_spec['port'], standby_spec['port']],
            'database': primary_spec['database'],
            'user': primary_spec['user'],
            **kwargs
        }

    @classmethod
    def connect_primary(cls, **kwargs):
        conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    @classmethod
    def connect_standby(cls, **kwargs):
        conn_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster,
            kwargs
        )
        return pg_connection.connect(**conn_spec, loop=cls.loop)

View File

@@ -0,0 +1,306 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import socket
import threading
import typing

from asyncpg import cluster

class StopServer(Exception):
    pass

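
# TCPFuzzingProxy is a man-in-the-middle TCP proxy that runs on its own
# thread and event loop.  It accepts client connections on
# listening_addr/listening_port, forwards traffic to
# backend_host/backend_port, and can pause forwarding to emulate a loss of
# network connectivity.  Illustrative (hypothetical) usage from a test:
#
#     proxy = TCPFuzzingProxy(backend_host='127.0.0.1', backend_port=5432)
#     proxy.start()
#     proxy.trigger_connectivity_loss()
#     proxy.restore_connectivity()
#     proxy.stop()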
class TCPFuzzingProxy:
    def __init__(self, *, listening_addr: str='127.0.0.1',
                 listening_port: typing.Optional[int]=None,
                 backend_host: str, backend_port: int,
                 settings: typing.Optional[dict]=None) -> None:
        self.listening_addr = listening_addr
        self.listening_port = listening_port
        self.backend_host = backend_host
        self.backend_port = backend_port
        self.settings = settings or {}
        self.loop = None
        self.connectivity = None
        self.connectivity_loss = None
        self.stop_event = None
        self.connections = {}
        self.sock = None
        self.listen_task = None

    async def _wait(self, work):
        work_task = asyncio.ensure_future(work)
        stop_event_task = asyncio.ensure_future(self.stop_event.wait())
        try:
            await asyncio.wait(
                [work_task, stop_event_task],
                return_when=asyncio.FIRST_COMPLETED)

            if self.stop_event.is_set():
                raise StopServer()
            else:
                return work_task.result()
        finally:
            if not work_task.done():
                work_task.cancel()
            if not stop_event_task.done():
                stop_event_task.cancel()

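    # start() launches the proxy thread and blocks until its event loop is
    # ready; stop() asks that loop to shut down and joins the thread.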
    def start(self):
        started = threading.Event()
        self.thread = threading.Thread(
            target=self._start_thread, args=(started,))
        self.thread.start()
        if not started.wait(timeout=2):
            raise RuntimeError('fuzzer proxy failed to start')

    def stop(self):
        self.loop.call_soon_threadsafe(self._stop)
        self.thread.join()

    def _stop(self):
        self.stop_event.set()

    def _start_thread(self, started_event):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        self.connectivity = asyncio.Event()
        self.connectivity.set()
        self.connectivity_loss = asyncio.Event()
        self.stop_event = asyncio.Event()

        if self.listening_port is None:
            self.listening_port = cluster.find_available_port()

        self.sock = socket.socket()
        self.sock.bind((self.listening_addr, self.listening_port))
        self.sock.listen(50)
        self.sock.setblocking(False)

        try:
            self.loop.run_until_complete(self._main(started_event))
        finally:
            self.loop.close()

    async def _main(self, started_event):
        self.listen_task = asyncio.ensure_future(self.listen())
        # Notify the main thread that we are ready to go.
        started_event.set()

        try:
            await self.listen_task
        finally:
            for c in list(self.connections):
                c.close()
            await asyncio.sleep(0.01)
            if hasattr(self.loop, 'remove_reader'):
                self.loop.remove_reader(self.sock.fileno())
            self.sock.close()

    async def listen(self):
        while True:
            try:
                client_sock, _ = await self._wait(
                    self.loop.sock_accept(self.sock))

                backend_sock = socket.socket()
                backend_sock.setblocking(False)

                await self._wait(self.loop.sock_connect(
                    backend_sock, (self.backend_host, self.backend_port)))
            except StopServer:
                break

            conn = Connection(client_sock, backend_sock, self)
            conn_task = self.loop.create_task(conn.handle())
            self.connections[conn] = conn_task

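    # The methods below are safe to call from the main test thread: they
    # schedule the actual state change on the proxy's event loop via
    # call_soon_threadsafe().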
    def trigger_connectivity_loss(self):
        self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)

    def _trigger_connectivity_loss(self):
        self.connectivity.clear()
        self.connectivity_loss.set()

    def restore_connectivity(self):
        self.loop.call_soon_threadsafe(self._restore_connectivity)

    def _restore_connectivity(self):
        self.connectivity.set()
        self.connectivity_loss.clear()

    def reset(self):
        self.restore_connectivity()

    def _close_connection(self, connection):
        conn_task = self.connections.pop(connection, None)
        if conn_task is not None:
            conn_task.cancel()

    def close_all_connections(self):
        for conn in list(self.connections):
            self.loop.call_soon_threadsafe(self._close_connection, conn)

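
# Connection represents a single proxied client<->backend session.  It runs
# two pump tasks, one per direction, and tears both down when either side
# closes or the proxy drops the connection.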
class Connection:
    def __init__(self, client_sock, backend_sock, proxy):
        self.client_sock = client_sock
        self.backend_sock = backend_sock
        self.proxy = proxy
        self.loop = proxy.loop
        self.connectivity = proxy.connectivity
        self.connectivity_loss = proxy.connectivity_loss
        self.proxy_to_backend_task = None
        self.proxy_from_backend_task = None
        self.is_closed = False

    def close(self):
        if self.is_closed:
            return

        self.is_closed = True

        if self.proxy_to_backend_task is not None:
            self.proxy_to_backend_task.cancel()
            self.proxy_to_backend_task = None

        if self.proxy_from_backend_task is not None:
            self.proxy_from_backend_task.cancel()
            self.proxy_from_backend_task = None

        self.proxy._close_connection(self)

    async def handle(self):
        self.proxy_to_backend_task = asyncio.ensure_future(
            self.proxy_to_backend())

        self.proxy_from_backend_task = asyncio.ensure_future(
            self.proxy_from_backend())

        try:
            await asyncio.wait(
                [self.proxy_to_backend_task, self.proxy_from_backend_task],
                return_when=asyncio.FIRST_COMPLETED)

        finally:
            if self.proxy_to_backend_task is not None:
                self.proxy_to_backend_task.cancel()

            if self.proxy_from_backend_task is not None:
                self.proxy_from_backend_task.cancel()

            # Asyncio fails to properly remove the readers and writers
            # when the task doing recv() or send() is cancelled, so
            # we must remove the readers and writers manually before
            # closing the sockets.
            self.loop.remove_reader(self.client_sock.fileno())
            self.loop.remove_writer(self.client_sock.fileno())
            self.loop.remove_reader(self.backend_sock.fileno())
            self.loop.remove_writer(self.backend_sock.fileno())

            self.client_sock.close()
            self.backend_sock.close()

    async def _read(self, sock, n):
        read_task = asyncio.ensure_future(
            self.loop.sock_recv(sock, n))
        conn_event_task = asyncio.ensure_future(
            self.connectivity_loss.wait())

        try:
            await asyncio.wait(
                [read_task, conn_event_task],
                return_when=asyncio.FIRST_COMPLETED)

            if self.connectivity_loss.is_set():
                return None
            else:
                return read_task.result()
        finally:
            if not self.loop.is_closed():
                if not read_task.done():
                    read_task.cancel()
                if not conn_event_task.done():
                    conn_event_task.cancel()

    async def _write(self, sock, data):
        write_task = asyncio.ensure_future(
            self.loop.sock_sendall(sock, data))
        conn_event_task = asyncio.ensure_future(
            self.connectivity_loss.wait())

        try:
            await asyncio.wait(
                [write_task, conn_event_task],
                return_when=asyncio.FIRST_COMPLETED)

            if self.connectivity_loss.is_set():
                return None
            else:
                return write_task.result()
        finally:
            if not self.loop.is_closed():
                if not write_task.done():
                    write_task.cancel()
                if not conn_event_task.done():
                    conn_event_task.cancel()

    async def proxy_to_backend(self):
        buf = None

        try:
            while True:
                await self.connectivity.wait()
                if buf is not None:
                    data = buf
                    buf = None
                else:
                    data = await self._read(self.client_sock, 4096)
                if data == b'':
                    break
                if self.connectivity_loss.is_set():
                    if data:
                        buf = data
                    continue
                await self._write(self.backend_sock, data)

        except ConnectionError:
            pass

        finally:
            if not self.loop.is_closed():
                self.loop.call_soon(self.close)

    async def proxy_from_backend(self):
        buf = None

        try:
            while True:
                await self.connectivity.wait()
                if buf is not None:
                    data = buf
                    buf = None
                else:
                    data = await self._read(self.backend_sock, 4096)
                if data == b'':
                    break
                if self.connectivity_loss.is_set():
                    if data:
                        buf = data
                    continue
                await self._write(self.client_sock, data)

        except ConnectionError:
            pass

        finally:
            if not self.loop.is_closed():
                self.loop.call_soon(self.close)