main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -1,3 +1,5 @@
import sys
from redis import asyncio # noqa
from redis.backoff import default_backoff
from redis.client import Redis, StrictRedis
@@ -16,15 +18,11 @@ from redis.exceptions import (
BusyLoadingError,
ChildDeadlockedError,
ConnectionError,
CrossSlotTransactionError,
DataError,
InvalidPipelineStack,
InvalidResponse,
MaxConnectionsError,
OutOfMemoryError,
PubSubError,
ReadOnlyError,
RedisClusterException,
RedisError,
ResponseError,
TimeoutError,
@@ -38,6 +36,11 @@ from redis.sentinel import (
)
from redis.utils import from_url
if sys.version_info >= (3, 8):
from importlib import metadata
else:
import importlib_metadata as metadata
def int_or_str(value):
try:
@@ -46,10 +49,17 @@ def int_or_str(value):
return value
__version__ = "6.4.0"
VERSION = tuple(map(int_or_str, __version__.split(".")))
try:
__version__ = metadata.version("redis")
except metadata.PackageNotFoundError:
__version__ = "99.99.99"
try:
VERSION = tuple(map(int_or_str, __version__.split(".")))
except AttributeError:
VERSION = tuple([99, 99, 99])
__all__ = [
"AuthenticationError",
"AuthenticationWrongNumberOfArgsError",
@@ -60,19 +70,15 @@ __all__ = [
"ConnectionError",
"ConnectionPool",
"CredentialProvider",
"CrossSlotTransactionError",
"DataError",
"from_url",
"default_backoff",
"InvalidPipelineStack",
"InvalidResponse",
"MaxConnectionsError",
"OutOfMemoryError",
"PubSubError",
"ReadOnlyError",
"Redis",
"RedisCluster",
"RedisClusterException",
"RedisError",
"ResponseError",
"Sentinel",

View File

@@ -1,9 +1,4 @@
from .base import (
AsyncPushNotificationsParser,
BaseParser,
PushNotificationsParser,
_AsyncRESPBase,
)
from .base import BaseParser, _AsyncRESPBase
from .commands import AsyncCommandsParser, CommandsParser
from .encoders import Encoder
from .hiredis import _AsyncHiredisParser, _HiredisParser
@@ -16,12 +11,10 @@ __all__ = [
"_AsyncRESPBase",
"_AsyncRESP2Parser",
"_AsyncRESP3Parser",
"AsyncPushNotificationsParser",
"CommandsParser",
"Encoder",
"BaseParser",
"_HiredisParser",
"_RESP2Parser",
"_RESP3Parser",
"PushNotificationsParser",
]

View File

@@ -1,7 +1,7 @@
import sys
from abc import ABC
from asyncio import IncompleteReadError, StreamReader, TimeoutError
from typing import Callable, List, Optional, Protocol, Union
from typing import List, Optional, Union
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
@@ -9,32 +9,26 @@ else:
from async_timeout import timeout as async_timeout
from ..exceptions import (
AskError,
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
BusyLoadingError,
ClusterCrossSlotError,
ClusterDownError,
ConnectionError,
ExecAbortError,
MasterDownError,
ModuleError,
MovedError,
NoPermissionError,
NoScriptError,
OutOfMemoryError,
ReadOnlyError,
RedisError,
ResponseError,
TryAgainError,
)
from ..typing import EncodableT
from .encoders import Encoder
from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
"Error unloading module: the module "
"exports one or more module-side data "
@@ -78,12 +72,6 @@ class BaseParser(ABC):
"READONLY": ReadOnlyError,
"NOAUTH": AuthenticationError,
"NOPERM": NoPermissionError,
"ASK": AskError,
"TRYAGAIN": TryAgainError,
"MOVED": MovedError,
"CLUSTERDOWN": ClusterDownError,
"CROSSSLOT": ClusterCrossSlotError,
"MASTERDOWN": MasterDownError,
}
@classmethod
@@ -158,58 +146,6 @@ class AsyncBaseParser(BaseParser):
raise NotImplementedError()
_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
class PushNotificationsParser(Protocol):
"""Protocol defining RESP3-specific parsing functionality"""
pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None
def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses"""
raise NotImplementedError()
def handle_push_response(self, response, **kwargs):
if response[0] not in _INVALIDATION_MESSAGE:
return self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
return self.invalidation_push_handler_func(response)
def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func
def set_invalidation_push_handler(self, invalidation_push_handler_func):
self.invalidation_push_handler_func = invalidation_push_handler_func
class AsyncPushNotificationsParser(Protocol):
"""Protocol defining async RESP3-specific parsing functionality"""
pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None
async def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses asynchronously"""
raise NotImplementedError()
async def handle_push_response(self, response, **kwargs):
"""Handle push responses asynchronously"""
if response[0] not in _INVALIDATION_MESSAGE:
return await self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
return await self.invalidation_push_handler_func(response)
def set_pubsub_push_handler(self, pubsub_push_handler_func):
"""Set the pubsub push handler function"""
self.pubsub_push_handler_func = pubsub_push_handler_func
def set_invalidation_push_handler(self, invalidation_push_handler_func):
"""Set the invalidation push handler function"""
self.invalidation_push_handler_func = invalidation_push_handler_func
class _AsyncRESPBase(AsyncBaseParser):
"""Base class for async resp parsing"""
@@ -246,7 +182,7 @@ class _AsyncRESPBase(AsyncBaseParser):
return True
try:
async with async_timeout(0):
return self._stream.at_eof()
return await self._stream.read(1)
except TimeoutError:
return False

View File

@@ -38,7 +38,7 @@ def parse_info(response):
response = str_if_bytes(response)
def get_value(value):
if "," not in value and "=" not in value:
if "," not in value or "=" not in value:
try:
if "." in value:
return float(value)
@@ -46,18 +46,11 @@ def parse_info(response):
return int(value)
except ValueError:
return value
elif "=" not in value:
return [get_value(v) for v in value.split(",") if v]
else:
sub_dict = {}
for item in value.split(","):
if not item:
continue
if "=" in item:
k, v = item.rsplit("=", 1)
sub_dict[k] = get_value(v)
else:
sub_dict[item] = True
k, v = item.rsplit("=", 1)
sub_dict[k] = get_value(v)
return sub_dict
for line in response.splitlines():
@@ -87,7 +80,7 @@ def parse_memory_stats(response, **kwargs):
"""Parse the results of MEMORY STATS"""
stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True)
for key, value in stats.items():
if key.startswith("db.") and isinstance(value, list):
if key.startswith("db."):
stats[key] = pairs_to_dict(
value, decode_keys=True, decode_string_values=True
)
@@ -275,22 +268,17 @@ def parse_xinfo_stream(response, **options):
data = {str_if_bytes(k): v for k, v in response.items()}
if not options.get("full", False):
first = data.get("first-entry")
if first is not None and first[0] is not None:
if first is not None:
data["first-entry"] = (first[0], pairs_to_dict(first[1]))
last = data["last-entry"]
if last is not None and last[0] is not None:
if last is not None:
data["last-entry"] = (last[0], pairs_to_dict(last[1]))
else:
data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]}
if len(data["groups"]) > 0 and isinstance(data["groups"][0], list):
if isinstance(data["groups"][0], list):
data["groups"] = [
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
]
for g in data["groups"]:
if g["consumers"] and g["consumers"][0] is not None:
g["consumers"] = [
pairs_to_dict(c, decode_keys=True) for c in g["consumers"]
]
else:
data["groups"] = [
{str_if_bytes(k): v for k, v in group.items()}
@@ -334,7 +322,7 @@ def float_or_none(response):
return float(response)
def bool_ok(response, **options):
def bool_ok(response):
return str_if_bytes(response) == "OK"
@@ -366,12 +354,7 @@ def parse_scan(response, **options):
def parse_hscan(response, **options):
cursor, r = response
no_values = options.get("no_values", False)
if no_values:
payload = r or []
else:
payload = r and pairs_to_dict(r) or {}
return int(cursor), payload
return int(cursor), r and pairs_to_dict(r) or {}
def parse_zscan(response, **options):
@@ -396,20 +379,13 @@ def parse_slowlog_get(response, **options):
# an O(N) complexity) instead of the command.
if isinstance(item[3], list):
result["command"] = space.join(item[3])
# These fields are optional, depends on environment.
if len(item) >= 6:
result["client_address"] = item[4]
result["client_name"] = item[5]
result["client_address"] = item[4]
result["client_name"] = item[5]
else:
result["complexity"] = item[3]
result["command"] = space.join(item[4])
# These fields are optional, depends on environment.
if len(item) >= 7:
result["client_address"] = item[5]
result["client_name"] = item[6]
result["client_address"] = item[5]
result["client_name"] = item[6]
return result
return [parse_item(item) for item in response]
@@ -452,11 +428,9 @@ def parse_cluster_info(response, **options):
def _parse_node_line(line):
line_items = line.split(" ")
node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8]
ip = addr.split("@")[0]
hostname = addr.split("@")[1].split(",")[1] if "@" in addr and "," in addr else ""
addr = addr.split("@")[0]
node_dict = {
"node_id": node_id,
"hostname": hostname,
"flags": flags,
"master_id": master_id,
"last_ping_sent": ping,
@@ -469,7 +443,7 @@ def _parse_node_line(line):
if len(line_items) >= 9:
slots, migrations = _parse_slots(line_items[8:])
node_dict["slots"], node_dict["migrations"] = slots, migrations
return ip, node_dict
return addr, node_dict
def _parse_slots(slot_ranges):
@@ -516,7 +490,7 @@ def parse_geosearch_generic(response, **options):
except KeyError: # it means the command was sent via execute_command
return response
if not isinstance(response, list):
if type(response) != list:
response_list = [response]
else:
response_list = response
@@ -676,8 +650,7 @@ def parse_client_info(value):
"omem",
"tot-mem",
}:
if int_key in client_info:
client_info[int_key] = int(client_info[int_key])
client_info[int_key] = int(client_info[int_key])
return client_info
@@ -840,28 +813,24 @@ _RedisCallbacksRESP2 = {
_RedisCallbacksRESP3 = {
**string_keys_to_dict(
"SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
),
**string_keys_to_dict(
"ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE "
"ZUNION HGETALL XREADGROUP",
lambda r, **kwargs: r,
),
**string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3),
"ACL LOG": lambda r: (
[
{str_if_bytes(key): str_if_bytes(value) for key, value in x.items()}
for x in r
]
if isinstance(r, list)
else bool_ok(r)
),
"ACL LOG": lambda r: [
{str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} for x in r
]
if isinstance(r, list)
else bool_ok(r),
"COMMAND": parse_command_resp3,
"CONFIG GET": lambda r: {
str_if_bytes(key) if key is not None else None: (
str_if_bytes(value) if value is not None else None
)
str_if_bytes(key)
if key is not None
else None: str_if_bytes(value)
if value is not None
else None
for key, value in r.items()
},
"MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()},
@@ -869,11 +838,11 @@ _RedisCallbacksRESP3 = {
"SENTINEL MASTERS": parse_sentinel_masters_resp3,
"SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3,
"SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3,
"STRALGO": lambda r, **options: (
{str_if_bytes(key): str_if_bytes(value) for key, value in r.items()}
if isinstance(r, dict)
else str_if_bytes(r)
),
"STRALGO": lambda r, **options: {
str_if_bytes(key): str_if_bytes(value) for key, value in r.items()
}
if isinstance(r, dict)
else str_if_bytes(r),
"XINFO CONSUMERS": lambda r: [
{str_if_bytes(key): value for key, value in x.items()} for x in r
],

View File

@@ -1,23 +1,19 @@
import asyncio
import socket
import sys
from logging import getLogger
from typing import Callable, List, Optional, TypedDict, Union
from typing import Callable, List, Optional, Union
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout
from redis.compat import TypedDict
from ..exceptions import ConnectionError, InvalidResponse, RedisError
from ..typing import EncodableT
from ..utils import HIREDIS_AVAILABLE
from .base import (
AsyncBaseParser,
AsyncPushNotificationsParser,
BaseParser,
PushNotificationsParser,
)
from .base import AsyncBaseParser, BaseParser
from .socket import (
NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
NONBLOCKING_EXCEPTIONS,
@@ -25,11 +21,6 @@ from .socket import (
SERVER_CLOSED_CONNECTION_ERROR,
)
# Used to signal that hiredis-py does not have enough data to parse.
# Using `False` or `None` is not reliable, given that the parser can
# return `False` or `None` for legitimate reasons from RESP payloads.
NOT_ENOUGH_DATA = object()
class _HiredisReaderArgs(TypedDict, total=False):
protocolError: Callable[[str], Exception]
@@ -38,7 +29,7 @@ class _HiredisReaderArgs(TypedDict, total=False):
errors: Optional[str]
class _HiredisParser(BaseParser, PushNotificationsParser):
class _HiredisParser(BaseParser):
"Parser class for connections using Hiredis"
def __init__(self, socket_read_size):
@@ -46,9 +37,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
raise RedisError("Hiredis is not installed")
self.socket_read_size = socket_read_size
self._buffer = bytearray(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None
def __del__(self):
try:
@@ -56,11 +44,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
except Exception:
pass
def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
def on_connect(self, connection, **kwargs):
import hiredis
@@ -70,32 +53,25 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
"protocolError": InvalidResponse,
"replyError": self.parse_error,
"errors": connection.encoder.encoding_errors,
"notEnoughData": NOT_ENOUGH_DATA,
}
if connection.encoder.decode_responses:
kwargs["encoding"] = connection.encoder.encoding
self._reader = hiredis.Reader(**kwargs)
self._next_response = NOT_ENOUGH_DATA
try:
self._hiredis_PushNotificationType = hiredis.PushNotification
except AttributeError:
# hiredis < 3.2
self._hiredis_PushNotificationType = None
self._next_response = False
def on_disconnect(self):
self._sock = None
self._reader = None
self._next_response = NOT_ENOUGH_DATA
self._next_response = False
def can_read(self, timeout):
if not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._next_response is NOT_ENOUGH_DATA:
if self._next_response is False:
self._next_response = self._reader.gets()
if self._next_response is NOT_ENOUGH_DATA:
if self._next_response is False:
return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
return True
@@ -129,24 +105,14 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
if custom_timeout:
sock.settimeout(self._socket_timeout)
def read_response(self, disable_decoding=False, push_request=False):
def read_response(self, disable_decoding=False):
if not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
# _next_response might be cached from a can_read() call
if self._next_response is not NOT_ENOUGH_DATA:
if self._next_response is not False:
response = self._next_response
self._next_response = NOT_ENOUGH_DATA
if self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if not push_request:
return self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
self._next_response = False
return response
if disable_decoding:
@@ -154,7 +120,7 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
else:
response = self._reader.gets()
while response is NOT_ENOUGH_DATA:
while response is False:
self.read_from_socket()
if disable_decoding:
response = self._reader.gets(False)
@@ -165,16 +131,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
# happened
if isinstance(response, ConnectionError):
raise response
elif self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if not push_request:
return self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
elif (
isinstance(response, list)
and response
@@ -184,7 +140,7 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
return response
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
class _AsyncHiredisParser(AsyncBaseParser):
"""Async implementation of parser class for connections using Hiredis"""
__slots__ = ("_reader",)
@@ -194,14 +150,6 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader = None
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None
async def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response
def on_connect(self, connection):
import hiredis
@@ -210,7 +158,6 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
kwargs: _HiredisReaderArgs = {
"protocolError": InvalidResponse,
"replyError": self.parse_error,
"notEnoughData": NOT_ENOUGH_DATA,
}
if connection.encoder.decode_responses:
kwargs["encoding"] = connection.encoder.encoding
@@ -219,21 +166,13 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
self._reader = hiredis.Reader(**kwargs)
self._connected = True
try:
self._hiredis_PushNotificationType = getattr(
hiredis, "PushNotification", None
)
except AttributeError:
# hiredis < 3.2
self._hiredis_PushNotificationType = None
def on_disconnect(self):
self._connected = False
async def can_read_destructive(self):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._reader.gets() is not NOT_ENOUGH_DATA:
if self._reader.gets():
return True
try:
async with async_timeout(0):
@@ -251,7 +190,7 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
return True
async def read_response(
self, disable_decoding: bool = False, push_request: bool = False
self, disable_decoding: bool = False
) -> Union[EncodableT, List[EncodableT]]:
# If `on_disconnect()` has been called, prohibit any more reads
# even if they could happen because data might be present.
@@ -259,33 +198,16 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
while response is NOT_ENOUGH_DATA:
response = self._reader.gets()
while response is False:
await self.read_from_socket()
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
response = self._reader.gets()
# if the response is a ConnectionError or the response is a list and
# the first item is a ConnectionError, raise it as something bad
# happened
if isinstance(response, ConnectionError):
raise response
elif self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = await self.handle_push_response(response)
if not push_request:
return await self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
elif (
isinstance(response, list)
and response

View File

@@ -3,26 +3,20 @@ from typing import Any, Union
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from .base import (
AsyncPushNotificationsParser,
PushNotificationsParser,
_AsyncRESPBase,
_RESPBase,
)
from .base import _AsyncRESPBase, _RESPBase
from .socket import SERVER_CLOSED_CONNECTION_ERROR
class _RESP3Parser(_RESPBase, PushNotificationsParser):
class _RESP3Parser(_RESPBase):
"""RESP3 protocol implementation"""
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self.push_handler_func = self.handle_push_response
def handle_pubsub_push_response(self, response):
def handle_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
logger.info("Push response: " + str(response))
return response
def read_response(self, disable_decoding=False, push_request=False):
@@ -91,16 +85,19 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we return sets as list, all the time, for predictability
# so we need to first convert to a list, and then try to convert it to a set
response = [
self._read_response(disable_decoding=disable_decoding)
for _ in range(int(response))
]
try:
response = set(response)
except TypeError:
pass
# map response
elif byte == b"%":
# We cannot use a dict-comprehension to parse stream.
# Evaluation order of key:val expression in dict comprehension only
# became defined to be left-right in version 3.8
# we use this approach and not dict comprehension here
# because this dict comprehension fails in python 3.7
resp_dict = {}
for _ in range(int(response)):
key = self._read_response(disable_decoding=disable_decoding)
@@ -116,13 +113,13 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
)
for _ in range(int(response))
]
response = self.handle_push_response(response)
res = self.push_handler_func(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
@@ -130,16 +127,18 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
response = self.encoder.decode(response)
return response
def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func
class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
class _AsyncRESP3Parser(_AsyncRESPBase):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self.push_handler_func = self.handle_push_response
async def handle_pubsub_push_response(self, response):
def handle_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
logger.info("Push response: " + str(response))
return response
async def read_response(
@@ -215,23 +214,23 @@ class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we always convert to a list, to have predictable return types
# so we need to first convert to a list, and then try to convert it to a set
response = [
(await self._read_response(disable_decoding=disable_decoding))
for _ in range(int(response))
]
try:
response = set(response)
except TypeError:
pass
# map response
elif byte == b"%":
# We cannot use a dict-comprehension to parse stream.
# Evaluation order of key:val expression in dict comprehension only
# became defined to be left-right in version 3.8
resp_dict = {}
for _ in range(int(response)):
key = await self._read_response(disable_decoding=disable_decoding)
resp_dict[key] = await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
response = {
(await self._read_response(disable_decoding=disable_decoding)): (
await self._read_response(disable_decoding=disable_decoding)
)
response = resp_dict
for _ in range(int(response))
}
# push response
elif byte == b">":
response = [
@@ -242,16 +241,19 @@ class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
)
for _ in range(int(response))
]
response = await self.handle_push_response(response)
res = self.push_handler_func(response)
if not push_request:
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response
def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func

View File

@@ -15,11 +15,9 @@ from typing import (
Mapping,
MutableMapping,
Optional,
Protocol,
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
cast,
@@ -39,7 +37,6 @@ from redis.asyncio.connection import (
)
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff
from redis.client import (
EMPTY_RESPONSE,
NEVER_DECODE,
@@ -52,40 +49,27 @@ from redis.commands import (
AsyncSentinelCommands,
list_or_args,
)
from redis.compat import Protocol, TypedDict
from redis.credentials import CredentialProvider
from redis.event import (
AfterPooledConnectionsInstantiationEvent,
AfterPubSubConnectionInstantiationEvent,
AfterSingleConnectionInstantiationEvent,
ClientType,
EventDispatcher,
)
from redis.exceptions import (
ConnectionError,
ExecAbortError,
PubSubError,
RedisError,
ResponseError,
TimeoutError,
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import (
SSL_AVAILABLE,
HIREDIS_AVAILABLE,
_set_info_logger,
deprecated_args,
deprecated_function,
get_lib_version,
safe_str,
str_if_bytes,
truncate_text,
)
if TYPE_CHECKING and SSL_AVAILABLE:
from ssl import TLSVersion, VerifyMode
else:
TLSVersion = None
VerifyMode = None
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
@@ -96,11 +80,13 @@ if TYPE_CHECKING:
class ResponseCallbackProtocol(Protocol):
def __call__(self, response: Any, **kwargs): ...
def __call__(self, response: Any, **kwargs):
...
class AsyncResponseCallbackProtocol(Protocol):
async def __call__(self, response: Any, **kwargs): ...
async def __call__(self, response: Any, **kwargs):
...
ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
@@ -182,7 +168,7 @@ class Redis(
warnings.warn(
DeprecationWarning(
'"auto_close_connection_pool" is deprecated '
"since version 5.0.1. "
"since version 5.0.0. "
"Please create a ConnectionPool explicitly and "
"provide to the Redis() constructor instead."
)
@@ -208,11 +194,6 @@ class Redis(
client.auto_close_connection_pool = True
return client
@deprecated_args(
args_to_warn=["retry_on_timeout"],
reason="TimeoutError is included by default.",
version="6.0.0",
)
def __init__(
self,
*,
@@ -230,19 +211,14 @@ class Redis(
encoding_errors: str = "strict",
decode_responses: bool = False,
retry_on_timeout: bool = False,
retry: Retry = Retry(
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
),
retry_on_error: Optional[list] = None,
ssl: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: Union[str, VerifyMode] = "required",
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = True,
ssl_min_version: Optional[TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
ssl_check_hostname: bool = False,
max_connections: Optional[int] = None,
single_connection_client: bool = False,
health_check_interval: int = 0,
@@ -250,38 +226,20 @@ class Redis(
lib_name: Optional[str] = "redis-py",
lib_version: Optional[str] = get_lib_version(),
username: Optional[str] = None,
retry: Optional[Retry] = None,
auto_close_connection_pool: Optional[bool] = None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = None,
):
"""
Initialize a new Redis client.
To specify a retry policy for specific errors, you have two options:
1. Set the `retry_on_error` to a list of the error/s to retry on, and
you can also set `retry` to a valid `Retry` object(in case the default
one is not appropriate) - with this approach the retries will be triggered
on the default errors specified in the Retry object enriched with the
errors specified in `retry_on_error`.
2. Define a `Retry` object with configured 'supported_errors' and set
it to the `retry` parameter - with this approach you completely redefine
the errors on which retries will happen.
`retry_on_timeout` is deprecated - please include the TimeoutError
either in the Retry object or in the `retry_on_error` list.
When 'connection_pool' is provided - the retry configuration of the
provided pool will be used.
To specify a retry policy for specific errors, first set
`retry_on_error` to a list of the error/s to retry on, then set
`retry` to a valid `Retry` object.
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
kwargs: Dict[str, Any]
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
# auto_close_connection_pool only has an effect if connection_pool is
# None. It is assumed that if connection_pool is not None, the user
# wants to manage the connection pool themselves.
@@ -289,7 +247,7 @@ class Redis(
warnings.warn(
DeprecationWarning(
'"auto_close_connection_pool" is deprecated '
"since version 5.0.1. "
"since version 5.0.0. "
"Please create a ConnectionPool explicitly and "
"provide to the Redis() constructor instead."
)
@@ -301,6 +259,8 @@ class Redis(
# Create internal connection pool, expected to be closed by Redis instance
if not retry_on_error:
retry_on_error = []
if retry_on_timeout is True:
retry_on_error.append(TimeoutError)
kwargs = {
"db": db,
"username": username,
@@ -310,6 +270,7 @@ class Redis(
"encoding": encoding,
"encoding_errors": encoding_errors,
"decode_responses": decode_responses,
"retry_on_timeout": retry_on_timeout,
"retry_on_error": retry_on_error,
"retry": copy.deepcopy(retry),
"max_connections": max_connections,
@@ -350,26 +311,14 @@ class Redis(
"ssl_ca_certs": ssl_ca_certs,
"ssl_ca_data": ssl_ca_data,
"ssl_check_hostname": ssl_check_hostname,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)
# This arg only used if no pool is passed in
self.auto_close_connection_pool = auto_close_connection_pool
connection_pool = ConnectionPool(**kwargs)
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.ASYNC, credential_provider
)
)
else:
# If a pool is passed in, do not close it
self.auto_close_connection_pool = False
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.ASYNC, credential_provider
)
)
self.connection_pool = connection_pool
self.single_connection_client = single_connection_client
@@ -388,10 +337,7 @@ class Redis(
self._single_conn_lock = asyncio.Lock()
def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"({self.connection_pool!r})>"
)
return f"{self.__class__.__name__}<{self.connection_pool!r}>"
def __await__(self):
return self.initialize().__await__()
@@ -400,13 +346,7 @@ class Redis(
if self.single_connection_client:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection()
self._event_dispatcher.dispatch(
AfterSingleConnectionInstantiationEvent(
self.connection, ClientType.ASYNC, self._single_conn_lock
)
)
self.connection = await self.connection_pool.get_connection("_")
return self
def set_response_callback(self, command: str, callback: ResponseCallbackT):
@@ -421,10 +361,10 @@ class Redis(
"""Get the connection's key-word arguments"""
return self.connection_pool.connection_kwargs
def get_retry(self) -> Optional[Retry]:
def get_retry(self) -> Optional["Retry"]:
return self.get_connection_kwargs().get("retry")
def set_retry(self, retry: Retry) -> None:
def set_retry(self, retry: "Retry") -> None:
self.get_connection_kwargs().update({"retry": retry})
self.connection_pool.set_retry(retry)
@@ -503,7 +443,6 @@ class Redis(
blocking_timeout: Optional[float] = None,
lock_class: Optional[Type[Lock]] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
) -> Lock:
"""
Return a new Lock object using key ``name`` that mimics
@@ -550,11 +489,6 @@ class Redis(
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
@@ -572,7 +506,6 @@ class Redis(
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
raise_on_release_error=raise_on_release_error,
)
def pubsub(self, **kwargs) -> "PubSub":
@@ -581,9 +514,7 @@ class Redis(
subscribe to channels and listen for messages that get published to
them.
"""
return PubSub(
self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
)
return PubSub(self.connection_pool, **kwargs)
def monitor(self) -> "Monitor":
return Monitor(self.connection_pool)
@@ -615,18 +546,15 @@ class Redis(
_grl().call_exception_handler(context)
except RuntimeError:
pass
self.connection._close()
async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
"""
Closes Redis client connection
Args:
close_connection_pool:
decides whether to close the connection pool used by this Redis client,
overriding Redis.auto_close_connection_pool.
By default, let Redis.auto_close_connection_pool decide
whether to close the connection pool.
:param close_connection_pool: decides whether to close the connection pool used
by this Redis client, overriding Redis.auto_close_connection_pool. By default,
let Redis.auto_close_connection_pool decide whether to close the connection
pool.
"""
conn = self.connection
if conn:
@@ -637,7 +565,7 @@ class Redis(
):
await self.connection_pool.disconnect()
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
"""
Alias for aclose(), for backwards compatibility
@@ -651,17 +579,18 @@ class Redis(
await conn.send_command(*args)
return await self.parse_response(conn, command_name, **options)
async def _close_connection(self, conn: Connection):
async def _disconnect_raise(self, conn: Connection, error: Exception):
"""
Close the connection before retrying.
The supported exceptions are already checked in the
retry object so we don't need to do it here.
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic(on connection level).
Close the connection and raise an exception
if retry_on_error is not set or the error
is not one of the specified error types
"""
await conn.disconnect()
if (
conn.retry_on_error is None
or isinstance(error, tuple(conn.retry_on_error)) is False
):
raise error
# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
@@ -669,7 +598,7 @@ class Redis(
await self.initialize()
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection()
conn = self.connection or await pool.get_connection(command_name, **options)
if self.single_connection_client:
await self._single_conn_lock.acquire()
@@ -678,7 +607,7 @@ class Redis(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda _: self._close_connection(conn),
lambda error: self._disconnect_raise(conn, error),
)
finally:
if self.single_connection_client:
@@ -704,9 +633,6 @@ class Redis(
if EMPTY_RESPONSE in options:
options.pop(EMPTY_RESPONSE)
# Remove keys entry, it needs only for cache.
options.pop("keys", None)
if command_name in self.response_callbacks:
# Mypy bug: https://github.com/python/mypy/issues/10977
command_name = cast(str, command_name)
@@ -743,7 +669,7 @@ class Monitor:
async def connect(self):
if self.connection is None:
self.connection = await self.connection_pool.get_connection()
self.connection = await self.connection_pool.get_connection("MONITOR")
async def __aenter__(self):
await self.connect()
@@ -820,12 +746,7 @@ class PubSub:
ignore_subscribe_messages: bool = False,
encoder=None,
push_handler_func: Optional[Callable] = None,
event_dispatcher: Optional["EventDispatcher"] = None,
):
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.connection_pool = connection_pool
self.shard_hint = shard_hint
self.ignore_subscribe_messages = ignore_subscribe_messages
@@ -862,7 +783,7 @@ class PubSub:
def __del__(self):
if self.connection:
self.connection.deregister_connect_callback(self.on_connect)
self.connection._deregister_connect_callback(self.on_connect)
async def aclose(self):
# In case a connection property does not yet exist
@@ -873,7 +794,7 @@ class PubSub:
async with self._lock:
if self.connection:
await self.connection.disconnect()
self.connection.deregister_connect_callback(self.on_connect)
self.connection._deregister_connect_callback(self.on_connect)
await self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
@@ -881,12 +802,12 @@ class PubSub:
self.patterns = {}
self.pending_unsubscribe_patterns = set()
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
async def close(self) -> None:
"""Alias for aclose(), for backwards compatibility"""
await self.aclose()
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="reset")
async def reset(self) -> None:
"""Alias for aclose(), for backwards compatibility"""
await self.aclose()
@@ -931,26 +852,26 @@ class PubSub:
Ensure that the PubSub is connected
"""
if self.connection is None:
self.connection = await self.connection_pool.get_connection()
self.connection = await self.connection_pool.get_connection(
"pubsub", self.shard_hint
)
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
self.connection._register_connect_callback(self.on_connect)
else:
await self.connection.connect()
if self.push_handler_func is not None:
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
self._event_dispatcher.dispatch(
AfterPubSubConnectionInstantiationEvent(
self.connection, self.connection_pool, ClientType.ASYNC, self._lock
)
)
async def _reconnect(self, conn):
async def _disconnect_raise_connect(self, conn, error):
"""
Try to reconnect
Close the connection and raise an exception
if retry_on_timeout is not set or the error
is not a TimeoutError. Otherwise, try to reconnect
"""
await conn.disconnect()
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
await conn.connect()
async def _execute(self, conn, command, *args, **kwargs):
@@ -963,7 +884,7 @@ class PubSub:
"""
return await conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda _: self._reconnect(conn),
lambda error: self._disconnect_raise_connect(conn, error),
)
async def parse_response(self, block: bool = True, timeout: float = 0):
@@ -1232,11 +1153,13 @@ class PubSub:
class PubsubWorkerExceptionHandler(Protocol):
def __call__(self, e: BaseException, pubsub: PubSub): ...
def __call__(self, e: BaseException, pubsub: PubSub):
...
class AsyncPubsubWorkerExceptionHandler(Protocol):
async def __call__(self, e: BaseException, pubsub: PubSub): ...
async def __call__(self, e: BaseException, pubsub: PubSub):
...
PSWorkerThreadExcHandlerT = Union[
@@ -1254,8 +1177,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
in one transmission. This is convenient for batch processing, such as
saving all the values in a list to Redis.
All commands executed within a pipeline(when running in transactional mode,
which is the default behavior) are wrapped with MULTI and EXEC
All commands executed within a pipeline are wrapped with MULTI and EXEC
calls. This guarantees all commands executed in the pipeline will be
executed atomically.
@@ -1284,7 +1206,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
self.shard_hint = shard_hint
self.watching = False
self.command_stack: CommandStackT = []
self.scripts: Set[Script] = set()
self.scripts: Set["Script"] = set()
self.explicit_transaction = False
async def __aenter__(self: _RedisT) -> _RedisT:
@@ -1356,50 +1278,49 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
return self.immediate_execute_command(*args, **kwargs)
return self.pipeline_execute_command(*args, **kwargs)
async def _disconnect_reset_raise_on_watching(
self,
conn: Connection,
error: Exception,
):
async def _disconnect_reset_raise(self, conn, error):
"""
Close the connection reset watching state and
raise an exception if we were watching.
The supported exceptions are already checked in the
retry object so we don't need to do it here.
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic(on connection level).
Close the connection, reset watching state and
raise an exception if we were watching,
retry_on_timeout is not set,
or the error is not a TimeoutError
"""
await conn.disconnect()
# if we were already watching a variable, the watch is no longer
# valid since this connection has died. raise a WatchError, which
# indicates the user should retry this transaction.
if self.watching:
await self.reset()
await self.aclose()
raise WatchError(
f"A {type(error).__name__} occurred while watching one or more keys"
"A ConnectionError occurred on while watching one or more keys"
)
# if retry_on_timeout is not set, or the error is not
# a TimeoutError, raise it
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
await self.aclose()
raise
async def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on the supported
errors for retry if we're already WATCHing a variable.
Used when issuing WATCH or subsequent commands retrieving their values but before
Execute a command immediately, but don't auto-retry on a
ConnectionError if we're already WATCHing a variable. Used when
issuing WATCH or subsequent commands retrieving their values but before
MULTI is called.
"""
command_name = args[0]
conn = self.connection
# if this is the first call, we need a connection
if not conn:
conn = await self.connection_pool.get_connection()
conn = await self.connection_pool.get_connection(
command_name, self.shard_hint
)
self.connection = conn
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise_on_watching(conn, error),
lambda error: self._disconnect_reset_raise(conn, error),
)
def pipeline_execute_command(self, *args, **options):
@@ -1484,10 +1405,6 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
if not isinstance(r, Exception):
args, options = cmd
command_name = args[0]
# Remove keys entry, it needs only for cache.
options.pop("keys", None)
if command_name in self.response_callbacks:
r = self.response_callbacks[command_name](r, **options)
if inspect.isawaitable(r):
@@ -1525,10 +1442,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
self, exception: Exception, number: int, command: Iterable[object]
) -> None:
cmd = " ".join(map(safe_str, command))
msg = (
f"Command # {number} ({truncate_text(cmd)}) "
"of pipeline caused error: {exception.args}"
)
msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}"
exception.args = (msg,) + exception.args[1:]
async def parse_response(
@@ -1554,15 +1468,11 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
if not exist:
s.sha = await immediate("SCRIPT LOAD", s.script)
async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception):
async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
"""
Close the connection, raise an exception if we were watching.
The supported exceptions are already checked in the
retry object so we don't need to do it here.
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic(on connection level).
Close the connection, raise an exception if we were watching,
and raise an exception if retry_on_timeout is not set,
or the error is not a TimeoutError
"""
await conn.disconnect()
# if we were watching a variable, the watch is no longer valid
@@ -1570,10 +1480,15 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
# indicates the user should retry this transaction.
if self.watching:
raise WatchError(
f"A {type(error).__name__} occurred while watching one or more keys"
"A ConnectionError occurred on while watching one or more keys"
)
# if retry_on_timeout is not set, or the error is not
# a TimeoutError, raise it
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
await self.reset()
raise
async def execute(self, raise_on_error: bool = True) -> List[Any]:
async def execute(self, raise_on_error: bool = True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
if not stack and not self.watching:
@@ -1587,7 +1502,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
conn = self.connection
if not conn:
conn = await self.connection_pool.get_connection()
conn = await self.connection_pool.get_connection("MULTI", self.shard_hint)
# assign to self.connection so reset() releases the connection
# back to the pool after we're done
self.connection = conn
@@ -1596,7 +1511,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
try:
return await conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_on_watching(conn, error),
lambda error: self._disconnect_raise_reset(conn, error),
)
finally:
await self.reset()

File diff suppressed because it is too large Load Diff

View File

@@ -3,8 +3,8 @@ import copy
import enum
import inspect
import socket
import ssl
import sys
import warnings
import weakref
from abc import abstractmethod
from itertools import chain
@@ -16,30 +16,14 @@ from typing import (
List,
Mapping,
Optional,
Protocol,
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
)
from urllib.parse import ParseResult, parse_qs, unquote, urlparse
from ..utils import SSL_AVAILABLE
if SSL_AVAILABLE:
import ssl
from ssl import SSLContext, TLSVersion
else:
ssl = None
TLSVersion = None
SSLContext = None
from ..auth.token import TokenInterface
from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
from ..utils import deprecated_args, format_error_message
# the functionality is available in 3.11.x but has a major issue before
# 3.11.3. See https://github.com/redis/redis-py/issues/2633
if sys.version_info >= (3, 11, 3):
@@ -49,6 +33,7 @@ else:
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redis.compat import Protocol, TypedDict
from redis.connection import DEFAULT_RESP_VERSION
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
from redis.exceptions import (
@@ -93,11 +78,13 @@ else:
class ConnectCallbackProtocol(Protocol):
def __call__(self, connection: "AbstractConnection"): ...
def __call__(self, connection: "AbstractConnection"):
...
class AsyncConnectCallbackProtocol(Protocol):
async def __call__(self, connection: "AbstractConnection"): ...
async def __call__(self, connection: "AbstractConnection"):
...
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
@@ -159,7 +146,6 @@ class AbstractConnection:
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = None,
):
if (username or password) and credential_provider is not None:
raise DataError(
@@ -168,10 +154,6 @@ class AbstractConnection:
"1. 'password' and (optional) 'username'\n"
"2. 'credential_provider'"
)
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.db = db
self.client_name = client_name
self.lib_name = lib_name
@@ -211,8 +193,6 @@ class AbstractConnection:
self.set_parser(parser_class)
self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
self._buffer_cutoff = 6000
self._re_auth_token: Optional[TokenInterface] = None
try:
p = int(protocol)
except TypeError:
@@ -224,33 +204,9 @@ class AbstractConnection:
raise ConnectionError("protocol must be either 2 or 3")
self.protocol = protocol
def __del__(self, _warnings: Any = warnings):
# For some reason, the individual streams don't get properly garbage
# collected and therefore produce no resource warnings. We add one
# here, in the same style as those from the stdlib.
if getattr(self, "_writer", None):
_warnings.warn(
f"unclosed Connection {self!r}", ResourceWarning, source=self
)
try:
asyncio.get_running_loop()
self._close()
except RuntimeError:
# No actions been taken if pool already closed.
pass
def _close(self):
"""
Internal method to silently close the connection without waiting
"""
if self._writer:
self._writer.close()
self._writer = self._reader = None
def __repr__(self):
repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
return f"{self.__class__.__name__}<{repr_args}>"
@abstractmethod
def repr_pieces(self):
@@ -260,24 +216,12 @@ class AbstractConnection:
def is_connected(self):
return self._reader is not None and self._writer is not None
def register_connect_callback(self, callback):
"""
Register a callback to be called when the connection is established either
initially or reconnected. This allows listeners to issue commands that
are ephemeral to the connection, for example pub/sub subscription or
key tracking. The callback must be a _method_ and will be kept as
a weak reference.
"""
def _register_connect_callback(self, callback):
wm = weakref.WeakMethod(callback)
if wm not in self._connect_callbacks:
self._connect_callbacks.append(wm)
def deregister_connect_callback(self, callback):
"""
De-register a previously registered callback. It will no-longer receive
notifications on connection events. Calling this is not required when the
listener goes away, since the callbacks are kept as weak methods.
"""
def _deregister_connect_callback(self, callback):
try:
self._connect_callbacks.remove(weakref.WeakMethod(callback))
except ValueError:
@@ -293,20 +237,12 @@ class AbstractConnection:
async def connect(self):
"""Connects to the Redis server if not already connected"""
await self.connect_check_health(check_health=True)
async def connect_check_health(
self, check_health: bool = True, retry_socket_connect: bool = True
):
if self.is_connected:
return
try:
if retry_socket_connect:
await self.retry.call_with_retry(
lambda: self._connect(), lambda error: self.disconnect()
)
else:
await self._connect()
await self.retry.call_with_retry(
lambda: self._connect(), lambda error: self.disconnect()
)
except asyncio.CancelledError:
raise # in 3.7 and earlier, this is an Exception, not BaseException
except (socket.timeout, asyncio.TimeoutError):
@@ -319,14 +255,12 @@ class AbstractConnection:
try:
if not self.redis_connect_func:
# Use the default on_connect function
await self.on_connect_check_health(check_health=check_health)
await self.on_connect()
else:
# Use the passed function redis_connect_func
(
await self.redis_connect_func(self)
if asyncio.iscoroutinefunction(self.redis_connect_func)
else self.redis_connect_func(self)
)
await self.redis_connect_func(self) if asyncio.iscoroutinefunction(
self.redis_connect_func
) else self.redis_connect_func(self)
except RedisError:
# clean up after any error in on_connect
await self.disconnect()
@@ -350,17 +284,12 @@ class AbstractConnection:
def _host_error(self) -> str:
pass
@abstractmethod
def _error_message(self, exception: BaseException) -> str:
return format_error_message(self._host_error(), exception)
def get_protocol(self):
return self.protocol
pass
async def on_connect(self) -> None:
"""Initialize the connection, authenticate and select a database"""
await self.on_connect_check_health(check_health=True)
async def on_connect_check_health(self, check_health: bool = True) -> None:
self._parser.on_connect(self)
parser = self._parser
@@ -371,8 +300,7 @@ class AbstractConnection:
self.credential_provider
or UsernamePasswordCredentialProvider(self.username, self.password)
)
auth_args = await cred_provider.get_credentials_async()
auth_args = cred_provider.get_credentials()
# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol not in [2, "2"]:
@@ -383,11 +311,7 @@ class AbstractConnection:
self._parser.on_connect(self)
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
await self.send_command(
"HELLO", self.protocol, "AUTH", *auth_args, check_health=False
)
await self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
response = await self.read_response()
if response.get(b"proto") != int(self.protocol) and response.get(
"proto"
@@ -418,7 +342,7 @@ class AbstractConnection:
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
await self.send_command("HELLO", self.protocol, check_health=check_health)
await self.send_command("HELLO", self.protocol)
response = await self.read_response()
# if response.get(b"proto") != self.protocol and response.get(
# "proto"
@@ -427,35 +351,18 @@ class AbstractConnection:
# if a client_name is given, set it
if self.client_name:
await self.send_command(
"CLIENT",
"SETNAME",
self.client_name,
check_health=check_health,
)
await self.send_command("CLIENT", "SETNAME", self.client_name)
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Error setting client name")
# set the library name and version, pipeline for lower startup latency
if self.lib_name:
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-NAME",
self.lib_name,
check_health=check_health,
)
await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
if self.lib_version:
await self.send_command(
"CLIENT",
"SETINFO",
"LIB-VER",
self.lib_version,
check_health=check_health,
)
await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
# if a database is specified, switch to it. Also pipeline this
if self.db:
await self.send_command("SELECT", self.db, check_health=check_health)
await self.send_command("SELECT", self.db)
# read responses from pipeline
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
@@ -517,8 +424,8 @@ class AbstractConnection:
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
) -> None:
if not self.is_connected:
await self.connect_check_health(check_health=False)
if check_health:
await self.connect()
elif check_health:
await self.check_health()
try:
@@ -581,7 +488,11 @@ class AbstractConnection:
read_timeout = timeout if timeout is not None else self.socket_timeout
host_error = self._host_error()
try:
if read_timeout is not None and self.protocol in ["3", 3]:
if (
read_timeout is not None
and self.protocol in ["3", 3]
and not HIREDIS_AVAILABLE
):
async with async_timeout(read_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
@@ -591,7 +502,7 @@ class AbstractConnection:
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
elif self.protocol in ["3", 3]:
elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE:
response = await self._parser.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
@@ -703,27 +614,6 @@ class AbstractConnection:
output.append(SYM_EMPTY.join(pieces))
return output
def _socket_is_empty(self):
"""Check if the socket is empty"""
return len(self._reader._buffer) == 0
async def process_invalidation_messages(self):
while not self._socket_is_empty():
await self.read_response(push_request=True)
def set_re_auth_token(self, token: TokenInterface):
self._re_auth_token = token
async def re_auth(self):
if self._re_auth_token is not None:
await self.send_command(
"AUTH",
self._re_auth_token.try_get("oid"),
self._re_auth_token.get_value(),
)
await self.read_response()
self._re_auth_token = None
class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"
@@ -781,6 +671,27 @@ class Connection(AbstractConnection):
def _host_error(self) -> str:
return f"{self.host}:{self.port}"
def _error_message(self, exception: BaseException) -> str:
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if not exception.args:
# asyncio has a bug where on Connection reset by peer, the
# exception is not instanciated, so args is empty. This is the
# workaround.
# See: https://github.com/redis/redis-py/issues/2237
# See: https://github.com/python/cpython/issues/94061
return f"Error connecting to {host_error}. Connection reset by peer"
elif len(exception.args) == 1:
return f"Error connecting to {host_error}. {exception.args[0]}."
else:
return (
f"Error {exception.args[0]} connecting to {host_error}. "
f"{exception.args[0]}."
)
class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
@@ -792,17 +703,12 @@ class SSLConnection(Connection):
self,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = True,
ssl_min_version: Optional[TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
ssl_check_hostname: bool = False,
**kwargs,
):
if not SSL_AVAILABLE:
raise RedisError("Python wasn't built with SSL support")
self.ssl_context: RedisSSLContext = RedisSSLContext(
keyfile=ssl_keyfile,
certfile=ssl_certfile,
@@ -810,8 +716,6 @@ class SSLConnection(Connection):
ca_certs=ssl_ca_certs,
ca_data=ssl_ca_data,
check_hostname=ssl_check_hostname,
min_version=ssl_min_version,
ciphers=ssl_ciphers,
)
super().__init__(**kwargs)
@@ -844,10 +748,6 @@ class SSLConnection(Connection):
def check_hostname(self):
return self.ssl_context.check_hostname
@property
def min_version(self):
return self.ssl_context.min_version
class RedisSSLContext:
__slots__ = (
@@ -858,30 +758,23 @@ class RedisSSLContext:
"ca_data",
"context",
"check_hostname",
"min_version",
"ciphers",
)
def __init__(
self,
keyfile: Optional[str] = None,
certfile: Optional[str] = None,
cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
cert_reqs: Optional[str] = None,
ca_certs: Optional[str] = None,
ca_data: Optional[str] = None,
check_hostname: bool = False,
min_version: Optional[TLSVersion] = None,
ciphers: Optional[str] = None,
):
if not SSL_AVAILABLE:
raise RedisError("Python wasn't built with SSL support")
self.keyfile = keyfile
self.certfile = certfile
if cert_reqs is None:
cert_reqs = ssl.CERT_NONE
self.cert_reqs = ssl.CERT_NONE
elif isinstance(cert_reqs, str):
CERT_REQS = { # noqa: N806
CERT_REQS = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,
"required": ssl.CERT_REQUIRED,
@@ -890,18 +783,13 @@ class RedisSSLContext:
raise RedisError(
f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
)
cert_reqs = CERT_REQS[cert_reqs]
self.cert_reqs = cert_reqs
self.cert_reqs = CERT_REQS[cert_reqs]
self.ca_certs = ca_certs
self.ca_data = ca_data
self.check_hostname = (
check_hostname if self.cert_reqs != ssl.CERT_NONE else False
)
self.min_version = min_version
self.ciphers = ciphers
self.context: Optional[SSLContext] = None
self.check_hostname = check_hostname
self.context: Optional[ssl.SSLContext] = None
def get(self) -> SSLContext:
def get(self) -> ssl.SSLContext:
if not self.context:
context = ssl.create_default_context()
context.check_hostname = self.check_hostname
@@ -910,10 +798,6 @@ class RedisSSLContext:
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
if self.ca_certs or self.ca_data:
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
if self.min_version is not None:
context.minimum_version = self.min_version
if self.ciphers is not None:
context.set_ciphers(self.ciphers)
self.context = context
return self.context
@@ -941,6 +825,20 @@ class UnixDomainSocketConnection(AbstractConnection):
def _host_error(self) -> str:
return self.path
def _error_message(self, exception: BaseException) -> str:
# args for socket.error can either be (errno, "message")
# or just "message"
host_error = self._host_error()
if len(exception.args) == 1:
return (
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
)
else:
return (
f"Error {exception.args[0]} connecting to unix socket: "
f"{host_error}. {exception.args[1]}."
)
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")
@@ -963,7 +861,6 @@ URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyTy
"max_connections": int,
"health_check_interval": int,
"ssl_check_hostname": to_bool,
"timeout": float,
}
)
@@ -990,7 +887,7 @@ def parse_url(url: str) -> ConnectKwargs:
try:
kwargs[name] = parser(value)
except (TypeError, ValueError):
raise ValueError(f"Invalid value for '{name}' in connection URL.")
raise ValueError(f"Invalid value for `{name}` in connection URL.")
else:
kwargs[name] = value
@@ -1042,7 +939,6 @@ class ConnectionPool:
By default, TCP connections are created unless ``connection_class``
is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
unix sockets.
:py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
Any additional keyword arguments are passed to the constructor of
``connection_class``.
@@ -1112,22 +1008,16 @@ class ConnectionPool:
self._available_connections: List[AbstractConnection] = []
self._in_use_connections: Set[AbstractConnection] = set()
self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
self._lock = asyncio.Lock()
self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
if self._event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
def __repr__(self):
conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
f"({conn_kwargs})>)>"
f"{self.__class__.__name__}"
f"<{self.connection_class(**self.connection_kwargs)!r}>"
)
def reset(self):
self._available_connections = []
self._in_use_connections = weakref.WeakSet()
self._in_use_connections = set()
def can_get_connection(self) -> bool:
"""Return True if a connection can be retrieved from the pool."""
@@ -1136,25 +1026,8 @@ class ConnectionPool:
or len(self._in_use_connections) < self.max_connections
)
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.3.0",
)
async def get_connection(self, command_name=None, *keys, **options):
async with self._lock:
"""Get a connected connection from the pool"""
connection = self.get_available_connection()
try:
await self.ensure_connection(connection)
except BaseException:
await self.release(connection)
raise
return connection
def get_available_connection(self):
"""Get a connection from the pool, without making sure it is connected"""
async def get_connection(self, command_name, *keys, **options):
"""Get a connection from the pool"""
try:
connection = self._available_connections.pop()
except IndexError:
@@ -1162,6 +1035,13 @@ class ConnectionPool:
raise ConnectionError("Too many connections") from None
connection = self.make_connection()
self._in_use_connections.add(connection)
try:
await self.ensure_connection(connection)
except BaseException:
await self.release(connection)
raise
return connection
def get_encoder(self):
@@ -1187,7 +1067,7 @@ class ConnectionPool:
try:
if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except (ConnectionError, TimeoutError, OSError):
except (ConnectionError, OSError):
await connection.disconnect()
await connection.connect()
if await connection.can_read_destructive():
@@ -1199,9 +1079,6 @@ class ConnectionPool:
# not doing so is an error that will cause an exception here.
self._in_use_connections.remove(connection)
self._available_connections.append(connection)
await self._event_dispatcher.dispatch_async(
AsyncAfterConnectionReleasedEvent(connection)
)
async def disconnect(self, inuse_connections: bool = True):
"""
@@ -1235,29 +1112,6 @@ class ConnectionPool:
for conn in self._in_use_connections:
conn.retry = retry
async def re_auth_callback(self, token: TokenInterface):
async with self._lock:
for conn in self._available_connections:
await conn.retry.call_with_retry(
lambda: conn.send_command(
"AUTH", token.try_get("oid"), token.get_value()
),
lambda error: self._mock(error),
)
await conn.retry.call_with_retry(
lambda: conn.read_response(), lambda error: self._mock(error)
)
for conn in self._in_use_connections:
conn.set_re_auth_token(token)
async def _mock(self, error: RedisError):
"""
Dummy functions, needs to be passed as error callback to retry object.
:param error:
:return:
"""
pass
class BlockingConnectionPool(ConnectionPool):
"""
@@ -1275,7 +1129,7 @@ class BlockingConnectionPool(ConnectionPool):
connection from the pool when all of connections are in use, rather than
raising a :py:class:`~redis.ConnectionError` (as the default
:py:class:`~redis.asyncio.ConnectionPool` implementation does), it
blocks the current `Task` for a specified number of seconds until
makes blocks the current `Task` for a specified number of seconds until
a connection becomes available.
Use ``max_connections`` to increase / decrease the pool size::
@@ -1309,29 +1163,16 @@ class BlockingConnectionPool(ConnectionPool):
self._condition = asyncio.Condition()
self.timeout = timeout
@deprecated_args(
args_to_warn=["*"],
reason="Use get_connection() without args instead",
version="5.3.0",
)
async def get_connection(self, command_name=None, *keys, **options):
async def get_connection(self, command_name, *keys, **options):
"""Gets a connection from the pool, blocking until one is available"""
try:
async with self._condition:
async with async_timeout(self.timeout):
async with async_timeout(self.timeout):
async with self._condition:
await self._condition.wait_for(self.can_get_connection)
connection = super().get_available_connection()
return await super().get_connection(command_name, *keys, **options)
except asyncio.TimeoutError as err:
raise ConnectionError("No connection available.") from err
# We now perform the connection check outside of the lock.
try:
await self.ensure_connection(connection)
return connection
except BaseException:
await self.release(connection)
raise
async def release(self, connection: AbstractConnection):
"""Releases the connection back to the pool."""
async with self._condition:

View File

@@ -1,18 +1,14 @@
import asyncio
import logging
import threading
import uuid
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, Optional, Union
from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number
if TYPE_CHECKING:
from redis.asyncio import Redis, RedisCluster
logger = logging.getLogger(__name__)
class Lock:
"""
@@ -86,9 +82,8 @@ class Lock:
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[Number] = None,
blocking_timeout: Optional[float] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
@@ -132,11 +127,6 @@ class Lock:
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
@@ -153,7 +143,6 @@ class Lock:
self.blocking_timeout = blocking_timeout
self.thread_local = bool(thread_local)
self.local = threading.local() if self.thread_local else SimpleNamespace()
self.raise_on_release_error = raise_on_release_error
self.local.token = None
self.register_scripts()
@@ -173,19 +162,12 @@ class Lock:
raise LockError("Unable to acquire lock within the time specified")
async def __aexit__(self, exc_type, exc_value, traceback):
try:
await self.release()
except LockError:
if self.raise_on_release_error:
raise
logger.warning(
"Lock was unlocked or no longer owned when exiting context manager."
)
await self.release()
async def acquire(
self,
blocking: Optional[bool] = None,
blocking_timeout: Optional[Number] = None,
blocking_timeout: Optional[float] = None,
token: Optional[Union[str, bytes]] = None,
):
"""
@@ -267,10 +249,7 @@ class Lock:
"""Releases the already acquired lock"""
expected_token = self.local.token
if expected_token is None:
raise LockError(
"Cannot release a lock that's not owned or is already unlocked.",
lock_name=self.name,
)
raise LockError("Cannot release an unlocked lock")
self.local.token = None
return self.do_release(expected_token)
@@ -283,7 +262,7 @@ class Lock:
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
def extend(
self, additional_time: Number, replace_ttl: bool = False
self, additional_time: float, replace_ttl: bool = False
) -> Awaitable[bool]:
"""
Adds more time to an already acquired lock.

View File

@@ -2,16 +2,18 @@ from asyncio import sleep
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar
from redis.exceptions import ConnectionError, RedisError, TimeoutError
from redis.retry import AbstractRetry
T = TypeVar("T")
if TYPE_CHECKING:
from redis.backoff import AbstractBackoff
class Retry(AbstractRetry[RedisError]):
__hash__ = AbstractRetry.__hash__
T = TypeVar("T")
class Retry:
"""Retry a specific number of times after a failure"""
__slots__ = "_backoff", "_retries", "_supported_errors"
def __init__(
self,
@@ -22,16 +24,23 @@ class Retry(AbstractRetry[RedisError]):
TimeoutError,
),
):
super().__init__(backoff, retries, supported_errors)
"""
Initialize a `Retry` object with a `Backoff` object
that retries a maximum of `retries` times.
`retries` can be negative to retry forever.
You can specify the types of supported errors which trigger
a retry with the `supported_errors` parameter.
"""
self._backoff = backoff
self._retries = retries
self._supported_errors = supported_errors
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Retry):
return NotImplemented
return (
self._backoff == other._backoff
and self._retries == other._retries
and set(self._supported_errors) == set(other._supported_errors)
def update_supported_errors(self, specified_errors: list):
"""
Updates the supported errors with the specified error types
"""
self._supported_errors = tuple(
set(self._supported_errors + tuple(specified_errors))
)
async def call_with_retry(

View File

@@ -11,12 +11,8 @@ from redis.asyncio.connection import (
SSLConnection,
)
from redis.commands import AsyncSentinelCommands
from redis.exceptions import (
ConnectionError,
ReadOnlyError,
ResponseError,
TimeoutError,
)
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
from redis.utils import str_if_bytes
class MasterNotFoundError(ConnectionError):
@@ -33,18 +29,20 @@ class SentinelManagedConnection(Connection):
super().__init__(**kwargs)
def __repr__(self):
s = f"<{self.__class__.__module__}.{self.__class__.__name__}"
pool = self.connection_pool
s = f"{self.__class__.__name__}<service={pool.service_name}"
if self.host:
host_info = f",host={self.host},port={self.port}"
s += host_info
return s + ")>"
return s + ">"
async def connect_to(self, address):
self.host, self.port = address
await self.connect_check_health(
check_health=self.connection_pool.check_connection,
retry_socket_connect=False,
)
await super().connect()
if self.connection_pool.check_connection:
await self.send_command("PING")
if str_if_bytes(await self.read_response()) != "PONG":
raise ConnectionError("PING failed")
async def _connect_retry(self):
if self._reader:
@@ -107,11 +105,9 @@ class SentinelConnectionPool(ConnectionPool):
def __init__(self, service_name, sentinel_manager, **kwargs):
kwargs["connection_class"] = kwargs.get(
"connection_class",
(
SentinelManagedSSLConnection
if kwargs.pop("ssl", False)
else SentinelManagedConnection
),
SentinelManagedSSLConnection
if kwargs.pop("ssl", False)
else SentinelManagedConnection,
)
self.is_master = kwargs.pop("is_master", True)
self.check_connection = kwargs.pop("check_connection", False)
@@ -124,8 +120,8 @@ class SentinelConnectionPool(ConnectionPool):
def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>"
f"{self.__class__.__name__}"
f"<service={self.service_name}({self.is_master and 'master' or 'slave'})>"
)
def reset(self):
@@ -201,7 +197,6 @@ class Sentinel(AsyncSentinelCommands):
sentinels,
min_other_sentinels=0,
sentinel_kwargs=None,
force_master_ip=None,
**connection_kwargs,
):
# if sentinel_kwargs isn't defined, use the socket_* options from
@@ -218,7 +213,6 @@ class Sentinel(AsyncSentinelCommands):
]
self.min_other_sentinels = min_other_sentinels
self.connection_kwargs = connection_kwargs
self._force_master_ip = force_master_ip
async def execute_command(self, *args, **kwargs):
"""
@@ -226,31 +220,19 @@ class Sentinel(AsyncSentinelCommands):
once - If set to True, then execute the resulting command on a single
node at random, rather than across the entire sentinel cluster.
"""
once = bool(kwargs.pop("once", False))
# Check if command is supposed to return the original
# responses instead of boolean value.
return_responses = bool(kwargs.pop("return_responses", False))
once = bool(kwargs.get("once", False))
if "once" in kwargs.keys():
kwargs.pop("once")
if once:
response = await random.choice(self.sentinels).execute_command(
*args, **kwargs
)
if return_responses:
return [response]
else:
return True if response else False
tasks = [
asyncio.Task(sentinel.execute_command(*args, **kwargs))
for sentinel in self.sentinels
]
responses = await asyncio.gather(*tasks)
if return_responses:
return responses
return all(responses)
await random.choice(self.sentinels).execute_command(*args, **kwargs)
else:
tasks = [
asyncio.Task(sentinel.execute_command(*args, **kwargs))
for sentinel in self.sentinels
]
await asyncio.gather(*tasks)
return True
def __repr__(self):
sentinel_addresses = []
@@ -259,10 +241,7 @@ class Sentinel(AsyncSentinelCommands):
f"{sentinel.connection_pool.connection_kwargs['host']}:"
f"{sentinel.connection_pool.connection_kwargs['port']}"
)
return (
f"<{self.__class__}.{self.__class__.__name__}"
f"(sentinels=[{','.join(sentinel_addresses)}])>"
)
return f"{self.__class__.__name__}<sentinels=[{','.join(sentinel_addresses)}]>"
def check_master_state(self, state: dict, service_name: str) -> bool:
if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
@@ -294,13 +273,7 @@ class Sentinel(AsyncSentinelCommands):
sentinel,
self.sentinels[0],
)
ip = (
self._force_master_ip
if self._force_master_ip is not None
else state["ip"]
)
return ip, state["port"]
return state["ip"], state["port"]
error_info = ""
if len(collected_errors) > 0:
@@ -341,8 +314,6 @@ class Sentinel(AsyncSentinelCommands):
):
"""
Returns a redis client instance for the ``service_name`` master.
Sentinel client will detect failover and reconnect Redis clients
automatically.
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
used to retrieve the master's address before establishing a new

View File

@@ -16,7 +16,7 @@ def from_url(url, **kwargs):
return Redis.from_url(url, **kwargs)
class pipeline: # noqa: N801
class pipeline:
def __init__(self, redis_obj: "Redis"):
self.p: "Pipeline" = redis_obj.pipeline()

View File

@@ -1,31 +0,0 @@
from typing import Iterable
class RequestTokenErr(Exception):
"""
Represents an exception during token request.
"""
def __init__(self, *args):
super().__init__(*args)
class InvalidTokenSchemaErr(Exception):
"""
Represents an exception related to invalid token schema.
"""
def __init__(self, missing_fields: Iterable[str] = []):
super().__init__(
"Unexpected token schema. Following fields are missing: "
+ ", ".join(missing_fields)
)
class TokenRenewalErr(Exception):
"""
Represents an exception during token renewal process.
"""
def __init__(self, *args):
super().__init__(*args)

View File

@@ -1,28 +0,0 @@
from abc import ABC, abstractmethod
from redis.auth.token import TokenInterface
"""
This interface is the facade of an identity provider
"""
class IdentityProviderInterface(ABC):
"""
Receive a token from the identity provider.
Receiving a token only works when being authenticated.
"""
@abstractmethod
def request_token(self, force_refresh=False) -> TokenInterface:
pass
class IdentityProviderConfigInterface(ABC):
"""
Configuration class that provides a configured identity provider.
"""
@abstractmethod
def get_provider(self) -> IdentityProviderInterface:
pass

View File

@@ -1,130 +0,0 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from redis.auth.err import InvalidTokenSchemaErr
class TokenInterface(ABC):
@abstractmethod
def is_expired(self) -> bool:
pass
@abstractmethod
def ttl(self) -> float:
pass
@abstractmethod
def try_get(self, key: str) -> str:
pass
@abstractmethod
def get_value(self) -> str:
pass
@abstractmethod
def get_expires_at_ms(self) -> float:
pass
@abstractmethod
def get_received_at_ms(self) -> float:
pass
class TokenResponse:
def __init__(self, token: TokenInterface):
self._token = token
def get_token(self) -> TokenInterface:
return self._token
def get_ttl_ms(self) -> float:
return self._token.get_expires_at_ms() - self._token.get_received_at_ms()
class SimpleToken(TokenInterface):
def __init__(
self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
) -> None:
self.value = value
self.expires_at = expires_at_ms
self.received_at = received_at_ms
self.claims = claims
def ttl(self) -> float:
if self.expires_at == -1:
return -1
return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000)
def is_expired(self) -> bool:
if self.expires_at == -1:
return False
return self.ttl() <= 0
def try_get(self, key: str) -> str:
return self.claims.get(key)
def get_value(self) -> str:
return self.value
def get_expires_at_ms(self) -> float:
return self.expires_at
def get_received_at_ms(self) -> float:
return self.received_at
class JWToken(TokenInterface):
REQUIRED_FIELDS = {"exp"}
def __init__(self, token: str):
try:
import jwt
except ImportError as ie:
raise ImportError(
f"The PyJWT library is required for {self.__class__.__name__}.",
) from ie
self._value = token
self._decoded = jwt.decode(
self._value,
options={"verify_signature": False},
algorithms=[jwt.get_unverified_header(self._value).get("alg")],
)
self._validate_token()
def is_expired(self) -> bool:
exp = self._decoded["exp"]
if exp == -1:
return False
return (
self._decoded["exp"] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000
)
def ttl(self) -> float:
exp = self._decoded["exp"]
if exp == -1:
return -1
return (
self._decoded["exp"] * 1000 - datetime.now(timezone.utc).timestamp() * 1000
)
def try_get(self, key: str) -> str:
return self._decoded.get(key)
def get_value(self) -> str:
return self._value
def get_expires_at_ms(self) -> float:
return float(self._decoded["exp"] * 1000)
def get_received_at_ms(self) -> float:
return datetime.now(timezone.utc).timestamp() * 1000
def _validate_token(self):
actual_fields = {x for x in self._decoded.keys()}
if len(self.REQUIRED_FIELDS - actual_fields) != 0:
raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields)

View File

@@ -1,370 +0,0 @@
import asyncio
import logging
import threading
from datetime import datetime, timezone
from time import sleep
from typing import Any, Awaitable, Callable, Union
from redis.auth.err import RequestTokenErr, TokenRenewalErr
from redis.auth.idp import IdentityProviderInterface
from redis.auth.token import TokenResponse
logger = logging.getLogger(__name__)
class CredentialsListener:
"""
Listeners that will be notified on events related to credentials.
Accepts callbacks and awaitable callbacks.
"""
def __init__(self):
self._on_next = None
self._on_error = None
@property
def on_next(self) -> Union[Callable[[Any], None], Awaitable]:
return self._on_next
@on_next.setter
def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None:
self._on_next = callback
@property
def on_error(self) -> Union[Callable[[Exception], None], Awaitable]:
return self._on_error
@on_error.setter
def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None:
self._on_error = callback
class RetryPolicy:
def __init__(self, max_attempts: int, delay_in_ms: float):
self.max_attempts = max_attempts
self.delay_in_ms = delay_in_ms
def get_max_attempts(self) -> int:
"""
Retry attempts before exception will be thrown.
:return: int
"""
return self.max_attempts
def get_delay_in_ms(self) -> float:
"""
Delay between retries in seconds.
:return: int
"""
return self.delay_in_ms
class TokenManagerConfig:
def __init__(
self,
expiration_refresh_ratio: float,
lower_refresh_bound_millis: int,
token_request_execution_timeout_in_ms: int,
retry_policy: RetryPolicy,
):
self._expiration_refresh_ratio = expiration_refresh_ratio
self._lower_refresh_bound_millis = lower_refresh_bound_millis
self._token_request_execution_timeout_in_ms = (
token_request_execution_timeout_in_ms
)
self._retry_policy = retry_policy
def get_expiration_refresh_ratio(self) -> float:
"""
Represents the ratio of a token's lifetime at which a refresh should be triggered. # noqa: E501
For example, a value of 0.75 means the token should be refreshed
when 75% of its lifetime has elapsed (or when 25% of its lifetime remains).
:return: float
"""
return self._expiration_refresh_ratio
def get_lower_refresh_bound_millis(self) -> int:
"""
Represents the minimum time in milliseconds before token expiration
to trigger a refresh, in milliseconds.
This value sets a fixed lower bound for when a token refresh should occur,
regardless of the token's total lifetime.
If set to 0 there will be no lower bound and the refresh will be triggered
based on the expirationRefreshRatio only.
:return: int
"""
return self._lower_refresh_bound_millis
def get_token_request_execution_timeout_in_ms(self) -> int:
"""
Represents the maximum time in milliseconds to wait
for a token request to complete.
:return: int
"""
return self._token_request_execution_timeout_in_ms
def get_retry_policy(self) -> RetryPolicy:
"""
Represents the retry policy for token requests.
:return: RetryPolicy
"""
return self._retry_policy
class TokenManager:
def __init__(
self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig
):
self._idp = identity_provider
self._config = config
self._next_timer = None
self._listener = None
self._init_timer = None
self._retries = 0
def __del__(self):
logger.info("Token manager are disposed")
self.stop()
def start(
self,
listener: CredentialsListener,
skip_initial: bool = False,
) -> Callable[[], None]:
self._listener = listener
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# Run loop in a separate thread to unblock main thread.
loop = asyncio.new_event_loop()
thread = threading.Thread(
target=_start_event_loop_in_thread, args=(loop,), daemon=True
)
thread.start()
# Event to block for initial execution.
init_event = asyncio.Event()
self._init_timer = loop.call_later(
0, self._renew_token, skip_initial, init_event
)
logger.info("Token manager started")
# Blocks in thread-safe manner.
asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result()
return self.stop
async def start_async(
self,
listener: CredentialsListener,
block_for_initial: bool = False,
initial_delay_in_ms: float = 0,
skip_initial: bool = False,
) -> Callable[[], None]:
self._listener = listener
loop = asyncio.get_running_loop()
init_event = asyncio.Event()
# Wraps the async callback with async wrapper to schedule with loop.call_later()
wrapped = _async_to_sync_wrapper(
loop, self._renew_token_async, skip_initial, init_event
)
self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped)
logger.info("Token manager started")
if block_for_initial:
await init_event.wait()
return self.stop
def stop(self):
if self._init_timer is not None:
self._init_timer.cancel()
if self._next_timer is not None:
self._next_timer.cancel()
def acquire_token(self, force_refresh=False) -> TokenResponse:
try:
token = self._idp.request_token(force_refresh)
except RequestTokenErr as e:
if self._retries < self._config.get_retry_policy().get_max_attempts():
self._retries += 1
sleep(self._config.get_retry_policy().get_delay_in_ms() / 1000)
return self.acquire_token(force_refresh)
else:
raise e
self._retries = 0
return TokenResponse(token)
async def acquire_token_async(self, force_refresh=False) -> TokenResponse:
try:
token = self._idp.request_token(force_refresh)
except RequestTokenErr as e:
if self._retries < self._config.get_retry_policy().get_max_attempts():
self._retries += 1
await asyncio.sleep(
self._config.get_retry_policy().get_delay_in_ms() / 1000
)
return await self.acquire_token_async(force_refresh)
else:
raise e
self._retries = 0
return TokenResponse(token)
def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float:
delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date)
delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date)
delay = min(delay_for_ratio_refresh, delay_for_lower_refresh)
return 0 if delay < 0 else delay / 1000
def _delay_for_lower_refresh(self, expire_date: float):
return (
expire_date
- self._config.get_lower_refresh_bound_millis()
- (datetime.now(timezone.utc).timestamp() * 1000)
)
def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float):
token_ttl = expire_date - issue_date
refresh_before = token_ttl - (
token_ttl * self._config.get_expiration_refresh_ratio()
)
return (
expire_date
- refresh_before
- (datetime.now(timezone.utc).timestamp() * 1000)
)
def _renew_token(
self, skip_initial: bool = False, init_event: asyncio.Event = None
):
"""
Task to renew token from identity provider.
Schedules renewal tasks based on token TTL.
"""
try:
token_res = self.acquire_token(force_refresh=True)
delay = self._calculate_renewal_delay(
token_res.get_token().get_expires_at_ms(),
token_res.get_token().get_received_at_ms(),
)
if token_res.get_token().is_expired():
raise TokenRenewalErr("Requested token is expired")
if self._listener.on_next is None:
logger.warning(
"No registered callback for token renewal task. Renewal cancelled"
)
return
if not skip_initial:
try:
self._listener.on_next(token_res.get_token())
except Exception as e:
raise TokenRenewalErr(e)
if delay <= 0:
return
loop = asyncio.get_running_loop()
self._next_timer = loop.call_later(delay, self._renew_token)
logger.info(f"Next token renewal scheduled in {delay} seconds")
return token_res
except Exception as e:
if self._listener.on_error is None:
raise e
self._listener.on_error(e)
finally:
if init_event:
init_event.set()
async def _renew_token_async(
self, skip_initial: bool = False, init_event: asyncio.Event = None
):
"""
Async task to renew tokens from identity provider.
Schedules renewal tasks based on token TTL.
"""
try:
token_res = await self.acquire_token_async(force_refresh=True)
delay = self._calculate_renewal_delay(
token_res.get_token().get_expires_at_ms(),
token_res.get_token().get_received_at_ms(),
)
if token_res.get_token().is_expired():
raise TokenRenewalErr("Requested token is expired")
if self._listener.on_next is None:
logger.warning(
"No registered callback for token renewal task. Renewal cancelled"
)
return
if not skip_initial:
try:
await self._listener.on_next(token_res.get_token())
except Exception as e:
raise TokenRenewalErr(e)
if delay <= 0:
return
loop = asyncio.get_running_loop()
wrapped = _async_to_sync_wrapper(loop, self._renew_token_async)
logger.info(f"Next token renewal scheduled in {delay} seconds")
loop.call_later(delay, wrapped)
except Exception as e:
if self._listener.on_error is None:
raise e
await self._listener.on_error(e)
finally:
if init_event:
init_event.set()
def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
"""
Wraps an asynchronous function so it can be used with loop.call_later.
:param loop: The event loop in which the coroutine will be executed.
:param coro_func: The coroutine function to wrap.
:param args: Positional arguments to pass to the coroutine function.
:param kwargs: Keyword arguments to pass to the coroutine function.
:return: A regular function suitable for loop.call_later.
"""
def wrapped():
# Schedule the coroutine in the event loop
asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)
return wrapped
def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop):
"""
Starts event loop in a thread.
Used to be able to schedule tasks using loop.call_later.
:param event_loop:
:return:
"""
asyncio.set_event_loop(event_loop)
event_loop.run_forever()

View File

@@ -19,7 +19,7 @@ class AbstractBackoff(ABC):
pass
@abstractmethod
def compute(self, failures: int) -> float:
def compute(self, failures):
"""Compute backoff in seconds upon failure"""
pass
@@ -27,34 +27,25 @@ class AbstractBackoff(ABC):
class ConstantBackoff(AbstractBackoff):
"""Constant backoff upon failure"""
def __init__(self, backoff: float) -> None:
def __init__(self, backoff):
"""`backoff`: backoff time in seconds"""
self._backoff = backoff
def __hash__(self) -> int:
return hash((self._backoff,))
def __eq__(self, other) -> bool:
if not isinstance(other, ConstantBackoff):
return NotImplemented
return self._backoff == other._backoff
def compute(self, failures: int) -> float:
def compute(self, failures):
return self._backoff
class NoBackoff(ConstantBackoff):
"""No backoff upon failure"""
def __init__(self) -> None:
def __init__(self):
super().__init__(0)
class ExponentialBackoff(AbstractBackoff):
"""Exponential backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE):
def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE):
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
@@ -62,23 +53,14 @@ class ExponentialBackoff(AbstractBackoff):
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, ExponentialBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
def compute(self, failures):
return min(self._cap, self._base * 2**failures)
class FullJitterBackoff(AbstractBackoff):
"""Full jitter backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE):
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
@@ -86,23 +68,14 @@ class FullJitterBackoff(AbstractBackoff):
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, FullJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
def compute(self, failures):
return random.uniform(0, min(self._cap, self._base * 2**failures))
class EqualJitterBackoff(AbstractBackoff):
"""Equal jitter backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE):
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
@@ -110,16 +83,7 @@ class EqualJitterBackoff(AbstractBackoff):
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, EqualJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
def compute(self, failures):
temp = min(self._cap, self._base * 2**failures) / 2
return temp + random.uniform(0, temp)
@@ -127,7 +91,7 @@ class EqualJitterBackoff(AbstractBackoff):
class DecorrelatedJitterBackoff(AbstractBackoff):
"""Decorrelated jitter backoff upon failure"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE):
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
@@ -136,48 +100,15 @@ class DecorrelatedJitterBackoff(AbstractBackoff):
self._base = base
self._previous_backoff = 0
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, DecorrelatedJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def reset(self) -> None:
def reset(self):
self._previous_backoff = 0
def compute(self, failures: int) -> float:
def compute(self, failures):
max_backoff = max(self._base, self._previous_backoff * 3)
temp = random.uniform(self._base, max_backoff)
self._previous_backoff = min(self._cap, temp)
return self._previous_backoff
class ExponentialWithJitterBackoff(AbstractBackoff):
"""Exponential backoff upon failure, with jitter"""
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
"""
`cap`: maximum backoff time in seconds
`base`: base backoff time in seconds
"""
self._cap = cap
self._base = base
def __hash__(self) -> int:
return hash((self._base, self._cap))
def __eq__(self, other) -> bool:
if not isinstance(other, ExponentialWithJitterBackoff):
return NotImplemented
return self._base == other._base and self._cap == other._cap
def compute(self, failures: int) -> float:
return min(self._cap, random.random() * self._base * 2**failures)
def default_backoff():
return EqualJitterBackoff()

View File

@@ -1,401 +0,0 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Union
class CacheEntryStatus(Enum):
VALID = "VALID"
IN_PROGRESS = "IN_PROGRESS"
class EvictionPolicyType(Enum):
time_based = "time_based"
frequency_based = "frequency_based"
@dataclass(frozen=True)
class CacheKey:
command: str
redis_keys: tuple
class CacheEntry:
def __init__(
self,
cache_key: CacheKey,
cache_value: bytes,
status: CacheEntryStatus,
connection_ref,
):
self.cache_key = cache_key
self.cache_value = cache_value
self.status = status
self.connection_ref = connection_ref
def __hash__(self):
return hash(
(self.cache_key, self.cache_value, self.status, self.connection_ref)
)
def __eq__(self, other):
return hash(self) == hash(other)
class EvictionPolicyInterface(ABC):
@property
@abstractmethod
def cache(self):
pass
@cache.setter
def cache(self, value):
pass
@property
@abstractmethod
def type(self) -> EvictionPolicyType:
pass
@abstractmethod
def evict_next(self) -> CacheKey:
pass
@abstractmethod
def evict_many(self, count: int) -> List[CacheKey]:
pass
@abstractmethod
def touch(self, cache_key: CacheKey) -> None:
pass
class CacheConfigurationInterface(ABC):
@abstractmethod
def get_cache_class(self):
pass
@abstractmethod
def get_max_size(self) -> int:
pass
@abstractmethod
def get_eviction_policy(self):
pass
@abstractmethod
def is_exceeds_max_size(self, count: int) -> bool:
pass
@abstractmethod
def is_allowed_to_cache(self, command: str) -> bool:
pass
class CacheInterface(ABC):
@property
@abstractmethod
def collection(self) -> OrderedDict:
pass
@property
@abstractmethod
def config(self) -> CacheConfigurationInterface:
pass
@property
@abstractmethod
def eviction_policy(self) -> EvictionPolicyInterface:
pass
@property
@abstractmethod
def size(self) -> int:
pass
@abstractmethod
def get(self, key: CacheKey) -> Union[CacheEntry, None]:
pass
@abstractmethod
def set(self, entry: CacheEntry) -> bool:
pass
@abstractmethod
def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
pass
@abstractmethod
def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
pass
@abstractmethod
def flush(self) -> int:
pass
@abstractmethod
def is_cachable(self, key: CacheKey) -> bool:
pass
class DefaultCache(CacheInterface):
def __init__(
self,
cache_config: CacheConfigurationInterface,
) -> None:
self._cache = OrderedDict()
self._cache_config = cache_config
self._eviction_policy = self._cache_config.get_eviction_policy().value()
self._eviction_policy.cache = self
@property
def collection(self) -> OrderedDict:
return self._cache
@property
def config(self) -> CacheConfigurationInterface:
return self._cache_config
@property
def eviction_policy(self) -> EvictionPolicyInterface:
return self._eviction_policy
@property
def size(self) -> int:
return len(self._cache)
def set(self, entry: CacheEntry) -> bool:
if not self.is_cachable(entry.cache_key):
return False
self._cache[entry.cache_key] = entry
self._eviction_policy.touch(entry.cache_key)
if self._cache_config.is_exceeds_max_size(len(self._cache)):
self._eviction_policy.evict_next()
return True
def get(self, key: CacheKey) -> Union[CacheEntry, None]:
entry = self._cache.get(key, None)
if entry is None:
return None
self._eviction_policy.touch(key)
return entry
def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
response = []
for key in cache_keys:
if self.get(key) is not None:
self._cache.pop(key)
response.append(True)
else:
response.append(False)
return response
def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
response = []
keys_to_delete = []
for redis_key in redis_keys:
if isinstance(redis_key, bytes):
redis_key = redis_key.decode()
for cache_key in self._cache:
if redis_key in cache_key.redis_keys:
keys_to_delete.append(cache_key)
response.append(True)
for key in keys_to_delete:
self._cache.pop(key)
return response
def flush(self) -> int:
elem_count = len(self._cache)
self._cache.clear()
return elem_count
def is_cachable(self, key: CacheKey) -> bool:
return self._cache_config.is_allowed_to_cache(key.command)
class LRUPolicy(EvictionPolicyInterface):
def __init__(self):
self.cache = None
@property
def cache(self):
return self._cache
@cache.setter
def cache(self, cache: CacheInterface):
self._cache = cache
@property
def type(self) -> EvictionPolicyType:
return EvictionPolicyType.time_based
def evict_next(self) -> CacheKey:
self._assert_cache()
popped_entry = self._cache.collection.popitem(last=False)
return popped_entry[0]
def evict_many(self, count: int) -> List[CacheKey]:
self._assert_cache()
if count > len(self._cache.collection):
raise ValueError("Evictions count is above cache size")
popped_keys = []
for _ in range(count):
popped_entry = self._cache.collection.popitem(last=False)
popped_keys.append(popped_entry[0])
return popped_keys
def touch(self, cache_key: CacheKey) -> None:
self._assert_cache()
if self._cache.collection.get(cache_key) is None:
raise ValueError("Given entry does not belong to the cache")
self._cache.collection.move_to_end(cache_key)
def _assert_cache(self):
if self.cache is None or not isinstance(self.cache, CacheInterface):
raise ValueError("Eviction policy should be associated with valid cache.")
class EvictionPolicy(Enum):
LRU = LRUPolicy
class CacheConfig(CacheConfigurationInterface):
DEFAULT_CACHE_CLASS = DefaultCache
DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU
DEFAULT_MAX_SIZE = 10000
DEFAULT_ALLOW_LIST = [
"BITCOUNT",
"BITFIELD_RO",
"BITPOS",
"EXISTS",
"GEODIST",
"GEOHASH",
"GEOPOS",
"GEORADIUSBYMEMBER_RO",
"GEORADIUS_RO",
"GEOSEARCH",
"GET",
"GETBIT",
"GETRANGE",
"HEXISTS",
"HGET",
"HGETALL",
"HKEYS",
"HLEN",
"HMGET",
"HSTRLEN",
"HVALS",
"JSON.ARRINDEX",
"JSON.ARRLEN",
"JSON.GET",
"JSON.MGET",
"JSON.OBJKEYS",
"JSON.OBJLEN",
"JSON.RESP",
"JSON.STRLEN",
"JSON.TYPE",
"LCS",
"LINDEX",
"LLEN",
"LPOS",
"LRANGE",
"MGET",
"SCARD",
"SDIFF",
"SINTER",
"SINTERCARD",
"SISMEMBER",
"SMEMBERS",
"SMISMEMBER",
"SORT_RO",
"STRLEN",
"SUBSTR",
"SUNION",
"TS.GET",
"TS.INFO",
"TS.RANGE",
"TS.REVRANGE",
"TYPE",
"XLEN",
"XPENDING",
"XRANGE",
"XREAD",
"XREVRANGE",
"ZCARD",
"ZCOUNT",
"ZDIFF",
"ZINTER",
"ZINTERCARD",
"ZLEXCOUNT",
"ZMSCORE",
"ZRANGE",
"ZRANGEBYLEX",
"ZRANGEBYSCORE",
"ZRANK",
"ZREVRANGE",
"ZREVRANGEBYLEX",
"ZREVRANGEBYSCORE",
"ZREVRANK",
"ZSCORE",
"ZUNION",
]
def __init__(
self,
max_size: int = DEFAULT_MAX_SIZE,
cache_class: Any = DEFAULT_CACHE_CLASS,
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
):
self._cache_class = cache_class
self._max_size = max_size
self._eviction_policy = eviction_policy
def get_cache_class(self):
return self._cache_class
def get_max_size(self) -> int:
return self._max_size
def get_eviction_policy(self) -> EvictionPolicy:
return self._eviction_policy
def is_exceeds_max_size(self, count: int) -> bool:
return count > self._max_size
def is_allowed_to_cache(self, command: str) -> bool:
return command in self.DEFAULT_ALLOW_LIST
class CacheFactoryInterface(ABC):
@abstractmethod
def get_cache(self) -> CacheInterface:
pass
class CacheFactory(CacheFactoryInterface):
def __init__(self, cache_config: Optional[CacheConfig] = None):
self._config = cache_config
if self._config is None:
self._config = CacheConfig()
def get_cache(self) -> CacheInterface:
cache_class = self._config.get_cache_class()
return cache_class(cache_config=self._config)

421
venv/lib/python3.12/site-packages/redis/client.py Executable file → Normal file
View File

@@ -2,19 +2,9 @@ import copy
import re
import threading
import time
import warnings
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Set,
Type,
Union,
)
from typing import Any, Callable, Dict, List, Optional, Type, Union
from redis._parsers.encoders import Encoder
from redis._parsers.helpers import (
@@ -23,54 +13,33 @@ from redis._parsers.helpers import (
_RedisCallbacksRESP3,
bool_ok,
)
from redis.backoff import ExponentialWithJitterBackoff
from redis.cache import CacheConfig, CacheInterface
from redis.commands import (
CoreCommands,
RedisModuleCommands,
SentinelCommands,
list_or_args,
)
from redis.commands.core import Script
from redis.connection import (
AbstractConnection,
Connection,
ConnectionPool,
SSLConnection,
UnixDomainSocketConnection,
)
from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection
from redis.credentials import CredentialProvider
from redis.event import (
AfterPooledConnectionsInstantiationEvent,
AfterPubSubConnectionInstantiationEvent,
AfterSingleConnectionInstantiationEvent,
ClientType,
EventDispatcher,
)
from redis.exceptions import (
ConnectionError,
ExecAbortError,
PubSubError,
RedisError,
ResponseError,
TimeoutError,
WatchError,
)
from redis.lock import Lock
from redis.retry import Retry
from redis.utils import (
HIREDIS_AVAILABLE,
_set_info_logger,
deprecated_args,
get_lib_version,
safe_str,
str_if_bytes,
truncate_text,
)
if TYPE_CHECKING:
import ssl
import OpenSSL
SYM_EMPTY = b""
EMPTY_RESPONSE = "EMPTY_RESPONSE"
@@ -125,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
"""
@classmethod
def from_url(cls, url: str, **kwargs) -> "Redis":
def from_url(cls, url: str, **kwargs) -> None:
"""
Return a Redis client object configured from the given URL
@@ -191,80 +160,56 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
client.auto_close_connection_pool = True
return client
@deprecated_args(
args_to_warn=["retry_on_timeout"],
reason="TimeoutError is included by default.",
version="6.0.0",
)
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: Optional[str] = None,
socket_timeout: Optional[float] = None,
socket_connect_timeout: Optional[float] = None,
socket_keepalive: Optional[bool] = None,
socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
connection_pool: Optional[ConnectionPool] = None,
unix_socket_path: Optional[str] = None,
encoding: str = "utf-8",
encoding_errors: str = "strict",
decode_responses: bool = False,
retry_on_timeout: bool = False,
retry: Retry = Retry(
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
),
retry_on_error: Optional[List[Type[Exception]]] = None,
ssl: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
ssl_ca_certs: Optional[str] = None,
ssl_ca_path: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = True,
ssl_password: Optional[str] = None,
ssl_validate_ocsp: bool = False,
ssl_validate_ocsp_stapled: bool = False,
ssl_ocsp_context: Optional["OpenSSL.SSL.Context"] = None,
ssl_ocsp_expected_cert: Optional[str] = None,
ssl_min_version: Optional["ssl.TLSVersion"] = None,
ssl_ciphers: Optional[str] = None,
max_connections: Optional[int] = None,
single_connection_client: bool = False,
health_check_interval: int = 0,
client_name: Optional[str] = None,
lib_name: Optional[str] = "redis-py",
lib_version: Optional[str] = get_lib_version(),
username: Optional[str] = None,
redis_connect_func: Optional[Callable[[], None]] = None,
host="localhost",
port=6379,
db=0,
password=None,
socket_timeout=None,
socket_connect_timeout=None,
socket_keepalive=None,
socket_keepalive_options=None,
connection_pool=None,
unix_socket_path=None,
encoding="utf-8",
encoding_errors="strict",
charset=None,
errors=None,
decode_responses=False,
retry_on_timeout=False,
retry_on_error=None,
ssl=False,
ssl_keyfile=None,
ssl_certfile=None,
ssl_cert_reqs="required",
ssl_ca_certs=None,
ssl_ca_path=None,
ssl_ca_data=None,
ssl_check_hostname=False,
ssl_password=None,
ssl_validate_ocsp=False,
ssl_validate_ocsp_stapled=False,
ssl_ocsp_context=None,
ssl_ocsp_expected_cert=None,
max_connections=None,
single_connection_client=False,
health_check_interval=0,
client_name=None,
lib_name="redis-py",
lib_version=get_lib_version(),
username=None,
retry=None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
cache: Optional[CacheInterface] = None,
cache_config: Optional[CacheConfig] = None,
event_dispatcher: Optional[EventDispatcher] = None,
) -> None:
"""
Initialize a new Redis client.
To specify a retry policy for specific errors, you have two options:
1. Set the `retry_on_error` to a list of the error/s to retry on, and
you can also set `retry` to a valid `Retry` object(in case the default
one is not appropriate) - with this approach the retries will be triggered
on the default errors specified in the Retry object enriched with the
errors specified in `retry_on_error`.
2. Define a `Retry` object with configured 'supported_errors' and set
it to the `retry` parameter - with this approach you completely redefine
the errors on which retries will happen.
`retry_on_timeout` is deprecated - please include the TimeoutError
either in the Retry object or in the `retry_on_error` list.
When 'connection_pool' is provided - the retry configuration of the
provided pool will be used.
To specify a retry policy for specific errors, first set
`retry_on_error` to a list of the error/s to retry on, then set
`retry` to a valid `Retry` object.
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
Args:
@@ -272,13 +217,25 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
if `True`, connection pool is not used. In that case `Redis`
instance use is not thread safe.
"""
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
if not connection_pool:
if charset is not None:
warnings.warn(
DeprecationWarning(
'"charset" is deprecated. Use "encoding" instead'
)
)
encoding = charset
if errors is not None:
warnings.warn(
DeprecationWarning(
'"errors" is deprecated. Use "encoding_errors" instead'
)
)
encoding_errors = errors
if not retry_on_error:
retry_on_error = []
if retry_on_timeout is True:
retry_on_error.append(TimeoutError)
kwargs = {
"db": db,
"username": username,
@@ -334,50 +291,17 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
"ssl_validate_ocsp": ssl_validate_ocsp,
"ssl_ocsp_context": ssl_ocsp_context,
"ssl_ocsp_expected_cert": ssl_ocsp_expected_cert,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)
if (cache_config or cache) and protocol in [3, "3"]:
kwargs.update(
{
"cache": cache,
"cache_config": cache_config,
}
)
connection_pool = ConnectionPool(**kwargs)
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.SYNC, credential_provider
)
)
self.auto_close_connection_pool = True
else:
self.auto_close_connection_pool = False
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.SYNC, credential_provider
)
)
self.connection_pool = connection_pool
if (cache_config or cache) and self.connection_pool.get_protocol() not in [
3,
"3",
]:
raise RedisError("Client caching is only supported with RESP version 3")
self.single_connection_lock = threading.RLock()
self.connection = None
self._single_connection_client = single_connection_client
if self._single_connection_client:
self.connection = self.connection_pool.get_connection()
self._event_dispatcher.dispatch(
AfterSingleConnectionInstantiationEvent(
self.connection, ClientType.SYNC, self.single_connection_lock
)
)
if single_connection_client:
self.connection = self.connection_pool.get_connection("_")
self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)
@@ -387,10 +311,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
self.response_callbacks.update(_RedisCallbacksRESP2)
def __repr__(self) -> str:
return (
f"<{type(self).__module__}.{type(self).__name__}"
f"({repr(self.connection_pool)})>"
)
return f"{type(self).__name__}<{repr(self.connection_pool)}>"
def get_encoder(self) -> "Encoder":
"""Get the connection pool's encoder"""
@@ -400,10 +321,10 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
"""Get the connection's key-word arguments"""
return self.connection_pool.connection_kwargs
def get_retry(self) -> Optional[Retry]:
def get_retry(self) -> Optional["Retry"]:
return self.get_connection_kwargs().get("retry")
def set_retry(self, retry: Retry) -> None:
def set_retry(self, retry: "Retry") -> None:
self.get_connection_kwargs().update({"retry": retry})
self.connection_pool.set_retry(retry)
@@ -448,7 +369,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
def transaction(
self, func: Callable[["Pipeline"], None], *watches, **kwargs
) -> Union[List[Any], Any, None]:
) -> None:
"""
Convenience method for executing the callable `func` as a transaction
while watching all keys specified in `watches`. The 'func' callable
@@ -479,7 +400,6 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
blocking_timeout: Optional[float] = None,
lock_class: Union[None, Any] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
):
"""
Return a new Lock object using key ``name`` that mimics
@@ -526,11 +446,6 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
@@ -548,7 +463,6 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
raise_on_release_error=raise_on_release_error,
)
def pubsub(self, **kwargs):
@@ -557,9 +471,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
subscribe to channels and listen for messages that get published to
them.
"""
return PubSub(
self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
)
return PubSub(self.connection_pool, **kwargs)
def monitor(self):
return Monitor(self.connection_pool)
@@ -576,12 +488,9 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
self.close()
def __del__(self):
try:
self.close()
except Exception:
pass
self.close()
def close(self) -> None:
def close(self):
# In case a connection property does not yet exist
# (due to a crash earlier in the Redis() constructor), return
# immediately as there is nothing to clean-up.
@@ -600,44 +509,37 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
"""
Send a command and parse the response
"""
conn.send_command(*args, **options)
conn.send_command(*args)
return self.parse_response(conn, command_name, **options)
def _close_connection(self, conn) -> None:
def _disconnect_raise(self, conn, error):
"""
Close the connection before retrying.
The supported exceptions are already checked in the
retry object so we don't need to do it here.
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic(on connection level).
Close the connection and raise an exception
if retry_on_error is not set or the error
is not one of the specified error types
"""
conn.disconnect()
if (
conn.retry_on_error is None
or isinstance(error, tuple(conn.retry_on_error)) is False
):
raise error
# COMMAND EXECUTION AND PROTOCOL PARSING
def execute_command(self, *args, **options):
return self._execute_command(*args, **options)
def _execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
pool = self.connection_pool
command_name = args[0]
conn = self.connection or pool.get_connection()
conn = self.connection or pool.get_connection(command_name, **options)
if self._single_connection_client:
self.single_connection_lock.acquire()
try:
return conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda _: self._close_connection(conn),
lambda error: self._disconnect_raise(conn, error),
)
finally:
if self._single_connection_client:
self.single_connection_lock.release()
if not self.connection:
pool.release(conn)
@@ -657,16 +559,10 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
if EMPTY_RESPONSE in options:
options.pop(EMPTY_RESPONSE)
# Remove keys entry, it needs only for cache.
options.pop("keys", None)
if command_name in self.response_callbacks:
return self.response_callbacks[command_name](response, **options)
return response
def get_cache(self) -> Optional[CacheInterface]:
return self.connection_pool.cache
StrictRedis = Redis
@@ -683,7 +579,7 @@ class Monitor:
def __init__(self, connection_pool):
self.connection_pool = connection_pool
self.connection = self.connection_pool.get_connection()
self.connection = self.connection_pool.get_connection("MONITOR")
def __enter__(self):
self.connection.send_command("MONITOR")
@@ -758,7 +654,6 @@ class PubSub:
ignore_subscribe_messages: bool = False,
encoder: Optional["Encoder"] = None,
push_handler_func: Union[None, Callable[[str], None]] = None,
event_dispatcher: Optional["EventDispatcher"] = None,
):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
@@ -769,12 +664,6 @@ class PubSub:
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
self.push_handler_func = push_handler_func
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self._lock = threading.RLock()
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
@@ -804,7 +693,7 @@ class PubSub:
def reset(self) -> None:
if self.connection:
self.connection.disconnect()
self.connection.deregister_connect_callback(self.on_connect)
self.connection._deregister_connect_callback(self.on_connect)
self.connection_pool.release(self.connection)
self.connection = None
self.health_check_response_counter = 0
@@ -857,23 +746,19 @@ class PubSub:
# subscribed to one or more channels
if self.connection is None:
self.connection = self.connection_pool.get_connection()
self.connection = self.connection_pool.get_connection(
"pubsub", self.shard_hint
)
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
if self.push_handler_func is not None:
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
self._event_dispatcher.dispatch(
AfterPubSubConnectionInstantiationEvent(
self.connection, self.connection_pool, ClientType.SYNC, self._lock
)
)
self.connection._register_connect_callback(self.on_connect)
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
self.clean_health_check_responses()
with self._lock:
self._execute(connection, connection.send_command, *args, **kwargs)
self._execute(connection, connection.send_command, *args, **kwargs)
def clean_health_check_responses(self) -> None:
"""
@@ -889,18 +774,19 @@ class PubSub:
else:
raise PubSubError(
"A non health check response was cleaned by "
"execute_command: {}".format(response)
"execute_command: {0}".format(response)
)
ttl -= 1
def _reconnect(self, conn) -> None:
def _disconnect_raise_connect(self, conn, error) -> None:
"""
The supported exceptions are already checked in the
retry object so we don't need to do it here.
In this error handler we are trying to reconnect to the server.
Close the connection and raise an exception
if retry_on_timeout is not set or the error
is not a TimeoutError. Otherwise, try to reconnect
"""
conn.disconnect()
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
raise error
conn.connect()
def _execute(self, conn, command, *args, **kwargs):
@@ -913,7 +799,7 @@ class PubSub:
"""
return conn.retry.call_with_retry(
lambda: command(*args, **kwargs),
lambda _: self._reconnect(conn),
lambda error: self._disconnect_raise_connect(conn, error),
)
def parse_response(self, block=True, timeout=0):
@@ -962,7 +848,7 @@ class PubSub:
"did you forget to call subscribe() or psubscribe()?"
)
if conn.health_check_interval and time.monotonic() > conn.next_health_check:
if conn.health_check_interval and time.time() > conn.next_health_check:
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
self.health_check_response_counter += 1
@@ -1112,12 +998,12 @@ class PubSub:
"""
if not self.subscribed:
# Wait for subscription
start_time = time.monotonic()
start_time = time.time()
if self.subscribed_event.wait(timeout) is True:
# The connection was subscribed during the timeout time frame.
# The timeout should be adjusted based on the time spent
# waiting for the subscription
time_spent = time.monotonic() - start_time
time_spent = time.time() - start_time
timeout = max(0.0, timeout - time_spent)
else:
# The connection isn't subscribed to any channels or patterns,
@@ -1214,7 +1100,7 @@ class PubSub:
def run_in_thread(
self,
sleep_time: float = 0.0,
sleep_time: int = 0,
daemon: bool = False,
exception_handler: Optional[Callable] = None,
) -> "PubSubWorkerThread":
@@ -1282,8 +1168,7 @@ class Pipeline(Redis):
in one transmission. This is convenient for batch processing, such as
saving all the values in a list to Redis.
All commands executed within a pipeline(when running in transactional mode,
which is the default behavior) are wrapped with MULTI and EXEC
All commands executed within a pipeline are wrapped with MULTI and EXEC
calls. This guarantees all commands executed in the pipeline will be
executed atomically.
@@ -1298,22 +1183,15 @@ class Pipeline(Redis):
UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
def __init__(
self,
connection_pool: ConnectionPool,
response_callbacks,
transaction,
shard_hint,
):
def __init__(self, connection_pool, response_callbacks, transaction, shard_hint):
self.connection_pool = connection_pool
self.connection: Optional[Connection] = None
self.connection = None
self.response_callbacks = response_callbacks
self.transaction = transaction
self.shard_hint = shard_hint
self.watching = False
self.command_stack = []
self.scripts: Set[Script] = set()
self.explicit_transaction = False
self.reset()
def __enter__(self) -> "Pipeline":
return self
@@ -1379,51 +1257,47 @@ class Pipeline(Redis):
return self.immediate_execute_command(*args, **kwargs)
return self.pipeline_execute_command(*args, **kwargs)
def _disconnect_reset_raise_on_watching(
self,
conn: AbstractConnection,
error: Exception,
) -> None:
def _disconnect_reset_raise(self, conn, error) -> None:
"""
Close the connection reset watching state and
raise an exception if we were watching.
The supported exceptions are already checked in the
retry object so we don't need to do it here.
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic(on connection level).
Close the connection, reset watching state and
raise an exception if we were watching,
retry_on_timeout is not set,
or the error is not a TimeoutError
"""
conn.disconnect()
# if we were already watching a variable, the watch is no longer
# valid since this connection has died. raise a WatchError, which
# indicates the user should retry this transaction.
if self.watching:
self.reset()
raise WatchError(
f"A {type(error).__name__} occurred while watching one or more keys"
"A ConnectionError occurred on while watching one or more keys"
)
# if retry_on_timeout is not set, or the error is not
# a TimeoutError, raise it
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
self.reset()
raise
def immediate_execute_command(self, *args, **options):
"""
Execute a command immediately, but don't auto-retry on the supported
errors for retry if we're already WATCHing a variable.
Used when issuing WATCH or subsequent commands retrieving their values but before
Execute a command immediately, but don't auto-retry on a
ConnectionError if we're already WATCHing a variable. Used when
issuing WATCH or subsequent commands retrieving their values but before
MULTI is called.
"""
command_name = args[0]
conn = self.connection
# if this is the first call, we need a connection
if not conn:
conn = self.connection_pool.get_connection()
conn = self.connection_pool.get_connection(command_name, self.shard_hint)
self.connection = conn
return conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_reset_raise_on_watching(conn, error),
lambda error: self._disconnect_reset_raise(conn, error),
)
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
@@ -1441,9 +1315,7 @@ class Pipeline(Redis):
self.command_stack.append((args, options))
return self
def _execute_transaction(
self, connection: Connection, commands, raise_on_error
) -> List:
def _execute_transaction(self, connection, commands, raise_on_error) -> List:
cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})])
all_cmds = connection.pack_commands(
[args for args, options in cmds if EMPTY_RESPONSE not in options]
@@ -1504,8 +1376,6 @@ class Pipeline(Redis):
for r, cmd in zip(response, commands):
if not isinstance(r, Exception):
args, options = cmd
# Remove keys entry, it needs only for cache.
options.pop("keys", None)
command_name = args[0]
if command_name in self.response_callbacks:
r = self.response_callbacks[command_name](r, **options)
@@ -1537,7 +1407,7 @@ class Pipeline(Redis):
def annotate_exception(self, exception, number, command):
cmd = " ".join(map(safe_str, command))
msg = (
f"Command # {number} ({truncate_text(cmd)}) of pipeline "
f"Command # {number} ({cmd}) of pipeline "
f"caused error: {exception.args[0]}"
)
exception.args = (msg,) + exception.args[1:]
@@ -1563,19 +1433,11 @@ class Pipeline(Redis):
if not exist:
s.sha = immediate("SCRIPT LOAD", s.script)
def _disconnect_raise_on_watching(
self,
conn: AbstractConnection,
error: Exception,
) -> None:
def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None:
"""
Close the connection, raise an exception if we were watching.
The supported exceptions are already checked in the
retry object so we don't need to do it here.
After we disconnect the connection, it will try to reconnect and
do a health check as part of the send_command logic(on connection level).
Close the connection, raise an exception if we were watching,
and raise an exception if TimeoutError is not part of retry_on_error,
or the error is not a TimeoutError
"""
conn.disconnect()
# if we were watching a variable, the watch is no longer valid
@@ -1583,10 +1445,17 @@ class Pipeline(Redis):
# indicates the user should retry this transaction.
if self.watching:
raise WatchError(
f"A {type(error).__name__} occurred while watching one or more keys"
"A ConnectionError occurred on while watching one or more keys"
)
# if TimeoutError is not part of retry_on_error, or the error
# is not a TimeoutError, raise it
if not (
TimeoutError in conn.retry_on_error and isinstance(error, TimeoutError)
):
self.reset()
raise error
def execute(self, raise_on_error: bool = True) -> List[Any]:
def execute(self, raise_on_error=True):
"""Execute all the commands in the current pipeline"""
stack = self.command_stack
if not stack and not self.watching:
@@ -1600,7 +1469,7 @@ class Pipeline(Redis):
conn = self.connection
if not conn:
conn = self.connection_pool.get_connection()
conn = self.connection_pool.get_connection("MULTI", self.shard_hint)
# assign to self.connection so reset() releases the connection
# back to the pool after we're done
self.connection = conn
@@ -1608,7 +1477,7 @@ class Pipeline(Redis):
try:
return conn.retry.call_with_retry(
lambda: execute(conn, stack, raise_on_error),
lambda error: self._disconnect_raise_on_watching(conn, error),
lambda error: self._disconnect_raise_reset(conn, error),
)
finally:
self.reset()

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,7 @@ from .commands import * # noqa
from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo
class AbstractBloom:
class AbstractBloom(object):
"""
The client allows to interact with RedisBloom and use all of
it's functionality.

View File

@@ -1,5 +1,6 @@
from redis.client import NEVER_DECODE
from redis.utils import deprecated_function
from redis.exceptions import ModuleError
from redis.utils import HIREDIS_AVAILABLE, deprecated_function
BF_RESERVE = "BF.RESERVE"
BF_ADD = "BF.ADD"
@@ -138,6 +139,9 @@ class BFCommands:
This command will return successive (iter, data) pairs until (0, NULL) to indicate completion.
For more information see `BF.SCANDUMP <https://redis.io/commands/bf.scandump>`_.
""" # noqa
if HIREDIS_AVAILABLE:
raise ModuleError("This command cannot be used when hiredis is available.")
params = [key, iter]
options = {}
options[NEVER_DECODE] = []

View File

@@ -1,7 +1,7 @@
from ..helpers import nativestr
class BFInfo:
class BFInfo(object):
capacity = None
size = None
filterNum = None
@@ -26,7 +26,7 @@ class BFInfo:
return getattr(self, item)
class CFInfo:
class CFInfo(object):
size = None
bucketNum = None
filterNum = None
@@ -57,7 +57,7 @@ class CFInfo:
return getattr(self, item)
class CMSInfo:
class CMSInfo(object):
width = None
depth = None
count = None
@@ -72,7 +72,7 @@ class CMSInfo:
return getattr(self, item)
class TopKInfo:
class TopKInfo(object):
k = None
width = None
depth = None
@@ -89,7 +89,7 @@ class TopKInfo:
return getattr(self, item)
class TDigestInfo:
class TDigestInfo(object):
compression = None
capacity = None
merged_nodes = None

View File

@@ -7,13 +7,13 @@ from typing import (
Iterable,
Iterator,
List,
Literal,
Mapping,
NoReturn,
Optional,
Union,
)
from redis.compat import Literal
from redis.crc import key_slot
from redis.exceptions import RedisClusterException, RedisError
from redis.typing import (
@@ -23,7 +23,6 @@ from redis.typing import (
KeysT,
KeyT,
PatternT,
ResponseT,
)
from .core import (
@@ -31,18 +30,21 @@ from .core import (
AsyncACLCommands,
AsyncDataAccessCommands,
AsyncFunctionCommands,
AsyncGearsCommands,
AsyncManagementCommands,
AsyncModuleCommands,
AsyncScriptCommands,
DataAccessCommands,
FunctionCommands,
GearsCommands,
ManagementCommands,
ModuleCommands,
PubSubCommands,
ResponseT,
ScriptCommands,
)
from .helpers import list_or_args
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
from .redismodules import RedisModuleCommands
if TYPE_CHECKING:
from redis.asyncio.cluster import TargetNodesT
@@ -223,7 +225,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
The keys are first split up into slots
and then an DEL command is sent for every slot
Non-existent keys are ignored.
Non-existant keys are ignored.
Returns the number of keys that were deleted.
For more information see https://redis.io/commands/del
@@ -238,7 +240,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
The keys are first split up into slots
and then an TOUCH command is sent for every slot
Non-existent keys are ignored.
Non-existant keys are ignored.
Returns the number of keys that were touched.
For more information see https://redis.io/commands/touch
@@ -252,7 +254,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
The keys are first split up into slots
and then an TOUCH command is sent for every slot
Non-existent keys are ignored.
Non-existant keys are ignored.
Returns the number of keys that were unlinked.
For more information see https://redis.io/commands/unlink
@@ -593,7 +595,7 @@ class ClusterManagementCommands(ManagementCommands):
"CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node
)
elif state.upper() == "STABLE":
raise RedisError('For "stable" state please use cluster_setslot_stable')
raise RedisError('For "stable" state please use ' "cluster_setslot_stable")
else:
raise RedisError(f"Invalid slot state: {state}")
@@ -691,6 +693,12 @@ class ClusterManagementCommands(ManagementCommands):
self.read_from_replicas = False
return self.execute_command("READWRITE", target_nodes=target_nodes)
def gears_refresh_cluster(self, **kwargs) -> ResponseT:
"""
On an OSS cluster, before executing any gears function, you must call this command. # noqa
"""
return self.execute_command("REDISGEARS_2.REFRESHCLUSTER", **kwargs)
class AsyncClusterManagementCommands(
ClusterManagementCommands, AsyncManagementCommands
@@ -866,6 +874,7 @@ class RedisClusterCommands(
ClusterDataAccessCommands,
ScriptCommands,
FunctionCommands,
GearsCommands,
ModuleCommands,
RedisModuleCommands,
):
@@ -896,8 +905,8 @@ class AsyncRedisClusterCommands(
AsyncClusterDataAccessCommands,
AsyncScriptCommands,
AsyncFunctionCommands,
AsyncGearsCommands,
AsyncModuleCommands,
AsyncRedisModuleCommands,
):
"""
A class for all Redis Cluster commands

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,263 @@
import warnings
from ..helpers import quote_string, random_string, stringify_param_value
from .commands import AsyncGraphCommands, GraphCommands
from .edge import Edge # noqa
from .node import Node # noqa
from .path import Path # noqa
DB_LABELS = "DB.LABELS"
DB_RAELATIONSHIPTYPES = "DB.RELATIONSHIPTYPES"
DB_PROPERTYKEYS = "DB.PROPERTYKEYS"
class Graph(GraphCommands):
"""
Graph, collection of nodes and edges.
"""
def __init__(self, client, name=random_string()):
"""
Create a new graph.
"""
warnings.warn(
DeprecationWarning(
"RedisGraph support is deprecated as of Redis Stack 7.2 \
(https://redis.com/blog/redisgraph-eol/)"
)
)
self.NAME = name # Graph key
self.client = client
self.execute_command = client.execute_command
self.nodes = {}
self.edges = []
self._labels = [] # List of node labels.
self._properties = [] # List of properties.
self._relationship_types = [] # List of relation types.
self.version = 0 # Graph version
@property
def name(self):
return self.NAME
def _clear_schema(self):
self._labels = []
self._properties = []
self._relationship_types = []
def _refresh_schema(self):
self._clear_schema()
self._refresh_labels()
self._refresh_relations()
self._refresh_attributes()
def _refresh_labels(self):
lbls = self.labels()
# Unpack data.
self._labels = [l[0] for _, l in enumerate(lbls)]
def _refresh_relations(self):
rels = self.relationship_types()
# Unpack data.
self._relationship_types = [r[0] for _, r in enumerate(rels)]
def _refresh_attributes(self):
props = self.property_keys()
# Unpack data.
self._properties = [p[0] for _, p in enumerate(props)]
def get_label(self, idx):
"""
Returns a label by it's index
Args:
idx:
The index of the label
"""
try:
label = self._labels[idx]
except IndexError:
# Refresh labels.
self._refresh_labels()
label = self._labels[idx]
return label
def get_relation(self, idx):
"""
Returns a relationship type by it's index
Args:
idx:
The index of the relation
"""
try:
relationship_type = self._relationship_types[idx]
except IndexError:
# Refresh relationship types.
self._refresh_relations()
relationship_type = self._relationship_types[idx]
return relationship_type
def get_property(self, idx):
"""
Returns a property by it's index
Args:
idx:
The index of the property
"""
try:
p = self._properties[idx]
except IndexError:
# Refresh properties.
self._refresh_attributes()
p = self._properties[idx]
return p
def add_node(self, node):
"""
Adds a node to the graph.
"""
if node.alias is None:
node.alias = random_string()
self.nodes[node.alias] = node
def add_edge(self, edge):
"""
Adds an edge to the graph.
"""
if not (self.nodes[edge.src_node.alias] and self.nodes[edge.dest_node.alias]):
raise AssertionError("Both edge's end must be in the graph")
self.edges.append(edge)
def _build_params_header(self, params):
if params is None:
return ""
if not isinstance(params, dict):
raise TypeError("'params' must be a dict")
# Header starts with "CYPHER"
params_header = "CYPHER "
for key, value in params.items():
params_header += str(key) + "=" + stringify_param_value(value) + " "
return params_header
# Procedures.
def call_procedure(self, procedure, *args, read_only=False, **kwagrs):
args = [quote_string(arg) for arg in args]
q = f"CALL {procedure}({','.join(args)})"
y = kwagrs.get("y", None)
if y is not None:
q += f"YIELD {','.join(y)}"
return self.query(q, read_only=read_only)
def labels(self):
return self.call_procedure(DB_LABELS, read_only=True).result_set
def relationship_types(self):
return self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True).result_set
def property_keys(self):
return self.call_procedure(DB_PROPERTYKEYS, read_only=True).result_set
class AsyncGraph(Graph, AsyncGraphCommands):
"""Async version for Graph"""
async def _refresh_labels(self):
lbls = await self.labels()
# Unpack data.
self._labels = [l[0] for _, l in enumerate(lbls)]
async def _refresh_attributes(self):
props = await self.property_keys()
# Unpack data.
self._properties = [p[0] for _, p in enumerate(props)]
async def _refresh_relations(self):
rels = await self.relationship_types()
# Unpack data.
self._relationship_types = [r[0] for _, r in enumerate(rels)]
async def get_label(self, idx):
"""
Returns a label by it's index
Args:
idx:
The index of the label
"""
try:
label = self._labels[idx]
except IndexError:
# Refresh labels.
await self._refresh_labels()
label = self._labels[idx]
return label
async def get_property(self, idx):
"""
Returns a property by it's index
Args:
idx:
The index of the property
"""
try:
p = self._properties[idx]
except IndexError:
# Refresh properties.
await self._refresh_attributes()
p = self._properties[idx]
return p
async def get_relation(self, idx):
"""
Returns a relationship type by it's index
Args:
idx:
The index of the relation
"""
try:
relationship_type = self._relationship_types[idx]
except IndexError:
# Refresh relationship types.
await self._refresh_relations()
relationship_type = self._relationship_types[idx]
return relationship_type
async def call_procedure(self, procedure, *args, read_only=False, **kwagrs):
args = [quote_string(arg) for arg in args]
q = f"CALL {procedure}({','.join(args)})"
y = kwagrs.get("y", None)
if y is not None:
f"YIELD {','.join(y)}"
return await self.query(q, read_only=read_only)
async def labels(self):
return ((await self.call_procedure(DB_LABELS, read_only=True))).result_set
async def property_keys(self):
return (await self.call_procedure(DB_PROPERTYKEYS, read_only=True)).result_set
async def relationship_types(self):
return (
await self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True)
).result_set

View File

@@ -0,0 +1,313 @@
from redis import DataError
from redis.exceptions import ResponseError
from .exceptions import VersionMismatchException
from .execution_plan import ExecutionPlan
from .query_result import AsyncQueryResult, QueryResult
PROFILE_CMD = "GRAPH.PROFILE"
RO_QUERY_CMD = "GRAPH.RO_QUERY"
QUERY_CMD = "GRAPH.QUERY"
DELETE_CMD = "GRAPH.DELETE"
SLOWLOG_CMD = "GRAPH.SLOWLOG"
CONFIG_CMD = "GRAPH.CONFIG"
LIST_CMD = "GRAPH.LIST"
EXPLAIN_CMD = "GRAPH.EXPLAIN"
class GraphCommands:
"""RedisGraph Commands"""
def commit(self):
"""
Create entire graph.
"""
if len(self.nodes) == 0 and len(self.edges) == 0:
return None
query = "CREATE "
for _, node in self.nodes.items():
query += str(node) + ","
query += ",".join([str(edge) for edge in self.edges])
# Discard leading comma.
if query[-1] == ",":
query = query[:-1]
return self.query(query)
def query(self, q, params=None, timeout=None, read_only=False, profile=False):
"""
Executes a query against the graph.
For more information see `GRAPH.QUERY <https://redis.io/commands/graph.query>`_. # noqa
Args:
q : str
The query.
params : dict
Query parameters.
timeout : int
Maximum runtime for read queries in milliseconds.
read_only : bool
Executes a readonly query if set to True.
profile : bool
Return details on results produced by and time
spent in each operation.
"""
# maintain original 'q'
query = q
# handle query parameters
query = self._build_params_header(params) + query
# construct query command
# ask for compact result-set format
# specify known graph version
if profile:
cmd = PROFILE_CMD
else:
cmd = RO_QUERY_CMD if read_only else QUERY_CMD
command = [cmd, self.name, query, "--compact"]
# include timeout is specified
if isinstance(timeout, int):
command.extend(["timeout", timeout])
elif timeout is not None:
raise Exception("Timeout argument must be a positive integer")
# issue query
try:
response = self.execute_command(*command)
return QueryResult(self, response, profile)
except ResponseError as e:
if "unknown command" in str(e) and read_only:
# `GRAPH.RO_QUERY` is unavailable in older versions.
return self.query(q, params, timeout, read_only=False)
raise e
except VersionMismatchException as e:
# client view over the graph schema is out of sync
# set client version and refresh local schema
self.version = e.version
self._refresh_schema()
# re-issue query
return self.query(q, params, timeout, read_only)
def merge(self, pattern):
"""
Merge pattern.
"""
query = "MERGE "
query += str(pattern)
return self.query(query)
def delete(self):
"""
Deletes graph.
For more information see `DELETE <https://redis.io/commands/graph.delete>`_. # noqa
"""
self._clear_schema()
return self.execute_command(DELETE_CMD, self.name)
# declared here, to override the built in redis.db.flush()
def flush(self):
"""
Commit the graph and reset the edges and the nodes to zero length.
"""
self.commit()
self.nodes = {}
self.edges = []
def bulk(self, **kwargs):
"""Internal only. Not supported."""
raise NotImplementedError(
"GRAPH.BULK is internal only. "
"Use https://github.com/redisgraph/redisgraph-bulk-loader."
)
def profile(self, query):
"""
Execute a query and produce an execution plan augmented with metrics
for each operation's execution. Return a string representation of a
query execution plan, with details on results produced by and time
spent in each operation.
For more information see `GRAPH.PROFILE <https://redis.io/commands/graph.profile>`_. # noqa
"""
return self.query(query, profile=True)
def slowlog(self):
"""
Get a list containing up to 10 of the slowest queries issued
against the given graph ID.
For more information see `GRAPH.SLOWLOG <https://redis.io/commands/graph.slowlog>`_. # noqa
Each item in the list has the following structure:
1. A unix timestamp at which the log entry was processed.
2. The issued command.
3. The issued query.
4. The amount of time needed for its execution, in milliseconds.
"""
return self.execute_command(SLOWLOG_CMD, self.name)
def config(self, name, value=None, set=False):
"""
Retrieve or update a RedisGraph configuration.
For more information see `https://redis.io/commands/graph.config-get/>`_. # noqa
Args:
name : str
The name of the configuration
value :
The value we want to set (can be used only when `set` is on)
set : bool
Turn on to set a configuration. Default behavior is get.
"""
params = ["SET" if set else "GET", name]
if value is not None:
if set:
params.append(value)
else:
raise DataError(
"``value`` can be provided only when ``set`` is True"
) # noqa
return self.execute_command(CONFIG_CMD, *params)
def list_keys(self):
"""
Lists all graph keys in the keyspace.
For more information see `GRAPH.LIST <https://redis.io/commands/graph.list>`_. # noqa
"""
return self.execute_command(LIST_CMD)
def execution_plan(self, query, params=None):
"""
Get the execution plan for given query,
GRAPH.EXPLAIN returns an array of operations.
Args:
query: the query that will be executed
params: query parameters
"""
query = self._build_params_header(params) + query
plan = self.execute_command(EXPLAIN_CMD, self.name, query)
if isinstance(plan[0], bytes):
plan = [b.decode() for b in plan]
return "\n".join(plan)
def explain(self, query, params=None):
"""
Get the execution plan for given query,
GRAPH.EXPLAIN returns ExecutionPlan object.
For more information see `GRAPH.EXPLAIN <https://redis.io/commands/graph.explain>`_. # noqa
Args:
query: the query that will be executed
params: query parameters
"""
query = self._build_params_header(params) + query
plan = self.execute_command(EXPLAIN_CMD, self.name, query)
return ExecutionPlan(plan)
class AsyncGraphCommands(GraphCommands):
async def query(self, q, params=None, timeout=None, read_only=False, profile=False):
"""
Executes a query against the graph.
For more information see `GRAPH.QUERY <https://oss.redis.com/redisgraph/master/commands/#graphquery>`_. # noqa
Args:
q : str
The query.
params : dict
Query parameters.
timeout : int
Maximum runtime for read queries in milliseconds.
read_only : bool
Executes a readonly query if set to True.
profile : bool
Return details on results produced by and time
spent in each operation.
"""
# maintain original 'q'
query = q
# handle query parameters
query = self._build_params_header(params) + query
# construct query command
# ask for compact result-set format
# specify known graph version
if profile:
cmd = PROFILE_CMD
else:
cmd = RO_QUERY_CMD if read_only else QUERY_CMD
command = [cmd, self.name, query, "--compact"]
# include timeout is specified
if isinstance(timeout, int):
command.extend(["timeout", timeout])
elif timeout is not None:
raise Exception("Timeout argument must be a positive integer")
# issue query
try:
response = await self.execute_command(*command)
return await AsyncQueryResult().initialize(self, response, profile)
except ResponseError as e:
if "unknown command" in str(e) and read_only:
# `GRAPH.RO_QUERY` is unavailable in older versions.
return await self.query(q, params, timeout, read_only=False)
raise e
except VersionMismatchException as e:
# client view over the graph schema is out of sync
# set client version and refresh local schema
self.version = e.version
self._refresh_schema()
# re-issue query
return await self.query(q, params, timeout, read_only)
async def execution_plan(self, query, params=None):
"""
Get the execution plan for given query,
GRAPH.EXPLAIN returns an array of operations.
Args:
query: the query that will be executed
params: query parameters
"""
query = self._build_params_header(params) + query
plan = await self.execute_command(EXPLAIN_CMD, self.name, query)
if isinstance(plan[0], bytes):
plan = [b.decode() for b in plan]
return "\n".join(plan)
async def explain(self, query, params=None):
"""
Get the execution plan for given query,
GRAPH.EXPLAIN returns ExecutionPlan object.
Args:
query: the query that will be executed
params: query parameters
"""
query = self._build_params_header(params) + query
plan = await self.execute_command(EXPLAIN_CMD, self.name, query)
return ExecutionPlan(plan)
async def flush(self):
"""
Commit the graph and reset the edges and the nodes to zero length.
"""
await self.commit()
self.nodes = {}
self.edges = []

View File

@@ -0,0 +1,91 @@
from ..helpers import quote_string
from .node import Node
class Edge:
"""
An edge connecting two nodes.
"""
def __init__(self, src_node, relation, dest_node, edge_id=None, properties=None):
"""
Create a new edge.
"""
if src_node is None or dest_node is None:
# NOTE(bors-42): It makes sense to change AssertionError to
# ValueError here
raise AssertionError("Both src_node & dest_node must be provided")
self.id = edge_id
self.relation = relation or ""
self.properties = properties or {}
self.src_node = src_node
self.dest_node = dest_node
def to_string(self):
res = ""
if self.properties:
props = ",".join(
key + ":" + str(quote_string(val))
for key, val in sorted(self.properties.items())
)
res += "{" + props + "}"
return res
def __str__(self):
# Source node.
if isinstance(self.src_node, Node):
res = str(self.src_node)
else:
res = "()"
# Edge
res += "-["
if self.relation:
res += ":" + self.relation
if self.properties:
props = ",".join(
key + ":" + str(quote_string(val))
for key, val in sorted(self.properties.items())
)
res += "{" + props + "}"
res += "]->"
# Dest node.
if isinstance(self.dest_node, Node):
res += str(self.dest_node)
else:
res += "()"
return res
def __eq__(self, rhs):
# Type checking
if not isinstance(rhs, Edge):
return False
# Quick positive check, if both IDs are set.
if self.id is not None and rhs.id is not None and self.id == rhs.id:
return True
# Source and destination nodes should match.
if self.src_node != rhs.src_node:
return False
if self.dest_node != rhs.dest_node:
return False
# Relation should match.
if self.relation != rhs.relation:
return False
# Quick check for number of properties.
if len(self.properties) != len(rhs.properties):
return False
# Compare properties.
if self.properties != rhs.properties:
return False
return True

View File

@@ -0,0 +1,3 @@
class VersionMismatchException(Exception):
def __init__(self, version):
self.version = version

View File

@@ -0,0 +1,211 @@
import re
class ProfileStats:
"""
ProfileStats, runtime execution statistics of operation.
"""
def __init__(self, records_produced, execution_time):
self.records_produced = records_produced
self.execution_time = execution_time
class Operation:
"""
Operation, single operation within execution plan.
"""
def __init__(self, name, args=None, profile_stats=None):
"""
Create a new operation.
Args:
name: string that represents the name of the operation
args: operation arguments
profile_stats: profile statistics
"""
self.name = name
self.args = args
self.profile_stats = profile_stats
self.children = []
def append_child(self, child):
if not isinstance(child, Operation) or self is child:
raise Exception("child must be Operation")
self.children.append(child)
return self
def child_count(self):
return len(self.children)
def __eq__(self, o: object) -> bool:
if not isinstance(o, Operation):
return False
return self.name == o.name and self.args == o.args
def __str__(self) -> str:
args_str = "" if self.args is None else " | " + self.args
return f"{self.name}{args_str}"
class ExecutionPlan:
"""
ExecutionPlan, collection of operations.
"""
def __init__(self, plan):
"""
Create a new execution plan.
Args:
plan: array of strings that represents the collection operations
the output from GRAPH.EXPLAIN
"""
if not isinstance(plan, list):
raise Exception("plan must be an array")
if isinstance(plan[0], bytes):
plan = [b.decode() for b in plan]
self.plan = plan
self.structured_plan = self._operation_tree()
def _compare_operations(self, root_a, root_b):
"""
Compare execution plan operation tree
Return: True if operation trees are equal, False otherwise
"""
# compare current root
if root_a != root_b:
return False
# make sure root have the same number of children
if root_a.child_count() != root_b.child_count():
return False
# recursively compare children
for i in range(root_a.child_count()):
if not self._compare_operations(root_a.children[i], root_b.children[i]):
return False
return True
def __str__(self) -> str:
def aggraget_str(str_children):
return "\n".join(
[
" " + line
for str_child in str_children
for line in str_child.splitlines()
]
)
def combine_str(x, y):
return f"{x}\n{y}"
return self._operation_traverse(
self.structured_plan, str, aggraget_str, combine_str
)
def __eq__(self, o: object) -> bool:
"""Compares two execution plans
Return: True if the two plans are equal False otherwise
"""
# make sure 'o' is an execution-plan
if not isinstance(o, ExecutionPlan):
return False
# get root for both plans
root_a = self.structured_plan
root_b = o.structured_plan
# compare execution trees
return self._compare_operations(root_a, root_b)
def _operation_traverse(self, op, op_f, aggregate_f, combine_f):
"""
Traverse operation tree recursively applying functions
Args:
op: operation to traverse
op_f: function applied for each operation
aggregate_f: aggregation function applied for all children of a single operation
combine_f: combine function applied for the operation result and the children result
""" # noqa
# apply op_f for each operation
op_res = op_f(op)
if len(op.children) == 0:
return op_res # no children return
else:
# apply _operation_traverse recursively
children = [
self._operation_traverse(child, op_f, aggregate_f, combine_f)
for child in op.children
]
# combine the operation result with the children aggregated result
return combine_f(op_res, aggregate_f(children))
def _operation_tree(self):
"""Build the operation tree from the string representation"""
# initial state
i = 0
level = 0
stack = []
current = None
def _create_operation(args):
profile_stats = None
name = args[0].strip()
args.pop(0)
if len(args) > 0 and "Records produced" in args[-1]:
records_produced = int(
re.search("Records produced: (\\d+)", args[-1]).group(1)
)
execution_time = float(
re.search("Execution time: (\\d+.\\d+) ms", args[-1]).group(1)
)
profile_stats = ProfileStats(records_produced, execution_time)
args.pop(-1)
return Operation(
name, None if len(args) == 0 else args[0].strip(), profile_stats
)
# iterate plan operations
while i < len(self.plan):
current_op = self.plan[i]
op_level = current_op.count(" ")
if op_level == level:
# if the operation level equal to the current level
# set the current operation and move next
child = _create_operation(current_op.split("|"))
if current:
current = stack.pop()
current.append_child(child)
current = child
i += 1
elif op_level == level + 1:
# if the operation is child of the current operation
# add it as child and set as current operation
child = _create_operation(current_op.split("|"))
current.append_child(child)
stack.append(current)
current = child
level += 1
i += 1
elif op_level < level:
# if the operation is not child of current operation
# go back to it's parent operation
levels_back = level - op_level + 1
for _ in range(levels_back):
current = stack.pop()
level -= levels_back
else:
raise Exception("corrupted plan")
return stack[0]

View File

@@ -0,0 +1,88 @@
from ..helpers import quote_string
class Node:
"""
A node within the graph.
"""
def __init__(self, node_id=None, alias=None, label=None, properties=None):
"""
Create a new node.
"""
self.id = node_id
self.alias = alias
if isinstance(label, list):
label = [inner_label for inner_label in label if inner_label != ""]
if (
label is None
or label == ""
or (isinstance(label, list) and len(label) == 0)
):
self.label = None
self.labels = None
elif isinstance(label, str):
self.label = label
self.labels = [label]
elif isinstance(label, list) and all(
[isinstance(inner_label, str) for inner_label in label]
):
self.label = label[0]
self.labels = label
else:
raise AssertionError(
"label should be either None, string or a list of strings"
)
self.properties = properties or {}
def to_string(self):
res = ""
if self.properties:
props = ",".join(
key + ":" + str(quote_string(val))
for key, val in sorted(self.properties.items())
)
res += "{" + props + "}"
return res
def __str__(self):
res = "("
if self.alias:
res += self.alias
if self.labels:
res += ":" + ":".join(self.labels)
if self.properties:
props = ",".join(
key + ":" + str(quote_string(val))
for key, val in sorted(self.properties.items())
)
res += "{" + props + "}"
res += ")"
return res
def __eq__(self, rhs):
# Type checking
if not isinstance(rhs, Node):
return False
# Quick positive check, if both IDs are set.
if self.id is not None and rhs.id is not None and self.id != rhs.id:
return False
# Label should match.
if self.label != rhs.label:
return False
# Quick check for number of properties.
if len(self.properties) != len(rhs.properties):
return False
# Compare properties.
if self.properties != rhs.properties:
return False
return True

View File

@@ -0,0 +1,78 @@
from .edge import Edge
from .node import Node
class Path:
def __init__(self, nodes, edges):
if not (isinstance(nodes, list) and isinstance(edges, list)):
raise TypeError("nodes and edges must be list")
self._nodes = nodes
self._edges = edges
self.append_type = Node
@classmethod
def new_empty_path(cls):
return cls([], [])
def nodes(self):
return self._nodes
def edges(self):
return self._edges
def get_node(self, index):
return self._nodes[index]
def get_relationship(self, index):
return self._edges[index]
def first_node(self):
return self._nodes[0]
def last_node(self):
return self._nodes[-1]
def edge_count(self):
return len(self._edges)
def nodes_count(self):
return len(self._nodes)
def add_node(self, node):
if not isinstance(node, self.append_type):
raise AssertionError("Add Edge before adding Node")
self._nodes.append(node)
self.append_type = Edge
return self
def add_edge(self, edge):
if not isinstance(edge, self.append_type):
raise AssertionError("Add Node before adding Edge")
self._edges.append(edge)
self.append_type = Node
return self
def __eq__(self, other):
# Type checking
if not isinstance(other, Path):
return False
return self.nodes() == other.nodes() and self.edges() == other.edges()
def __str__(self):
res = "<"
edge_count = self.edge_count()
for i in range(0, edge_count):
node_id = self.get_node(i).id
res += "(" + str(node_id) + ")"
edge = self.get_relationship(i)
res += (
"-[" + str(int(edge.id)) + "]->"
if edge.src_node == node_id
else "<-[" + str(int(edge.id)) + "]-"
)
node_id = self.get_node(edge_count).id
res += "(" + str(node_id) + ")"
res += ">"
return res

View File

@@ -0,0 +1,573 @@
import sys
from collections import OrderedDict
from distutils.util import strtobool
# from prettytable import PrettyTable
from redis import ResponseError
from .edge import Edge
from .exceptions import VersionMismatchException
from .node import Node
from .path import Path
LABELS_ADDED = "Labels added"
LABELS_REMOVED = "Labels removed"
NODES_CREATED = "Nodes created"
NODES_DELETED = "Nodes deleted"
RELATIONSHIPS_DELETED = "Relationships deleted"
PROPERTIES_SET = "Properties set"
PROPERTIES_REMOVED = "Properties removed"
RELATIONSHIPS_CREATED = "Relationships created"
INDICES_CREATED = "Indices created"
INDICES_DELETED = "Indices deleted"
CACHED_EXECUTION = "Cached execution"
INTERNAL_EXECUTION_TIME = "internal execution time"
STATS = [
LABELS_ADDED,
LABELS_REMOVED,
NODES_CREATED,
PROPERTIES_SET,
PROPERTIES_REMOVED,
RELATIONSHIPS_CREATED,
NODES_DELETED,
RELATIONSHIPS_DELETED,
INDICES_CREATED,
INDICES_DELETED,
CACHED_EXECUTION,
INTERNAL_EXECUTION_TIME,
]
class ResultSetColumnTypes:
COLUMN_UNKNOWN = 0
COLUMN_SCALAR = 1
COLUMN_NODE = 2 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa
COLUMN_RELATION = 3 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa
class ResultSetScalarTypes:
VALUE_UNKNOWN = 0
VALUE_NULL = 1
VALUE_STRING = 2
VALUE_INTEGER = 3
VALUE_BOOLEAN = 4
VALUE_DOUBLE = 5
VALUE_ARRAY = 6
VALUE_EDGE = 7
VALUE_NODE = 8
VALUE_PATH = 9
VALUE_MAP = 10
VALUE_POINT = 11
class QueryResult:
def __init__(self, graph, response, profile=False):
"""
A class that represents a result of the query operation.
Args:
graph:
The graph on which the query was executed.
response:
The response from the server.
profile:
A boolean indicating if the query command was "GRAPH.PROFILE"
"""
self.graph = graph
self.header = []
self.result_set = []
# in case of an error an exception will be raised
self._check_for_errors(response)
if len(response) == 1:
self.parse_statistics(response[0])
elif profile:
self.parse_profile(response)
else:
# start by parsing statistics, matches the one we have
self.parse_statistics(response[-1]) # Last element.
self.parse_results(response)
def _check_for_errors(self, response):
"""
Check if the response contains an error.
"""
if isinstance(response[0], ResponseError):
error = response[0]
if str(error) == "version mismatch":
version = response[1]
error = VersionMismatchException(version)
raise error
# If we encountered a run-time error, the last response
# element will be an exception
if isinstance(response[-1], ResponseError):
raise response[-1]
def parse_results(self, raw_result_set):
"""
Parse the query execution result returned from the server.
"""
self.header = self.parse_header(raw_result_set)
# Empty header.
if len(self.header) == 0:
return
self.result_set = self.parse_records(raw_result_set)
def parse_statistics(self, raw_statistics):
"""
Parse the statistics returned in the response.
"""
self.statistics = {}
# decode statistics
for idx, stat in enumerate(raw_statistics):
if isinstance(stat, bytes):
raw_statistics[idx] = stat.decode()
for s in STATS:
v = self._get_value(s, raw_statistics)
if v is not None:
self.statistics[s] = v
def parse_header(self, raw_result_set):
"""
Parse the header of the result.
"""
# An array of column name/column type pairs.
header = raw_result_set[0]
return header
def parse_records(self, raw_result_set):
"""
Parses the result set and returns a list of records.
"""
records = [
[
self.parse_record_types[self.header[idx][0]](cell)
for idx, cell in enumerate(row)
]
for row in raw_result_set[1]
]
return records
def parse_entity_properties(self, props):
"""
Parse node / edge properties.
"""
# [[name, value type, value] X N]
properties = {}
for prop in props:
prop_name = self.graph.get_property(prop[0])
prop_value = self.parse_scalar(prop[1:])
properties[prop_name] = prop_value
return properties
def parse_string(self, cell):
"""
Parse the cell as a string.
"""
if isinstance(cell, bytes):
return cell.decode()
elif not isinstance(cell, str):
return str(cell)
else:
return cell
def parse_node(self, cell):
"""
Parse the cell to a node.
"""
# Node ID (integer),
# [label string offset (integer)],
# [[name, value type, value] X N]
node_id = int(cell[0])
labels = None
if len(cell[1]) > 0:
labels = []
for inner_label in cell[1]:
labels.append(self.graph.get_label(inner_label))
properties = self.parse_entity_properties(cell[2])
return Node(node_id=node_id, label=labels, properties=properties)
def parse_edge(self, cell):
"""
Parse the cell to an edge.
"""
# Edge ID (integer),
# reltype string offset (integer),
# src node ID offset (integer),
# dest node ID offset (integer),
# [[name, value, value type] X N]
edge_id = int(cell[0])
relation = self.graph.get_relation(cell[1])
src_node_id = int(cell[2])
dest_node_id = int(cell[3])
properties = self.parse_entity_properties(cell[4])
return Edge(
src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties
)
def parse_path(self, cell):
"""
Parse the cell to a path.
"""
nodes = self.parse_scalar(cell[0])
edges = self.parse_scalar(cell[1])
return Path(nodes, edges)
def parse_map(self, cell):
"""
Parse the cell as a map.
"""
m = OrderedDict()
n_entries = len(cell)
# A map is an array of key value pairs.
# 1. key (string)
# 2. array: (value type, value)
for i in range(0, n_entries, 2):
key = self.parse_string(cell[i])
m[key] = self.parse_scalar(cell[i + 1])
return m
def parse_point(self, cell):
"""
Parse the cell to point.
"""
p = {}
# A point is received an array of the form: [latitude, longitude]
# It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa
p["latitude"] = float(cell[0])
p["longitude"] = float(cell[1])
return p
def parse_null(self, cell):
"""
Parse a null value.
"""
return None
def parse_integer(self, cell):
"""
Parse the integer value from the cell.
"""
return int(cell)
def parse_boolean(self, value):
"""
Parse the cell value as a boolean.
"""
value = value.decode() if isinstance(value, bytes) else value
try:
scalar = True if strtobool(value) else False
except ValueError:
sys.stderr.write("unknown boolean type\n")
scalar = None
return scalar
def parse_double(self, cell):
"""
Parse the cell as a double.
"""
return float(cell)
def parse_array(self, value):
"""
Parse an array of values.
"""
scalar = [self.parse_scalar(value[i]) for i in range(len(value))]
return scalar
def parse_unknown(self, cell):
"""
Parse a cell of unknown type.
"""
sys.stderr.write("Unknown type\n")
return None
def parse_scalar(self, cell):
"""
Parse a scalar value from a cell in the result set.
"""
scalar_type = int(cell[0])
value = cell[1]
scalar = self.parse_scalar_types[scalar_type](value)
return scalar
def parse_profile(self, response):
self.result_set = [x[0 : x.index(",")].strip() for x in response]
def is_empty(self):
return len(self.result_set) == 0
@staticmethod
def _get_value(prop, statistics):
for stat in statistics:
if prop in stat:
return float(stat.split(": ")[1].split(" ")[0])
return None
def _get_stat(self, stat):
return self.statistics[stat] if stat in self.statistics else 0
@property
def labels_added(self):
"""Returns the number of labels added in the query"""
return self._get_stat(LABELS_ADDED)
@property
def labels_removed(self):
"""Returns the number of labels removed in the query"""
return self._get_stat(LABELS_REMOVED)
@property
def nodes_created(self):
"""Returns the number of nodes created in the query"""
return self._get_stat(NODES_CREATED)
@property
def nodes_deleted(self):
"""Returns the number of nodes deleted in the query"""
return self._get_stat(NODES_DELETED)
@property
def properties_set(self):
"""Returns the number of properties set in the query"""
return self._get_stat(PROPERTIES_SET)
@property
def properties_removed(self):
"""Returns the number of properties removed in the query"""
return self._get_stat(PROPERTIES_REMOVED)
@property
def relationships_created(self):
"""Returns the number of relationships created in the query"""
return self._get_stat(RELATIONSHIPS_CREATED)
@property
def relationships_deleted(self):
"""Returns the number of relationships deleted in the query"""
return self._get_stat(RELATIONSHIPS_DELETED)
@property
def indices_created(self):
"""Returns the number of indices created in the query"""
return self._get_stat(INDICES_CREATED)
@property
def indices_deleted(self):
"""Returns the number of indices deleted in the query"""
return self._get_stat(INDICES_DELETED)
@property
def cached_execution(self):
"""Returns whether or not the query execution plan was cached"""
return self._get_stat(CACHED_EXECUTION) == 1
@property
def run_time_ms(self):
"""Returns the server execution time of the query"""
return self._get_stat(INTERNAL_EXECUTION_TIME)
@property
def parse_scalar_types(self):
return {
ResultSetScalarTypes.VALUE_NULL: self.parse_null,
ResultSetScalarTypes.VALUE_STRING: self.parse_string,
ResultSetScalarTypes.VALUE_INTEGER: self.parse_integer,
ResultSetScalarTypes.VALUE_BOOLEAN: self.parse_boolean,
ResultSetScalarTypes.VALUE_DOUBLE: self.parse_double,
ResultSetScalarTypes.VALUE_ARRAY: self.parse_array,
ResultSetScalarTypes.VALUE_NODE: self.parse_node,
ResultSetScalarTypes.VALUE_EDGE: self.parse_edge,
ResultSetScalarTypes.VALUE_PATH: self.parse_path,
ResultSetScalarTypes.VALUE_MAP: self.parse_map,
ResultSetScalarTypes.VALUE_POINT: self.parse_point,
ResultSetScalarTypes.VALUE_UNKNOWN: self.parse_unknown,
}
@property
def parse_record_types(self):
return {
ResultSetColumnTypes.COLUMN_SCALAR: self.parse_scalar,
ResultSetColumnTypes.COLUMN_NODE: self.parse_node,
ResultSetColumnTypes.COLUMN_RELATION: self.parse_edge,
ResultSetColumnTypes.COLUMN_UNKNOWN: self.parse_unknown,
}
class AsyncQueryResult(QueryResult):
"""
Async version for the QueryResult class - a class that
represents a result of the query operation.
"""
def __init__(self):
"""
To init the class you must call self.initialize()
"""
pass
async def initialize(self, graph, response, profile=False):
"""
Initializes the class.
Args:
graph:
The graph on which the query was executed.
response:
The response from the server.
profile:
A boolean indicating if the query command was "GRAPH.PROFILE"
"""
self.graph = graph
self.header = []
self.result_set = []
# in case of an error an exception will be raised
self._check_for_errors(response)
if len(response) == 1:
self.parse_statistics(response[0])
elif profile:
self.parse_profile(response)
else:
# start by parsing statistics, matches the one we have
self.parse_statistics(response[-1]) # Last element.
await self.parse_results(response)
return self
async def parse_node(self, cell):
"""
Parses a node from the cell.
"""
# Node ID (integer),
# [label string offset (integer)],
# [[name, value type, value] X N]
labels = None
if len(cell[1]) > 0:
labels = []
for inner_label in cell[1]:
labels.append(await self.graph.get_label(inner_label))
properties = await self.parse_entity_properties(cell[2])
node_id = int(cell[0])
return Node(node_id=node_id, label=labels, properties=properties)
async def parse_scalar(self, cell):
"""
Parses a scalar value from the server response.
"""
scalar_type = int(cell[0])
value = cell[1]
try:
scalar = await self.parse_scalar_types[scalar_type](value)
except TypeError:
# Not all of the functions are async
scalar = self.parse_scalar_types[scalar_type](value)
return scalar
async def parse_records(self, raw_result_set):
"""
Parses the result set and returns a list of records.
"""
records = []
for row in raw_result_set[1]:
record = [
await self.parse_record_types[self.header[idx][0]](cell)
for idx, cell in enumerate(row)
]
records.append(record)
return records
async def parse_results(self, raw_result_set):
"""
Parse the query execution result returned from the server.
"""
self.header = self.parse_header(raw_result_set)
# Empty header.
if len(self.header) == 0:
return
self.result_set = await self.parse_records(raw_result_set)
async def parse_entity_properties(self, props):
"""
Parse node / edge properties.
"""
# [[name, value type, value] X N]
properties = {}
for prop in props:
prop_name = await self.graph.get_property(prop[0])
prop_value = await self.parse_scalar(prop[1:])
properties[prop_name] = prop_value
return properties
async def parse_edge(self, cell):
"""
Parse the cell to an edge.
"""
# Edge ID (integer),
# reltype string offset (integer),
# src node ID offset (integer),
# dest node ID offset (integer),
# [[name, value, value type] X N]
edge_id = int(cell[0])
relation = await self.graph.get_relation(cell[1])
src_node_id = int(cell[2])
dest_node_id = int(cell[3])
properties = await self.parse_entity_properties(cell[4])
return Edge(
src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties
)
async def parse_path(self, cell):
"""
Parse the cell to a path.
"""
nodes = await self.parse_scalar(cell[0])
edges = await self.parse_scalar(cell[1])
return Path(nodes, edges)
async def parse_map(self, cell):
"""
Parse the cell to a map.
"""
m = OrderedDict()
n_entries = len(cell)
# A map is an array of key value pairs.
# 1. key (string)
# 2. array: (value type, value)
for i in range(0, n_entries, 2):
key = self.parse_string(cell[i])
m[key] = await self.parse_scalar(cell[i + 1])
return m
async def parse_array(self, value):
"""
Parse array value.
"""
scalar = [await self.parse_scalar(value[i]) for i in range(len(value))]
return scalar

View File

@@ -43,32 +43,19 @@ def parse_to_list(response):
"""Optimistically parse the response to a list."""
res = []
special_values = {"infinity", "nan", "-infinity"}
if response is None:
return res
for item in response:
if item is None:
res.append(None)
continue
try:
item_str = nativestr(item)
res.append(int(item))
except ValueError:
try:
res.append(float(item))
except ValueError:
res.append(nativestr(item))
except TypeError:
res.append(None)
continue
if isinstance(item_str, str) and item_str.lower() in special_values:
res.append(item_str) # Keep as string
else:
try:
res.append(int(item))
except ValueError:
try:
res.append(float(item))
except ValueError:
res.append(item_str)
return res
@@ -77,11 +64,6 @@ def parse_list_to_dict(response):
for i in range(0, len(response), 2):
if isinstance(response[i], list):
res["Child iterators"].append(parse_list_to_dict(response[i]))
try:
if isinstance(response[i + 1], list):
res["Child iterators"].append(parse_list_to_dict(response[i + 1]))
except IndexError:
pass
elif isinstance(response[i + 1], list):
res["Child iterators"] = [parse_list_to_dict(response[i + 1])]
else:
@@ -92,6 +74,25 @@ def parse_list_to_dict(response):
return res
def parse_to_dict(response):
if response is None:
return {}
res = {}
for det in response:
if isinstance(det[1], list):
res[det[0]] = parse_list_to_dict(det[1])
else:
try: # try to set the attribute. may be provided without value
try: # try to convert the value to float
res[det[0]] = float(det[1])
except (TypeError, ValueError):
res[det[0]] = det[1]
except IndexError:
pass
return res
def random_string(length=10):
"""
Returns a random N character long string.
@@ -101,6 +102,26 @@ def random_string(length=10):
)
def quote_string(v):
"""
RedisGraph strings must be quoted,
quote_string wraps given v with quotes incase
v is a string.
"""
if isinstance(v, bytes):
v = v.decode()
elif not isinstance(v, str):
return v
if len(v) == 0:
return '""'
v = v.replace("\\", "\\\\")
v = v.replace('"', '\\"')
return f'"{v}"'
def decode_dict_keys(obj):
"""Decode the keys of the given dictionary with utf-8."""
newobj = copy.copy(obj)
@@ -111,6 +132,33 @@ def decode_dict_keys(obj):
return newobj
def stringify_param_value(value):
"""
Turn a parameter value into a string suitable for the params header of
a Cypher command.
You may pass any value that would be accepted by `json.dumps()`.
Ways in which output differs from that of `str()`:
* Strings are quoted.
* None --> "null".
* In dictionaries, keys are _not_ quoted.
:param value: The parameter value to be turned into a string.
:return: string
"""
if isinstance(value, str):
return quote_string(value)
elif value is None:
return "null"
elif isinstance(value, (list, tuple)):
return f'[{",".join(map(stringify_param_value, value))}]'
elif isinstance(value, dict):
return f'{{{",".join(f"{k}:{stringify_param_value(v)}" for k, v in value.items())}}}' # noqa
else:
return str(value)
def get_protocol_version(client):
if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis):
return client.connection_pool.connection_kwargs.get("protocol")

View File

@@ -120,7 +120,7 @@ class JSON(JSONCommands):
startup_nodes=self.client.nodes_manager.startup_nodes,
result_callbacks=self.client.result_callbacks,
cluster_response_callbacks=self.client.cluster_response_callbacks,
cluster_error_retry_attempts=self.client.retry.get_retries(),
cluster_error_retry_attempts=self.client.cluster_error_retry_attempts,
read_from_replicas=self.client.read_from_replicas,
reinitialize_steps=self.client.reinitialize_steps,
lock=self.client._lock,

View File

@@ -1,5 +1,3 @@
from typing import List, Mapping, Union
from typing import Any, Dict, List, Union
JsonType = Union[
str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"]
]
JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]

View File

@@ -15,7 +15,7 @@ class JSONCommands:
def arrappend(
self, name: str, path: Optional[str] = Path.root_path(), *args: List[JsonType]
) -> List[Optional[int]]:
) -> List[Union[int, None]]:
"""Append the objects ``args`` to the array under the
``path` in key ``name``.
@@ -33,7 +33,7 @@ class JSONCommands:
scalar: int,
start: Optional[int] = None,
stop: Optional[int] = None,
) -> List[Optional[int]]:
) -> List[Union[int, None]]:
"""
Return the index of ``scalar`` in the JSON array under ``path`` at key
``name``.
@@ -49,11 +49,11 @@ class JSONCommands:
if stop is not None:
pieces.append(stop)
return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name])
return self.execute_command("JSON.ARRINDEX", *pieces)
def arrinsert(
self, name: str, path: str, index: int, *args: List[JsonType]
) -> List[Optional[int]]:
) -> List[Union[int, None]]:
"""Insert the objects ``args`` to the array at index ``index``
under the ``path` in key ``name``.
@@ -66,20 +66,20 @@ class JSONCommands:
def arrlen(
self, name: str, path: Optional[str] = Path.root_path()
) -> List[Optional[int]]:
) -> List[Union[int, None]]:
"""Return the length of the array JSON value under ``path``
at key``name``.
For more information see `JSON.ARRLEN <https://redis.io/commands/json.arrlen>`_.
""" # noqa
return self.execute_command("JSON.ARRLEN", name, str(path), keys=[name])
return self.execute_command("JSON.ARRLEN", name, str(path))
def arrpop(
self,
name: str,
path: Optional[str] = Path.root_path(),
index: Optional[int] = -1,
) -> List[Optional[str]]:
) -> List[Union[str, None]]:
"""Pop the element at ``index`` in the array JSON value under
``path`` at key ``name``.
@@ -89,7 +89,7 @@ class JSONCommands:
def arrtrim(
self, name: str, path: str, start: int, stop: int
) -> List[Optional[int]]:
) -> List[Union[int, None]]:
"""Trim the array JSON value under ``path`` at key ``name`` to the
inclusive range given by ``start`` and ``stop``.
@@ -102,34 +102,32 @@ class JSONCommands:
For more information see `JSON.TYPE <https://redis.io/commands/json.type>`_.
""" # noqa
return self.execute_command("JSON.TYPE", name, str(path), keys=[name])
return self.execute_command("JSON.TYPE", name, str(path))
def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List:
"""Return the JSON value under ``path`` at key ``name``.
For more information see `JSON.RESP <https://redis.io/commands/json.resp>`_.
""" # noqa
return self.execute_command("JSON.RESP", name, str(path), keys=[name])
return self.execute_command("JSON.RESP", name, str(path))
def objkeys(
self, name: str, path: Optional[str] = Path.root_path()
) -> List[Optional[List[str]]]:
) -> List[Union[List[str], None]]:
"""Return the key names in the dictionary JSON value under ``path`` at
key ``name``.
For more information see `JSON.OBJKEYS <https://redis.io/commands/json.objkeys>`_.
""" # noqa
return self.execute_command("JSON.OBJKEYS", name, str(path), keys=[name])
return self.execute_command("JSON.OBJKEYS", name, str(path))
def objlen(
self, name: str, path: Optional[str] = Path.root_path()
) -> List[Optional[int]]:
def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int:
"""Return the length of the dictionary JSON value under ``path`` at key
``name``.
For more information see `JSON.OBJLEN <https://redis.io/commands/json.objlen>`_.
""" # noqa
return self.execute_command("JSON.OBJLEN", name, str(path), keys=[name])
return self.execute_command("JSON.OBJLEN", name, str(path))
def numincrby(self, name: str, path: str, number: int) -> str:
"""Increment the numeric (integer or floating point) JSON value under
@@ -175,7 +173,7 @@ class JSONCommands:
def get(
self, name: str, *args, no_escape: Optional[bool] = False
) -> Optional[List[JsonType]]:
) -> List[JsonType]:
"""
Get the object stored as a JSON value at key ``name``.
@@ -199,7 +197,7 @@ class JSONCommands:
# Handle case where key doesn't exist. The JSONDecoder would raise a
# TypeError exception since it can't decode None
try:
return self.execute_command("JSON.GET", *pieces, keys=[name])
return self.execute_command("JSON.GET", *pieces)
except TypeError:
return None
@@ -213,7 +211,7 @@ class JSONCommands:
pieces = []
pieces += keys
pieces.append(str(path))
return self.execute_command("JSON.MGET", *pieces, keys=keys)
return self.execute_command("JSON.MGET", *pieces)
def set(
self,
@@ -314,7 +312,7 @@ class JSONCommands:
"""
with open(file_name) as fp:
with open(file_name, "r") as fp:
file_content = loads(fp.read())
return self.set(name, path, file_content, nx=nx, xx=xx, decode_keys=decode_keys)
@@ -326,7 +324,7 @@ class JSONCommands:
nx: Optional[bool] = False,
xx: Optional[bool] = False,
decode_keys: Optional[bool] = False,
) -> Dict[str, bool]:
) -> List[Dict[str, bool]]:
"""
Iterate over ``root_folder`` and set each JSON file to a value
under ``json_path`` with the file name as the key.
@@ -357,7 +355,7 @@ class JSONCommands:
return set_files_result
def strlen(self, name: str, path: Optional[str] = None) -> List[Optional[int]]:
def strlen(self, name: str, path: Optional[str] = None) -> List[Union[int, None]]:
"""Return the length of the string JSON value under ``path`` at key
``name``.
@@ -366,7 +364,7 @@ class JSONCommands:
pieces = [name]
if path is not None:
pieces.append(str(path))
return self.execute_command("JSON.STRLEN", *pieces, keys=[name])
return self.execute_command("JSON.STRLEN", *pieces)
def toggle(
self, name: str, path: Optional[str] = Path.root_path()
@@ -379,7 +377,7 @@ class JSONCommands:
return self.execute_command("JSON.TOGGLE", name, str(path))
def strappend(
self, name: str, value: str, path: Optional[str] = Path.root_path()
self, name: str, value: str, path: Optional[int] = Path.root_path()
) -> Union[int, List[Optional[int]]]:
"""Append to the string JSON value. If two options are specified after
the key name, the path is determined to be the first. If a single

View File

@@ -1,14 +1,4 @@
from __future__ import annotations
from json import JSONDecoder, JSONEncoder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .bf import BFBloom, CFBloom, CMSBloom, TDigestBloom, TOPKBloom
from .json import JSON
from .search import AsyncSearch, Search
from .timeseries import TimeSeries
from .vectorset import VectorSet
class RedisModuleCommands:
@@ -16,7 +6,7 @@ class RedisModuleCommands:
modules into the command namespace.
"""
def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()) -> JSON:
def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()):
"""Access the json namespace, providing support for redis json."""
from .json import JSON
@@ -24,7 +14,7 @@ class RedisModuleCommands:
jj = JSON(client=self, encoder=encoder, decoder=decoder)
return jj
def ft(self, index_name="idx") -> Search:
def ft(self, index_name="idx"):
"""Access the search namespace, providing support for redis search."""
from .search import Search
@@ -32,7 +22,7 @@ class RedisModuleCommands:
s = Search(client=self, index_name=index_name)
return s
def ts(self) -> TimeSeries:
def ts(self):
"""Access the timeseries namespace, providing support for
redis timeseries data.
"""
@@ -42,7 +32,7 @@ class RedisModuleCommands:
s = TimeSeries(client=self)
return s
def bf(self) -> BFBloom:
def bf(self):
"""Access the bloom namespace."""
from .bf import BFBloom
@@ -50,7 +40,7 @@ class RedisModuleCommands:
bf = BFBloom(client=self)
return bf
def cf(self) -> CFBloom:
def cf(self):
"""Access the bloom namespace."""
from .bf import CFBloom
@@ -58,7 +48,7 @@ class RedisModuleCommands:
cf = CFBloom(client=self)
return cf
def cms(self) -> CMSBloom:
def cms(self):
"""Access the bloom namespace."""
from .bf import CMSBloom
@@ -66,7 +56,7 @@ class RedisModuleCommands:
cms = CMSBloom(client=self)
return cms
def topk(self) -> TOPKBloom:
def topk(self):
"""Access the bloom namespace."""
from .bf import TOPKBloom
@@ -74,7 +64,7 @@ class RedisModuleCommands:
topk = TOPKBloom(client=self)
return topk
def tdigest(self) -> TDigestBloom:
def tdigest(self):
"""Access the bloom namespace."""
from .bf import TDigestBloom
@@ -82,20 +72,32 @@ class RedisModuleCommands:
tdigest = TDigestBloom(client=self)
return tdigest
def vset(self) -> VectorSet:
"""Access the VectorSet commands namespace."""
def graph(self, index_name="idx"):
"""Access the graph namespace, providing support for
redis graph data.
"""
from .vectorset import VectorSet
from .graph import Graph
vset = VectorSet(client=self)
return vset
g = Graph(client=self, name=index_name)
return g
class AsyncRedisModuleCommands(RedisModuleCommands):
def ft(self, index_name="idx") -> AsyncSearch:
def ft(self, index_name="idx"):
"""Access the search namespace, providing support for redis search."""
from .search import AsyncSearch
s = AsyncSearch(client=self, index_name=index_name)
return s
def graph(self, index_name="idx"):
"""Access the graph namespace, providing support for
redis graph data.
"""
from .graph import AsyncGraph
g = AsyncGraph(client=self, name=index_name)
return g

View File

@@ -1,7 +1,7 @@
def to_string(s, encoding: str = "utf-8"):
def to_string(s):
if isinstance(s, str):
return s
elif isinstance(s, bytes):
return s.decode(encoding, "ignore")
return s.decode("utf-8", "ignore")
else:
return s # Not a string we care about

View File

@@ -1,7 +1,5 @@
from typing import List, Union
from redis.commands.search.dialect import DEFAULT_DIALECT
FIELDNAME = object()
@@ -26,7 +24,7 @@ class Reducer:
NAME = None
def __init__(self, *args: str) -> None:
def __init__(self, *args: List[str]) -> None:
self._args = args
self._field = None
self._alias = None
@@ -112,11 +110,9 @@ class AggregateRequest:
self._with_schema = False
self._verbatim = False
self._cursor = []
self._dialect = DEFAULT_DIALECT
self._add_scores = False
self._scorer = "TFIDF"
self._dialect = None
def load(self, *fields: str) -> "AggregateRequest":
def load(self, *fields: List[str]) -> "AggregateRequest":
"""
Indicate the fields to be returned in the response. These fields are
returned in addition to any others implicitly specified.
@@ -223,7 +219,7 @@ class AggregateRequest:
self._aggregateplan.extend(_limit.build_args())
return self
def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
"""
Indicate how the results should be sorted. This can also be used for
*top-N* style queries
@@ -296,24 +292,6 @@ class AggregateRequest:
self._with_schema = True
return self
def add_scores(self) -> "AggregateRequest":
"""
If set, includes the score as an ordinary field of the row.
"""
self._add_scores = True
return self
def scorer(self, scorer: str) -> "AggregateRequest":
"""
Use a different scoring function to evaluate document relevance.
Default is `TFIDF`.
:param scorer: The scoring function to use
(e.g. `TFIDF.DOCNORM` or `BM25`)
"""
self._scorer = scorer
return self
def verbatim(self) -> "AggregateRequest":
self._verbatim = True
return self
@@ -337,19 +315,12 @@ class AggregateRequest:
if self._verbatim:
ret.append("VERBATIM")
if self._scorer:
ret.extend(["SCORER", self._scorer])
if self._add_scores:
ret.append("ADDSCORES")
if self._cursor:
ret += self._cursor
if self._loadall:
ret.append("LOAD")
ret.append("*")
elif self._loadfields:
ret.append("LOAD")
ret.append(str(len(self._loadfields)))

View File

@@ -2,16 +2,13 @@ import itertools
import time
from typing import Dict, List, Optional, Union
from redis.client import NEVER_DECODE, Pipeline
from redis.client import Pipeline
from redis.utils import deprecated_function
from ..helpers import get_protocol_version
from ..helpers import get_protocol_version, parse_to_dict
from ._util import to_string
from .aggregation import AggregateRequest, AggregateResult, Cursor
from .document import Document
from .field import Field
from .index_definition import IndexDefinition
from .profile_information import ProfileInformation
from .query import Query
from .result import Result
from .suggestion import SuggestionParser
@@ -23,6 +20,7 @@ ALTER_CMD = "FT.ALTER"
SEARCH_CMD = "FT.SEARCH"
ADD_CMD = "FT.ADD"
ADDHASH_CMD = "FT.ADDHASH"
DROP_CMD = "FT.DROP"
DROPINDEX_CMD = "FT.DROPINDEX"
EXPLAIN_CMD = "FT.EXPLAIN"
EXPLAINCLI_CMD = "FT.EXPLAINCLI"
@@ -34,6 +32,7 @@ SPELLCHECK_CMD = "FT.SPELLCHECK"
DICT_ADD_CMD = "FT.DICTADD"
DICT_DEL_CMD = "FT.DICTDEL"
DICT_DUMP_CMD = "FT.DICTDUMP"
GET_CMD = "FT.GET"
MGET_CMD = "FT.MGET"
CONFIG_CMD = "FT.CONFIG"
TAGVALS_CMD = "FT.TAGVALS"
@@ -66,7 +65,7 @@ class SearchCommands:
def _parse_results(self, cmd, res, **kwargs):
if get_protocol_version(self.client) in ["3", 3]:
return ProfileInformation(res) if cmd == "FT.PROFILE" else res
return res
else:
return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs)
@@ -81,7 +80,6 @@ class SearchCommands:
duration=kwargs["duration"],
has_payload=kwargs["query"]._with_payloads,
with_scores=kwargs["query"]._with_scores,
field_encodings=kwargs["query"]._return_fields_decode_as,
)
def _parse_aggregate(self, res, **kwargs):
@@ -100,7 +98,7 @@ class SearchCommands:
with_scores=query._with_scores,
)
return result, ProfileInformation(res[1])
return result, parse_to_dict(res[1])
def _parse_spellcheck(self, res, **kwargs):
corrections = {}
@@ -153,43 +151,44 @@ class SearchCommands:
def create_index(
self,
fields: List[Field],
no_term_offsets: bool = False,
no_field_flags: bool = False,
stopwords: Optional[List[str]] = None,
definition: Optional[IndexDefinition] = None,
fields,
no_term_offsets=False,
no_field_flags=False,
stopwords=None,
definition=None,
max_text_fields=False,
temporary=None,
no_highlight: bool = False,
no_term_frequencies: bool = False,
skip_initial_scan: bool = False,
no_highlight=False,
no_term_frequencies=False,
skip_initial_scan=False,
):
"""
Creates the search index. The index must not already exist.
Create the search index. The index must not already exist.
For more information, see https://redis.io/commands/ft.create/
### Parameters:
Args:
fields: A list of Field objects.
no_term_offsets: If `true`, term offsets will not be saved in the index.
no_field_flags: If true, field flags that allow searching in specific fields
will not be saved.
stopwords: If provided, the index will be created with this custom stopword
list. The list can be empty.
definition: If provided, the index will be created with this custom index
definition.
max_text_fields: If true, indexes will be encoded as if there were more than
32 text fields, allowing for additional fields beyond 32.
temporary: Creates a lightweight temporary index which will expire after the
specified period of inactivity. The internal idle timer is reset
whenever the index is searched or added to.
no_highlight: If true, disables highlighting support. Also implied by
`no_term_offsets`.
no_term_frequencies: If true, term frequencies will not be saved in the
index.
skip_initial_scan: If true, the initial scan and indexing will be skipped.
- **fields**: a list of TextField or NumericField objects
- **no_term_offsets**: If true, we will not save term offsets in
the index
- **no_field_flags**: If true, we will not save field flags that
allow searching in specific fields
- **stopwords**: If not None, we create the index with this custom
stopword list. The list can be empty
- **max_text_fields**: If true, we will encode indexes as if there
were more than 32 text fields which allows you to add additional
fields (beyond 32).
- **temporary**: Create a lightweight temporary index which will
expire after the specified period of inactivity (in seconds). The
internal idle timer is reset whenever the index is searched or added to.
- **no_highlight**: If true, disabling highlighting support.
Also implied by no_term_offsets.
- **no_term_frequencies**: If true, we avoid saving the term frequencies
in the index.
- **skip_initial_scan**: If true, we do not scan and index.
For more information see `FT.CREATE <https://redis.io/commands/ft.create>`_.
""" # noqa
"""
args = [CREATE_CMD, self.index_name]
if definition is not None:
args += definition.args
@@ -253,18 +252,8 @@ class SearchCommands:
For more information see `FT.DROPINDEX <https://redis.io/commands/ft.dropindex>`_.
""" # noqa
args = [DROPINDEX_CMD, self.index_name]
delete_str = (
"DD"
if isinstance(delete_documents, bool) and delete_documents is True
else ""
)
if delete_str:
args.append(delete_str)
return self.execute_command(*args)
delete_str = "DD" if delete_documents else ""
return self.execute_command(DROPINDEX_CMD, self.index_name, delete_str)
def _add_document(
self,
@@ -346,30 +335,30 @@ class SearchCommands:
"""
Add a single document to the index.
Args:
### Parameters
doc_id: the id of the saved document.
nosave: if set to true, we just index the document, and don't
- **doc_id**: the id of the saved document.
- **nosave**: if set to true, we just index the document, and don't
save a copy of it. This means that searches will just
return ids.
score: the document ranking, between 0.0 and 1.0
payload: optional inner-index payload we can save for fast
access in scoring functions
replace: if True, and the document already is in the index,
we perform an update and reindex the document
partial: if True, the fields specified will be added to the
- **score**: the document ranking, between 0.0 and 1.0
- **payload**: optional inner-index payload we can save for fast
i access in scoring functions
- **replace**: if True, and the document already is in the index,
we perform an update and reindex the document
- **partial**: if True, the fields specified will be added to the
existing document.
This has the added benefit that any fields specified
with `no_index`
will not be reindexed again. Implies `replace`
language: Specify the language used for document tokenization.
no_create: if True, the document is only updated and reindexed
- **language**: Specify the language used for document tokenization.
- **no_create**: if True, the document is only updated and reindexed
if it already exists.
If the document does not exist, an error will be
returned. Implies `replace`
fields: kwargs dictionary of the document fields to be saved
and/or indexed.
NOTE: Geo points shoule be encoded as strings of "lon,lat"
- **fields** kwargs dictionary of the document fields to be saved
and/or indexed.
NOTE: Geo points shoule be encoded as strings of "lon,lat"
""" # noqa
return self._add_document(
doc_id,
@@ -404,7 +393,6 @@ class SearchCommands:
doc_id, conn=None, score=score, language=language, replace=replace
)
@deprecated_function(version="2.0.0", reason="deprecated since redisearch 2.0")
def delete_document(self, doc_id, conn=None, delete_actual_document=False):
"""
Delete a document from index
@@ -439,7 +427,6 @@ class SearchCommands:
return Document(id=id, **fields)
@deprecated_function(version="2.0.0", reason="deprecated since redisearch 2.0")
def get(self, *ids):
"""
Returns the full contents of multiple documents.
@@ -510,19 +497,14 @@ class SearchCommands:
For more information see `FT.SEARCH <https://redis.io/commands/ft.search>`_.
""" # noqa
args, query = self._mk_query_args(query, query_params=query_params)
st = time.monotonic()
options = {}
if get_protocol_version(self.client) not in ["3", 3]:
options[NEVER_DECODE] = True
res = self.execute_command(SEARCH_CMD, *args, **options)
st = time.time()
res = self.execute_command(SEARCH_CMD, *args)
if isinstance(res, Pipeline):
return res
return self._parse_results(
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0
)
def explain(
@@ -542,7 +524,7 @@ class SearchCommands:
def aggregate(
self,
query: Union[AggregateRequest, Cursor],
query: Union[str, Query],
query_params: Dict[str, Union[str, int, float]] = None,
):
"""
@@ -573,7 +555,7 @@ class SearchCommands:
)
def _get_aggregate_result(
self, raw: List, query: Union[AggregateRequest, Cursor], has_cursor: bool
self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool
):
if has_cursor:
if isinstance(query, Cursor):
@@ -596,7 +578,7 @@ class SearchCommands:
def profile(
self,
query: Union[Query, AggregateRequest],
query: Union[str, Query, AggregateRequest],
limited: bool = False,
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
):
@@ -606,13 +588,13 @@ class SearchCommands:
### Parameters
**query**: This can be either an `AggregateRequest` or `Query`.
**query**: This can be either an `AggregateRequest`, `Query` or string.
**limited**: If set to True, removes details of reader iterator.
**query_params**: Define one or more value parameters.
Each parameter has a name and a value.
"""
st = time.monotonic()
st = time.time()
cmd = [PROFILE_CMD, self.index_name, ""]
if limited:
cmd.append("LIMITED")
@@ -631,20 +613,20 @@ class SearchCommands:
res = self.execute_command(*cmd)
return self._parse_results(
PROFILE_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0
)
def spellcheck(self, query, distance=None, include=None, exclude=None):
"""
Issue a spellcheck query
Args:
### Parameters
query: search query.
distance: the maximal Levenshtein distance for spelling
**query**: search query.
**distance***: the maximal Levenshtein distance for spelling
suggestions (default: 1, max: 4).
include: specifies an inclusion custom dictionary.
exclude: specifies an exclusion custom dictionary.
**include**: specifies an inclusion custom dictionary.
**exclude**: specifies an exclusion custom dictionary.
For more information see `FT.SPELLCHECK <https://redis.io/commands/ft.spellcheck>`_.
""" # noqa
@@ -702,10 +684,6 @@ class SearchCommands:
cmd = [DICT_DUMP_CMD, name]
return self.execute_command(*cmd)
@deprecated_function(
version="8.0.0",
reason="deprecated since Redis 8.0, call config_set from core module instead",
)
def config_set(self, option: str, value: str) -> bool:
"""Set runtime configuration option.
@@ -720,10 +698,6 @@ class SearchCommands:
raw = self.execute_command(*cmd)
return raw == "OK"
@deprecated_function(
version="8.0.0",
reason="deprecated since Redis 8.0, call config_get from core module instead",
)
def config_get(self, option: str) -> str:
"""Get runtime configuration option value.
@@ -950,24 +924,19 @@ class AsyncSearchCommands(SearchCommands):
For more information see `FT.SEARCH <https://redis.io/commands/ft.search>`_.
""" # noqa
args, query = self._mk_query_args(query, query_params=query_params)
st = time.monotonic()
options = {}
if get_protocol_version(self.client) not in ["3", 3]:
options[NEVER_DECODE] = True
res = await self.execute_command(SEARCH_CMD, *args, **options)
st = time.time()
res = await self.execute_command(SEARCH_CMD, *args)
if isinstance(res, Pipeline):
return res
return self._parse_results(
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0
)
async def aggregate(
self,
query: Union[AggregateResult, Cursor],
query: Union[str, Query],
query_params: Dict[str, Union[str, int, float]] = None,
):
"""
@@ -1025,10 +994,6 @@ class AsyncSearchCommands(SearchCommands):
return self._parse_results(SPELLCHECK_CMD, res)
@deprecated_function(
version="8.0.0",
reason="deprecated since Redis 8.0, call config_set from core module instead",
)
async def config_set(self, option: str, value: str) -> bool:
"""Set runtime configuration option.
@@ -1043,10 +1008,6 @@ class AsyncSearchCommands(SearchCommands):
raw = await self.execute_command(*cmd)
return raw == "OK"
@deprecated_function(
version="8.0.0",
reason="deprecated since Redis 8.0, call config_get from core module instead",
)
async def config_get(self, option: str) -> str:
"""Get runtime configuration option value.

View File

@@ -1,3 +0,0 @@
# Value for the default dialect to be used as a part of
# Search or Aggregate query.
DEFAULT_DIALECT = 2

View File

@@ -4,10 +4,6 @@ from redis import DataError
class Field:
"""
A class representing a field in a document.
"""
NUMERIC = "NUMERIC"
TEXT = "TEXT"
WEIGHT = "WEIGHT"
@@ -17,9 +13,6 @@ class Field:
SORTABLE = "SORTABLE"
NOINDEX = "NOINDEX"
AS = "AS"
GEOSHAPE = "GEOSHAPE"
INDEX_MISSING = "INDEXMISSING"
INDEX_EMPTY = "INDEXEMPTY"
def __init__(
self,
@@ -27,24 +20,8 @@ class Field:
args: List[str] = None,
sortable: bool = False,
no_index: bool = False,
index_missing: bool = False,
index_empty: bool = False,
as_name: str = None,
):
"""
Create a new field object.
Args:
name: The name of the field.
args:
sortable: If `True`, the field will be sortable.
no_index: If `True`, the field will not be indexed.
index_missing: If `True`, it will be possible to search for documents that
have this field missing.
index_empty: If `True`, it will be possible to search for documents that
have this field empty.
as_name: If provided, this alias will be used for the field.
"""
if args is None:
args = []
self.name = name
@@ -56,10 +33,6 @@ class Field:
self.args_suffix.append(Field.SORTABLE)
if no_index:
self.args_suffix.append(Field.NOINDEX)
if index_missing:
self.args_suffix.append(Field.INDEX_MISSING)
if index_empty:
self.args_suffix.append(Field.INDEX_EMPTY)
if no_index and not sortable:
raise ValueError("Non-Sortable non-Indexable fields are ignored")
@@ -118,21 +91,6 @@ class NumericField(Field):
Field.__init__(self, name, args=[Field.NUMERIC], **kwargs)
class GeoShapeField(Field):
"""
GeoShapeField is used to enable within/contain indexing/searching
"""
SPHERICAL = "SPHERICAL"
FLAT = "FLAT"
def __init__(self, name: str, coord_system=None, **kwargs):
args = [Field.GEOSHAPE]
if coord_system:
args.append(coord_system)
Field.__init__(self, name, args=args, **kwargs)
class GeoField(Field):
"""
GeoField is used to define a geo-indexing field in a schema definition
@@ -181,7 +139,7 @@ class VectorField(Field):
``name`` is the name of the field.
``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA".
``algorithm`` can be "FLAT" or "HNSW".
``attributes`` each algorithm can have specific attributes. Some of them
are mandatory and some of them are optional. See
@@ -194,10 +152,10 @@ class VectorField(Field):
if sort or noindex:
raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.")
if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]:
if algorithm.upper() not in ["FLAT", "HNSW"]:
raise DataError(
"Realtime vector indexing supporting 3 Indexing Methods:"
"'FLAT', 'HNSW', and 'SVS-VAMANA'."
"Realtime vector indexing supporting 2 Indexing Methods:"
"'FLAT' and 'HNSW'."
)
attr_li = []

View File

@@ -1,14 +0,0 @@
from typing import Any
class ProfileInformation:
"""
Wrapper around FT.PROFILE response
"""
def __init__(self, info: Any) -> None:
self._info: Any = info
@property
def info(self) -> Any:
return self._info

View File

@@ -1,7 +1,5 @@
from typing import List, Optional, Union
from redis.commands.search.dialect import DEFAULT_DIALECT
class Query:
"""
@@ -37,12 +35,11 @@ class Query:
self._in_order: bool = False
self._sortby: Optional[SortbyField] = None
self._return_fields: List = []
self._return_fields_decode_as: dict = {}
self._summarize_fields: List = []
self._highlight_fields: List = []
self._language: Optional[str] = None
self._expander: Optional[str] = None
self._dialect: int = DEFAULT_DIALECT
self._dialect: Optional[int] = None
def query_string(self) -> str:
"""Return the query string of this query only."""
@@ -56,27 +53,13 @@ class Query:
def return_fields(self, *fields) -> "Query":
"""Add fields to return fields."""
for field in fields:
self.return_field(field)
self._return_fields += fields
return self
def return_field(
self,
field: str,
as_field: Optional[str] = None,
decode_field: Optional[bool] = True,
encoding: Optional[str] = "utf8",
) -> "Query":
"""
Add a field to the list of fields to return.
- **field**: The field to include in query results
- **as_field**: The alias for the field
- **decode_field**: Whether to decode the field from bytes to string
- **encoding**: The encoding to use when decoding the field
"""
def return_field(self, field: str, as_field: Optional[str] = None) -> "Query":
"""Add field to return fields (Optional: add 'AS' name
to the field)."""
self._return_fields.append(field)
self._return_fields_decode_as[field] = encoding if decode_field else None
if as_field is not None:
self._return_fields += ("AS", as_field)
return self
@@ -179,8 +162,6 @@ class Query:
Use a different scoring function to evaluate document relevance.
Default is `TFIDF`.
Since Redis 8.0 default was changed to BM25STD.
:param scorer: The scoring function to use
(e.g. `TFIDF.DOCNORM` or `BM25`)
"""

View File

@@ -1,5 +1,3 @@
from typing import Optional
from ._util import to_string
from .document import Document
@@ -11,19 +9,11 @@ class Result:
"""
def __init__(
self,
res,
hascontent,
duration=0,
has_payload=False,
with_scores=False,
field_encodings: Optional[dict] = None,
self, res, hascontent, duration=0, has_payload=False, with_scores=False
):
"""
- duration: the execution time of the query
- has_payload: whether the query has payloads
- with_scores: whether the query has scores
- field_encodings: a dictionary of field encodings if any is provided
- **snippets**: An optional dictionary of the form
{field: snippet_size} for snippet formatting
"""
self.total = res[0]
@@ -49,22 +39,18 @@ class Result:
fields = {}
if hascontent and res[i + fields_offset] is not None:
keys = map(to_string, res[i + fields_offset][::2])
values = res[i + fields_offset][1::2]
for key, value in zip(keys, values):
if field_encodings is None or key not in field_encodings:
fields[key] = to_string(value)
continue
encoding = field_encodings[key]
# If the encoding is None, we don't need to decode the value
if encoding is None:
fields[key] = value
else:
fields[key] = to_string(value, encoding=encoding)
fields = (
dict(
dict(
zip(
map(to_string, res[i + fields_offset][::2]),
map(to_string, res[i + fields_offset][1::2]),
)
)
)
if hascontent
else {}
)
try:
del fields["id"]
except KeyError:

View File

@@ -11,35 +11,16 @@ class SentinelCommands:
"""Redis Sentinel's SENTINEL command."""
warnings.warn(DeprecationWarning("Use the individual sentinel_* methods"))
def sentinel_get_master_addr_by_name(self, service_name, return_responses=False):
"""
Returns a (host, port) pair for the given ``service_name`` when return_responses is True,
otherwise returns a boolean value that indicates if the command was successful.
"""
return self.execute_command(
"SENTINEL GET-MASTER-ADDR-BY-NAME",
service_name,
once=True,
return_responses=return_responses,
)
def sentinel_get_master_addr_by_name(self, service_name):
"""Returns a (host, port) pair for the given ``service_name``"""
return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name)
def sentinel_master(self, service_name, return_responses=False):
"""
Returns a dictionary containing the specified masters state, when return_responses is True,
otherwise returns a boolean value that indicates if the command was successful.
"""
return self.execute_command(
"SENTINEL MASTER", service_name, return_responses=return_responses
)
def sentinel_master(self, service_name):
"""Returns a dictionary containing the specified masters state."""
return self.execute_command("SENTINEL MASTER", service_name)
def sentinel_masters(self):
"""
Returns a list of dictionaries containing each master's state.
Important: This function is called by the Sentinel implementation and is
called directly on the Redis standalone client for sentinels,
so it doesn't support the "once" and "return_responses" options.
"""
"""Returns a list of dictionaries containing each master's state."""
return self.execute_command("SENTINEL MASTERS")
def sentinel_monitor(self, name, ip, port, quorum):
@@ -50,27 +31,16 @@ class SentinelCommands:
"""Remove a master from Sentinel's monitoring"""
return self.execute_command("SENTINEL REMOVE", name)
def sentinel_sentinels(self, service_name, return_responses=False):
"""
Returns a list of sentinels for ``service_name``, when return_responses is True,
otherwise returns a boolean value that indicates if the command was successful.
"""
return self.execute_command(
"SENTINEL SENTINELS", service_name, return_responses=return_responses
)
def sentinel_sentinels(self, service_name):
"""Returns a list of sentinels for ``service_name``"""
return self.execute_command("SENTINEL SENTINELS", service_name)
def sentinel_set(self, name, option, value):
"""Set Sentinel monitoring parameters for a given master"""
return self.execute_command("SENTINEL SET", name, option, value)
def sentinel_slaves(self, service_name):
"""
Returns a list of slaves for ``service_name``
Important: This function is called by the Sentinel implementation and is
called directly on the Redis standalone client for sentinels,
so it doesn't support the "once" and "return_responses" options.
"""
"""Returns a list of slaves for ``service_name``"""
return self.execute_command("SENTINEL SLAVES", service_name)
def sentinel_reset(self, pattern):

View File

@@ -84,7 +84,7 @@ class TimeSeries(TimeSeriesCommands):
startup_nodes=self.client.nodes_manager.startup_nodes,
result_callbacks=self.client.result_callbacks,
cluster_response_callbacks=self.client.cluster_response_callbacks,
cluster_error_retry_attempts=self.client.retry.get_retries(),
cluster_error_retry_attempts=self.client.cluster_error_retry_attempts,
read_from_replicas=self.client.read_from_replicas,
reinitialize_steps=self.client.reinitialize_steps,
lock=self.client._lock,

View File

@@ -6,7 +6,7 @@ class TSInfo:
"""
Hold information and statistics on the time-series.
Can be created using ``tsinfo`` command
https://redis.io/docs/latest/commands/ts.info/
https://oss.redis.com/redistimeseries/commands/#tsinfo.
"""
rules = []
@@ -57,7 +57,7 @@ class TSInfo:
Policy that will define handling of duplicate samples.
Can read more about on
https://redis.io/docs/latest/develop/data-types/timeseries/configuration/#duplicate_policy
https://oss.redis.com/redistimeseries/configuration/#duplicate_policy
"""
response = dict(zip(map(nativestr, args[::2]), args[1::2]))
self.rules = response.get("rules")
@@ -78,7 +78,7 @@ class TSInfo:
self.chunk_size = response["chunkSize"]
if "duplicatePolicy" in response:
self.duplicate_policy = response["duplicatePolicy"]
if isinstance(self.duplicate_policy, bytes):
if type(self.duplicate_policy) == bytes:
self.duplicate_policy = self.duplicate_policy.decode()
def get(self, item):

View File

@@ -5,7 +5,7 @@ def list_to_dict(aList):
return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))}
def parse_range(response, **kwargs):
def parse_range(response):
"""Parse range response. Used by TS.RANGE and TS.REVRANGE."""
return [tuple((r[0], float(r[1]))) for r in response]

View File

@@ -1,46 +0,0 @@
import json
from redis._parsers.helpers import pairs_to_dict
from redis.commands.vectorset.utils import (
parse_vemb_result,
parse_vlinks_result,
parse_vsim_result,
)
from ..helpers import get_protocol_version
from .commands import (
VEMB_CMD,
VGETATTR_CMD,
VINFO_CMD,
VLINKS_CMD,
VSIM_CMD,
VectorSetCommands,
)
class VectorSet(VectorSetCommands):
def __init__(self, client, **kwargs):
"""Create a new VectorSet client."""
# Set the module commands' callbacks
self._MODULE_CALLBACKS = {
VEMB_CMD: parse_vemb_result,
VGETATTR_CMD: lambda r: r and json.loads(r) or None,
}
self._RESP2_MODULE_CALLBACKS = {
VINFO_CMD: lambda r: r and pairs_to_dict(r) or None,
VSIM_CMD: parse_vsim_result,
VLINKS_CMD: parse_vlinks_result,
}
self._RESP3_MODULE_CALLBACKS = {}
self.client = client
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
self._MODULE_CALLBACKS.update(self._RESP3_MODULE_CALLBACKS)
else:
self._MODULE_CALLBACKS.update(self._RESP2_MODULE_CALLBACKS)
for k, v in self._MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)

View File

@@ -1,374 +0,0 @@
import json
from enum import Enum
from typing import Awaitable, Dict, List, Optional, Union
from redis.client import NEVER_DECODE
from redis.commands.helpers import get_protocol_version
from redis.exceptions import DataError
from redis.typing import CommandsProtocol, EncodableT, KeyT, Number
VADD_CMD = "VADD"
VSIM_CMD = "VSIM"
VREM_CMD = "VREM"
VDIM_CMD = "VDIM"
VCARD_CMD = "VCARD"
VEMB_CMD = "VEMB"
VLINKS_CMD = "VLINKS"
VINFO_CMD = "VINFO"
VSETATTR_CMD = "VSETATTR"
VGETATTR_CMD = "VGETATTR"
VRANDMEMBER_CMD = "VRANDMEMBER"
class QuantizationOptions(Enum):
"""Quantization options for the VADD command."""
NOQUANT = "NOQUANT"
BIN = "BIN"
Q8 = "Q8"
class CallbacksOptions(Enum):
"""Options that can be set for the commands callbacks"""
RAW = "RAW"
WITHSCORES = "WITHSCORES"
ALLOW_DECODING = "ALLOW_DECODING"
RESP3 = "RESP3"
class VectorSetCommands(CommandsProtocol):
"""Redis VectorSet commands"""
def vadd(
self,
key: KeyT,
vector: Union[List[float], bytes],
element: str,
reduce_dim: Optional[int] = None,
cas: Optional[bool] = False,
quantization: Optional[QuantizationOptions] = None,
ef: Optional[Number] = None,
attributes: Optional[Union[dict, str]] = None,
numlinks: Optional[int] = None,
) -> Union[Awaitable[int], int]:
"""
Add vector ``vector`` for element ``element`` to a vector set ``key``.
``reduce_dim`` sets the dimensions to reduce the vector to.
If not provided, the vector is not reduced.
``cas`` is a boolean flag that indicates whether to use CAS (check-and-set style)
when adding the vector. If not provided, CAS is not used.
``quantization`` sets the quantization type to use.
If not provided, int8 quantization is used.
The options are:
- NOQUANT: No quantization
- BIN: Binary quantization
- Q8: Signed 8-bit quantization
``ef`` sets the exploration factor to use.
If not provided, the default exploration factor is used.
``attributes`` is a dictionary or json string that contains the attributes to set for the vector.
If not provided, no attributes are set.
``numlinks`` sets the number of links to create for the vector.
If not provided, the default number of links is used.
For more information see https://redis.io/commands/vadd
"""
if not vector or not element:
raise DataError("Both vector and element must be provided")
pieces = []
if reduce_dim:
pieces.extend(["REDUCE", reduce_dim])
values_pieces = []
if isinstance(vector, bytes):
values_pieces.extend(["FP32", vector])
else:
values_pieces.extend(["VALUES", len(vector)])
values_pieces.extend(vector)
pieces.extend(values_pieces)
pieces.append(element)
if cas:
pieces.append("CAS")
if quantization:
pieces.append(quantization.value)
if ef:
pieces.extend(["EF", ef])
if attributes:
if isinstance(attributes, dict):
# transform attributes to json string
attributes_json = json.dumps(attributes)
else:
attributes_json = attributes
pieces.extend(["SETATTR", attributes_json])
if numlinks:
pieces.extend(["M", numlinks])
return self.execute_command(VADD_CMD, key, *pieces)
def vsim(
self,
key: KeyT,
input: Union[List[float], bytes, str],
with_scores: Optional[bool] = False,
count: Optional[int] = None,
ef: Optional[Number] = None,
filter: Optional[str] = None,
filter_ef: Optional[str] = None,
truth: Optional[bool] = False,
no_thread: Optional[bool] = False,
epsilon: Optional[Number] = None,
) -> Union[
Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]],
Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]],
]:
"""
Compare a vector or element ``input`` with the other vectors in a vector set ``key``.
``with_scores`` sets if the results should be returned with the
similarity scores of the elements in the result.
``count`` sets the number of results to return.
``ef`` sets the exploration factor.
``filter`` sets filter that should be applied for the search.
``filter_ef`` sets the max filtering effort.
``truth`` when enabled forces the command to perform linear scan.
``no_thread`` when enabled forces the command to execute the search
on the data structure in the main thread.
``epsilon`` floating point between 0 and 1, if specified will return
only elements with distance no further than the specified one.
For more information see https://redis.io/commands/vsim
"""
if not input:
raise DataError("'input' should be provided")
pieces = []
options = {}
if isinstance(input, bytes):
pieces.extend(["FP32", input])
elif isinstance(input, list):
pieces.extend(["VALUES", len(input)])
pieces.extend(input)
else:
pieces.extend(["ELE", input])
if with_scores:
pieces.append("WITHSCORES")
options[CallbacksOptions.WITHSCORES.value] = True
if count:
pieces.extend(["COUNT", count])
if epsilon:
pieces.extend(["EPSILON", epsilon])
if ef:
pieces.extend(["EF", ef])
if filter:
pieces.extend(["FILTER", filter])
if filter_ef:
pieces.extend(["FILTER-EF", filter_ef])
if truth:
pieces.append("TRUTH")
if no_thread:
pieces.append("NOTHREAD")
return self.execute_command(VSIM_CMD, key, *pieces, **options)
def vdim(self, key: KeyT) -> Union[Awaitable[int], int]:
"""
Get the dimension of a vector set.
In the case of vectors that were populated using the `REDUCE`
option, for random projection, the vector set will report the size of
the projected (reduced) dimension.
Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.
For more information see https://redis.io/commands/vdim
"""
return self.execute_command(VDIM_CMD, key)
def vcard(self, key: KeyT) -> Union[Awaitable[int], int]:
"""
Get the cardinality(the number of elements) of a vector set with key ``key``.
Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.
For more information see https://redis.io/commands/vcard
"""
return self.execute_command(VCARD_CMD, key)
def vrem(self, key: KeyT, element: str) -> Union[Awaitable[int], int]:
"""
Remove an element from a vector set.
For more information see https://redis.io/commands/vrem
"""
return self.execute_command(VREM_CMD, key, element)
def vemb(
self, key: KeyT, element: str, raw: Optional[bool] = False
) -> Union[
Awaitable[Optional[Union[List[EncodableT], Dict[str, EncodableT]]]],
Optional[Union[List[EncodableT], Dict[str, EncodableT]]],
]:
"""
Get the approximated vector of an element ``element`` from vector set ``key``.
``raw`` is a boolean flag that indicates whether to return the
interal representation used by the vector.
For more information see https://redis.io/commands/vembed
"""
options = {}
pieces = []
pieces.extend([key, element])
if get_protocol_version(self.client) in ["3", 3]:
options[CallbacksOptions.RESP3.value] = True
if raw:
pieces.append("RAW")
options[NEVER_DECODE] = True
if (
hasattr(self.client, "connection_pool")
and self.client.connection_pool.connection_kwargs["decode_responses"]
) or (
hasattr(self.client, "nodes_manager")
and self.client.nodes_manager.connection_kwargs["decode_responses"]
):
# allow decoding in the postprocessing callback
# if the user set decode_responses=True
# in the connection pool
options[CallbacksOptions.ALLOW_DECODING.value] = True
options[CallbacksOptions.RAW.value] = True
return self.execute_command(VEMB_CMD, *pieces, **options)
def vlinks(
self, key: KeyT, element: str, with_scores: Optional[bool] = False
) -> Union[
Awaitable[
Optional[
List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]
]
],
Optional[List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]],
]:
"""
Returns the neighbors for each level the element ``element`` exists in the vector set ``key``.
The result is a list of lists, where each list contains the neighbors for one level.
If the element does not exist, or if the vector set does not exist, None is returned.
If the ``WITHSCORES`` option is provided, the result is a list of dicts,
where each dict contains the neighbors for one level, with the scores as values.
For more information see https://redis.io/commands/vlinks
"""
options = {}
pieces = []
pieces.extend([key, element])
if with_scores:
pieces.append("WITHSCORES")
options[CallbacksOptions.WITHSCORES.value] = True
return self.execute_command(VLINKS_CMD, *pieces, **options)
def vinfo(self, key: KeyT) -> Union[Awaitable[dict], dict]:
"""
Get information about a vector set.
For more information see https://redis.io/commands/vinfo
"""
return self.execute_command(VINFO_CMD, key)
def vsetattr(
self, key: KeyT, element: str, attributes: Optional[Union[dict, str]] = None
) -> Union[Awaitable[int], int]:
"""
Associate or remove JSON attributes ``attributes`` of element ``element``
for vector set ``key``.
For more information see https://redis.io/commands/vsetattr
"""
if attributes is None:
attributes_json = "{}"
elif isinstance(attributes, dict):
# transform attributes to json string
attributes_json = json.dumps(attributes)
else:
attributes_json = attributes
return self.execute_command(VSETATTR_CMD, key, element, attributes_json)
def vgetattr(
self, key: KeyT, element: str
) -> Union[Optional[Awaitable[dict]], Optional[dict]]:
"""
Retrieve the JSON attributes of an element ``elemet`` for vector set ``key``.
If the element does not exist, or if the vector set does not exist, None is
returned.
For more information see https://redis.io/commands/vgetattr
"""
return self.execute_command(VGETATTR_CMD, key, element)
def vrandmember(
self, key: KeyT, count: Optional[int] = None
) -> Union[
Awaitable[Optional[Union[List[str], str]]], Optional[Union[List[str], str]]
]:
"""
Returns random elements from a vector set ``key``.
``count`` is the number of elements to return.
If ``count`` is not provided, a single element is returned as a single string.
If ``count`` is positive(smaller than the number of elements
in the vector set), the command returns a list with up to ``count``
distinct elements from the vector set
If ``count`` is negative, the command returns a list with ``count`` random elements,
potentially with duplicates.
If ``count`` is greater than the number of elements in the vector set,
only the entire set is returned as a list.
If the vector set does not exist, ``None`` is returned.
For more information see https://redis.io/commands/vrandmember
"""
pieces = []
pieces.append(key)
if count is not None:
pieces.append(count)
return self.execute_command(VRANDMEMBER_CMD, *pieces)

View File

@@ -1,94 +0,0 @@
from redis._parsers.helpers import pairs_to_dict
from redis.commands.vectorset.commands import CallbacksOptions
def parse_vemb_result(response, **options):
"""
Handle VEMB result since the command can returning different result
structures depending on input options and on quantization type of the vector set.
Parsing VEMB result into:
- List[Union[bytes, Union[int, float]]]
- Dict[str, Union[bytes, str, float]]
"""
if response is None:
return response
if options.get(CallbacksOptions.RAW.value):
result = {}
result["quantization"] = (
response[0].decode("utf-8")
if options.get(CallbacksOptions.ALLOW_DECODING.value)
else response[0]
)
result["raw"] = response[1]
result["l2"] = float(response[2])
if len(response) > 3:
result["range"] = float(response[3])
return result
else:
if options.get(CallbacksOptions.RESP3.value):
return response
result = []
for i in range(len(response)):
try:
result.append(int(response[i]))
except ValueError:
# if the value is not an integer, it should be a float
result.append(float(response[i]))
return result
def parse_vlinks_result(response, **options):
"""
Handle VLINKS result since the command can be returning different result
structures depending on input options.
Parsing VLINKS result into:
- List[List[str]]
- List[Dict[str, Number]]
"""
if response is None:
return response
if options.get(CallbacksOptions.WITHSCORES.value):
result = []
# Redis will return a list of list of strings.
# This list have to be transformed to list of dicts
for level_item in response:
level_data_dict = {}
for key, value in pairs_to_dict(level_item).items():
value = float(value)
level_data_dict[key] = value
result.append(level_data_dict)
return result
else:
# return the list of elements for each level
# list of lists
return response
def parse_vsim_result(response, **options):
"""
Handle VSIM result since the command can be returning different result
structures depending on input options.
Parsing VSIM result into:
- List[List[str]]
- List[Dict[str, Number]]
"""
if response is None:
return response
if options.get(CallbacksOptions.WITHSCORES.value):
# Redis will return a list of list of pairs.
# This list have to be transformed to dict
result_dict = {}
for key, value in pairs_to_dict(response).items():
value = float(value)
result_dict[key] = value
return result_dict
else:
# return the list of elements for each level
# list of lists
return response

View File

@@ -0,0 +1,6 @@
# flake8: noqa
try:
from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import]
except ImportError:
from typing_extensions import Literal # lgtm [py/unused-import]
from typing_extensions import Protocol, TypedDict

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,4 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Tuple, Union
logger = logging.getLogger(__name__)
from typing import Optional, Tuple, Union
class CredentialProvider:
@@ -13,38 +9,6 @@ class CredentialProvider:
def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
raise NotImplementedError("get_credentials must be implemented")
async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
logger.warning(
"This method is added for backward compatability. "
"Please override it in your implementation."
)
return self.get_credentials()
class StreamingCredentialProvider(CredentialProvider, ABC):
"""
Credential provider that streams credentials in the background.
"""
@abstractmethod
def on_next(self, callback: Callable[[Any], None]):
"""
Specifies the callback that should be invoked
when the next credentials will be retrieved.
:param callback: Callback with
:return:
"""
pass
@abstractmethod
def on_error(self, callback: Callable[[Exception], None]):
pass
@abstractmethod
def is_streaming(self) -> bool:
pass
class UsernamePasswordCredentialProvider(CredentialProvider):
"""
@@ -60,6 +24,3 @@ class UsernamePasswordCredentialProvider(CredentialProvider):
if self.username:
return self.username, self.password
return (self.password,)
async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
return self.get_credentials()

View File

@@ -1,394 +0,0 @@
import asyncio
import threading
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Optional, Union
from redis.auth.token import TokenInterface
from redis.credentials import CredentialProvider, StreamingCredentialProvider
class EventListenerInterface(ABC):
"""
Represents a listener for given event object.
"""
@abstractmethod
def listen(self, event: object):
pass
class AsyncEventListenerInterface(ABC):
"""
Represents an async listener for given event object.
"""
@abstractmethod
async def listen(self, event: object):
pass
class EventDispatcherInterface(ABC):
"""
Represents a dispatcher that dispatches events to listeners
associated with given event.
"""
@abstractmethod
def dispatch(self, event: object):
pass
@abstractmethod
async def dispatch_async(self, event: object):
pass
class EventException(Exception):
"""
Exception wrapper that adds an event object into exception context.
"""
def __init__(self, exception: Exception, event: object):
self.exception = exception
self.event = event
super().__init__(exception)
class EventDispatcher(EventDispatcherInterface):
# TODO: Make dispatcher to accept external mappings.
def __init__(self):
"""
Mapping should be extended for any new events or listeners to be added.
"""
self._event_listeners_mapping = {
AfterConnectionReleasedEvent: [
ReAuthConnectionListener(),
],
AfterPooledConnectionsInstantiationEvent: [
RegisterReAuthForPooledConnections()
],
AfterSingleConnectionInstantiationEvent: [
RegisterReAuthForSingleConnection()
],
AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()],
AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()],
AsyncAfterConnectionReleasedEvent: [
AsyncReAuthConnectionListener(),
],
}
def dispatch(self, event: object):
listeners = self._event_listeners_mapping.get(type(event))
for listener in listeners:
listener.listen(event)
async def dispatch_async(self, event: object):
listeners = self._event_listeners_mapping.get(type(event))
for listener in listeners:
await listener.listen(event)
class AfterConnectionReleasedEvent:
"""
Event that will be fired before each command execution.
"""
def __init__(self, connection):
self._connection = connection
@property
def connection(self):
return self._connection
class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent):
pass
class ClientType(Enum):
SYNC = ("sync",)
ASYNC = ("async",)
class AfterPooledConnectionsInstantiationEvent:
"""
Event that will be fired after pooled connection instances was created.
"""
def __init__(
self,
connection_pools: List,
client_type: ClientType,
credential_provider: Optional[CredentialProvider] = None,
):
self._connection_pools = connection_pools
self._client_type = client_type
self._credential_provider = credential_provider
@property
def connection_pools(self):
return self._connection_pools
@property
def client_type(self) -> ClientType:
return self._client_type
@property
def credential_provider(self) -> Union[CredentialProvider, None]:
return self._credential_provider
class AfterSingleConnectionInstantiationEvent:
"""
Event that will be fired after single connection instances was created.
:param connection_lock: For sync client thread-lock should be provided,
for async asyncio.Lock
"""
def __init__(
self,
connection,
client_type: ClientType,
connection_lock: Union[threading.RLock, asyncio.Lock],
):
self._connection = connection
self._client_type = client_type
self._connection_lock = connection_lock
@property
def connection(self):
return self._connection
@property
def client_type(self) -> ClientType:
return self._client_type
@property
def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
return self._connection_lock
class AfterPubSubConnectionInstantiationEvent:
def __init__(
self,
pubsub_connection,
connection_pool,
client_type: ClientType,
connection_lock: Union[threading.RLock, asyncio.Lock],
):
self._pubsub_connection = pubsub_connection
self._connection_pool = connection_pool
self._client_type = client_type
self._connection_lock = connection_lock
@property
def pubsub_connection(self):
return self._pubsub_connection
@property
def connection_pool(self):
return self._connection_pool
@property
def client_type(self) -> ClientType:
return self._client_type
@property
def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
return self._connection_lock
class AfterAsyncClusterInstantiationEvent:
"""
Event that will be fired after async cluster instance was created.
Async cluster doesn't use connection pools,
instead ClusterNode object manages connections.
"""
def __init__(
self,
nodes: dict,
credential_provider: Optional[CredentialProvider] = None,
):
self._nodes = nodes
self._credential_provider = credential_provider
@property
def nodes(self) -> dict:
return self._nodes
@property
def credential_provider(self) -> Union[CredentialProvider, None]:
return self._credential_provider
class ReAuthConnectionListener(EventListenerInterface):
"""
Listener that performs re-authentication of given connection.
"""
def listen(self, event: AfterConnectionReleasedEvent):
event.connection.re_auth()
class AsyncReAuthConnectionListener(AsyncEventListenerInterface):
"""
Async listener that performs re-authentication of given connection.
"""
async def listen(self, event: AsyncAfterConnectionReleasedEvent):
await event.connection.re_auth()
class RegisterReAuthForPooledConnections(EventListenerInterface):
"""
Listener that registers a re-authentication callback for pooled connections.
Required by :class:`StreamingCredentialProvider`.
"""
def __init__(self):
self._event = None
def listen(self, event: AfterPooledConnectionsInstantiationEvent):
if isinstance(event.credential_provider, StreamingCredentialProvider):
self._event = event
if event.client_type == ClientType.SYNC:
event.credential_provider.on_next(self._re_auth)
event.credential_provider.on_error(self._raise_on_error)
else:
event.credential_provider.on_next(self._re_auth_async)
event.credential_provider.on_error(self._raise_on_error_async)
def _re_auth(self, token):
for pool in self._event.connection_pools:
pool.re_auth_callback(token)
async def _re_auth_async(self, token):
for pool in self._event.connection_pools:
await pool.re_auth_callback(token)
def _raise_on_error(self, error: Exception):
raise EventException(error, self._event)
async def _raise_on_error_async(self, error: Exception):
raise EventException(error, self._event)
class RegisterReAuthForSingleConnection(EventListenerInterface):
"""
Listener that registers a re-authentication callback for single connection.
Required by :class:`StreamingCredentialProvider`.
"""
def __init__(self):
self._event = None
def listen(self, event: AfterSingleConnectionInstantiationEvent):
if isinstance(
event.connection.credential_provider, StreamingCredentialProvider
):
self._event = event
if event.client_type == ClientType.SYNC:
event.connection.credential_provider.on_next(self._re_auth)
event.connection.credential_provider.on_error(self._raise_on_error)
else:
event.connection.credential_provider.on_next(self._re_auth_async)
event.connection.credential_provider.on_error(
self._raise_on_error_async
)
def _re_auth(self, token):
with self._event.connection_lock:
self._event.connection.send_command(
"AUTH", token.try_get("oid"), token.get_value()
)
self._event.connection.read_response()
async def _re_auth_async(self, token):
async with self._event.connection_lock:
await self._event.connection.send_command(
"AUTH", token.try_get("oid"), token.get_value()
)
await self._event.connection.read_response()
def _raise_on_error(self, error: Exception):
raise EventException(error, self._event)
async def _raise_on_error_async(self, error: Exception):
raise EventException(error, self._event)
class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
def __init__(self):
self._event = None
def listen(self, event: AfterAsyncClusterInstantiationEvent):
if isinstance(event.credential_provider, StreamingCredentialProvider):
self._event = event
event.credential_provider.on_next(self._re_auth)
event.credential_provider.on_error(self._raise_on_error)
async def _re_auth(self, token: TokenInterface):
for key in self._event.nodes:
await self._event.nodes[key].re_auth_callback(token)
async def _raise_on_error(self, error: Exception):
raise EventException(error, self._event)
class RegisterReAuthForPubSub(EventListenerInterface):
def __init__(self):
self._connection = None
self._connection_pool = None
self._client_type = None
self._connection_lock = None
self._event = None
def listen(self, event: AfterPubSubConnectionInstantiationEvent):
if isinstance(
event.pubsub_connection.credential_provider, StreamingCredentialProvider
) and event.pubsub_connection.get_protocol() in [3, "3"]:
self._event = event
self._connection = event.pubsub_connection
self._connection_pool = event.connection_pool
self._client_type = event.client_type
self._connection_lock = event.connection_lock
if self._client_type == ClientType.SYNC:
self._connection.credential_provider.on_next(self._re_auth)
self._connection.credential_provider.on_error(self._raise_on_error)
else:
self._connection.credential_provider.on_next(self._re_auth_async)
self._connection.credential_provider.on_error(
self._raise_on_error_async
)
def _re_auth(self, token: TokenInterface):
with self._connection_lock:
self._connection.send_command(
"AUTH", token.try_get("oid"), token.get_value()
)
self._connection.read_response()
self._connection_pool.re_auth_callback(token)
async def _re_auth_async(self, token: TokenInterface):
async with self._connection_lock:
await self._connection.send_command(
"AUTH", token.try_get("oid"), token.get_value()
)
await self._connection.read_response()
await self._connection_pool.re_auth_callback(token)
def _raise_on_error(self, error: Exception):
raise EventException(error, self._event)
async def _raise_on_error_async(self, error: Exception):
raise EventException(error, self._event)

View File

@@ -79,24 +79,18 @@ class ModuleError(ResponseError):
class LockError(RedisError, ValueError):
"Errors acquiring or releasing a lock"
# NOTE: For backwards compatibility, this class derives from ValueError.
# This was originally chosen to behave like threading.Lock.
def __init__(self, message=None, lock_name=None):
self.message = message
self.lock_name = lock_name
pass
class LockNotOwnedError(LockError):
"Error trying to extend or release a lock that is not owned (anymore)"
"Error trying to extend or release a lock that is (no longer) owned"
pass
class ChildDeadlockedError(Exception):
"Error indicating that a child process is deadlocked after a fork()"
pass
@@ -221,27 +215,4 @@ class SlotNotCoveredError(RedisClusterException):
class MaxConnectionsError(ConnectionError):
"""
Raised when a connection pool has reached its max_connections limit.
This indicates pool exhaustion rather than an actual connection failure.
"""
pass
class CrossSlotTransactionError(RedisClusterException):
"""
Raised when a transaction or watch is triggered in a pipeline
and not all keys or all commands belong to the same slot.
"""
pass
class InvalidPipelineStack(RedisClusterException):
"""
Raised on unexpected response length on pipelines. This is
most likely a handling error on the stack.
"""
pass
...

View File

@@ -1,4 +1,3 @@
import logging
import threading
import time as mod_time
import uuid
@@ -8,8 +7,6 @@ from typing import Optional, Type
from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number
logger = logging.getLogger(__name__)
class Lock:
"""
@@ -85,7 +82,6 @@ class Lock:
blocking: bool = True,
blocking_timeout: Optional[Number] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
@@ -129,11 +125,6 @@ class Lock:
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
@@ -149,7 +140,6 @@ class Lock:
self.blocking = blocking
self.blocking_timeout = blocking_timeout
self.thread_local = bool(thread_local)
self.raise_on_release_error = raise_on_release_error
self.local = threading.local() if self.thread_local else SimpleNamespace()
self.local.token = None
self.register_scripts()
@@ -167,10 +157,7 @@ class Lock:
def __enter__(self) -> "Lock":
if self.acquire():
return self
raise LockError(
"Unable to acquire lock within the time specified",
lock_name=self.name,
)
raise LockError("Unable to acquire lock within the time specified")
def __exit__(
self,
@@ -178,14 +165,7 @@ class Lock:
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
try:
self.release()
except LockError:
if self.raise_on_release_error:
raise
logger.warning(
"Lock was unlocked or no longer owned when exiting context manager."
)
self.release()
def acquire(
self,
@@ -268,10 +248,7 @@ class Lock:
"""
expected_token = self.local.token
if expected_token is None:
raise LockError(
"Cannot release a lock that's not owned or is already unlocked.",
lock_name=self.name,
)
raise LockError("Cannot release an unlocked lock")
self.local.token = None
self.do_release(expected_token)
@@ -279,12 +256,9 @@ class Lock:
if not bool(
self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
):
raise LockNotOwnedError(
"Cannot release a lock that's no longer owned",
lock_name=self.name,
)
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
def extend(self, additional_time: Number, replace_ttl: bool = False) -> bool:
def extend(self, additional_time: int, replace_ttl: bool = False) -> bool:
"""
Adds more time to an already acquired lock.
@@ -296,12 +270,12 @@ class Lock:
`additional_time`.
"""
if self.local.token is None:
raise LockError("Cannot extend an unlocked lock", lock_name=self.name)
raise LockError("Cannot extend an unlocked lock")
if self.timeout is None:
raise LockError("Cannot extend a lock with no timeout", lock_name=self.name)
raise LockError("Cannot extend a lock with no timeout")
return self.do_extend(additional_time, replace_ttl)
def do_extend(self, additional_time: Number, replace_ttl: bool) -> bool:
def do_extend(self, additional_time: int, replace_ttl: bool) -> bool:
additional_time = int(additional_time * 1000)
if not bool(
self.lua_extend(
@@ -310,10 +284,7 @@ class Lock:
client=self.redis,
)
):
raise LockNotOwnedError(
"Cannot extend a lock that's no longer owned",
lock_name=self.name,
)
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
return True
def reacquire(self) -> bool:
@@ -321,12 +292,9 @@ class Lock:
Resets a TTL of an already acquired lock back to a timeout value.
"""
if self.local.token is None:
raise LockError("Cannot reacquire an unlocked lock", lock_name=self.name)
raise LockError("Cannot reacquire an unlocked lock")
if self.timeout is None:
raise LockError(
"Cannot reacquire a lock with no timeout",
lock_name=self.name,
)
raise LockError("Cannot reacquire a lock with no timeout")
return self.do_reacquire()
def do_reacquire(self) -> bool:
@@ -336,8 +304,5 @@ class Lock:
keys=[self.name], args=[self.local.token, timeout], client=self.redis
)
):
raise LockNotOwnedError(
"Cannot reacquire a lock that's no longer owned",
lock_name=self.name,
)
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
return True

View File

@@ -15,7 +15,6 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.hashes import SHA1, Hash
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from cryptography.x509 import ocsp
from redis.exceptions import AuthorizationError, ConnectionError
@@ -57,12 +56,12 @@ def _check_certificate(issuer_cert, ocsp_bytes, validate=True):
if ocsp_response.response_status == ocsp.OCSPResponseStatus.SUCCESSFUL:
if ocsp_response.certificate_status != ocsp.OCSPCertStatus.GOOD:
raise ConnectionError(
f"Received an {str(ocsp_response.certificate_status).split('.')[1]} "
f'Received an {str(ocsp_response.certificate_status).split(".")[1]} '
"ocsp certificate status"
)
else:
raise ConnectionError(
"failed to retrieve a successful response from the ocsp responder"
"failed to retrieve a sucessful response from the ocsp responder"
)
if ocsp_response.this_update >= datetime.datetime.now():
@@ -140,7 +139,7 @@ def _get_pubkey_hash(certificate):
def ocsp_staple_verifier(con, ocsp_bytes, expected=None):
"""An implementation of a function for set_ocsp_client_callback in PyOpenSSL.
"""An implemention of a function for set_ocsp_client_callback in PyOpenSSL.
This function validates that the provide ocsp_bytes response is valid,
and matches the expected, stapled responses.
@@ -267,7 +266,7 @@ class OCSPVerifier:
return url
def check_certificate(self, server, cert, issuer_url):
"""Checks the validity of an ocsp server for an issuer"""
"""Checks the validitity of an ocsp server for an issuer"""
r = requests.get(issuer_url)
if not r.ok:

View File

@@ -1,27 +1,17 @@
import abc
import socket
from time import sleep
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Tuple, Type, TypeVar
from redis.exceptions import ConnectionError, TimeoutError
T = TypeVar("T")
E = TypeVar("E", bound=Exception, covariant=True)
if TYPE_CHECKING:
from redis.backoff import AbstractBackoff
class AbstractRetry(Generic[E], abc.ABC):
class Retry:
"""Retry a specific number of times after a failure"""
_supported_errors: Tuple[Type[E], ...]
def __init__(
self,
backoff: "AbstractBackoff",
retries: int,
supported_errors: Tuple[Type[E], ...],
backoff,
retries,
supported_errors=(ConnectionError, TimeoutError, socket.timeout),
):
"""
Initialize a `Retry` object with a `Backoff` object
@@ -34,14 +24,7 @@ class AbstractRetry(Generic[E], abc.ABC):
self._retries = retries
self._supported_errors = supported_errors
@abc.abstractmethod
def __eq__(self, other: Any) -> bool:
return NotImplemented
def __hash__(self) -> int:
return hash((self._backoff, self._retries, frozenset(self._supported_errors)))
def update_supported_errors(self, specified_errors: Iterable[Type[E]]) -> None:
def update_supported_errors(self, specified_errors: list):
"""
Updates the supported errors with the specified error types
"""
@@ -49,49 +32,7 @@ class AbstractRetry(Generic[E], abc.ABC):
set(self._supported_errors + tuple(specified_errors))
)
def get_retries(self) -> int:
"""
Get the number of retries.
"""
return self._retries
def update_retries(self, value: int) -> None:
"""
Set the number of retries.
"""
self._retries = value
class Retry(AbstractRetry[Exception]):
__hash__ = AbstractRetry.__hash__
def __init__(
self,
backoff: "AbstractBackoff",
retries: int,
supported_errors: Tuple[Type[Exception], ...] = (
ConnectionError,
TimeoutError,
socket.timeout,
),
):
super().__init__(backoff, retries, supported_errors)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Retry):
return NotImplemented
return (
self._backoff == other._backoff
and self._retries == other._retries
and set(self._supported_errors) == set(other._supported_errors)
)
def call_with_retry(
self,
do: Callable[[], T],
fail: Callable[[Exception], Any],
) -> T:
def call_with_retry(self, do, fail):
"""
Execute an operation that might fail and returns its result, or
raise the exception that was thrown depending on the `Backoff` object.

View File

@@ -5,12 +5,8 @@ from typing import Optional
from redis.client import Redis
from redis.commands import SentinelCommands
from redis.connection import Connection, ConnectionPool, SSLConnection
from redis.exceptions import (
ConnectionError,
ReadOnlyError,
ResponseError,
TimeoutError,
)
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
from redis.utils import str_if_bytes
class MasterNotFoundError(ConnectionError):
@@ -28,10 +24,7 @@ class SentinelManagedConnection(Connection):
def __repr__(self):
pool = self.connection_pool
s = (
f"<{type(self).__module__}.{type(self).__name__}"
f"(service={pool.service_name}%s)>"
)
s = f"{type(self).__name__}<service={pool.service_name}%s>"
if self.host:
host_info = f",host={self.host},port={self.port}"
s = s % host_info
@@ -39,11 +32,11 @@ class SentinelManagedConnection(Connection):
def connect_to(self, address):
self.host, self.port = address
self.connect_check_health(
check_health=self.connection_pool.check_connection,
retry_socket_connect=False,
)
super().connect()
if self.connection_pool.check_connection:
self.send_command("PING")
if str_if_bytes(self.read_response()) != "PONG":
raise ConnectionError("PING failed")
def _connect_retry(self):
if self._sock:
@@ -149,11 +142,9 @@ class SentinelConnectionPool(ConnectionPool):
def __init__(self, service_name, sentinel_manager, **kwargs):
kwargs["connection_class"] = kwargs.get(
"connection_class",
(
SentinelManagedSSLConnection
if kwargs.pop("ssl", False)
else SentinelManagedConnection
),
SentinelManagedSSLConnection
if kwargs.pop("ssl", False)
else SentinelManagedConnection,
)
self.is_master = kwargs.pop("is_master", True)
self.check_connection = kwargs.pop("check_connection", False)
@@ -171,10 +162,7 @@ class SentinelConnectionPool(ConnectionPool):
def __repr__(self):
role = "master" if self.is_master else "slave"
return (
f"<{type(self).__module__}.{type(self).__name__}"
f"(service={self.service_name}({role}))>"
)
return f"{type(self).__name__}<service={self.service_name}({role})"
def reset(self):
super().reset()
@@ -233,7 +221,6 @@ class Sentinel(SentinelCommands):
sentinels,
min_other_sentinels=0,
sentinel_kwargs=None,
force_master_ip=None,
**connection_kwargs,
):
# if sentinel_kwargs isn't defined, use the socket_* options from
@@ -250,7 +237,6 @@ class Sentinel(SentinelCommands):
]
self.min_other_sentinels = min_other_sentinels
self.connection_kwargs = connection_kwargs
self._force_master_ip = force_master_ip
def execute_command(self, *args, **kwargs):
"""
@@ -258,27 +244,16 @@ class Sentinel(SentinelCommands):
once - If set to True, then execute the resulting command on a single
node at random, rather than across the entire sentinel cluster.
"""
once = bool(kwargs.pop("once", False))
# Check if command is supposed to return the original
# responses instead of boolean value.
return_responses = bool(kwargs.pop("return_responses", False))
once = bool(kwargs.get("once", False))
if "once" in kwargs.keys():
kwargs.pop("once")
if once:
response = random.choice(self.sentinels).execute_command(*args, **kwargs)
if return_responses:
return [response]
else:
return True if response else False
responses = []
for sentinel in self.sentinels:
responses.append(sentinel.execute_command(*args, **kwargs))
if return_responses:
return responses
return all(responses)
random.choice(self.sentinels).execute_command(*args, **kwargs)
else:
for sentinel in self.sentinels:
sentinel.execute_command(*args, **kwargs)
return True
def __repr__(self):
sentinel_addresses = []
@@ -286,10 +261,7 @@ class Sentinel(SentinelCommands):
sentinel_addresses.append(
"{host}:{port}".format_map(sentinel.connection_pool.connection_kwargs)
)
return (
f"<{type(self).__module__}.{type(self).__name__}"
f"(sentinels=[{','.join(sentinel_addresses)}])>"
)
return f'{type(self).__name__}<sentinels=[{",".join(sentinel_addresses)}]>'
def check_master_state(self, state, service_name):
if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
@@ -321,13 +293,7 @@ class Sentinel(SentinelCommands):
sentinel,
self.sentinels[0],
)
ip = (
self._force_master_ip
if self._force_master_ip is not None
else state["ip"]
)
return ip, state["port"]
return state["ip"], state["port"]
error_info = ""
if len(collected_errors) > 0:
@@ -364,8 +330,6 @@ class Sentinel(SentinelCommands):
):
"""
Returns a redis client instance for the ``service_name`` master.
Sentinel client will detect failover and reconnect Redis clients
automatically.
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
used to retrieve the master's address before establishing a new

View File

@@ -7,18 +7,21 @@ from typing import (
Awaitable,
Iterable,
Mapping,
Protocol,
Type,
TypeVar,
Union,
)
from redis.compat import Protocol
if TYPE_CHECKING:
from redis._parsers import Encoder
from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool
from redis.connection import ConnectionPool
Number = Union[int, float]
EncodedT = Union[bytes, bytearray, memoryview]
EncodedT = Union[bytes, memoryview]
DecodedT = Union[str, int, float]
EncodableT = Union[EncodedT, DecodedT]
AbsExpiryT = Union[int, datetime]
@@ -30,7 +33,6 @@ KeyT = _StringLikeT # Main redis key space
PatternT = _StringLikeT # Patterns matched against keys, fields etc
FieldT = EncodableT # Fields within hash tables, streams and geo commands
KeysT = Union[KeyT, Iterable[KeyT]]
ResponseT = Union[Awaitable[Any], Any]
ChannelT = _StringLikeT
GroupT = _StringLikeT # Consumer group
ConsumerT = _StringLikeT # Consumer name
@@ -50,8 +52,14 @@ ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Except
class CommandsProtocol(Protocol):
def execute_command(self, *args, **options) -> ResponseT: ...
connection_pool: Union["AsyncConnectionPool", "ConnectionPool"]
def execute_command(self, *args, **options):
...
class ClusterCommandsProtocol(CommandsProtocol):
class ClusterCommandsProtocol(CommandsProtocol, Protocol):
encoder: "Encoder"
def execute_command(self, *args, **options) -> Union[Any, Awaitable]:
...

View File

@@ -1,26 +1,18 @@
import datetime
import logging
import textwrap
from collections.abc import Callable
import sys
from contextlib import contextmanager
from functools import wraps
from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union
from redis.exceptions import DataError
from redis.typing import AbsExpiryT, EncodableT, ExpiryT
from typing import Any, Dict, Mapping, Union
try:
import hiredis # noqa
# Only support Hiredis >= 3.0:
hiredis_version = hiredis.__version__.split(".")
HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or (
int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
)
if not HIREDIS_AVAILABLE:
raise ImportError("hiredis package should be >= 3.2.0")
# Only support Hiredis >= 1.0:
HIREDIS_AVAILABLE = not hiredis.__version__.startswith("0.")
HIREDIS_PACK_AVAILABLE = hasattr(hiredis, "pack_command")
except ImportError:
HIREDIS_AVAILABLE = False
HIREDIS_PACK_AVAILABLE = False
try:
import ssl # noqa
@@ -36,7 +28,10 @@ try:
except ImportError:
CRYPTOGRAPHY_AVAILABLE = False
from importlib import metadata
if sys.version_info >= (3, 8):
from importlib import metadata
else:
import importlib_metadata as metadata
def from_url(url, **kwargs):
@@ -131,74 +126,6 @@ def deprecated_function(reason="", version="", name=None):
return decorator
def warn_deprecated_arg_usage(
arg_name: Union[list, str],
function_name: str,
reason: str = "",
version: str = "",
stacklevel: int = 2,
):
import warnings
msg = (
f"Call to '{function_name}' function with deprecated"
f" usage of input argument/s '{arg_name}'."
)
if reason:
msg += f" ({reason})"
if version:
msg += f" -- Deprecated since version {version}."
warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel)
C = TypeVar("C", bound=Callable)
def deprecated_args(
args_to_warn: list = ["*"],
allowed_args: list = [],
reason: str = "",
version: str = "",
) -> Callable[[C], C]:
"""
Decorator to mark specified args of a function as deprecated.
If '*' is in args_to_warn, all arguments will be marked as deprecated.
"""
def decorator(func: C) -> C:
@wraps(func)
def wrapper(*args, **kwargs):
# Get function argument names
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
provided_args = dict(zip(arg_names, args))
provided_args.update(kwargs)
provided_args.pop("self", None)
for allowed_arg in allowed_args:
provided_args.pop(allowed_arg, None)
for arg in args_to_warn:
if arg == "*" and len(provided_args) > 0:
warn_deprecated_arg_usage(
list(provided_args.keys()),
func.__name__,
reason,
version,
stacklevel=3,
)
elif arg in provided_args:
warn_deprecated_arg_usage(
arg, func.__name__, reason, version, stacklevel=3
)
return func(*args, **kwargs)
return wrapper
return decorator
def _set_info_logger():
"""
Set up a logger that log info logs to stdout.
@@ -218,97 +145,3 @@ def get_lib_version():
except metadata.PackageNotFoundError:
libver = "99.99.99"
return libver
def format_error_message(host_error: str, exception: BaseException) -> str:
if not exception.args:
return f"Error connecting to {host_error}."
elif len(exception.args) == 1:
return f"Error {exception.args[0]} connecting to {host_error}."
else:
return (
f"Error {exception.args[0]} connecting to {host_error}. "
f"{exception.args[1]}."
)
def compare_versions(version1: str, version2: str) -> int:
"""
Compare two versions.
:return: -1 if version1 > version2
0 if both versions are equal
1 if version1 < version2
"""
num_versions1 = list(map(int, version1.split(".")))
num_versions2 = list(map(int, version2.split(".")))
if len(num_versions1) > len(num_versions2):
diff = len(num_versions1) - len(num_versions2)
for _ in range(diff):
num_versions2.append(0)
elif len(num_versions1) < len(num_versions2):
diff = len(num_versions2) - len(num_versions1)
for _ in range(diff):
num_versions1.append(0)
for i, ver in enumerate(num_versions1):
if num_versions1[i] > num_versions2[i]:
return -1
elif num_versions1[i] < num_versions2[i]:
return 1
return 0
def ensure_string(key):
if isinstance(key, bytes):
return key.decode("utf-8")
elif isinstance(key, str):
return key
else:
raise TypeError("Key must be either a string or bytes")
def extract_expire_flags(
ex: Optional[ExpiryT] = None,
px: Optional[ExpiryT] = None,
exat: Optional[AbsExpiryT] = None,
pxat: Optional[AbsExpiryT] = None,
) -> List[EncodableT]:
exp_options: list[EncodableT] = []
if ex is not None:
exp_options.append("EX")
if isinstance(ex, datetime.timedelta):
exp_options.append(int(ex.total_seconds()))
elif isinstance(ex, int):
exp_options.append(ex)
elif isinstance(ex, str) and ex.isdigit():
exp_options.append(int(ex))
else:
raise DataError("ex must be datetime.timedelta or int")
elif px is not None:
exp_options.append("PX")
if isinstance(px, datetime.timedelta):
exp_options.append(int(px.total_seconds() * 1000))
elif isinstance(px, int):
exp_options.append(px)
else:
raise DataError("px must be datetime.timedelta or int")
elif exat is not None:
if isinstance(exat, datetime.datetime):
exat = int(exat.timestamp())
exp_options.extend(["EXAT", exat])
elif pxat is not None:
if isinstance(pxat, datetime.datetime):
pxat = int(pxat.timestamp() * 1000)
exp_options.extend(["PXAT", pxat])
return exp_options
def truncate_text(txt, max_length=100):
return textwrap.shorten(
text=txt, width=max_length, placeholder="...", break_long_words=True
)