main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -15,11 +15,9 @@ from typing import (
Mapping,
MutableMapping,
Optional,
Protocol,
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
cast,
@@ -39,7 +37,6 @@ from redis.asyncio.connection import (
)
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff
from redis.client import (
EMPTY_RESPONSE,
NEVER_DECODE,
@@ -52,40 +49,27 @@ from redis.commands import (
AsyncSentinelCommands,
list_or_args,
)
from redis.compat import Protocol, TypedDict
from redis.credentials import CredentialProvider
from redis.event import (
AfterPooledConnectionsInstantiationEvent,
AfterPubSubConnectionInstantiationEvent,
AfterSingleConnectionInstantiationEvent,
ClientType,
EventDispatcher,
)
from redis.exceptions import (
ConnectionError,
ExecAbortError,
PubSubError,
RedisError,
ResponseError,
TimeoutError,
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import (
SSL_AVAILABLE,
HIREDIS_AVAILABLE,
_set_info_logger,
deprecated_args,
deprecated_function,
get_lib_version,
safe_str,
str_if_bytes,
truncate_text,
)
if TYPE_CHECKING and SSL_AVAILABLE:
from ssl import TLSVersion, VerifyMode
else:
TLSVersion = None
VerifyMode = None
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
@@ -96,11 +80,13 @@ if TYPE_CHECKING:
class ResponseCallbackProtocol(Protocol):
def __call__(self, response: Any, **kwargs): ...
def __call__(self, response: Any, **kwargs):
...
class AsyncResponseCallbackProtocol(Protocol):
async def __call__(self, response: Any, **kwargs): ...
async def __call__(self, response: Any, **kwargs):
...
ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
@@ -182,7 +168,7 @@ class Redis(
warnings.warn(
DeprecationWarning(
'"auto_close_connection_pool" is deprecated '
"since version 5.0.1. "
"since version 5.0.0. "
"Please create a ConnectionPool explicitly and "
"provide to the Redis() constructor instead."
)
@@ -208,11 +194,6 @@ class Redis(
client.auto_close_connection_pool = True
return client
@deprecated_args(
args_to_warn=["retry_on_timeout"],
reason="TimeoutError is included by default.",
version="6.0.0",
)
def __init__(
self,
*,
@@ -230,19 +211,14 @@ class Redis(
encoding_errors: str = "strict",
decode_responses: bool = False,
retry_on_timeout: bool = False,
retry: Retry = Retry(
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
),
retry_on_error: Optional[list] = None,
ssl: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: Union[str, VerifyMode] = "required",
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = True,
ssl_min_version: Optional[TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
ssl_check_hostname: bool = False,
max_connections: Optional[int] = None,
single_connection_client: bool = False,
health_check_interval: int = 0,
@@ -250,38 +226,20 @@ class Redis(
lib_name: Optional[str] = "redis-py",
lib_version: Optional[str] = get_lib_version(),
username: Optional[str] = None,
retry: Optional[Retry] = None,
auto_close_connection_pool: Optional[bool] = None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = None,
):
"""
Initialize a new Redis client.
To specify a retry policy for specific errors, you have two options:
1. Set the `retry_on_error` to a list of the error/s to retry on, and
you can also set `retry` to a valid `Retry` object(in case the default
one is not appropriate) - with this approach the retries will be triggered
on the default errors specified in the Retry object enriched with the
errors specified in `retry_on_error`.
2. Define a `Retry` object with configured 'supported_errors' and set
it to the `retry` parameter - with this approach you completely redefine
the errors on which retries will happen.
`retry_on_timeout` is deprecated - please include the TimeoutError
either in the Retry object or in the `retry_on_error` list.
When 'connection_pool' is provided - the retry configuration of the
provided pool will be used.
To specify a retry policy for specific errors, first set
`retry_on_error` to a list of the error/s to retry on, then set
`retry` to a valid `Retry` object.
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
kwargs: Dict[str, Any]
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
# auto_close_connection_pool only has an effect if connection_pool is
# None. It is assumed that if connection_pool is not None, the user
# wants to manage the connection pool themselves.
@@ -289,7 +247,7 @@ class Redis(
warnings.warn(
DeprecationWarning(
'"auto_close_connection_pool" is deprecated '
"since version 5.0.1. "
"since version 5.0.0. "
"Please create a ConnectionPool explicitly and "
"provide to the Redis() constructor instead."
)
@@ -301,6 +259,8 @@ class Redis(
# Create internal connection pool, expected to be closed by Redis instance
if not retry_on_error:
retry_on_error = []
if retry_on_timeout is True:
retry_on_error.append(TimeoutError)
kwargs = {
"db": db,
"username": username,
@@ -310,6 +270,7 @@ class Redis(
"encoding": encoding,
"encoding_errors": encoding_errors,
"decode_responses": decode_responses,
"retry_on_timeout": retry_on_timeout,
"retry_on_error": retry_on_error,
"retry": copy.deepcopy(retry),
"max_connections": max_connections,
@@ -350,26 +311,14 @@ class Redis(
"ssl_ca_certs": ssl_ca_certs,
"ssl_ca_data": ssl_ca_data,
"ssl_check_hostname": ssl_check_hostname,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)
# This arg only used if no pool is passed in
self.auto_close_connection_pool = auto_close_connection_pool
connection_pool = ConnectionPool(**kwargs)
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.ASYNC, credential_provider
)
)
else:
# If a pool is passed in, do not close it
self.auto_close_connection_pool = False
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.ASYNC, credential_provider
)
)
self.connection_pool = connection_pool
self.single_connection_client = single_connection_client
@@ -388,10 +337,7 @@ class Redis(
self._single_conn_lock = asyncio.Lock()
def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"({self.connection_pool!r})>"
)
return f"{self.__class__.__name__}<{self.connection_pool!r}>"
def __await__(self):
return self.initialize().__await__()
@@ -400,13 +346,7 @@ class Redis(
if self.single_connection_client:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection()
self._event_dispatcher.dispatch(
AfterSingleConnectionInstantiationEvent(
self.connection, ClientType.ASYNC, self._single_conn_lock
)
)
self.connection = await self.connection_pool.get_connection("_")
return self
def set_response_callback(self, command: str, callback: ResponseCallbackT):
@@ -421,10 +361,10 @@ class Redis(
"""Get the connection's key-word arguments"""
return self.connection_pool.connection_kwargs
def get_retry(self) -> Optional[Retry]:
def get_retry(self) -> Optional["Retry"]:
return self.get_connection_kwargs().get("retry")
def set_retry(self, retry: Retry) -> None:
def set_retry(self, retry: "Retry") -> None:
self.get_connection_kwargs().update({"retry": retry})
self.connection_pool.set_retry(retry)
@@ -503,7 +443,6 @@ class Redis(
blocking_timeout: Optional[float] = None,
lock_class: Optional[Type[Lock]] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
) -> Lock:
"""
Return a new Lock object using key ``name`` that mimics
@@ -550,11 +489,6 @@ class Redis(
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
@@ -572,7 +506,6 @@ class Redis(
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
raise_on_release_error=raise_on_release_error,
)
def pubsub(self, **kwargs) -> "PubSub":
@@ -581,9 +514,7 @@ class Redis(
subscribe to channels and listen for messages that get published to
them.
"""
return PubSub(
self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
)
return PubSub(self.connection_pool, **kwargs)
def monitor(self) -> "Monitor":
return Monitor(self.connection_pool)
@@ -615,18 +546,15 @@ class Redis(
_grl().call_exception_handler(context)
except RuntimeError:
pass
self.connection._close()
async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
"""
Closes Redis client connection
Args:
close_connection_pool:
decides whether to close the connection pool used by this Redis client,
overriding Redis.auto_close_connection_pool.
By default, let Redis.auto_close_connection_pool decide
whether to close the connection pool.
:param close_connection_pool: decides whether to close the connection pool used
by this Redis client, overriding Redis.auto_close_connection_pool. By default,
let Redis.auto_close_connection_pool decide whether to close the connection
pool.
"""
conn = self.connection
if conn:
@@ -637,7 +565,7 @@ class Redis(
):
await self.connection_pool.disconnect()
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
"""
Alias for aclose(), for backwards compatibility
@@ -651,17 +579,18 @@ class Redis(
await conn.send_command(*args)
return await self.parse_response(conn, command_name, **options)
async def _close_connection(self, conn: Connection):
async def _disconnect_raise(self, conn: Connection, error: Exception):
"""
Close the connection before retrying.
The supported exceptions are already checked in the
retry object so we don't need to do it here.
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic(on connection level).
Close the connection and raise an exception
if retry_on_error is not set or the error
is not one of the specified error types
"""
await conn.disconnect()
if (
conn.retry_on_error is None
or isinstance(error, tuple(conn.retry_on_error)) is False
):
raise error
# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
@@ -669,7 +598,7 @@ class Redis(
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection()
conn = self.connection or await pool.get_connection(command_name, **options)
if self.single_connection_client:
await self._single_conn_lock.acquire()
@@ -678,7 +607,7 @@ class Redis(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda _: self._close_connection(conn),
lambda error: self._disconnect_raise(conn, error),
)
finally:
if self.single_connection_client:
@@ -704,9 +633,6 @@ class Redis(
if EMPTY_RESPONSE in options:
options.pop(EMPTY_RESPONSE)
# Remove keys entry, it needs only for cache.
options.pop("keys", None)
if command_name in self.response_callbacks:
# Mypy bug: https://github.com/python/mypy/issues/10977
command_name = cast(str, command_name)
@@ -743,7 +669,7 @@ class Monitor:
async def connect(self):
if self.connection is None:
self.connection = await self.connection_pool.get_connection()
self.connection = await self.connection_pool.get_connection("MONITOR")
async def __aenter__(self):
await self.connect()
@@ -820,12 +746,7 @@ class PubSub:
ignore_subscribe_messages: bool = False,
encoder=None,
push_handler_func: Optional[Callable] = None,
event_dispatcher: Optional["EventDispatcher"] = None,
):
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.connection_pool = connection_pool
self.shard_hint = shard_hint
self.ignore_subscribe_messages = ignore_subscribe_messages
@@ -862,7 +783,7 @@ class PubSub:
def __del__(self):
if self.connection:
self.connection.deregister_connect_callback(self.on_connect)
self.connection._deregister_connect_callback(self.on_connect)
async def aclose(self):
# In case a connection property does not yet exist
@@ -873,7 +794,7 @@ class PubSub:
async with self._lock:
if self.connection:
await self.connection.disconnect()
self.connection.deregister_connect_callback(self.on_connect)
self.connection._deregister_connect_callback(self.on_connect)
await self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
@@ -881,12 +802,12 @@ class PubSub:
self.patterns = {}
self.pending_unsubscribe_patterns = set()
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
async def close(self) -> None:
"""Alias for aclose(), for backwards compatibility"""
await self.aclose()
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="reset")
async def reset(self) -> None:
"""Alias for aclose(), for backwards compatibility"""
await self.aclose()
@@ -931,26 +852,26 @@ class PubSub:
Ensure that the PubSub is connected
"""
if self.connection is None:
self.connection = await self.connection_pool.get_connection()
self.connection = await self.connection_pool.get_connection(
"pubsub", self.shard_hint
)
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
self.connection._register_connect_callback(self.on_connect)
else:
await self.connection.connect()
if self.push_handler_func is not None:
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
self._event_dispatcher.dispatch(
AfterPubSubConnectionInstantiationEvent(
self.connection, self.connection_pool, ClientType.ASYNC, self._lock
)
)
async def _reconnect(self, conn):
async def _disconnect_raise_connect(self, conn, error):
"""
Try to reconnect
Close the connection and raise an exception
if retry_on_timeout is not set or the error
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()
async def _execute(self, conn, command, *args, **kwargs):
@@ -963,7 +884,7 @@ class PubSub:
"""
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda _: self._reconnect(conn),
lambda error: self._disconnect_raise_connect(conn, error),
)
async def parse_response(self, block: bool = True, timeout: float = 0):
@@ -1232,11 +1153,13 @@ class PubSub:
class PubsubWorkerExceptionHandler(Protocol):
def __call__(self, e: BaseException, pubsub: PubSub): ...
def __call__(self, e: BaseException, pubsub: PubSub):
...
class AsyncPubsubWorkerExceptionHandler(Protocol):
async def __call__(self, e: BaseException, pubsub: PubSub): ...
async def __call__(self, e: BaseException, pubsub: PubSub):
...
PSWorkerThreadExcHandlerT = Union[
@@ -1254,8 +1177,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
in one transmission. This is convenient for batch processing, such as
saving all the values in a list to Redis.
All commands executed within a pipeline(when running in transactional mode,
which is the default behavior) are wrapped with MULTI and EXEC
All commands executed within a pipeline are wrapped with MULTI and EXEC
calls. This guarantees all commands executed in the pipeline will be
executed atomically.
@@ -1284,7 +1206,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
self.shard_hint = shard_hint
self.watching = False
self.command_stack: CommandStackT = []
self.scripts: Set[Script] = set()
self.scripts: Set["Script"] = set()
self.explicit_transaction = False
async def __aenter__(self: _RedisT) -> _RedisT:
@@ -1356,50 +1278,49 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
return self.immediate_execute_command(*args, **kwargs)
return self.pipeline_execute_command(*args, **kwargs)
async def _disconnect_reset_raise_on_watching(
self,
conn: Connection,
error: Exception,
):
async def _disconnect_reset_raise(self, conn, error):
"""
Close the connection reset watching state and
raise an exception if we were watching.
The supported exceptions are already checked in the
retry object so we don't need to do it here.
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic(on connection level).
Close the connection, reset watching state and
raise an exception if we were watching,
retry_on_timeout is not set,
or the error is not a TimeoutError
"""
await conn.disconnect()
# if we were already watching a variable, the watch is no longer
# valid since this connection has died. raise a WatchError, which
# indicates the user should retry this transaction.
if self.watching:
await self.reset()
await self.aclose()
raise WatchError(
f"A {type(error).__name__} occurred while watching one or more keys"
"A ConnectionError occurred on while watching one or more keys"
)
# if retry_on_timeout is not set, or the error is not
# a TimeoutError, raise it
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
await self.aclose()
raise
async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on the supported
errors for retry if we're already WATCHing a variable.
Used when issuing WATCH or subsequent commands retrieving their values but before
Execute a command immediately, but don't auto-retry on a
ConnectionError if we're already WATCHing a variable. Used when
issuing WATCH or subsequent commands retrieving their values but before
MULTI is called.
"""
command_name = args[0]
conn = self.connection
# if this is the first call, we need a connection
if not conn:
conn = await self.connection_pool.get_connection()
conn = await self.connection_pool.get_connection(
command_name, self.shard_hint
)
self.connection = conn
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise_on_watching(conn, error),
lambda error: self._disconnect_reset_raise(conn, error),
)
def pipeline_execute_command(self, *args, **options):
@@ -1484,10 +1405,6 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
if not isinstance(r, Exception):
args, options = cmd
command_name = args[0]
# Remove keys entry, it needs only for cache.
options.pop("keys", None)
if command_name in self.response_callbacks:
r = self.response_callbacks[command_name](r, **options)
if inspect.isawaitable(r):
@@ -1525,10 +1442,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
self, exception: Exception, number: int, command: Iterable[object]
) -> None:
cmd = " ".join(map(safe_str, command))
msg = (
f"Command # {number} ({truncate_text(cmd)}) "
"of pipeline caused error: {exception.args}"
)
msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}"
exception.args = (msg,) + exception.args[1:]
async def parse_response(
@@ -1554,15 +1468,11 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
if not exist:
s.sha = await immediate("SCRIPT LOAD", s.script)
async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception):
async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
"""
Close the connection, raise an exception if we were watching.
The supported exceptions are already checked in the
retry object so we don't need to do it here.
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic(on connection level).
Close the connection, raise an exception if we were watching,
and raise an exception if retry_on_timeout is not set,
or the error is not a TimeoutError
"""
await conn.disconnect()
# if we were watching a variable, the watch is no longer valid
@@ -1570,10 +1480,15 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
# indicates the user should retry this transaction.
if self.watching:
raise WatchError(
f"A {type(error).__name__} occurred while watching one or more keys"
"A ConnectionError occurred on while watching one or more keys"
)
# if retry_on_timeout is not set, or the error is not
# a TimeoutError, raise it
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
await self.reset()
raise
async def execute(self, raise_on_error: bool = True) -> List[Any]:
async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
if not stack and not self.watching:
@@ -1587,7 +1502,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
conn = self.connection
if not conn:
conn = await self.connection_pool.get_connection()
conn = await self.connection_pool.get_connection("MULTI", self.shard_hint)
# assign to self.connection so reset() releases the connection
# back to the pool after we're done
self.connection = conn
@@ -1596,7 +1511,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
try:
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_on_watching(conn, error),
lambda error: self._disconnect_raise_reset(conn, error),
)
finally:
await self.reset()

File diff suppressed because it is too large Load Diff

View File

@@ -3,8 +3,8 @@ import copy
import enum
import inspect
import socket
import ssl
import sys
import warnings
import weakref
from abc import abstractmethod
from itertools import chain
@@ -16,30 +16,14 @@ from typing import (
List,
Mapping,
Optional,
Protocol,
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
)
from urllib.parse import ParseResult, parse_qs, unquote, urlparse
from ..utils import SSL_AVAILABLE
if SSL_AVAILABLE:
import ssl
from ssl import SSLContext, TLSVersion
else:
ssl = None
TLSVersion = None
SSLContext = None
from ..auth.token import TokenInterface
from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
from ..utils import deprecated_args, format_error_message
# the functionality is available in 3.11.x but has a major issue before
# 3.11.3. See https://github.com/redis/redis-py/issues/2633
if sys.version_info >= (3, 11, 3):
@@ -49,6 +33,7 @@ else:
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redis.compat import Protocol, TypedDict
from redis.connection import DEFAULT_RESP_VERSION
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
@@ -93,11 +78,13 @@ else:
class ConnectCallbackProtocol(Protocol):
def __call__(self, connection: "AbstractConnection"): ...
def __call__(self, connection: "AbstractConnection"):
...
class AsyncConnectCallbackProtocol(Protocol):
async def __call__(self, connection: "AbstractConnection"): ...
async def __call__(self, connection: "AbstractConnection"):
...
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
@@ -159,7 +146,6 @@ class AbstractConnection:
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = None,
):
if (username or password) and credential_provider is not None:
raise DataError(
@@ -168,10 +154,6 @@ class AbstractConnection:
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.db = db
self.client_name = client_name
self.lib_name = lib_name
@@ -211,8 +193,6 @@ class AbstractConnection:
self.set_parser(parser_class)
self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
self._buffer_cutoff = 6000
self._re_auth_token: Optional[TokenInterface] = None
try:
p = int(protocol)
except TypeError:
@@ -224,33 +204,9 @@ class AbstractConnection:
raise ConnectionError("protocol must be either 2 or 3")
self.protocol = protocol
def __del__(self, _warnings: Any = warnings):
# For some reason, the individual streams don't get properly garbage
# collected and therefore produce no resource warnings. We add one
# here, in the same style as those from the stdlib.
if getattr(self, "_writer", None):
_warnings.warn(
f"unclosed Connection {self!r}", ResourceWarning, source=self
)
try:
asyncio.get_running_loop()
self._close()
except RuntimeError:
# No actions been taken if pool already closed.
pass
def _close(self):
"""
Internal method to silently close the connection without waiting
"""
if self._writer:
self._writer.close()
self._writer = self._reader = None
def __repr__(self):
repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
return f"{self.__class__.__name__}<{repr_args}>"
@abstractmethod
def repr_pieces(self):
@@ -260,24 +216,12 @@ class AbstractConnection:
def is_connected(self):
return self._reader is not None and self._writer is not None
def register_connect_callback(self, callback):
"""
Register a callback to be called when the connection is established either
initially or reconnected. This allows listeners to issue commands that
are ephemeral to the connection, for example pub/sub subscription or
key tracking. The callback must be a _method_ and will be kept as
a weak reference.
"""
def _register_connect_callback(self, callback):
wm = weakref.WeakMethod(callback)
if wm not in self._connect_callbacks:
self._connect_callbacks.append(wm)
def deregister_connect_callback(self, callback):
"""
De-register a previously registered callback. It will no-longer receive
notifications on connection events. Calling this is not required when the
listener goes away, since the callbacks are kept as weak methods.
"""
def _deregister_connect_callback(self, callback):
try:
self._connect_callbacks.remove(weakref.WeakMethod(callback))
except ValueError:
@@ -293,20 +237,12 @@ class AbstractConnection:
async def connect(self):
"""Connects to the Redis server if not already connected"""
await self.connect_check_health(check_health=True)
async def connect_check_health(
self, check_health: bool = True, retry_socket_connect: bool = True
):
if self.is_connected:
return
try:
if retry_socket_connect:
await self.retry.call_with_retry(
lambda: self._connect(), lambda error: self.disconnect()
)
else:
await self._connect()
await self.retry.call_with_retry(
lambda: self._connect(), lambda error: self.disconnect()
)
except asyncio.CancelledError:
raise # in 3.7 and earlier, this is an Exception, not BaseException
except (socket.timeout, asyncio.TimeoutError):
@@ -319,14 +255,12 @@ class AbstractConnection:
try:
if not self.redis_connect_func:
# Use the default on_connect function
await self.on_connect_check_health(check_health=check_health)
await self.on_connect()
else:
# Use the passed function redis_connect_func
(
await self.redis_connect_func(self)
if asyncio.iscoroutinefunction(self.redis_connect_func)
else self.redis_connect_func(self)
)
await self.redis_connect_func(self) if asyncio.iscoroutinefunction(
self.redis_connect_func
) else self.redis_connect_func(self)
except RedisError:
# clean up after any error in on_connect
await self.disconnect()
@@ -350,17 +284,12 @@ class AbstractConnection:
def _host_error(self) -> str:
pass
@abstractmethod
def _error_message(self, exception: BaseException) -> str:
return format_error_message(self._host_error(), exception)
def get_protocol(self):
return self.protocol
pass
async def on_connect(self) -> None:
"""Initialize the connection, authenticate and select a database"""
await self.on_connect_check_health(check_health=True)
async def on_connect_check_health(self, check_health: bool = True) -> None:
self._parser.on_connect(self)
parser = self._parser
@@ -371,8 +300,7 @@ class AbstractConnection:
self.credential_provider
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = await cred_provider.get_credentials_async()
auth_args = cred_provider.get_credentials()
# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol not in [2, "2"]:
@@ -383,11 +311,7 @@ class AbstractConnection:
self._parser.on_connect(self)
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
await self.send_command(
"HELLO", self.protocol, "AUTH", *auth_args, check_health=False
)
await self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
response = await self.read_response()
if response.get(b"proto") != int(self.protocol) and response.get(
"proto"
@@ -418,7 +342,7 @@ class AbstractConnection:
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
await self.send_command("HELLO", self.protocol, check_health=check_health)
await self.send_command("HELLO", self.protocol)
response = await self.read_response()
# if response.get(b"proto") != self.protocol and response.get(
# "proto"
@@ -427,35 +351,18 @@ class AbstractConnection:
# if a client_name is given, set it
if self.client_name:
await self.send_command(
"CLIENT",
"SETNAME",
self.client_name,
check_health=check_health,
)
await self.send_command("CLIENT", "SETNAME", self.client_name)
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Error setting client name")
# set the library name and version, pipeline for lower startup latency
if self.lib_name:
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-NAME",
self.lib_name,
check_health=check_health,
)
await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
if self.lib_version:
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-VER",
self.lib_version,
check_health=check_health,
)
await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
# if a database is specified, switch to it. Also pipeline this
if self.db:
await self.send_command("SELECT", self.db, check_health=check_health)
await self.send_command("SELECT", self.db)
# read responses from pipeline
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
@@ -517,8 +424,8 @@ class AbstractConnection:
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
) -> None:
if not self.is_connected:
await self.connect_check_health(check_health=False)
if check_health:
await self.connect()
elif check_health:
await self.check_health()
try:
@@ -581,7 +488,11 @@ class AbstractConnection:
read_timeout = timeout if timeout is not None else self.socket_timeout
host_error = self._host_error()
try:
if read_timeout is not None and self.protocol in ["3", 3]:
if (
read_timeout is not None
and self.protocol in ["3", 3]
and not HIREDIS_AVAILABLE
):
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
@@ -591,7 +502,7 @@ class AbstractConnection:
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
elif self.protocol in ["3", 3]:
elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE:
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
@@ -703,27 +614,6 @@ class AbstractConnection:
output.append(SYM_EMPTY.join(pieces))
return output
def _socket_is_empty(self):
"""Check if the socket is empty"""
return len(self._reader._buffer) == 0
async def process_invalidation_messages(self):
while not self._socket_is_empty():
await self.read_response(push_request=True)
def set_re_auth_token(self, token: TokenInterface):
self._re_auth_token = token
async def re_auth(self):
if self._re_auth_token is not None:
await self.send_command(
"AUTH",
self._re_auth_token.try_get("oid"),
self._re_auth_token.get_value(),
)
await self.read_response()
self._re_auth_token = None
class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"
@@ -781,6 +671,27 @@ class Connection(AbstractConnection):
def _host_error(self) -> str:
return f"{self.host}:{self.port}"
def _error_message(self, exception: BaseException) -> str:
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if not exception.args:
# asyncio has a bug where on Connection reset by peer, the
# exception is not instanciated, so args is empty. This is the
# workaround.
# See: https://github.com/redis/redis-py/issues/2237
# See: https://github.com/python/cpython/issues/94061
return f"Error connecting to {host_error}. Connection reset by peer"
elif len(exception.args) == 1:
return f"Error connecting to {host_error}. {exception.args[0]}."
else:
return (
f"Error {exception.args[0]} connecting to {host_error}. "
f"{exception.args[0]}."
)
class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
@@ -792,17 +703,12 @@ class SSLConnection(Connection):
self,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = True,
ssl_min_version: Optional[TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
ssl_check_hostname: bool = False,
**kwargs,
):
if not SSL_AVAILABLE:
raise RedisError("Python wasn't built with SSL support")
self.ssl_context: RedisSSLContext = RedisSSLContext(
keyfile=ssl_keyfile,
certfile=ssl_certfile,
@@ -810,8 +716,6 @@ class SSLConnection(Connection):
ca_certs=ssl_ca_certs,
ca_data=ssl_ca_data,
check_hostname=ssl_check_hostname,
min_version=ssl_min_version,
ciphers=ssl_ciphers,
)
super().__init__(**kwargs)
@@ -844,10 +748,6 @@ class SSLConnection(Connection):
def check_hostname(self):
return self.ssl_context.check_hostname
@property
def min_version(self):
return self.ssl_context.min_version
class RedisSSLContext:
__slots__ = (
@@ -858,30 +758,23 @@ class RedisSSLContext:
"ca_data",
"context",
"check_hostname",
"min_version",
"ciphers",
)
def __init__(
self,
keyfile: Optional[str] = None,
certfile: Optional[str] = None,
cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
cert_reqs: Optional[str] = None,
ca_certs: Optional[str] = None,
ca_data: Optional[str] = None,
check_hostname: bool = False,
min_version: Optional[TLSVersion] = None,
ciphers: Optional[str] = None,
):
if not SSL_AVAILABLE:
raise RedisError("Python wasn't built with SSL support")
self.keyfile = keyfile
self.certfile = certfile
if cert_reqs is None:
cert_reqs = ssl.CERT_NONE
self.cert_reqs = ssl.CERT_NONE
elif isinstance(cert_reqs, str):
CERT_REQS = { # noqa: N806
CERT_REQS = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,
"required": ssl.CERT_REQUIRED,
@@ -890,18 +783,13 @@ class RedisSSLContext:
raise RedisError(
f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
)
cert_reqs = CERT_REQS[cert_reqs]
self.cert_reqs = cert_reqs
self.cert_reqs = CERT_REQS[cert_reqs]
self.ca_certs = ca_certs
self.ca_data = ca_data
self.check_hostname = (
check_hostname if self.cert_reqs != ssl.CERT_NONE else False
)
self.min_version = min_version
self.ciphers = ciphers
self.context: Optional[SSLContext] = None
self.check_hostname = check_hostname
self.context: Optional[ssl.SSLContext] = None
def get(self) -> SSLContext:
def get(self) -> ssl.SSLContext:
if not self.context:
context = ssl.create_default_context()
context.check_hostname = self.check_hostname
@@ -910,10 +798,6 @@ class RedisSSLContext:
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
if self.ca_certs or self.ca_data:
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
if self.min_version is not None:
context.minimum_version = self.min_version
if self.ciphers is not None:
context.set_ciphers(self.ciphers)
self.context = context
return self.context
@@ -941,6 +825,20 @@ class UnixDomainSocketConnection(AbstractConnection):
def _host_error(self) -> str:
return self.path
def _error_message(self, exception: BaseException) -> str:
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if len(exception.args) == 1:
return (
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
)
else:
return (
f"Error {exception.args[0]} connecting to unix socket: "
f"{host_error}. {exception.args[1]}."
)
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")
@@ -963,7 +861,6 @@ URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyTy
"max_connections": int,
"health_check_interval": int,
"ssl_check_hostname": to_bool,
"timeout": float,
}
)
@@ -990,7 +887,7 @@ def parse_url(url: str) -> ConnectKwargs:
try:
kwargs[name] = parser(value)
except (TypeError, ValueError):
raise ValueError(f"Invalid value for '{name}' in connection URL.")
raise ValueError(f"Invalid value for `{name}` in connection URL.")
else:
kwargs[name] = value
@@ -1042,7 +939,6 @@ class ConnectionPool:
By default, TCP connections are created unless ``connection_class``
is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
unix sockets.
:py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
Any additional keyword arguments are passed to the constructor of
``connection_class``.
@@ -1112,22 +1008,16 @@ class ConnectionPool:
self._available_connections: List[AbstractConnection] = []
self._in_use_connections: Set[AbstractConnection] = set()
self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
self._lock = asyncio.Lock()
self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
if self._event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
def __repr__(self):
conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
f"({conn_kwargs})>)>"
f"{self.__class__.__name__}"
f"<{self.connection_class(**self.connection_kwargs)!r}>"
)
def reset(self):
self._available_connections = []
self._in_use_connections = weakref.WeakSet()
self._in_use_connections = set()
def can_get_connection(self) -> bool:
"""Return True if a connection can be retrieved from the pool."""
@@ -1136,25 +1026,8 @@ class ConnectionPool:
or len(self._in_use_connections) < self.max_connections
)
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.3.0",
)
async def get_connection(self, command_name=None, *keys, **options):
async with self._lock:
"""Get a connected connection from the pool"""
connection = self.get_available_connection()
try:
await self.ensure_connection(connection)
except BaseException:
await self.release(connection)
raise
return connection
def get_available_connection(self):
"""Get a connection from the pool, without making sure it is connected"""
async def get_connection(self, command_name, *keys, **options):
"""Get a connection from the pool"""
try:
connection = self._available_connections.pop()
except IndexError:
@@ -1162,6 +1035,13 @@ class ConnectionPool:
raise ConnectionError("Too many connections") from None
connection = self.make_connection()
self._in_use_connections.add(connection)
try:
await self.ensure_connection(connection)
except BaseException:
await self.release(connection)
raise
return connection
def get_encoder(self):
@@ -1187,7 +1067,7 @@ class ConnectionPool:
try:
if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except (ConnectionError, TimeoutError, OSError):
except (ConnectionError, OSError):
await connection.disconnect()
await connection.connect()
if await connection.can_read_destructive():
@@ -1199,9 +1079,6 @@ class ConnectionPool:
# not doing so is an error that will cause an exception here.
self._in_use_connections.remove(connection)
self._available_connections.append(connection)
await self._event_dispatcher.dispatch_async(
AsyncAfterConnectionReleasedEvent(connection)
)
async def disconnect(self, inuse_connections: bool = True):
"""
@@ -1235,29 +1112,6 @@ class ConnectionPool:
for conn in self._in_use_connections:
conn.retry = retry
async def re_auth_callback(self, token: TokenInterface):
async with self._lock:
for conn in self._available_connections:
await conn.retry.call_with_retry(
lambda: conn.send_command(
"AUTH", token.try_get("oid"), token.get_value()
),
lambda error: self._mock(error),
)
await conn.retry.call_with_retry(
lambda: conn.read_response(), lambda error: self._mock(error)
)
for conn in self._in_use_connections:
conn.set_re_auth_token(token)
async def _mock(self, error: RedisError):
"""
Dummy functions, needs to be passed as error callback to retry object.
:param error:
:return:
"""
pass
class BlockingConnectionPool(ConnectionPool):
"""
@@ -1275,7 +1129,7 @@ class BlockingConnectionPool(ConnectionPool):
connection from the pool when all of connections are in use, rather than
raising a :py:class:`~redis.ConnectionError` (as the default
:py:class:`~redis.asyncio.ConnectionPool` implementation does), it
blocks the current `Task` for a specified number of seconds until
makes blocks the current `Task` for a specified number of seconds until
a connection becomes available.
Use ``max_connections`` to increase / decrease the pool size::
@@ -1309,29 +1163,16 @@ class BlockingConnectionPool(ConnectionPool):
self._condition = asyncio.Condition()
self.timeout = timeout
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.3.0",
)
async def get_connection(self, command_name=None, *keys, **options):
async def get_connection(self, command_name, *keys, **options):
"""Gets a connection from the pool, blocking until one is available"""
try:
async with self._condition:
async with async_timeout(self.timeout):
async with async_timeout(self.timeout):
async with self._condition:
await self._condition.wait_for(self.can_get_connection)
connection = super().get_available_connection()
return await super().get_connection(command_name, *keys, **options)
except asyncio.TimeoutError as err:
raise ConnectionError("No connection available.") from err
# We now perform the connection check outside of the lock.
try:
await self.ensure_connection(connection)
return connection
except BaseException:
await self.release(connection)
raise
async def release(self, connection: AbstractConnection):
"""Releases the connection back to the pool."""
async with self._condition:

View File

@@ -1,18 +1,14 @@
import asyncio
import logging
import threading
import uuid
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, Optional, Union
from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number
if TYPE_CHECKING:
from redis.asyncio import Redis, RedisCluster
logger = logging.getLogger(__name__)
class Lock:
"""
@@ -86,9 +82,8 @@ class Lock:
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[Number] = None,
blocking_timeout: Optional[float] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
@@ -132,11 +127,6 @@ class Lock:
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
@@ -153,7 +143,6 @@ class Lock:
self.blocking_timeout = blocking_timeout
self.thread_local = bool(thread_local)
self.local = threading.local() if self.thread_local else SimpleNamespace()
self.raise_on_release_error = raise_on_release_error
self.local.token = None
self.register_scripts()
@@ -173,19 +162,12 @@ class Lock:
raise LockError("Unable to acquire lock within the time specified")
async def __aexit__(self, exc_type, exc_value, traceback):
try:
await self.release()
except LockError:
if self.raise_on_release_error:
raise
logger.warning(
"Lock was unlocked or no longer owned when exiting context manager."
)
await self.release()
async def acquire(
self,
blocking: Optional[bool] = None,
blocking_timeout: Optional[Number] = None,
blocking_timeout: Optional[float] = None,
token: Optional[Union[str, bytes]] = None,
):
"""
@@ -267,10 +249,7 @@ class Lock:
"""Releases the already acquired lock"""
expected_token = self.local.token
if expected_token is None:
raise LockError(
"Cannot release a lock that's not owned or is already unlocked.",
lock_name=self.name,
)
raise LockError("Cannot release an unlocked lock")
self.local.token = None
return self.do_release(expected_token)
@@ -283,7 +262,7 @@ class Lock:
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
def extend(
self, additional_time: Number, replace_ttl: bool = False
self, additional_time: float, replace_ttl: bool = False
) -> Awaitable[bool]:
"""
Adds more time to an already acquired lock.

View File

@@ -2,16 +2,18 @@ from asyncio import sleep
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar
from redis.exceptions import ConnectionError, RedisError, TimeoutError
from redis.retry import AbstractRetry
T = TypeVar("T")
if TYPE_CHECKING:
from redis.backoff import AbstractBackoff
class Retry(AbstractRetry[RedisError]):
__hash__ = AbstractRetry.__hash__
T = TypeVar("T")
class Retry:
"""Retry a specific number of times after a failure"""
__slots__ = "_backoff", "_retries", "_supported_errors"
def __init__(
self,
@@ -22,16 +24,23 @@ class Retry(AbstractRetry[RedisError]):
TimeoutError,
),
):
super().__init__(backoff, retries, supported_errors)
"""
Initialize a `Retry` object with a `Backoff` object
that retries a maximum of `retries` times.
`retries` can be negative to retry forever.
You can specify the types of supported errors which trigger
a retry with the `supported_errors` parameter.
"""
self._backoff = backoff
self._retries = retries
self._supported_errors = supported_errors
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Retry):
return NotImplemented
return (
self._backoff == other._backoff
and self._retries == other._retries
and set(self._supported_errors) == set(other._supported_errors)
def update_supported_errors(self, specified_errors: list):
"""
Updates the supported errors with the specified error types
"""
self._supported_errors = tuple(
set(self._supported_errors + tuple(specified_errors))
)
async def call_with_retry(

View File

@@ -11,12 +11,8 @@ from redis.asyncio.connection import (
SSLConnection,
)
from redis.commands import AsyncSentinelCommands
from redis.exceptions import (
ConnectionError,
ReadOnlyError,
ResponseError,
TimeoutError,
)
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
from redis.utils import str_if_bytes
class MasterNotFoundError(ConnectionError):
@@ -33,18 +29,20 @@ class SentinelManagedConnection(Connection):
super().__init__(**kwargs)
def __repr__(self):
s = f"<{self.__class__.__module__}.{self.__class__.__name__}"
pool = self.connection_pool
s = f"{self.__class__.__name__}<service={pool.service_name}"
if self.host:
host_info = f",host={self.host},port={self.port}"
s += host_info
return s + ")>"
return s + ">"
async def connect_to(self, address):
self.host, self.port = address
await self.connect_check_health(
check_health=self.connection_pool.check_connection,
retry_socket_connect=False,
)
await super().connect()
if self.connection_pool.check_connection:
await self.send_command("PING")
if str_if_bytes(await self.read_response()) != "PONG":
raise ConnectionError("PING failed")
async def _connect_retry(self):
if self._reader:
@@ -107,11 +105,9 @@ class SentinelConnectionPool(ConnectionPool):
def __init__(self, service_name, sentinel_manager, **kwargs):
kwargs["connection_class"] = kwargs.get(
"connection_class",
(
SentinelManagedSSLConnection
if kwargs.pop("ssl", False)
else SentinelManagedConnection
),
SentinelManagedSSLConnection
if kwargs.pop("ssl", False)
else SentinelManagedConnection,
)
self.is_master = kwargs.pop("is_master", True)
self.check_connection = kwargs.pop("check_connection", False)
@@ -124,8 +120,8 @@ class SentinelConnectionPool(ConnectionPool):
def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>"
f"{self.__class__.__name__}"
f"<service={self.service_name}({self.is_master and 'master' or 'slave'})>"
)
def reset(self):
@@ -201,7 +197,6 @@ class Sentinel(AsyncSentinelCommands):
sentinels,
min_other_sentinels=0,
sentinel_kwargs=None,
force_master_ip=None,
**connection_kwargs,
):
# if sentinel_kwargs isn't defined, use the socket_* options from
@@ -218,7 +213,6 @@ class Sentinel(AsyncSentinelCommands):
]
self.min_other_sentinels = min_other_sentinels
self.connection_kwargs = connection_kwargs
self._force_master_ip = force_master_ip
async def execute_command(self, *args, **kwargs):
"""
@@ -226,31 +220,19 @@ class Sentinel(AsyncSentinelCommands):
once - If set to True, then execute the resulting command on a single
node at random, rather than across the entire sentinel cluster.
"""
once = bool(kwargs.pop("once", False))
# Check if command is supposed to return the original
# responses instead of boolean value.
return_responses = bool(kwargs.pop("return_responses", False))
once = bool(kwargs.get("once", False))
if "once" in kwargs.keys():
kwargs.pop("once")
if once:
response = await random.choice(self.sentinels).execute_command(
*args, **kwargs
)
if return_responses:
return [response]
else:
return True if response else False
tasks = [
asyncio.Task(sentinel.execute_command(*args, **kwargs))
for sentinel in self.sentinels
]
responses = await asyncio.gather(*tasks)
if return_responses:
return responses
return all(responses)
await random.choice(self.sentinels).execute_command(*args, **kwargs)
else:
tasks = [
asyncio.Task(sentinel.execute_command(*args, **kwargs))
for sentinel in self.sentinels
]
await asyncio.gather(*tasks)
return True
def __repr__(self):
sentinel_addresses = []
@@ -259,10 +241,7 @@ class Sentinel(AsyncSentinelCommands):
f"{sentinel.connection_pool.connection_kwargs['host']}:"
f"{sentinel.connection_pool.connection_kwargs['port']}"
)
return (
f"<{self.__class__}.{self.__class__.__name__}"
f"(sentinels=[{','.join(sentinel_addresses)}])>"
)
return f"{self.__class__.__name__}<sentinels=[{','.join(sentinel_addresses)}]>"
def check_master_state(self, state: dict, service_name: str) -> bool:
if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
@@ -294,13 +273,7 @@ class Sentinel(AsyncSentinelCommands):
sentinel,
self.sentinels[0],
)
ip = (
self._force_master_ip
if self._force_master_ip is not None
else state["ip"]
)
return ip, state["port"]
return state["ip"], state["port"]
error_info = ""
if len(collected_errors) > 0:
@@ -341,8 +314,6 @@ class Sentinel(AsyncSentinelCommands):
):
"""
Returns a redis client instance for the ``service_name`` master.
Sentinel client will detect failover and reconnect Redis clients
automatically.
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
used to retrieve the master's address before establishing a new

View File

@@ -16,7 +16,7 @@ def from_url(url, **kwargs):
return Redis.from_url(url, **kwargs)
class pipeline: # noqa: N801
class pipeline:
def __init__(self, redis_obj: "Redis"):
self.p: "Pipeline" = redis_obj.pipeline()