Major fixes and new features
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
527
venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py
Normal file
527
venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py
Normal file
@@ -0,0 +1,527 @@
|
||||
# Copyright (C) 2016-present the asyncpg authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of asyncpg and is released under
|
||||
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
|
||||
import asyncpg
|
||||
from asyncpg import cluster as pg_cluster
|
||||
from asyncpg import connection as pg_connection
|
||||
from asyncpg import pool as pg_pool
|
||||
|
||||
from . import fuzzer
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def silence_asyncio_long_exec_warning():
|
||||
def flt(log_record):
|
||||
msg = log_record.getMessage()
|
||||
return not msg.startswith('Executing ')
|
||||
|
||||
logger = logging.getLogger('asyncio')
|
||||
logger.addFilter(flt)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.removeFilter(flt)
|
||||
|
||||
|
||||
def with_timeout(timeout):
|
||||
def wrap(func):
|
||||
func.__timeout__ = timeout
|
||||
return func
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
class TestCaseMeta(type(unittest.TestCase)):
|
||||
TEST_TIMEOUT = None
|
||||
|
||||
@staticmethod
|
||||
def _iter_methods(bases, ns):
|
||||
for base in bases:
|
||||
for methname in dir(base):
|
||||
if not methname.startswith('test_'):
|
||||
continue
|
||||
|
||||
meth = getattr(base, methname)
|
||||
if not inspect.iscoroutinefunction(meth):
|
||||
continue
|
||||
|
||||
yield methname, meth
|
||||
|
||||
for methname, meth in ns.items():
|
||||
if not methname.startswith('test_'):
|
||||
continue
|
||||
|
||||
if not inspect.iscoroutinefunction(meth):
|
||||
continue
|
||||
|
||||
yield methname, meth
|
||||
|
||||
def __new__(mcls, name, bases, ns):
|
||||
for methname, meth in mcls._iter_methods(bases, ns):
|
||||
@functools.wraps(meth)
|
||||
def wrapper(self, *args, __meth__=meth, **kwargs):
|
||||
coro = __meth__(self, *args, **kwargs)
|
||||
timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT)
|
||||
if timeout:
|
||||
coro = asyncio.wait_for(coro, timeout)
|
||||
try:
|
||||
self.loop.run_until_complete(coro)
|
||||
except asyncio.TimeoutError:
|
||||
raise self.failureException(
|
||||
'test timed out after {} seconds'.format(
|
||||
timeout)) from None
|
||||
else:
|
||||
self.loop.run_until_complete(coro)
|
||||
ns[methname] = wrapper
|
||||
|
||||
return super().__new__(mcls, name, bases, ns)
|
||||
|
||||
|
||||
class TestCase(unittest.TestCase, metaclass=TestCaseMeta):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if os.environ.get('USE_UVLOOP'):
|
||||
import uvloop
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(None)
|
||||
cls.loop = loop
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def setUp(self):
|
||||
self.loop.set_exception_handler(self.loop_exception_handler)
|
||||
self.__unhandled_exceptions = []
|
||||
|
||||
def tearDown(self):
|
||||
if self.__unhandled_exceptions:
|
||||
formatted = []
|
||||
|
||||
for i, context in enumerate(self.__unhandled_exceptions):
|
||||
formatted.append(self._format_loop_exception(context, i + 1))
|
||||
|
||||
self.fail(
|
||||
'unexpected exceptions in asynchronous code:\n' +
|
||||
'\n'.join(formatted))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assertRunUnder(self, delta):
|
||||
st = time.monotonic()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
elapsed = time.monotonic() - st
|
||||
if elapsed > delta:
|
||||
raise AssertionError(
|
||||
'running block took {:0.3f}s which is longer '
|
||||
'than the expected maximum of {:0.3f}s'.format(
|
||||
elapsed, delta))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assertLoopErrorHandlerCalled(self, msg_re: str):
|
||||
contexts = []
|
||||
|
||||
def handler(loop, ctx):
|
||||
contexts.append(ctx)
|
||||
|
||||
old_handler = self.loop.get_exception_handler()
|
||||
self.loop.set_exception_handler(handler)
|
||||
try:
|
||||
yield
|
||||
|
||||
for ctx in contexts:
|
||||
msg = ctx.get('message')
|
||||
if msg and re.search(msg_re, msg):
|
||||
return
|
||||
|
||||
raise AssertionError(
|
||||
'no message matching {!r} was logged with '
|
||||
'loop.call_exception_handler()'.format(msg_re))
|
||||
|
||||
finally:
|
||||
self.loop.set_exception_handler(old_handler)
|
||||
|
||||
def loop_exception_handler(self, loop, context):
|
||||
self.__unhandled_exceptions.append(context)
|
||||
loop.default_exception_handler(context)
|
||||
|
||||
def _format_loop_exception(self, context, n):
|
||||
message = context.get('message', 'Unhandled exception in event loop')
|
||||
exception = context.get('exception')
|
||||
if exception is not None:
|
||||
exc_info = (type(exception), exception, exception.__traceback__)
|
||||
else:
|
||||
exc_info = None
|
||||
|
||||
lines = []
|
||||
for key in sorted(context):
|
||||
if key in {'message', 'exception'}:
|
||||
continue
|
||||
value = context[key]
|
||||
if key == 'source_traceback':
|
||||
tb = ''.join(traceback.format_list(value))
|
||||
value = 'Object created at (most recent call last):\n'
|
||||
value += tb.rstrip()
|
||||
else:
|
||||
try:
|
||||
value = repr(value)
|
||||
except Exception as ex:
|
||||
value = ('Exception in __repr__ {!r}; '
|
||||
'value type: {!r}'.format(ex, type(value)))
|
||||
lines.append('[{}]: {}\n\n'.format(key, value))
|
||||
|
||||
if exc_info is not None:
|
||||
lines.append('[exception]:\n')
|
||||
formatted_exc = textwrap.indent(
|
||||
''.join(traceback.format_exception(*exc_info)), ' ')
|
||||
lines.append(formatted_exc)
|
||||
|
||||
details = textwrap.indent(''.join(lines), ' ')
|
||||
return '{:02d}. {}:\n{}\n'.format(n, message, details)
|
||||
|
||||
|
||||
_default_cluster = None
|
||||
|
||||
|
||||
def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
|
||||
cluster = ClusterCls(**cluster_kwargs)
|
||||
cluster.init(**(initdb_options or {}))
|
||||
cluster.trust_local_connections()
|
||||
atexit.register(_shutdown_cluster, cluster)
|
||||
return cluster
|
||||
|
||||
|
||||
def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
|
||||
initdb_options=None):
|
||||
cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
|
||||
cluster.start(port='dynamic', server_settings=server_settings)
|
||||
return cluster
|
||||
|
||||
|
||||
def _get_initdb_options(initdb_options=None):
|
||||
if not initdb_options:
|
||||
initdb_options = {}
|
||||
else:
|
||||
initdb_options = dict(initdb_options)
|
||||
|
||||
# Make the default superuser name stable.
|
||||
if 'username' not in initdb_options:
|
||||
initdb_options['username'] = 'postgres'
|
||||
|
||||
return initdb_options
|
||||
|
||||
|
||||
def _init_default_cluster(initdb_options=None):
|
||||
global _default_cluster
|
||||
|
||||
if _default_cluster is None:
|
||||
pg_host = os.environ.get('PGHOST')
|
||||
if pg_host:
|
||||
# Using existing cluster, assuming it is initialized and running
|
||||
_default_cluster = pg_cluster.RunningCluster()
|
||||
else:
|
||||
_default_cluster = _init_cluster(
|
||||
pg_cluster.TempCluster, cluster_kwargs={},
|
||||
initdb_options=_get_initdb_options(initdb_options))
|
||||
|
||||
return _default_cluster
|
||||
|
||||
|
||||
def _shutdown_cluster(cluster):
|
||||
if cluster.get_status() == 'running':
|
||||
cluster.stop()
|
||||
if cluster.get_status() != 'not-initialized':
|
||||
cluster.destroy()
|
||||
|
||||
|
||||
def create_pool(dsn=None, *,
|
||||
min_size=10,
|
||||
max_size=10,
|
||||
max_queries=50000,
|
||||
max_inactive_connection_lifetime=60.0,
|
||||
setup=None,
|
||||
init=None,
|
||||
loop=None,
|
||||
pool_class=pg_pool.Pool,
|
||||
connection_class=pg_connection.Connection,
|
||||
record_class=asyncpg.Record,
|
||||
**connect_kwargs):
|
||||
return pool_class(
|
||||
dsn,
|
||||
min_size=min_size, max_size=max_size,
|
||||
max_queries=max_queries, loop=loop, setup=setup, init=init,
|
||||
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
|
||||
connection_class=connection_class,
|
||||
record_class=record_class,
|
||||
**connect_kwargs)
|
||||
|
||||
|
||||
class ClusterTestCase(TestCase):
|
||||
@classmethod
|
||||
def get_server_settings(cls):
|
||||
settings = {
|
||||
'log_connections': 'on'
|
||||
}
|
||||
|
||||
if cls.cluster.get_pg_version() >= (11, 0):
|
||||
# JITting messes up timing tests, and
|
||||
# is not essential for testing.
|
||||
settings['jit'] = 'off'
|
||||
|
||||
return settings
|
||||
|
||||
@classmethod
|
||||
def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}):
|
||||
cluster = _init_cluster(ClusterCls, cluster_kwargs,
|
||||
_get_initdb_options(initdb_options))
|
||||
cls._clusters.append(cluster)
|
||||
return cluster
|
||||
|
||||
@classmethod
|
||||
def start_cluster(cls, cluster, *, server_settings={}):
|
||||
cluster.start(port='dynamic', server_settings=server_settings)
|
||||
|
||||
@classmethod
|
||||
def setup_cluster(cls):
|
||||
cls.cluster = _init_default_cluster()
|
||||
|
||||
if cls.cluster.get_status() != 'running':
|
||||
cls.cluster.start(
|
||||
port='dynamic', server_settings=cls.get_server_settings())
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls._clusters = []
|
||||
cls.setup_cluster()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super().tearDownClass()
|
||||
for cluster in cls._clusters:
|
||||
if cluster is not _default_cluster:
|
||||
cluster.stop()
|
||||
cluster.destroy()
|
||||
cls._clusters = []
|
||||
|
||||
@classmethod
|
||||
def get_connection_spec(cls, kwargs={}):
|
||||
conn_spec = cls.cluster.get_connection_spec()
|
||||
if kwargs.get('dsn'):
|
||||
conn_spec.pop('host')
|
||||
conn_spec.update(kwargs)
|
||||
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
|
||||
if 'database' not in conn_spec:
|
||||
conn_spec['database'] = 'postgres'
|
||||
if 'user' not in conn_spec:
|
||||
conn_spec['user'] = 'postgres'
|
||||
return conn_spec
|
||||
|
||||
@classmethod
|
||||
def connect(cls, **kwargs):
|
||||
conn_spec = cls.get_connection_spec(kwargs)
|
||||
return pg_connection.connect(**conn_spec, loop=cls.loop)
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._pools = []
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
for pool in self._pools:
|
||||
pool.terminate()
|
||||
self._pools = []
|
||||
|
||||
def create_pool(self, pool_class=pg_pool.Pool,
|
||||
connection_class=pg_connection.Connection, **kwargs):
|
||||
conn_spec = self.get_connection_spec(kwargs)
|
||||
pool = create_pool(loop=self.loop, pool_class=pool_class,
|
||||
connection_class=connection_class, **conn_spec)
|
||||
self._pools.append(pool)
|
||||
return pool
|
||||
|
||||
|
||||
class ProxiedClusterTestCase(ClusterTestCase):
|
||||
@classmethod
|
||||
def get_server_settings(cls):
|
||||
settings = dict(super().get_server_settings())
|
||||
settings['listen_addresses'] = '127.0.0.1'
|
||||
return settings
|
||||
|
||||
@classmethod
|
||||
def get_proxy_settings(cls):
|
||||
return {'fuzzing-mode': None}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
conn_spec = cls.cluster.get_connection_spec()
|
||||
host = conn_spec.get('host')
|
||||
if not host:
|
||||
host = '127.0.0.1'
|
||||
elif host.startswith('/'):
|
||||
host = '127.0.0.1'
|
||||
cls.proxy = fuzzer.TCPFuzzingProxy(
|
||||
backend_host=host,
|
||||
backend_port=conn_spec['port'],
|
||||
)
|
||||
cls.proxy.start()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.proxy.stop()
|
||||
super().tearDownClass()
|
||||
|
||||
@classmethod
|
||||
def get_connection_spec(cls, kwargs):
|
||||
conn_spec = super().get_connection_spec(kwargs)
|
||||
conn_spec['host'] = cls.proxy.listening_addr
|
||||
conn_spec['port'] = cls.proxy.listening_port
|
||||
return conn_spec
|
||||
|
||||
def tearDown(self):
|
||||
self.proxy.reset()
|
||||
super().tearDown()
|
||||
|
||||
|
||||
def with_connection_options(**options):
|
||||
if not options:
|
||||
raise ValueError('no connection options were specified')
|
||||
|
||||
def wrap(func):
|
||||
func.__connect_options__ = options
|
||||
return func
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
class ConnectedTestCase(ClusterTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# Extract options set up with `with_connection_options`.
|
||||
test_func = getattr(self, self._testMethodName).__func__
|
||||
opts = getattr(test_func, '__connect_options__', {})
|
||||
self.con = self.loop.run_until_complete(self.connect(**opts))
|
||||
self.server_version = self.con.get_server_version()
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
self.loop.run_until_complete(self.con.close())
|
||||
self.con = None
|
||||
finally:
|
||||
super().tearDown()
|
||||
|
||||
|
||||
class HotStandbyTestCase(ClusterTestCase):
|
||||
|
||||
@classmethod
|
||||
def setup_cluster(cls):
|
||||
cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
|
||||
cls.start_cluster(
|
||||
cls.master_cluster,
|
||||
server_settings={
|
||||
'max_wal_senders': 10,
|
||||
'wal_level': 'hot_standby'
|
||||
}
|
||||
)
|
||||
|
||||
con = None
|
||||
|
||||
try:
|
||||
con = cls.loop.run_until_complete(
|
||||
cls.master_cluster.connect(
|
||||
database='postgres', user='postgres', loop=cls.loop))
|
||||
|
||||
cls.loop.run_until_complete(
|
||||
con.execute('''
|
||||
CREATE ROLE replication WITH LOGIN REPLICATION
|
||||
'''))
|
||||
|
||||
cls.master_cluster.trust_local_replication_by('replication')
|
||||
|
||||
conn_spec = cls.master_cluster.get_connection_spec()
|
||||
|
||||
cls.standby_cluster = cls.new_cluster(
|
||||
pg_cluster.HotStandbyCluster,
|
||||
cluster_kwargs={
|
||||
'master': conn_spec,
|
||||
'replication_user': 'replication'
|
||||
}
|
||||
)
|
||||
cls.start_cluster(
|
||||
cls.standby_cluster,
|
||||
server_settings={
|
||||
'hot_standby': True
|
||||
}
|
||||
)
|
||||
|
||||
finally:
|
||||
if con is not None:
|
||||
cls.loop.run_until_complete(con.close())
|
||||
|
||||
@classmethod
|
||||
def get_cluster_connection_spec(cls, cluster, kwargs={}):
|
||||
conn_spec = cluster.get_connection_spec()
|
||||
if kwargs.get('dsn'):
|
||||
conn_spec.pop('host')
|
||||
conn_spec.update(kwargs)
|
||||
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
|
||||
if 'database' not in conn_spec:
|
||||
conn_spec['database'] = 'postgres'
|
||||
if 'user' not in conn_spec:
|
||||
conn_spec['user'] = 'postgres'
|
||||
return conn_spec
|
||||
|
||||
@classmethod
|
||||
def get_connection_spec(cls, kwargs={}):
|
||||
primary_spec = cls.get_cluster_connection_spec(
|
||||
cls.master_cluster, kwargs
|
||||
)
|
||||
standby_spec = cls.get_cluster_connection_spec(
|
||||
cls.standby_cluster, kwargs
|
||||
)
|
||||
return {
|
||||
'host': [primary_spec['host'], standby_spec['host']],
|
||||
'port': [primary_spec['port'], standby_spec['port']],
|
||||
'database': primary_spec['database'],
|
||||
'user': primary_spec['user'],
|
||||
**kwargs
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def connect_primary(cls, **kwargs):
|
||||
conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
|
||||
return pg_connection.connect(**conn_spec, loop=cls.loop)
|
||||
|
||||
@classmethod
|
||||
def connect_standby(cls, **kwargs):
|
||||
conn_spec = cls.get_cluster_connection_spec(
|
||||
cls.standby_cluster,
|
||||
kwargs
|
||||
)
|
||||
return pg_connection.connect(**conn_spec, loop=cls.loop)
|
||||
306
venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py
Normal file
306
venv/lib/python3.12/site-packages/asyncpg/_testbase/fuzzer.py
Normal file
@@ -0,0 +1,306 @@
|
||||
# Copyright (C) 2016-present the asyncpg authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of asyncpg and is released under
|
||||
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import threading
|
||||
import typing
|
||||
|
||||
from asyncpg import cluster
|
||||
|
||||
|
||||
class StopServer(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class TCPFuzzingProxy:
|
||||
def __init__(self, *, listening_addr: str='127.0.0.1',
|
||||
listening_port: typing.Optional[int]=None,
|
||||
backend_host: str, backend_port: int,
|
||||
settings: typing.Optional[dict]=None) -> None:
|
||||
self.listening_addr = listening_addr
|
||||
self.listening_port = listening_port
|
||||
self.backend_host = backend_host
|
||||
self.backend_port = backend_port
|
||||
self.settings = settings or {}
|
||||
self.loop = None
|
||||
self.connectivity = None
|
||||
self.connectivity_loss = None
|
||||
self.stop_event = None
|
||||
self.connections = {}
|
||||
self.sock = None
|
||||
self.listen_task = None
|
||||
|
||||
async def _wait(self, work):
|
||||
work_task = asyncio.ensure_future(work)
|
||||
stop_event_task = asyncio.ensure_future(self.stop_event.wait())
|
||||
|
||||
try:
|
||||
await asyncio.wait(
|
||||
[work_task, stop_event_task],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if self.stop_event.is_set():
|
||||
raise StopServer()
|
||||
else:
|
||||
return work_task.result()
|
||||
finally:
|
||||
if not work_task.done():
|
||||
work_task.cancel()
|
||||
if not stop_event_task.done():
|
||||
stop_event_task.cancel()
|
||||
|
||||
def start(self):
|
||||
started = threading.Event()
|
||||
self.thread = threading.Thread(
|
||||
target=self._start_thread, args=(started,))
|
||||
self.thread.start()
|
||||
if not started.wait(timeout=2):
|
||||
raise RuntimeError('fuzzer proxy failed to start')
|
||||
|
||||
def stop(self):
|
||||
self.loop.call_soon_threadsafe(self._stop)
|
||||
self.thread.join()
|
||||
|
||||
def _stop(self):
|
||||
self.stop_event.set()
|
||||
|
||||
def _start_thread(self, started_event):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
self.connectivity = asyncio.Event()
|
||||
self.connectivity.set()
|
||||
self.connectivity_loss = asyncio.Event()
|
||||
self.stop_event = asyncio.Event()
|
||||
|
||||
if self.listening_port is None:
|
||||
self.listening_port = cluster.find_available_port()
|
||||
|
||||
self.sock = socket.socket()
|
||||
self.sock.bind((self.listening_addr, self.listening_port))
|
||||
self.sock.listen(50)
|
||||
self.sock.setblocking(False)
|
||||
|
||||
try:
|
||||
self.loop.run_until_complete(self._main(started_event))
|
||||
finally:
|
||||
self.loop.close()
|
||||
|
||||
async def _main(self, started_event):
|
||||
self.listen_task = asyncio.ensure_future(self.listen())
|
||||
# Notify the main thread that we are ready to go.
|
||||
started_event.set()
|
||||
try:
|
||||
await self.listen_task
|
||||
finally:
|
||||
for c in list(self.connections):
|
||||
c.close()
|
||||
await asyncio.sleep(0.01)
|
||||
if hasattr(self.loop, 'remove_reader'):
|
||||
self.loop.remove_reader(self.sock.fileno())
|
||||
self.sock.close()
|
||||
|
||||
async def listen(self):
|
||||
while True:
|
||||
try:
|
||||
client_sock, _ = await self._wait(
|
||||
self.loop.sock_accept(self.sock))
|
||||
|
||||
backend_sock = socket.socket()
|
||||
backend_sock.setblocking(False)
|
||||
|
||||
await self._wait(self.loop.sock_connect(
|
||||
backend_sock, (self.backend_host, self.backend_port)))
|
||||
except StopServer:
|
||||
break
|
||||
|
||||
conn = Connection(client_sock, backend_sock, self)
|
||||
conn_task = self.loop.create_task(conn.handle())
|
||||
self.connections[conn] = conn_task
|
||||
|
||||
def trigger_connectivity_loss(self):
|
||||
self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)
|
||||
|
||||
def _trigger_connectivity_loss(self):
|
||||
self.connectivity.clear()
|
||||
self.connectivity_loss.set()
|
||||
|
||||
def restore_connectivity(self):
|
||||
self.loop.call_soon_threadsafe(self._restore_connectivity)
|
||||
|
||||
def _restore_connectivity(self):
|
||||
self.connectivity.set()
|
||||
self.connectivity_loss.clear()
|
||||
|
||||
def reset(self):
|
||||
self.restore_connectivity()
|
||||
|
||||
def _close_connection(self, connection):
|
||||
conn_task = self.connections.pop(connection, None)
|
||||
if conn_task is not None:
|
||||
conn_task.cancel()
|
||||
|
||||
def close_all_connections(self):
|
||||
for conn in list(self.connections):
|
||||
self.loop.call_soon_threadsafe(self._close_connection, conn)
|
||||
|
||||
|
||||
class Connection:
|
||||
def __init__(self, client_sock, backend_sock, proxy):
|
||||
self.client_sock = client_sock
|
||||
self.backend_sock = backend_sock
|
||||
self.proxy = proxy
|
||||
self.loop = proxy.loop
|
||||
self.connectivity = proxy.connectivity
|
||||
self.connectivity_loss = proxy.connectivity_loss
|
||||
self.proxy_to_backend_task = None
|
||||
self.proxy_from_backend_task = None
|
||||
self.is_closed = False
|
||||
|
||||
def close(self):
|
||||
if self.is_closed:
|
||||
return
|
||||
|
||||
self.is_closed = True
|
||||
|
||||
if self.proxy_to_backend_task is not None:
|
||||
self.proxy_to_backend_task.cancel()
|
||||
self.proxy_to_backend_task = None
|
||||
|
||||
if self.proxy_from_backend_task is not None:
|
||||
self.proxy_from_backend_task.cancel()
|
||||
self.proxy_from_backend_task = None
|
||||
|
||||
self.proxy._close_connection(self)
|
||||
|
||||
async def handle(self):
|
||||
self.proxy_to_backend_task = asyncio.ensure_future(
|
||||
self.proxy_to_backend())
|
||||
|
||||
self.proxy_from_backend_task = asyncio.ensure_future(
|
||||
self.proxy_from_backend())
|
||||
|
||||
try:
|
||||
await asyncio.wait(
|
||||
[self.proxy_to_backend_task, self.proxy_from_backend_task],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
finally:
|
||||
if self.proxy_to_backend_task is not None:
|
||||
self.proxy_to_backend_task.cancel()
|
||||
|
||||
if self.proxy_from_backend_task is not None:
|
||||
self.proxy_from_backend_task.cancel()
|
||||
|
||||
# Asyncio fails to properly remove the readers and writers
|
||||
# when the task doing recv() or send() is cancelled, so
|
||||
# we must remove the readers and writers manually before
|
||||
# closing the sockets.
|
||||
self.loop.remove_reader(self.client_sock.fileno())
|
||||
self.loop.remove_writer(self.client_sock.fileno())
|
||||
self.loop.remove_reader(self.backend_sock.fileno())
|
||||
self.loop.remove_writer(self.backend_sock.fileno())
|
||||
|
||||
self.client_sock.close()
|
||||
self.backend_sock.close()
|
||||
|
||||
async def _read(self, sock, n):
|
||||
read_task = asyncio.ensure_future(
|
||||
self.loop.sock_recv(sock, n))
|
||||
conn_event_task = asyncio.ensure_future(
|
||||
self.connectivity_loss.wait())
|
||||
|
||||
try:
|
||||
await asyncio.wait(
|
||||
[read_task, conn_event_task],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if self.connectivity_loss.is_set():
|
||||
return None
|
||||
else:
|
||||
return read_task.result()
|
||||
finally:
|
||||
if not self.loop.is_closed():
|
||||
if not read_task.done():
|
||||
read_task.cancel()
|
||||
if not conn_event_task.done():
|
||||
conn_event_task.cancel()
|
||||
|
||||
async def _write(self, sock, data):
|
||||
write_task = asyncio.ensure_future(
|
||||
self.loop.sock_sendall(sock, data))
|
||||
conn_event_task = asyncio.ensure_future(
|
||||
self.connectivity_loss.wait())
|
||||
|
||||
try:
|
||||
await asyncio.wait(
|
||||
[write_task, conn_event_task],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if self.connectivity_loss.is_set():
|
||||
return None
|
||||
else:
|
||||
return write_task.result()
|
||||
finally:
|
||||
if not self.loop.is_closed():
|
||||
if not write_task.done():
|
||||
write_task.cancel()
|
||||
if not conn_event_task.done():
|
||||
conn_event_task.cancel()
|
||||
|
||||
async def proxy_to_backend(self):
|
||||
buf = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
await self.connectivity.wait()
|
||||
if buf is not None:
|
||||
data = buf
|
||||
buf = None
|
||||
else:
|
||||
data = await self._read(self.client_sock, 4096)
|
||||
if data == b'':
|
||||
break
|
||||
if self.connectivity_loss.is_set():
|
||||
if data:
|
||||
buf = data
|
||||
continue
|
||||
await self._write(self.backend_sock, data)
|
||||
|
||||
except ConnectionError:
|
||||
pass
|
||||
|
||||
finally:
|
||||
if not self.loop.is_closed():
|
||||
self.loop.call_soon(self.close)
|
||||
|
||||
async def proxy_from_backend(self):
|
||||
buf = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
await self.connectivity.wait()
|
||||
if buf is not None:
|
||||
data = buf
|
||||
buf = None
|
||||
else:
|
||||
data = await self._read(self.backend_sock, 4096)
|
||||
if data == b'':
|
||||
break
|
||||
if self.connectivity_loss.is_set():
|
||||
if data:
|
||||
buf = data
|
||||
continue
|
||||
await self._write(self.client_sock, data)
|
||||
|
||||
except ConnectionError:
|
||||
pass
|
||||
|
||||
finally:
|
||||
if not self.loop.is_closed():
|
||||
self.loop.call_soon(self.close)
|
||||
Reference in New Issue
Block a user