main commit

2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions


@@ -3,8 +3,8 @@ import copy
import enum
import inspect
import socket
import ssl
import sys
import warnings
import weakref
from abc import abstractmethod
from itertools import chain
@@ -16,30 +16,14 @@ from typing import (
List,
Mapping,
Optional,
Protocol,
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
)
from urllib.parse import ParseResult, parse_qs, unquote, urlparse
from ..utils import SSL_AVAILABLE
if SSL_AVAILABLE:
import ssl
from ssl import SSLContext, TLSVersion
else:
ssl = None
TLSVersion = None
SSLContext = None
from ..auth.token import TokenInterface
from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
from ..utils import deprecated_args, format_error_message
# the functionality is available in 3.11.x but has a major issue before
# 3.11.3. See https://github.com/redis/redis-py/issues/2633
if sys.version_info >= (3, 11, 3):
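Note: the SSL_AVAILABLE guard in this hunk lets the module import on Python builds that lack ssl, deferring the failure until SSL features are actually used. A minimal standalone sketch of that pattern, with SSL_AVAILABLE computed locally instead of imported from ..utils:

# Sketch of the optional-ssl import guard (no redis-py imports assumed).
try:
    import ssl
    from ssl import SSLContext, TLSVersion
    SSL_AVAILABLE = True
except ImportError:
    ssl = None
    SSLContext = None
    TLSVersion = None
    SSL_AVAILABLE = False


def default_ssl_context() -> "SSLContext":
    # Fail loudly only when SSL is actually requested.
    if not SSL_AVAILABLE:
        raise RuntimeError("Python wasn't built with SSL support")
    return ssl.create_default_context()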
@@ -49,6 +33,7 @@ else:
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redis.compat import Protocol, TypedDict
from redis.connection import DEFAULT_RESP_VERSION
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
@@ -93,11 +78,13 @@ else:
class ConnectCallbackProtocol(Protocol):
def __call__(self, connection: "AbstractConnection"): ...
def __call__(self, connection: "AbstractConnection"):
...
class AsyncConnectCallbackProtocol(Protocol):
async def __call__(self, connection: "AbstractConnection"): ...
async def __call__(self, connection: "AbstractConnection"):
...
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
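Note: both sides of this hunk declare the same callable protocols; only the placement of the `...` body differs. A short sketch of how a bound method satisfies them structurally (the KeyTracker class is hypothetical, and AbstractConnection is only referenced as a forward reference):

from typing import Protocol, Union


class ConnectCallbackProtocol(Protocol):
    def __call__(self, connection: "AbstractConnection"): ...


class AsyncConnectCallbackProtocol(Protocol):
    async def __call__(self, connection: "AbstractConnection"): ...


ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]


class KeyTracker:
    # A bound method with this signature satisfies AsyncConnectCallbackProtocol,
    # so instances can be registered as connect callbacks.
    async def on_reconnect(self, connection: "AbstractConnection"):
        print("re-subscribing after reconnect", connection)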
@@ -159,7 +146,6 @@ class AbstractConnection:
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = None,
):
if (username or password) and credential_provider is not None:
raise DataError(
@@ -168,10 +154,6 @@ class AbstractConnection:
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.db = db
self.client_name = client_name
self.lib_name = lib_name
@@ -211,8 +193,6 @@ class AbstractConnection:
self.set_parser(parser_class)
self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
self._buffer_cutoff = 6000
self._re_auth_token: Optional[TokenInterface] = None
try:
p = int(protocol)
except TypeError:
@@ -224,33 +204,9 @@ class AbstractConnection:
raise ConnectionError("protocol must be either 2 or 3")
self.protocol = protocol
def __del__(self, _warnings: Any = warnings):
# For some reason, the individual streams don't get properly garbage
# collected and therefore produce no resource warnings. We add one
# here, in the same style as those from the stdlib.
if getattr(self, "_writer", None):
_warnings.warn(
f"unclosed Connection {self!r}", ResourceWarning, source=self
)
try:
asyncio.get_running_loop()
self._close()
except RuntimeError:
# No action is taken if the pool is already closed.
pass
def _close(self):
"""
Internal method to silently close the connection without waiting
"""
if self._writer:
self._writer.close()
self._writer = self._reader = None
def __repr__(self):
repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
return f"{self.__class__.__name__}<{repr_args}>"
@abstractmethod
def repr_pieces(self):
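Note: the __del__/_close pair in this hunk emits a ResourceWarning for connections dropped without being closed, mirroring the stdlib stream classes. A compact sketch of that pattern, assuming an illustrative Conn class rather than the real connection type:

import asyncio
import warnings


class Conn:
    def __init__(self):
        self._writer = object()  # stand-in for an asyncio StreamWriter

    def _close(self):
        # Close silently, without awaiting the transport.
        self._writer = None

    def __del__(self, _warnings=warnings):
        if getattr(self, "_writer", None):
            _warnings.warn(f"unclosed {self!r}", ResourceWarning, source=self)
            try:
                asyncio.get_running_loop()
                self._close()
            except RuntimeError:
                # No running event loop: nothing safe to do at GC time.
                pass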
@@ -260,24 +216,12 @@ class AbstractConnection:
def is_connected(self):
return self._reader is not None and self._writer is not None
def register_connect_callback(self, callback):
"""
Register a callback to be called when the connection is established either
initially or reconnected. This allows listeners to issue commands that
are ephemeral to the connection, for example pub/sub subscription or
key tracking. The callback must be a _method_ and will be kept as
a weak reference.
"""
def _register_connect_callback(self, callback):
wm = weakref.WeakMethod(callback)
if wm not in self._connect_callbacks:
self._connect_callbacks.append(wm)
def deregister_connect_callback(self, callback):
"""
De-register a previously registered callback. It will no-longer receive
notifications on connection events. Calling this is not required when the
listener goes away, since the callbacks are kept as weak methods.
"""
def _deregister_connect_callback(self, callback):
try:
self._connect_callbacks.remove(weakref.WeakMethod(callback))
except ValueError:
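Note: both sides store connect callbacks as weakref.WeakMethod entries, so a listener that disappears simply stops receiving notifications without explicit deregistration. A minimal sketch of that bookkeeping, with an illustrative registry class:

import weakref


class CallbackRegistry:
    def __init__(self):
        self._connect_callbacks = []

    def register(self, callback):
        # Callbacks must be bound methods; WeakMethod rejects plain functions.
        wm = weakref.WeakMethod(callback)
        if wm not in self._connect_callbacks:
            self._connect_callbacks.append(wm)

    def deregister(self, callback):
        try:
            self._connect_callbacks.remove(weakref.WeakMethod(callback))
        except ValueError:
            pass

    def callbacks(self):
        # Dead references resolve to None and are skipped.
        return [cb for ref in self._connect_callbacks if (cb := ref()) is not None]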
@@ -293,20 +237,12 @@ class AbstractConnection:
async def connect(self):
"""Connects to the Redis server if not already connected"""
await self.connect_check_health(check_health=True)
async def connect_check_health(
self, check_health: bool = True, retry_socket_connect: bool = True
):
if self.is_connected:
return
try:
if retry_socket_connect:
await self.retry.call_with_retry(
lambda: self._connect(), lambda error: self.disconnect()
)
else:
await self._connect()
await self.retry.call_with_retry(
lambda: self._connect(), lambda error: self.disconnect()
)
except asyncio.CancelledError:
raise # in 3.7 and earlier, this is an Exception, not BaseException
except (socket.timeout, asyncio.TimeoutError):
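Note: both variants funnel the raw socket connect through Retry.call_with_retry, disconnecting between attempts; the longer variant only adds flags to bypass the retry wrapper or the post-connect health check. A rough sketch of that control flow, where conn is any object exposing is_connected, _connect, disconnect, and retry, and the exception types are simplified:

import asyncio
import socket


async def connect(conn, retry_socket_connect: bool = True):
    if conn.is_connected:
        return
    try:
        if retry_socket_connect:
            await conn.retry.call_with_retry(
                lambda: conn._connect(),          # attempt
                lambda error: conn.disconnect(),  # clean up before the next try
            )
        else:
            await conn._connect()
    except asyncio.CancelledError:
        raise
    except (socket.timeout, asyncio.TimeoutError) as err:
        raise TimeoutError("Timeout connecting to server") from err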
@@ -319,14 +255,12 @@ class AbstractConnection:
try:
if not self.redis_connect_func:
# Use the default on_connect function
await self.on_connect_check_health(check_health=check_health)
await self.on_connect()
else:
# Use the passed function redis_connect_func
(
await self.redis_connect_func(self)
if asyncio.iscoroutinefunction(self.redis_connect_func)
else self.redis_connect_func(self)
)
await self.redis_connect_func(self) if asyncio.iscoroutinefunction(
self.redis_connect_func
) else self.redis_connect_func(self)
except RedisError:
# clean up after any error in on_connect
await self.disconnect()
@@ -350,17 +284,12 @@ class AbstractConnection:
def _host_error(self) -> str:
pass
@abstractmethod
def _error_message(self, exception: BaseException) -> str:
return format_error_message(self._host_error(), exception)
def get_protocol(self):
return self.protocol
pass
async def on_connect(self) -> None:
"""Initialize the connection, authenticate and select a database"""
await self.on_connect_check_health(check_health=True)
async def on_connect_check_health(self, check_health: bool = True) -> None:
self._parser.on_connect(self)
parser = self._parser
@@ -371,8 +300,7 @@ class AbstractConnection:
self.credential_provider
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = await cred_provider.get_credentials_async()
auth_args = cred_provider.get_credentials()
# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol not in [2, "2"]:
@@ -383,11 +311,7 @@ class AbstractConnection:
self._parser.on_connect(self)
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
await self.send_command(
"HELLO", self.protocol, "AUTH", *auth_args, check_health=False
)
await self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
response = await self.read_response()
if response.get(b"proto") != int(self.protocol) and response.get(
"proto"
@@ -418,7 +342,7 @@ class AbstractConnection:
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
await self.send_command("HELLO", self.protocol, check_health=check_health)
await self.send_command("HELLO", self.protocol)
response = await self.read_response()
# if response.get(b"proto") != self.protocol and response.get(
# "proto"
@@ -427,35 +351,18 @@ class AbstractConnection:
# if a client_name is given, set it
if self.client_name:
await self.send_command(
"CLIENT",
"SETNAME",
self.client_name,
check_health=check_health,
)
await self.send_command("CLIENT", "SETNAME", self.client_name)
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Error setting client name")
# set the library name and version, pipeline for lower startup latency
if self.lib_name:
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-NAME",
self.lib_name,
check_health=check_health,
)
await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
if self.lib_version:
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-VER",
self.lib_version,
check_health=check_health,
)
await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
# if a database is specified, switch to it. Also pipeline this
if self.db:
await self.send_command("SELECT", self.db, check_health=check_health)
await self.send_command("SELECT", self.db)
# read responses from pipeline
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
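Note: both sides pipeline the connection handshake: CLIENT SETNAME, CLIENT SETINFO, and SELECT are written back to back and their replies drained afterwards, saving round trips at startup; the hunk differs only in whether check_health is forwarded. A simplified sketch of that send-then-drain shape (not the exact ordering the driver uses, and send_command/read_response are assumed connection methods):

async def handshake(conn, client_name, lib_name, lib_version, db):
    # Queue the setup commands without waiting for individual replies...
    pending = 0
    if client_name:
        await conn.send_command("CLIENT", "SETNAME", client_name)
        pending += 1
    if lib_name:
        await conn.send_command("CLIENT", "SETINFO", "LIB-NAME", lib_name)
        pending += 1
    if lib_version:
        await conn.send_command("CLIENT", "SETINFO", "LIB-VER", lib_version)
        pending += 1
    if db:
        await conn.send_command("SELECT", db)
        pending += 1
    # ...then drain the replies in order.
    for _ in range(pending):
        await conn.read_response()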
@@ -517,8 +424,8 @@ class AbstractConnection:
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
) -> None:
if not self.is_connected:
await self.connect_check_health(check_health=False)
if check_health:
await self.connect()
elif check_health:
await self.check_health()
try:
@@ -581,7 +488,11 @@ class AbstractConnection:
read_timeout = timeout if timeout is not None else self.socket_timeout
host_error = self._host_error()
try:
if read_timeout is not None and self.protocol in ["3", 3]:
if (
read_timeout is not None
and self.protocol in ["3", 3]
and not HIREDIS_AVAILABLE
):
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
@@ -591,7 +502,7 @@ class AbstractConnection:
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
elif self.protocol in ["3", 3]:
elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE:
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
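Note: the read path branches on the protocol version: for RESP3 the parser needs the push_request flag so out-of-band push messages are handled, and one side of the hunk additionally gates that branch on the hiredis parser not being in use; the socket timeout is applied with async_timeout. A compact sketch of that dispatch, where hiredis_active stands in for the HIREDIS_AVAILABLE flag and parser is any object with a matching read_response:

from asyncio import timeout as async_timeout  # Python 3.11+


async def read_reply(parser, protocol, read_timeout=None, hiredis_active=False,
                     disable_decoding=False, push_request=False):
    resp3 = protocol in ("3", 3) and not hiredis_active
    if read_timeout is not None:
        async with async_timeout(read_timeout):
            if resp3:
                return await parser.read_response(
                    disable_decoding=disable_decoding, push_request=push_request
                )
            return await parser.read_response(disable_decoding=disable_decoding)
    if resp3:
        return await parser.read_response(
            disable_decoding=disable_decoding, push_request=push_request
        )
    return await parser.read_response(disable_decoding=disable_decoding)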
@@ -703,27 +614,6 @@ class AbstractConnection:
output.append(SYM_EMPTY.join(pieces))
return output
def _socket_is_empty(self):
"""Check if the socket is empty"""
return len(self._reader._buffer) == 0
async def process_invalidation_messages(self):
while not self._socket_is_empty():
await self.read_response(push_request=True)
def set_re_auth_token(self, token: TokenInterface):
self._re_auth_token = token
async def re_auth(self):
if self._re_auth_token is not None:
await self.send_command(
"AUTH",
self._re_auth_token.try_get("oid"),
self._re_auth_token.get_value(),
)
await self.read_response()
self._re_auth_token = None
class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"
@@ -781,6 +671,27 @@ class Connection(AbstractConnection):
def _host_error(self) -> str:
return f"{self.host}:{self.port}"
def _error_message(self, exception: BaseException) -> str:
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if not exception.args:
# asyncio has a bug where on Connection reset by peer, the
exception is not instantiated, so args is empty. This is the
# workaround.
# See: https://github.com/redis/redis-py/issues/2237
# See: https://github.com/python/cpython/issues/94061
return f"Error connecting to {host_error}. Connection reset by peer"
elif len(exception.args) == 1:
return f"Error connecting to {host_error}. {exception.args[0]}."
else:
return (
f"Error {exception.args[0]} connecting to {host_error}. "
f"{exception.args[0]}."
)
class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
@@ -792,17 +703,12 @@ class SSLConnection(Connection):
self,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = True,
ssl_min_version: Optional[TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
ssl_check_hostname: bool = False,
**kwargs,
):
if not SSL_AVAILABLE:
raise RedisError("Python wasn't built with SSL support")
self.ssl_context: RedisSSLContext = RedisSSLContext(
keyfile=ssl_keyfile,
certfile=ssl_certfile,
@@ -810,8 +716,6 @@ class SSLConnection(Connection):
ca_certs=ssl_ca_certs,
ca_data=ssl_ca_data,
check_hostname=ssl_check_hostname,
min_version=ssl_min_version,
ciphers=ssl_ciphers,
)
super().__init__(**kwargs)
@@ -844,10 +748,6 @@ class SSLConnection(Connection):
def check_hostname(self):
return self.ssl_context.check_hostname
@property
def min_version(self):
return self.ssl_context.min_version
class RedisSSLContext:
__slots__ = (
@@ -858,30 +758,23 @@ class RedisSSLContext:
"ca_data",
"context",
"check_hostname",
"min_version",
"ciphers",
)
def __init__(
self,
keyfile: Optional[str] = None,
certfile: Optional[str] = None,
cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
cert_reqs: Optional[str] = None,
ca_certs: Optional[str] = None,
ca_data: Optional[str] = None,
check_hostname: bool = False,
min_version: Optional[TLSVersion] = None,
ciphers: Optional[str] = None,
):
if not SSL_AVAILABLE:
raise RedisError("Python wasn't built with SSL support")
self.keyfile = keyfile
self.certfile = certfile
if cert_reqs is None:
cert_reqs = ssl.CERT_NONE
self.cert_reqs = ssl.CERT_NONE
elif isinstance(cert_reqs, str):
CERT_REQS = { # noqa: N806
CERT_REQS = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,
"required": ssl.CERT_REQUIRED,
@@ -890,18 +783,13 @@ class RedisSSLContext:
raise RedisError(
f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
)
cert_reqs = CERT_REQS[cert_reqs]
self.cert_reqs = cert_reqs
self.cert_reqs = CERT_REQS[cert_reqs]
self.ca_certs = ca_certs
self.ca_data = ca_data
self.check_hostname = (
check_hostname if self.cert_reqs != ssl.CERT_NONE else False
)
self.min_version = min_version
self.ciphers = ciphers
self.context: Optional[SSLContext] = None
self.check_hostname = check_hostname
self.context: Optional[ssl.SSLContext] = None
def get(self) -> SSLContext:
def get(self) -> ssl.SSLContext:
if not self.context:
context = ssl.create_default_context()
context.check_hostname = self.check_hostname
@@ -910,10 +798,6 @@ class RedisSSLContext:
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
if self.ca_certs or self.ca_data:
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
if self.min_version is not None:
context.minimum_version = self.min_version
if self.ciphers is not None:
context.set_ciphers(self.ciphers)
self.context = context
return self.context
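Note: the min_version/ciphers handling in this hunk maps directly onto stdlib SSLContext attributes. A small sketch of building such a context with those options; the default values shown here are examples, not defaults taken from the diff:

import ssl
from typing import Optional


def build_tls_context(
    check_hostname: bool = True,
    cert_reqs: ssl.VerifyMode = ssl.CERT_REQUIRED,
    min_version: Optional[ssl.TLSVersion] = ssl.TLSVersion.TLSv1_2,
    ciphers: Optional[str] = None,
) -> ssl.SSLContext:
    context = ssl.create_default_context()
    context.check_hostname = check_hostname
    context.verify_mode = cert_reqs
    if min_version is not None:
        context.minimum_version = min_version
    if ciphers is not None:
        context.set_ciphers(ciphers)
    return context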
@@ -941,6 +825,20 @@ class UnixDomainSocketConnection(AbstractConnection):
def _host_error(self) -> str:
return self.path
def _error_message(self, exception: BaseException) -> str:
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if len(exception.args) == 1:
return (
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
)
else:
return (
f"Error {exception.args[0]} connecting to unix socket: "
f"{host_error}. {exception.args[1]}."
)
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")
@@ -963,7 +861,6 @@ URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyTy
"max_connections": int,
"health_check_interval": int,
"ssl_check_hostname": to_bool,
"timeout": float,
}
)
@@ -990,7 +887,7 @@ def parse_url(url: str) -> ConnectKwargs:
try:
kwargs[name] = parser(value)
except (TypeError, ValueError):
raise ValueError(f"Invalid value for '{name}' in connection URL.")
raise ValueError(f"Invalid value for `{name}` in connection URL.")
else:
kwargs[name] = value
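Note: parse_url runs known query arguments through the parser table above and passes unknown keys through untouched (this hunk only changes the quoting in the error message and drops the `timeout` entry). A brief usage sketch with a trimmed parser table and to_bool reimplemented locally for the example:

from urllib.parse import parse_qs, urlparse

FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")


def to_bool(value):
    if value is None or value == "":
        return None
    return value.upper() not in FALSE_STRINGS


PARSERS = {"db": int, "socket_timeout": float, "retry_on_timeout": to_bool}


def parse_query(url: str) -> dict:
    # Known keys go through their parser; unknown keys pass through untouched.
    kwargs = {}
    for name, value in parse_qs(urlparse(url).query).items():
        value = value[0]
        parser = PARSERS.get(name)
        if parser:
            try:
                kwargs[name] = parser(value)
            except (TypeError, ValueError):
                raise ValueError(f"Invalid value for '{name}' in connection URL.")
        else:
            kwargs[name] = value
    return kwargs


print(parse_query("redis://localhost:6379/0?socket_timeout=1.5&retry_on_timeout=no"))
# {'socket_timeout': 1.5, 'retry_on_timeout': False}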
@@ -1042,7 +939,6 @@ class ConnectionPool:
By default, TCP connections are created unless ``connection_class``
is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
unix sockets.
:py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
Any additional keyword arguments are passed to the constructor of
``connection_class``.
@@ -1112,22 +1008,16 @@ class ConnectionPool:
self._available_connections: List[AbstractConnection] = []
self._in_use_connections: Set[AbstractConnection] = set()
self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
self._lock = asyncio.Lock()
self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
if self._event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
def __repr__(self):
conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
f"({conn_kwargs})>)>"
f"{self.__class__.__name__}"
f"<{self.connection_class(**self.connection_kwargs)!r}>"
)
def reset(self):
self._available_connections = []
self._in_use_connections = weakref.WeakSet()
self._in_use_connections = set()
def can_get_connection(self) -> bool:
"""Return True if a connection can be retrieved from the pool."""
@@ -1136,25 +1026,8 @@ class ConnectionPool:
or len(self._in_use_connections) < self.max_connections
)
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.3.0",
)
async def get_connection(self, command_name=None, *keys, **options):
async with self._lock:
"""Get a connected connection from the pool"""
connection = self.get_available_connection()
try:
await self.ensure_connection(connection)
except BaseException:
await self.release(connection)
raise
return connection
def get_available_connection(self):
"""Get a connection from the pool, without making sure it is connected"""
async def get_connection(self, command_name, *keys, **options):
"""Get a connection from the pool"""
try:
connection = self._available_connections.pop()
except IndexError:
@@ -1162,6 +1035,13 @@ class ConnectionPool:
raise ConnectionError("Too many connections") from None
connection = self.make_connection()
self._in_use_connections.add(connection)
try:
await self.ensure_connection(connection)
except BaseException:
await self.release(connection)
raise
return connection
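Note: both variants follow the same acquire discipline: pop an idle connection (or create one while under max_connections), mark it in use, then verify it is actually connected and healthy, releasing it back if that verification fails; one variant additionally serializes the whole acquire under an asyncio.Lock. A condensed sketch of that acquire/verify/release pattern, where pool is any object exposing the attributes used below:

async def acquire(pool):
    try:
        connection = pool._available_connections.pop()
    except IndexError:
        if len(pool._in_use_connections) >= pool.max_connections:
            raise ConnectionError("Too many connections") from None
        connection = pool.make_connection()
    pool._in_use_connections.add(connection)
    try:
        # Connect lazily and make sure the socket has no stale data.
        await pool.ensure_connection(connection)
    except BaseException:
        await pool.release(connection)
        raise
    return connection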
def get_encoder(self):
@@ -1187,7 +1067,7 @@ class ConnectionPool:
try:
if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except (ConnectionError, TimeoutError, OSError):
except (ConnectionError, OSError):
await connection.disconnect()
await connection.connect()
if await connection.can_read_destructive():
@@ -1199,9 +1079,6 @@ class ConnectionPool:
# not doing so is an error that will cause an exception here.
self._in_use_connections.remove(connection)
self._available_connections.append(connection)
await self._event_dispatcher.dispatch_async(
AsyncAfterConnectionReleasedEvent(connection)
)
async def disconnect(self, inuse_connections: bool = True):
"""
@@ -1235,29 +1112,6 @@ class ConnectionPool:
for conn in self._in_use_connections:
conn.retry = retry
async def re_auth_callback(self, token: TokenInterface):
async with self._lock:
for conn in self._available_connections:
await conn.retry.call_with_retry(
lambda: conn.send_command(
"AUTH", token.try_get("oid"), token.get_value()
),
lambda error: self._mock(error),
)
await conn.retry.call_with_retry(
lambda: conn.read_response(), lambda error: self._mock(error)
)
for conn in self._in_use_connections:
conn.set_re_auth_token(token)
async def _mock(self, error: RedisError):
"""
Dummy function; needs to be passed as the error callback to the retry object.
:param error:
:return:
"""
pass
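Note: the re_auth_callback in this hunk reacts to a token refresh by re-authenticating every idle connection right away and tagging the in-use ones so they re-auth on their next use. A schematic sketch of that split, dropping the retry wrapper used in the diff; token follows the same try_get/get_value shape as the Token sketch above:

async def on_token_renewed(pool, token):
    async with pool._lock:
        for conn in pool._available_connections:
            # Idle connections can be re-authenticated immediately.
            await conn.send_command("AUTH", token.try_get("oid"), token.get_value())
            await conn.read_response()
        for conn in pool._in_use_connections:
            # Busy connections pick the token up the next time they are driven.
            conn.set_re_auth_token(token)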
class BlockingConnectionPool(ConnectionPool):
"""
@@ -1275,7 +1129,7 @@ class BlockingConnectionPool(ConnectionPool):
connection from the pool when all of the connections are in use, rather than
raising a :py:class:`~redis.ConnectionError` (as the default
:py:class:`~redis.asyncio.ConnectionPool` implementation does), it
blocks the current `Task` for a specified number of seconds until
makes blocks the current `Task` for a specified number of seconds until
a connection becomes available.
Use ``max_connections`` to increase / decrease the pool size::
@@ -1309,29 +1163,16 @@ class BlockingConnectionPool(ConnectionPool):
self._condition = asyncio.Condition()
self.timeout = timeout
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.3.0",
)
async def get_connection(self, command_name=None, *keys, **options):
async def get_connection(self, command_name, *keys, **options):
"""Gets a connection from the pool, blocking until one is available"""
try:
async with self._condition:
async with async_timeout(self.timeout):
async with async_timeout(self.timeout):
async with self._condition:
await self._condition.wait_for(self.can_get_connection)
connection = super().get_available_connection()
return await super().get_connection(command_name, *keys, **options)
except asyncio.TimeoutError as err:
raise ConnectionError("No connection available.") from err
# We now perform the connection check outside of the lock.
try:
await self.ensure_connection(connection)
return connection
except BaseException:
await self.release(connection)
raise
async def release(self, connection: AbstractConnection):
"""Releases the connection back to the pool."""
async with self._condition:
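Note: for the blocking pool, both variants wait on an asyncio.Condition under a timeout and translate a timeout into ConnectionError; they differ mainly in whether the timeout wraps the lock acquisition and where the connection check happens. A small self-contained sketch of that wait-with-timeout shape, with a toy pool in place of the real one:

import asyncio


class TinyBlockingPool:
    def __init__(self, max_connections: int = 2, timeout: float = 5.0):
        self._condition = asyncio.Condition()
        self.timeout = timeout
        self.max_connections = max_connections
        self._in_use = 0

    def can_get_connection(self) -> bool:
        return self._in_use < self.max_connections

    async def get_connection(self):
        try:
            async with asyncio.timeout(self.timeout):  # Python 3.11+
                async with self._condition:
                    await self._condition.wait_for(self.can_get_connection)
                    self._in_use += 1
                    return object()  # stand-in for a real connection
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

    async def release(self, connection) -> None:
        async with self._condition:
            self._in_use -= 1
            self._condition.notify()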