API refactor

2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions


@@ -3,8 +3,8 @@ import copy
import enum
import inspect
import socket
import ssl
import sys
import warnings
import weakref
from abc import abstractmethod
from itertools import chain
@@ -16,14 +16,30 @@ from typing import (
List,
Mapping,
Optional,
Protocol,
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
)
from urllib.parse import ParseResult, parse_qs, unquote, urlparse
from ..utils import SSL_AVAILABLE
if SSL_AVAILABLE:
import ssl
from ssl import SSLContext, TLSVersion
else:
ssl = None
TLSVersion = None
SSLContext = None
from ..auth.token import TokenInterface
from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
from ..utils import deprecated_args, format_error_message
# the functionality is available in 3.11.x but has a major issue before
# 3.11.3. See https://github.com/redis/redis-py/issues/2633
if sys.version_info >= (3, 11, 3):
@@ -33,7 +49,6 @@ else:
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redis.compat import Protocol, TypedDict
from redis.connection import DEFAULT_RESP_VERSION
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
@@ -78,13 +93,11 @@ else:
class ConnectCallbackProtocol(Protocol):
def __call__(self, connection: "AbstractConnection"):
...
def __call__(self, connection: "AbstractConnection"): ...
class AsyncConnectCallbackProtocol(Protocol):
async def __call__(self, connection: "AbstractConnection"):
...
async def __call__(self, connection: "AbstractConnection"): ...
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
@@ -146,6 +159,7 @@ class AbstractConnection:
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = None,
):
if (username or password) and credential_provider is not None:
raise DataError(
@@ -154,6 +168,10 @@ class AbstractConnection:
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.db = db
self.client_name = client_name
self.lib_name = lib_name
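As a rough sketch of the new parameter (assuming a local Redis on localhost:6379): a single EventDispatcher can be shared by passing it through the pool's connection kwargs; when it is omitted, each connection creates its own dispatcher, as the constructor above shows.

import asyncio
from redis.asyncio.connection import ConnectionPool
from redis.event import EventDispatcher

async def main():
    # One dispatcher shared by the pool and every connection it creates.
    dispatcher = EventDispatcher()
    pool = ConnectionPool(host="localhost", port=6379, event_dispatcher=dispatcher)
    conn = await pool.get_connection()
    try:
        await conn.send_command("PING")
        print(await conn.read_response())
    finally:
        await pool.release(conn)
        await pool.disconnect()

asyncio.run(main())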
@@ -193,6 +211,8 @@ class AbstractConnection:
self.set_parser(parser_class)
self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
self._buffer_cutoff = 6000
self._re_auth_token: Optional[TokenInterface] = None
try:
p = int(protocol)
except TypeError:
@@ -204,9 +224,33 @@ class AbstractConnection:
raise ConnectionError("protocol must be either 2 or 3")
self.protocol = protocol
def __del__(self, _warnings: Any = warnings):
# For some reason, the individual streams don't get properly garbage
# collected and therefore produce no resource warnings. We add one
# here, in the same style as those from the stdlib.
if getattr(self, "_writer", None):
_warnings.warn(
f"unclosed Connection {self!r}", ResourceWarning, source=self
)
try:
asyncio.get_running_loop()
self._close()
except RuntimeError:
# No action needs to be taken if the pool is already closed.
pass
def _close(self):
"""
Internal method to silently close the connection without waiting for the transport to finish closing
"""
if self._writer:
self._writer.close()
self._writer = self._reader = None
def __repr__(self):
repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
return f"{self.__class__.__name__}<{repr_args}>"
return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
@abstractmethod
def repr_pieces(self):
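For illustration, the reworked __repr__ nests the repr_pieces() fields inside the module-qualified class name; the output shown in the comment is a guess at the typical form, not taken from this diff.

from redis.asyncio.connection import Connection

conn = Connection(host="localhost", port=6379)
print(repr(conn))
# expected shape (fields come from repr_pieces() and may differ):
# <redis.asyncio.connection.Connection(host=localhost,port=6379,db=0)>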
@@ -216,12 +260,24 @@ class AbstractConnection:
def is_connected(self):
return self._reader is not None and self._writer is not None
def _register_connect_callback(self, callback):
def register_connect_callback(self, callback):
"""
Register a callback to be called when the connection is established, either
initially or after a reconnect. This allows listeners to issue commands that
are ephemeral to the connection, for example pub/sub subscription or
key tracking. The callback must be a _method_ and will be kept as
a weak reference.
"""
wm = weakref.WeakMethod(callback)
if wm not in self._connect_callbacks:
self._connect_callbacks.append(wm)
def _deregister_connect_callback(self, callback):
def deregister_connect_callback(self, callback):
"""
De-register a previously registered callback. It will no longer receive
notifications on connection events. Calling this is not required when the
listener goes away, since the callbacks are kept as weak methods.
"""
try:
self._connect_callbacks.remove(weakref.WeakMethod(callback))
except ValueError:
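A hedged usage sketch of the now-public callback API: the callback must be a bound method because it is stored as a weakref.WeakMethod (a plain function would be rejected). The Resubscriber class and the channel name are hypothetical.

from redis.asyncio.connection import Connection

class Resubscriber:
    async def on_reconnect(self, connection):
        # Re-issue state that is ephemeral to the connection, e.g. a pub/sub
        # subscription, every time the connection is (re)established.
        await connection.send_command("SUBSCRIBE", "news")
        await connection.read_response()

listener = Resubscriber()
conn = Connection(host="localhost", port=6379)
conn.register_connect_callback(listener.on_reconnect)
# ... later, when the listener should stop receiving notifications:
conn.deregister_connect_callback(listener.on_reconnect)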
@@ -237,12 +293,20 @@ class AbstractConnection:
async def connect(self):
"""Connects to the Redis server if not already connected"""
await self.connect_check_health(check_health=True)
async def connect_check_health(
self, check_health: bool = True, retry_socket_connect: bool = True
):
if self.is_connected:
return
try:
await self.retry.call_with_retry(
lambda: self._connect(), lambda error: self.disconnect()
)
if retry_socket_connect:
await self.retry.call_with_retry(
lambda: self._connect(), lambda error: self.disconnect()
)
else:
await self._connect()
except asyncio.CancelledError:
raise # in 3.7 and earlier, this is an Exception, not BaseException
except (socket.timeout, asyncio.TimeoutError):
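A small sketch of the split introduced here (assuming a local Redis): connect() keeps the old behaviour, retrying the socket connect through the configured Retry policy and running the handshake health checks, while connect_check_health() lets callers opt out of either.

import asyncio
from redis.asyncio.connection import Connection

async def main():
    conn = Connection(host="localhost", port=6379)
    # Default path: retried socket connect plus the usual handshake checks.
    await conn.connect()
    await conn.disconnect()
    # Single attempt, and no health check after the handshake.
    await conn.connect_check_health(check_health=False, retry_socket_connect=False)
    await conn.disconnect()

asyncio.run(main())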
@@ -255,12 +319,14 @@ class AbstractConnection:
try:
if not self.redis_connect_func:
# Use the default on_connect function
await self.on_connect()
await self.on_connect_check_health(check_health=check_health)
else:
# Use the passed function redis_connect_func
await self.redis_connect_func(self) if asyncio.iscoroutinefunction(
self.redis_connect_func
) else self.redis_connect_func(self)
(
await self.redis_connect_func(self)
if asyncio.iscoroutinefunction(self.redis_connect_func)
else self.redis_connect_func(self)
)
except RedisError:
# clean up after any error in on_connect
await self.disconnect()
@@ -284,12 +350,17 @@ class AbstractConnection:
def _host_error(self) -> str:
pass
@abstractmethod
def _error_message(self, exception: BaseException) -> str:
pass
return format_error_message(self._host_error(), exception)
def get_protocol(self):
return self.protocol
async def on_connect(self) -> None:
"""Initialize the connection, authenticate and select a database"""
await self.on_connect_check_health(check_health=True)
async def on_connect_check_health(self, check_health: bool = True) -> None:
self._parser.on_connect(self)
parser = self._parser
@@ -300,7 +371,8 @@ class AbstractConnection:
self.credential_provider
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = cred_provider.get_credentials()
auth_args = await cred_provider.get_credentials_async()
# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol not in [2, "2"]:
@@ -311,7 +383,11 @@ class AbstractConnection:
self._parser.on_connect(self)
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
await self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
await self.send_command(
"HELLO", self.protocol, "AUTH", *auth_args, check_health=False
)
response = await self.read_response()
if response.get(b"proto") != int(self.protocol) and response.get(
"proto"
@@ -342,7 +418,7 @@ class AbstractConnection:
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
await self.send_command("HELLO", self.protocol)
await self.send_command("HELLO", self.protocol, check_health=check_health)
response = await self.read_response()
# if response.get(b"proto") != self.protocol and response.get(
# "proto"
@@ -351,18 +427,35 @@ class AbstractConnection:
# if a client_name is given, set it
if self.client_name:
await self.send_command("CLIENT", "SETNAME", self.client_name)
await self.send_command(
"CLIENT",
"SETNAME",
self.client_name,
check_health=check_health,
)
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Error setting client name")
# set the library name and version, pipeline for lower startup latency
if self.lib_name:
await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-NAME",
self.lib_name,
check_health=check_health,
)
if self.lib_version:
await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-VER",
self.lib_version,
check_health=check_health,
)
# if a database is specified, switch to it. Also pipeline this
if self.db:
await self.send_command("SELECT", self.db)
await self.send_command("SELECT", self.db, check_health=check_health)
# read responses from pipeline
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
@@ -424,8 +517,8 @@ class AbstractConnection:
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
) -> None:
if not self.is_connected:
await self.connect()
elif check_health:
await self.connect_check_health(check_health=False)
if check_health:
await self.check_health()
try:
@@ -488,11 +581,7 @@ class AbstractConnection:
read_timeout = timeout if timeout is not None else self.socket_timeout
host_error = self._host_error()
try:
if (
read_timeout is not None
and self.protocol in ["3", 3]
and not HIREDIS_AVAILABLE
):
if read_timeout is not None and self.protocol in ["3", 3]:
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
@@ -502,7 +591,7 @@ class AbstractConnection:
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE:
elif self.protocol in ["3", 3]:
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
@@ -614,6 +703,27 @@ class AbstractConnection:
output.append(SYM_EMPTY.join(pieces))
return output
def _socket_is_empty(self):
"""Check if the socket is empty"""
return len(self._reader._buffer) == 0
async def process_invalidation_messages(self):
while not self._socket_is_empty():
await self.read_response(push_request=True)
def set_re_auth_token(self, token: TokenInterface):
self._re_auth_token = token
async def re_auth(self):
if self._re_auth_token is not None:
await self.send_command(
"AUTH",
self._re_auth_token.try_get("oid"),
self._re_auth_token.get_value(),
)
await self.read_response()
self._re_auth_token = None
class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"
@@ -671,27 +781,6 @@ class Connection(AbstractConnection):
def _host_error(self) -> str:
return f"{self.host}:{self.port}"
def _error_message(self, exception: BaseException) -> str:
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if not exception.args:
# asyncio has a bug where on Connection reset by peer, the
# exception is not instantiated, so args is empty. This is the
# workaround.
# See: https://github.com/redis/redis-py/issues/2237
# See: https://github.com/python/cpython/issues/94061
return f"Error connecting to {host_error}. Connection reset by peer"
elif len(exception.args) == 1:
return f"Error connecting to {host_error}. {exception.args[0]}."
else:
return (
f"Error {exception.args[0]} connecting to {host_error}. "
f"{exception.args[0]}."
)
class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
@@ -703,12 +792,17 @@ class SSLConnection(Connection):
self,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: str = "required",
ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_check_hostname: bool = True,
ssl_min_version: Optional[TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
**kwargs,
):
if not SSL_AVAILABLE:
raise RedisError("Python wasn't built with SSL support")
self.ssl_context: RedisSSLContext = RedisSSLContext(
keyfile=ssl_keyfile,
certfile=ssl_certfile,
@@ -716,6 +810,8 @@ class SSLConnection(Connection):
ca_certs=ssl_ca_certs,
ca_data=ssl_ca_data,
check_hostname=ssl_check_hostname,
min_version=ssl_min_version,
ciphers=ssl_ciphers,
)
super().__init__(**kwargs)
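Illustrative only, with a placeholder port and CA path: the new ssl_min_version and ssl_ciphers options, plus ssl_cert_reqs accepting an ssl.VerifyMode directly; note that ssl_check_hostname now defaults to True.

import ssl
from redis.asyncio.connection import ConnectionPool, SSLConnection

pool = ConnectionPool(
    connection_class=SSLConnection,
    host="localhost",
    port=6380,                        # placeholder TLS port
    ssl_cert_reqs=ssl.CERT_REQUIRED,  # an ssl.VerifyMode is now accepted directly
    ssl_ca_certs="/path/to/ca.pem",   # placeholder path
    ssl_min_version=ssl.TLSVersion.TLSv1_2,
    ssl_ciphers="ECDHE+AESGCM",
)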
@@ -748,6 +844,10 @@ class SSLConnection(Connection):
def check_hostname(self):
return self.ssl_context.check_hostname
@property
def min_version(self):
return self.ssl_context.min_version
class RedisSSLContext:
__slots__ = (
@@ -758,23 +858,30 @@ class RedisSSLContext:
"ca_data",
"context",
"check_hostname",
"min_version",
"ciphers",
)
def __init__(
self,
keyfile: Optional[str] = None,
certfile: Optional[str] = None,
cert_reqs: Optional[str] = None,
cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
ca_certs: Optional[str] = None,
ca_data: Optional[str] = None,
check_hostname: bool = False,
min_version: Optional[TLSVersion] = None,
ciphers: Optional[str] = None,
):
if not SSL_AVAILABLE:
raise RedisError("Python wasn't built with SSL support")
self.keyfile = keyfile
self.certfile = certfile
if cert_reqs is None:
self.cert_reqs = ssl.CERT_NONE
cert_reqs = ssl.CERT_NONE
elif isinstance(cert_reqs, str):
CERT_REQS = {
CERT_REQS = { # noqa: N806
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,
"required": ssl.CERT_REQUIRED,
@@ -783,13 +890,18 @@ class RedisSSLContext:
raise RedisError(
f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
)
self.cert_reqs = CERT_REQS[cert_reqs]
cert_reqs = CERT_REQS[cert_reqs]
self.cert_reqs = cert_reqs
self.ca_certs = ca_certs
self.ca_data = ca_data
self.check_hostname = check_hostname
self.context: Optional[ssl.SSLContext] = None
self.check_hostname = (
check_hostname if self.cert_reqs != ssl.CERT_NONE else False
)
self.min_version = min_version
self.ciphers = ciphers
self.context: Optional[SSLContext] = None
def get(self) -> ssl.SSLContext:
def get(self) -> SSLContext:
if not self.context:
context = ssl.create_default_context()
context.check_hostname = self.check_hostname
@@ -798,6 +910,10 @@ class RedisSSLContext:
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
if self.ca_certs or self.ca_data:
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
if self.min_version is not None:
context.minimum_version = self.min_version
if self.ciphers is not None:
context.set_ciphers(self.ciphers)
self.context = context
return self.context
@@ -825,20 +941,6 @@ class UnixDomainSocketConnection(AbstractConnection):
def _host_error(self) -> str:
return self.path
def _error_message(self, exception: BaseException) -> str:
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if len(exception.args) == 1:
return (
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
)
else:
return (
f"Error {exception.args[0]} connecting to unix socket: "
f"{host_error}. {exception.args[1]}."
)
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")
@@ -861,6 +963,7 @@ URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyTy
"max_connections": int,
"health_check_interval": int,
"ssl_check_hostname": to_bool,
"timeout": float,
}
)
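A quick sketch of the new parser entry: a timeout query argument in a connection URL is now coerced to float by parse_url.

from redis.asyncio.connection import parse_url

kwargs = parse_url("redis://localhost:6379/0?timeout=1.5&health_check_interval=30")
print(kwargs["timeout"])                # 1.5 (float)
print(kwargs["health_check_interval"])  # 30 (int)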
@@ -887,7 +990,7 @@ def parse_url(url: str) -> ConnectKwargs:
try:
kwargs[name] = parser(value)
except (TypeError, ValueError):
raise ValueError(f"Invalid value for `{name}` in connection URL.")
raise ValueError(f"Invalid value for '{name}' in connection URL.")
else:
kwargs[name] = value
@@ -939,6 +1042,7 @@ class ConnectionPool:
By default, TCP connections are created unless ``connection_class``
is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
unix sockets.
:py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
Any additional keyword arguments are passed to the constructor of
``connection_class``.
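For illustration, the three connection classes mentioned in the docstring can be selected explicitly or through from_url schemes (redis://, rediss://, unix://); the socket path below is a placeholder.

from redis.asyncio.connection import ConnectionPool, UnixDomainSocketConnection

tcp_pool = ConnectionPool.from_url("redis://localhost:6379/0")    # Connection
tls_pool = ConnectionPool.from_url("rediss://localhost:6380/0")   # SSLConnection
unix_pool = ConnectionPool(
    connection_class=UnixDomainSocketConnection,
    path="/tmp/redis.sock",  # placeholder socket path
)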
@@ -1008,16 +1112,22 @@ class ConnectionPool:
self._available_connections: List[AbstractConnection] = []
self._in_use_connections: Set[AbstractConnection] = set()
self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
self._lock = asyncio.Lock()
self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
if self._event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
def __repr__(self):
conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
return (
f"{self.__class__.__name__}"
f"<{self.connection_class(**self.connection_kwargs)!r}>"
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
f"({conn_kwargs})>)>"
)
def reset(self):
self._available_connections = []
self._in_use_connections = set()
self._in_use_connections = weakref.WeakSet()
def can_get_connection(self) -> bool:
"""Return True if a connection can be retrieved from the pool."""
@@ -1026,8 +1136,25 @@ class ConnectionPool:
or len(self._in_use_connections) < self.max_connections
)
async def get_connection(self, command_name, *keys, **options):
"""Get a connection from the pool"""
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.3.0",
)
async def get_connection(self, command_name=None, *keys, **options):
"""Get a connected connection from the pool"""
async with self._lock:
connection = self.get_available_connection()
try:
await self.ensure_connection(connection)
except BaseException:
await self.release(connection)
raise
return connection
def get_available_connection(self):
"""Get a connection from the pool, without making sure it is connected"""
try:
connection = self._available_connections.pop()
except IndexError:
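A short sketch of the deprecation (assuming a local Redis): the positional command name is still accepted but now emits a DeprecationWarning, and the argument-free form is the supported one.

import asyncio
from redis.asyncio.connection import ConnectionPool

async def main():
    pool = ConnectionPool(host="localhost", port=6379)
    conn = await pool.get_connection()          # preferred form as of 5.3.0
    # conn = await pool.get_connection("PING")  # still works, emits DeprecationWarning
    await pool.release(conn)
    await pool.disconnect()

asyncio.run(main())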
@@ -1035,13 +1162,6 @@ class ConnectionPool:
raise ConnectionError("Too many connections") from None
connection = self.make_connection()
self._in_use_connections.add(connection)
try:
await self.ensure_connection(connection)
except BaseException:
await self.release(connection)
raise
return connection
def get_encoder(self):
@@ -1067,7 +1187,7 @@ class ConnectionPool:
try:
if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except (ConnectionError, OSError):
except (ConnectionError, TimeoutError, OSError):
await connection.disconnect()
await connection.connect()
if await connection.can_read_destructive():
@@ -1079,6 +1199,9 @@ class ConnectionPool:
# not doing so is an error that will cause an exception here.
self._in_use_connections.remove(connection)
self._available_connections.append(connection)
await self._event_dispatcher.dispatch_async(
AsyncAfterConnectionReleasedEvent(connection)
)
async def disconnect(self, inuse_connections: bool = True):
"""
@@ -1112,6 +1235,29 @@ class ConnectionPool:
for conn in self._in_use_connections:
conn.retry = retry
async def re_auth_callback(self, token: TokenInterface):
async with self._lock:
for conn in self._available_connections:
await conn.retry.call_with_retry(
lambda: conn.send_command(
"AUTH", token.try_get("oid"), token.get_value()
),
lambda error: self._mock(error),
)
await conn.retry.call_with_retry(
lambda: conn.read_response(), lambda error: self._mock(error)
)
for conn in self._in_use_connections:
conn.set_re_auth_token(token)
async def _mock(self, error: RedisError):
"""
Dummy function; needs to be passed as the error callback to the retry object.
:param error:
:return:
"""
pass
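A hedged sketch of the re-auth hook above: FakeToken is a hypothetical stand-in for a TokenInterface implementation (real tokens come from a streaming credential provider); only try_get("oid") and get_value() are used here. Idle connections are re-authenticated immediately, while in-use connections receive the token and re-authenticate later.

import asyncio
from redis.asyncio.connection import ConnectionPool

class FakeToken:
    def try_get(self, key):
        return "user-oid" if key == "oid" else None

    def get_value(self):
        return "new-password"

async def main():
    pool = ConnectionPool(host="localhost", port=6379)
    await pool.re_auth_callback(FakeToken())
    await pool.disconnect()

asyncio.run(main())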
class BlockingConnectionPool(ConnectionPool):
"""
@@ -1129,7 +1275,7 @@ class BlockingConnectionPool(ConnectionPool):
connection from the pool when all of connections are in use, rather than
raising a :py:class:`~redis.ConnectionError` (as the default
:py:class:`~redis.asyncio.ConnectionPool` implementation does), it
makes blocks the current `Task` for a specified number of seconds until
blocks the current `Task` for a specified number of seconds until
a connection becomes available.
Use ``max_connections`` to increase / decrease the pool size::
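Roughly, as a sketch of the blocking behaviour described above (assuming a local Redis): once max_connections connections are in use, a further caller blocks for up to timeout seconds before ConnectionError is raised.

import asyncio
from redis.asyncio import BlockingConnectionPool, Redis

async def main():
    pool = BlockingConnectionPool(
        host="localhost", port=6379, max_connections=10, timeout=20
    )
    client = Redis(connection_pool=pool)
    await client.ping()
    await client.aclose()
    await pool.disconnect()

asyncio.run(main())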
@@ -1163,16 +1309,29 @@ class BlockingConnectionPool(ConnectionPool):
self._condition = asyncio.Condition()
self.timeout = timeout
async def get_connection(self, command_name, *keys, **options):
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.3.0",
)
async def get_connection(self, command_name=None, *keys, **options):
"""Gets a connection from the pool, blocking until one is available"""
try:
async with async_timeout(self.timeout):
async with self._condition:
async with self._condition:
async with async_timeout(self.timeout):
await self._condition.wait_for(self.can_get_connection)
return await super().get_connection(command_name, *keys, **options)
connection = super().get_available_connection()
except asyncio.TimeoutError as err:
raise ConnectionError("No connection available.") from err
# We now perform the connection check outside of the lock.
try:
await self.ensure_connection(connection)
return connection
except BaseException:
await self.release(connection)
raise
async def release(self, connection: AbstractConnection):
"""Releases the connection back to the pool."""
async with self._condition: