@@ -1,3 +1,5 @@
import sys

from redis import asyncio # noqa
from redis.backoff import default_backoff
from redis.client import Redis, StrictRedis
@@ -16,15 +18,11 @@ from redis.exceptions import (
BusyLoadingError,
ChildDeadlockedError,
ConnectionError,
CrossSlotTransactionError,
DataError,
InvalidPipelineStack,
InvalidResponse,
MaxConnectionsError,
OutOfMemoryError,
PubSubError,
ReadOnlyError,
RedisClusterException,
RedisError,
ResponseError,
TimeoutError,
@@ -38,6 +36,11 @@ from redis.sentinel import (
)
from redis.utils import from_url

if sys.version_info >= (3, 8):
from importlib import metadata
else:
import importlib_metadata as metadata


def int_or_str(value):
try:
@@ -46,10 +49,17 @@ def int_or_str(value):
return value


__version__ = "6.4.0"
VERSION = tuple(map(int_or_str, __version__.split(".")))
try:
__version__ = metadata.version("redis")
except metadata.PackageNotFoundError:
__version__ = "99.99.99"


try:
VERSION = tuple(map(int_or_str, __version__.split(".")))
except AttributeError:
VERSION = tuple([99, 99, 99])

__all__ = [
"AuthenticationError",
"AuthenticationWrongNumberOfArgsError",
@@ -60,19 +70,15 @@ __all__ = [
"ConnectionError",
"ConnectionPool",
"CredentialProvider",
"CrossSlotTransactionError",
"DataError",
"from_url",
"default_backoff",
"InvalidPipelineStack",
"InvalidResponse",
"MaxConnectionsError",
"OutOfMemoryError",
"PubSubError",
"ReadOnlyError",
"Redis",
"RedisCluster",
"RedisClusterException",
"RedisError",
"ResponseError",
"Sentinel",

@@ -1,9 +1,4 @@
from .base import (
AsyncPushNotificationsParser,
BaseParser,
PushNotificationsParser,
_AsyncRESPBase,
)
from .base import BaseParser, _AsyncRESPBase
from .commands import AsyncCommandsParser, CommandsParser
from .encoders import Encoder
from .hiredis import _AsyncHiredisParser, _HiredisParser
@@ -16,12 +11,10 @@ __all__ = [
"_AsyncRESPBase",
"_AsyncRESP2Parser",
"_AsyncRESP3Parser",
"AsyncPushNotificationsParser",
"CommandsParser",
"Encoder",
"BaseParser",
"_HiredisParser",
"_RESP2Parser",
"_RESP3Parser",
"PushNotificationsParser",
]

@@ -1,7 +1,7 @@
import sys
from abc import ABC
from asyncio import IncompleteReadError, StreamReader, TimeoutError
from typing import Callable, List, Optional, Protocol, Union
from typing import List, Optional, Union

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
@@ -9,32 +9,26 @@ else:
from async_timeout import timeout as async_timeout

from ..exceptions import (
AskError,
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
BusyLoadingError,
ClusterCrossSlotError,
ClusterDownError,
ConnectionError,
ExecAbortError,
MasterDownError,
ModuleError,
MovedError,
NoPermissionError,
NoScriptError,
OutOfMemoryError,
ReadOnlyError,
RedisError,
ResponseError,
TryAgainError,
)
from ..typing import EncodableT
from .encoders import Encoder
from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer

MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs."
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible."
MODULE_EXPORTS_DATA_TYPES_ERROR = (
"Error unloading module: the module "
"exports one or more module-side data "
@@ -78,12 +72,6 @@ class BaseParser(ABC):
"READONLY": ReadOnlyError,
"NOAUTH": AuthenticationError,
"NOPERM": NoPermissionError,
"ASK": AskError,
"TRYAGAIN": TryAgainError,
"MOVED": MovedError,
"CLUSTERDOWN": ClusterDownError,
"CROSSSLOT": ClusterCrossSlotError,
"MASTERDOWN": MasterDownError,
}

@classmethod
@@ -158,58 +146,6 @@ class AsyncBaseParser(BaseParser):
raise NotImplementedError()


_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]


class PushNotificationsParser(Protocol):
"""Protocol defining RESP3-specific parsing functionality"""

pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None

def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses"""
raise NotImplementedError()

def handle_push_response(self, response, **kwargs):
if response[0] not in _INVALIDATION_MESSAGE:
return self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
return self.invalidation_push_handler_func(response)

def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidation_push_handler_func):
self.invalidation_push_handler_func = invalidation_push_handler_func


class AsyncPushNotificationsParser(Protocol):
"""Protocol defining async RESP3-specific parsing functionality"""

pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None

async def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses asynchronously"""
raise NotImplementedError()

async def handle_push_response(self, response, **kwargs):
"""Handle push responses asynchronously"""
if response[0] not in _INVALIDATION_MESSAGE:
return await self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
return await self.invalidation_push_handler_func(response)

def set_pubsub_push_handler(self, pubsub_push_handler_func):
"""Set the pubsub push handler function"""
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidation_push_handler_func):
"""Set the invalidation push handler function"""
self.invalidation_push_handler_func = invalidation_push_handler_func


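The protocol classes above only declare the push-handler hooks; a concrete parser supplies them. A minimal hedged sketch of how custom handlers might be registered on a parser implementing PushNotificationsParser (the handler names and the `parser` variable are illustrative, not part of this diff):

# Hedged sketch: wiring custom push handlers onto a parser that implements
# PushNotificationsParser (e.g. an assumed _RESP3Parser instance `parser`).
def log_pubsub_push(response):
    # Receives ordinary pub/sub push frames.
    print("pubsub push:", response)

def on_invalidation(response):
    # Receives RESP3 "invalidate" frames (client-side caching notifications).
    print("invalidate:", response)

parser.set_pubsub_push_handler(log_pubsub_push)
parser.set_invalidation_push_handler(on_invalidation)
# handle_push_response() then routes frames: anything whose first element is
# not in _INVALIDATION_MESSAGE goes to the pubsub handler, the rest to the
# invalidation handler.
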
class _AsyncRESPBase(AsyncBaseParser):
"""Base class for async resp parsing"""

@@ -246,7 +182,7 @@ class _AsyncRESPBase(AsyncBaseParser):
return True
try:
async with async_timeout(0):
return self._stream.at_eof()
return await self._stream.read(1)
except TimeoutError:
return False


@@ -38,7 +38,7 @@ def parse_info(response):
response = str_if_bytes(response)

def get_value(value):
if "," not in value and "=" not in value:
if "," not in value or "=" not in value:
try:
if "." in value:
return float(value)
@@ -46,18 +46,11 @@ def parse_info(response):
return int(value)
except ValueError:
return value
elif "=" not in value:
return [get_value(v) for v in value.split(",") if v]
else:
sub_dict = {}
for item in value.split(","):
if not item:
continue
if "=" in item:
k, v = item.rsplit("=", 1)
sub_dict[k] = get_value(v)
else:
sub_dict[item] = True
k, v = item.rsplit("=", 1)
sub_dict[k] = get_value(v)
return sub_dict

for line in response.splitlines():
@@ -87,7 +80,7 @@ def parse_memory_stats(response, **kwargs):
"""Parse the results of MEMORY STATS"""
stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True)
for key, value in stats.items():
if key.startswith("db.") and isinstance(value, list):
if key.startswith("db."):
stats[key] = pairs_to_dict(
value, decode_keys=True, decode_string_values=True
)
@@ -275,22 +268,17 @@ def parse_xinfo_stream(response, **options):
data = {str_if_bytes(k): v for k, v in response.items()}
if not options.get("full", False):
first = data.get("first-entry")
if first is not None and first[0] is not None:
if first is not None:
data["first-entry"] = (first[0], pairs_to_dict(first[1]))
last = data["last-entry"]
if last is not None and last[0] is not None:
if last is not None:
data["last-entry"] = (last[0], pairs_to_dict(last[1]))
else:
data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]}
if len(data["groups"]) > 0 and isinstance(data["groups"][0], list):
if isinstance(data["groups"][0], list):
data["groups"] = [
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
]
for g in data["groups"]:
if g["consumers"] and g["consumers"][0] is not None:
g["consumers"] = [
pairs_to_dict(c, decode_keys=True) for c in g["consumers"]
]
else:
data["groups"] = [
{str_if_bytes(k): v for k, v in group.items()}
@@ -334,7 +322,7 @@ def float_or_none(response):
return float(response)


def bool_ok(response, **options):
def bool_ok(response):
return str_if_bytes(response) == "OK"


@@ -366,12 +354,7 @@ def parse_scan(response, **options):

def parse_hscan(response, **options):
cursor, r = response
no_values = options.get("no_values", False)
if no_values:
payload = r or []
else:
payload = r and pairs_to_dict(r) or {}
return int(cursor), payload
return int(cursor), r and pairs_to_dict(r) or {}


def parse_zscan(response, **options):
@@ -396,20 +379,13 @@ def parse_slowlog_get(response, **options):
# an O(N) complexity) instead of the command.
if isinstance(item[3], list):
result["command"] = space.join(item[3])

# These fields are optional, depends on environment.
if len(item) >= 6:
result["client_address"] = item[4]
result["client_name"] = item[5]
result["client_address"] = item[4]
result["client_name"] = item[5]
else:
result["complexity"] = item[3]
result["command"] = space.join(item[4])

# These fields are optional, depends on environment.
if len(item) >= 7:
result["client_address"] = item[5]
result["client_name"] = item[6]

result["client_address"] = item[5]
result["client_name"] = item[6]
return result

return [parse_item(item) for item in response]
@@ -452,11 +428,9 @@ def parse_cluster_info(response, **options):
def _parse_node_line(line):
line_items = line.split(" ")
node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8]
ip = addr.split("@")[0]
hostname = addr.split("@")[1].split(",")[1] if "@" in addr and "," in addr else ""
addr = addr.split("@")[0]
node_dict = {
"node_id": node_id,
"hostname": hostname,
"flags": flags,
"master_id": master_id,
"last_ping_sent": ping,
@@ -469,7 +443,7 @@ def _parse_node_line(line):
if len(line_items) >= 9:
slots, migrations = _parse_slots(line_items[8:])
node_dict["slots"], node_dict["migrations"] = slots, migrations
return ip, node_dict
return addr, node_dict


def _parse_slots(slot_ranges):
@@ -516,7 +490,7 @@ def parse_geosearch_generic(response, **options):
except KeyError: # it means the command was sent via execute_command
return response

if not isinstance(response, list):
if type(response) != list:
response_list = [response]
else:
response_list = response
@@ -676,8 +650,7 @@ def parse_client_info(value):
"omem",
"tot-mem",
}:
if int_key in client_info:
client_info[int_key] = int(client_info[int_key])
client_info[int_key] = int(client_info[int_key])
return client_info


@@ -840,28 +813,24 @@ _RedisCallbacksRESP2 = {


_RedisCallbacksRESP3 = {
**string_keys_to_dict(
"SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
),
**string_keys_to_dict(
"ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE "
"ZUNION HGETALL XREADGROUP",
lambda r, **kwargs: r,
),
**string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3),
"ACL LOG": lambda r: (
[
{str_if_bytes(key): str_if_bytes(value) for key, value in x.items()}
for x in r
]
if isinstance(r, list)
else bool_ok(r)
),
"ACL LOG": lambda r: [
{str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} for x in r
]
if isinstance(r, list)
else bool_ok(r),
"COMMAND": parse_command_resp3,
"CONFIG GET": lambda r: {
str_if_bytes(key) if key is not None else None: (
str_if_bytes(value) if value is not None else None
)
str_if_bytes(key)
if key is not None
else None: str_if_bytes(value)
if value is not None
else None
for key, value in r.items()
},
"MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()},
@@ -869,11 +838,11 @@ _RedisCallbacksRESP3 = {
"SENTINEL MASTERS": parse_sentinel_masters_resp3,
"SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3,
"SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3,
"STRALGO": lambda r, **options: (
{str_if_bytes(key): str_if_bytes(value) for key, value in r.items()}
if isinstance(r, dict)
else str_if_bytes(r)
),
"STRALGO": lambda r, **options: {
str_if_bytes(key): str_if_bytes(value) for key, value in r.items()
}
if isinstance(r, dict)
else str_if_bytes(r),
"XINFO CONSUMERS": lambda r: [
{str_if_bytes(key): value for key, value in x.items()} for x in r
],

@@ -1,23 +1,19 @@
import asyncio
import socket
import sys
from logging import getLogger
from typing import Callable, List, Optional, TypedDict, Union
from typing import Callable, List, Optional, Union

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
from async_timeout import timeout as async_timeout

from redis.compat import TypedDict

from ..exceptions import ConnectionError, InvalidResponse, RedisError
from ..typing import EncodableT
from ..utils import HIREDIS_AVAILABLE
from .base import (
AsyncBaseParser,
AsyncPushNotificationsParser,
BaseParser,
PushNotificationsParser,
)
from .base import AsyncBaseParser, BaseParser
from .socket import (
NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
NONBLOCKING_EXCEPTIONS,
@@ -25,11 +21,6 @@ from .socket import (
SERVER_CLOSED_CONNECTION_ERROR,
)

# Used to signal that hiredis-py does not have enough data to parse.
# Using `False` or `None` is not reliable, given that the parser can
# return `False` or `None` for legitimate reasons from RESP payloads.
NOT_ENOUGH_DATA = object()


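The comment above motivates the sentinel: a private object() is distinguishable by identity from every value the reader can legitimately return. A minimal standalone illustration of the pattern (not redis-py code, names are made up for the sketch):

# Hedged illustration of the sentinel pattern used by NOT_ENOUGH_DATA.
_MISSING = object()  # unique marker; never equal to real parsed data

def try_parse(buffer: bytes):
    # Pretend parser: legitimate RESP replies can be None or False,
    # so neither value can double as "no full reply buffered yet".
    if not buffer:
        return _MISSING
    return None if buffer == b"$-1\r\n" else False

reply = try_parse(b"")
if reply is _MISSING:   # identity check, not truthiness
    print("need more data")
else:
    print("got a real reply:", reply)
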
class _HiredisReaderArgs(TypedDict, total=False):
protocolError: Callable[[str], Exception]
@@ -38,7 +29,7 @@ class _HiredisReaderArgs(TypedDict, total=False):
errors: Optional[str]


class _HiredisParser(BaseParser, PushNotificationsParser):
class _HiredisParser(BaseParser):
"Parser class for connections using Hiredis"

def __init__(self, socket_read_size):
@@ -46,9 +37,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
raise RedisError("Hiredis is not installed")
self.socket_read_size = socket_read_size
self._buffer = bytearray(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None

def __del__(self):
try:
@@ -56,11 +44,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
except Exception:
pass

def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response

def on_connect(self, connection, **kwargs):
import hiredis

@@ -70,32 +53,25 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
"protocolError": InvalidResponse,
"replyError": self.parse_error,
"errors": connection.encoder.encoding_errors,
"notEnoughData": NOT_ENOUGH_DATA,
}

if connection.encoder.decode_responses:
kwargs["encoding"] = connection.encoder.encoding
self._reader = hiredis.Reader(**kwargs)
self._next_response = NOT_ENOUGH_DATA

try:
self._hiredis_PushNotificationType = hiredis.PushNotification
except AttributeError:
# hiredis < 3.2
self._hiredis_PushNotificationType = None
self._next_response = False

def on_disconnect(self):
self._sock = None
self._reader = None
self._next_response = NOT_ENOUGH_DATA
self._next_response = False

def can_read(self, timeout):
if not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

if self._next_response is NOT_ENOUGH_DATA:
if self._next_response is False:
self._next_response = self._reader.gets()
if self._next_response is NOT_ENOUGH_DATA:
if self._next_response is False:
return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
return True

@@ -129,24 +105,14 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
if custom_timeout:
sock.settimeout(self._socket_timeout)

def read_response(self, disable_decoding=False, push_request=False):
def read_response(self, disable_decoding=False):
if not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

# _next_response might be cached from a can_read() call
if self._next_response is not NOT_ENOUGH_DATA:
if self._next_response is not False:
response = self._next_response
self._next_response = NOT_ENOUGH_DATA
if self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if not push_request:
return self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
self._next_response = False
return response

if disable_decoding:
@@ -154,7 +120,7 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
else:
response = self._reader.gets()

while response is NOT_ENOUGH_DATA:
while response is False:
self.read_from_socket()
if disable_decoding:
response = self._reader.gets(False)
@@ -165,16 +131,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
# happened
if isinstance(response, ConnectionError):
raise response
elif self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if not push_request:
return self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
elif (
isinstance(response, list)
and response
@@ -184,7 +140,7 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
return response


class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
class _AsyncHiredisParser(AsyncBaseParser):
"""Async implementation of parser class for connections using Hiredis"""

__slots__ = ("_reader",)
@@ -194,14 +150,6 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
raise RedisError("Hiredis is not available.")
super().__init__(socket_read_size=socket_read_size)
self._reader = None
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None

async def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
return response

def on_connect(self, connection):
import hiredis
@@ -210,7 +158,6 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
kwargs: _HiredisReaderArgs = {
"protocolError": InvalidResponse,
"replyError": self.parse_error,
"notEnoughData": NOT_ENOUGH_DATA,
}
if connection.encoder.decode_responses:
kwargs["encoding"] = connection.encoder.encoding
@@ -219,21 +166,13 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
self._reader = hiredis.Reader(**kwargs)
self._connected = True

try:
self._hiredis_PushNotificationType = getattr(
hiredis, "PushNotification", None
)
except AttributeError:
# hiredis < 3.2
self._hiredis_PushNotificationType = None

def on_disconnect(self):
self._connected = False

async def can_read_destructive(self):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._reader.gets() is not NOT_ENOUGH_DATA:
if self._reader.gets():
return True
try:
async with async_timeout(0):
@@ -251,7 +190,7 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
return True

async def read_response(
self, disable_decoding: bool = False, push_request: bool = False
self, disable_decoding: bool = False
) -> Union[EncodableT, List[EncodableT]]:
# If `on_disconnect()` has been called, prohibit any more reads
# even if they could happen because data might be present.
@@ -259,33 +198,16 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
if not self._connected:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None

if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()

while response is NOT_ENOUGH_DATA:
response = self._reader.gets()
while response is False:
await self.read_from_socket()
if disable_decoding:
response = self._reader.gets(False)
else:
response = self._reader.gets()
response = self._reader.gets()

# if the response is a ConnectionError or the response is a list and
# the first item is a ConnectionError, raise it as something bad
# happened
if isinstance(response, ConnectionError):
raise response
elif self._hiredis_PushNotificationType is not None and isinstance(
response, self._hiredis_PushNotificationType
):
response = await self.handle_push_response(response)
if not push_request:
return await self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
elif (
isinstance(response, list)
and response

@@ -3,26 +3,20 @@ from typing import Any, Union

from ..exceptions import ConnectionError, InvalidResponse, ResponseError
from ..typing import EncodableT
from .base import (
AsyncPushNotificationsParser,
PushNotificationsParser,
_AsyncRESPBase,
_RESPBase,
)
from .base import _AsyncRESPBase, _RESPBase
from .socket import SERVER_CLOSED_CONNECTION_ERROR


class _RESP3Parser(_RESPBase, PushNotificationsParser):
class _RESP3Parser(_RESPBase):
"""RESP3 protocol implementation"""

def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self.push_handler_func = self.handle_push_response

def handle_pubsub_push_response(self, response):
def handle_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
logger.info("Push response: " + str(response))
return response

def read_response(self, disable_decoding=False, push_request=False):
@@ -91,16 +85,19 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we return sets as list, all the time, for predictability
# so we need to first convert to a list, and then try to convert it to a set
response = [
self._read_response(disable_decoding=disable_decoding)
for _ in range(int(response))
]
try:
response = set(response)
except TypeError:
pass
# map response
elif byte == b"%":
# We cannot use a dict-comprehension to parse stream.
# Evaluation order of key:val expression in dict comprehension only
# became defined to be left-right in version 3.8
# we use this approach and not dict comprehension here
# because this dict comprehension fails in python 3.7
resp_dict = {}
for _ in range(int(response)):
key = self._read_response(disable_decoding=disable_decoding)
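The comment above is about evaluation order: before Python 3.8, the value expression of a dict comprehension could be evaluated before the key expression, which scrambles items read from a stream. A small standalone illustration (not redis-py code):

# Hedged illustration: why an explicit loop is safer than a dict comprehension
# when both key and value are pulled from the same stream/iterator.
items = iter(["key-1", "val-1", "key-2", "val-2"])

result = {}
for _ in range(2):
    k = next(items)          # key is unambiguously read first
    result[k] = next(items)  # then its value
assert result == {"key-1": "val-1", "key-2": "val-2"}

# {next(items): next(items) for _ in range(2)} only guarantees key-before-value
# evaluation on Python 3.8+, so older interpreters could swap keys and values.
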
@@ -116,13 +113,13 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
)
for _ in range(int(response))
]
response = self.handle_push_response(response)
res = self.push_handler_func(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

@@ -130,16 +127,18 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func

class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):

class _AsyncRESP3Parser(_AsyncRESPBase):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
self.push_handler_func = self.handle_push_response

async def handle_pubsub_push_response(self, response):
def handle_push_response(self, response):
logger = getLogger("push_response")
logger.debug("Push response: " + str(response))
logger.info("Push response: " + str(response))
return response

async def read_response(
@@ -215,23 +214,23 @@ class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
# set response
elif byte == b"~":
# redis can return unhashable types (like dict) in a set,
# so we always convert to a list, to have predictable return types
# so we need to first convert to a list, and then try to convert it to a set
response = [
(await self._read_response(disable_decoding=disable_decoding))
for _ in range(int(response))
]
try:
response = set(response)
except TypeError:
pass
# map response
elif byte == b"%":
# We cannot use a dict-comprehension to parse stream.
# Evaluation order of key:val expression in dict comprehension only
# became defined to be left-right in version 3.8
resp_dict = {}
for _ in range(int(response)):
key = await self._read_response(disable_decoding=disable_decoding)
resp_dict[key] = await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
response = {
(await self._read_response(disable_decoding=disable_decoding)): (
await self._read_response(disable_decoding=disable_decoding)
)
response = resp_dict
for _ in range(int(response))
}
# push response
elif byte == b">":
response = [
@@ -242,16 +241,19 @@ class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
)
for _ in range(int(response))
]
response = await self.handle_push_response(response)
res = self.push_handler_func(response)
if not push_request:
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return response
return res
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func

@@ -15,11 +15,9 @@ from typing import (
Mapping,
MutableMapping,
Optional,
Protocol,
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
cast,
@@ -39,7 +37,6 @@ from redis.asyncio.connection import (
)
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff
from redis.client import (
EMPTY_RESPONSE,
NEVER_DECODE,
@@ -52,40 +49,27 @@ from redis.commands import (
AsyncSentinelCommands,
list_or_args,
)
from redis.compat import Protocol, TypedDict
from redis.credentials import CredentialProvider
from redis.event import (
AfterPooledConnectionsInstantiationEvent,
AfterPubSubConnectionInstantiationEvent,
AfterSingleConnectionInstantiationEvent,
ClientType,
EventDispatcher,
)
from redis.exceptions import (
ConnectionError,
ExecAbortError,
PubSubError,
RedisError,
ResponseError,
TimeoutError,
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.utils import (
SSL_AVAILABLE,
HIREDIS_AVAILABLE,
_set_info_logger,
deprecated_args,
deprecated_function,
get_lib_version,
safe_str,
str_if_bytes,
truncate_text,
)

if TYPE_CHECKING and SSL_AVAILABLE:
from ssl import TLSVersion, VerifyMode
else:
TLSVersion = None
VerifyMode = None

PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
@@ -96,11 +80,13 @@ if TYPE_CHECKING:


class ResponseCallbackProtocol(Protocol):
def __call__(self, response: Any, **kwargs): ...
def __call__(self, response: Any, **kwargs):
...


class AsyncResponseCallbackProtocol(Protocol):
async def __call__(self, response: Any, **kwargs): ...
async def __call__(self, response: Any, **kwargs):
...


ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
@@ -182,7 +168,7 @@ class Redis(
warnings.warn(
DeprecationWarning(
'"auto_close_connection_pool" is deprecated '
"since version 5.0.1. "
"since version 5.0.0. "
"Please create a ConnectionPool explicitly and "
"provide to the Redis() constructor instead."
)
@@ -208,11 +194,6 @@ class Redis(
client.auto_close_connection_pool = True
return client

@deprecated_args(
args_to_warn=["retry_on_timeout"],
reason="TimeoutError is included by default.",
version="6.0.0",
)
def __init__(
self,
*,
@@ -230,19 +211,14 @@ class Redis(
encoding_errors: str = "strict",
decode_responses: bool = False,
retry_on_timeout: bool = False,
retry: Retry = Retry(
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
),
retry_on_error: Optional[list] = None,
ssl: bool = False,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: Union[str, VerifyMode] = "required",
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = True,
ssl_min_version: Optional[TLSVersion] = None,
ssl_ciphers: Optional[str] = None,
ssl_check_hostname: bool = False,
max_connections: Optional[int] = None,
single_connection_client: bool = False,
health_check_interval: int = 0,
@@ -250,38 +226,20 @@ class Redis(
lib_name: Optional[str] = "redis-py",
lib_version: Optional[str] = get_lib_version(),
username: Optional[str] = None,
retry: Optional[Retry] = None,
auto_close_connection_pool: Optional[bool] = None,
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = None,
):
"""
Initialize a new Redis client.

To specify a retry policy for specific errors, you have two options:

1. Set the `retry_on_error` to a list of the error/s to retry on, and
you can also set `retry` to a valid `Retry` object(in case the default
one is not appropriate) - with this approach the retries will be triggered
on the default errors specified in the Retry object enriched with the
errors specified in `retry_on_error`.

2. Define a `Retry` object with configured 'supported_errors' and set
it to the `retry` parameter - with this approach you completely redefine
the errors on which retries will happen.

`retry_on_timeout` is deprecated - please include the TimeoutError
either in the Retry object or in the `retry_on_error` list.

When 'connection_pool' is provided - the retry configuration of the
provided pool will be used.
To specify a retry policy for specific errors, first set
`retry_on_error` to a list of the error/s to retry on, then set
`retry` to a valid `Retry` object.
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
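A hedged usage sketch of the retry policy described in the docstring above, built only from names that appear elsewhere in this diff (Retry, ExponentialWithJitterBackoff, retry_on_error, aclose); the connection details are assumptions:

# Hedged sketch: retry configuration for the asyncio client (option 1 above).
import asyncio

from redis.asyncio import Redis
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialWithJitterBackoff
from redis.exceptions import BusyLoadingError

async def main():
    client = Redis(
        host="localhost",  # assumed connection details
        port=6379,
        # Retry object supplies the backoff and attempt count; the errors it
        # covers by default are enriched with anything in retry_on_error.
        retry=Retry(backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3),
        retry_on_error=[BusyLoadingError],
    )
    print(await client.ping())
    await client.aclose()

asyncio.run(main())
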
kwargs: Dict[str, Any]
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
# auto_close_connection_pool only has an effect if connection_pool is
# None. It is assumed that if connection_pool is not None, the user
# wants to manage the connection pool themselves.
@@ -289,7 +247,7 @@ class Redis(
warnings.warn(
DeprecationWarning(
'"auto_close_connection_pool" is deprecated '
"since version 5.0.1. "
"since version 5.0.0. "
"Please create a ConnectionPool explicitly and "
"provide to the Redis() constructor instead."
)
@@ -301,6 +259,8 @@ class Redis(
# Create internal connection pool, expected to be closed by Redis instance
if not retry_on_error:
retry_on_error = []
if retry_on_timeout is True:
retry_on_error.append(TimeoutError)
kwargs = {
"db": db,
"username": username,
@@ -310,6 +270,7 @@ class Redis(
"encoding": encoding,
"encoding_errors": encoding_errors,
"decode_responses": decode_responses,
"retry_on_timeout": retry_on_timeout,
"retry_on_error": retry_on_error,
"retry": copy.deepcopy(retry),
"max_connections": max_connections,
@@ -350,26 +311,14 @@ class Redis(
"ssl_ca_certs": ssl_ca_certs,
"ssl_ca_data": ssl_ca_data,
"ssl_check_hostname": ssl_check_hostname,
"ssl_min_version": ssl_min_version,
"ssl_ciphers": ssl_ciphers,
}
)
# This arg only used if no pool is passed in
self.auto_close_connection_pool = auto_close_connection_pool
connection_pool = ConnectionPool(**kwargs)
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.ASYNC, credential_provider
)
)
else:
# If a pool is passed in, do not close it
self.auto_close_connection_pool = False
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.ASYNC, credential_provider
)
)

self.connection_pool = connection_pool
self.single_connection_client = single_connection_client
@@ -388,10 +337,7 @@ class Redis(
self._single_conn_lock = asyncio.Lock()

def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"({self.connection_pool!r})>"
)
return f"{self.__class__.__name__}<{self.connection_pool!r}>"

def __await__(self):
return self.initialize().__await__()
@@ -400,13 +346,7 @@ class Redis(
if self.single_connection_client:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection()

self._event_dispatcher.dispatch(
AfterSingleConnectionInstantiationEvent(
self.connection, ClientType.ASYNC, self._single_conn_lock
)
)
self.connection = await self.connection_pool.get_connection("_")
return self

def set_response_callback(self, command: str, callback: ResponseCallbackT):
@@ -421,10 +361,10 @@ class Redis(
"""Get the connection's key-word arguments"""
return self.connection_pool.connection_kwargs

def get_retry(self) -> Optional[Retry]:
def get_retry(self) -> Optional["Retry"]:
return self.get_connection_kwargs().get("retry")

def set_retry(self, retry: Retry) -> None:
def set_retry(self, retry: "Retry") -> None:
self.get_connection_kwargs().update({"retry": retry})
self.connection_pool.set_retry(retry)

@@ -503,7 +443,6 @@ class Redis(
blocking_timeout: Optional[float] = None,
lock_class: Optional[Type[Lock]] = None,
thread_local: bool = True,
raise_on_release_error: bool = True,
) -> Lock:
"""
Return a new Lock object using key ``name`` that mimics
@@ -550,11 +489,6 @@ class Redis(
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.

``raise_on_release_error`` indicates whether to raise an exception when
the lock is no longer owned when exiting the context manager. By default,
this is True, meaning an exception will be raised. If False, the warning
will be logged and the exception will be suppressed.

In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
@@ -572,7 +506,6 @@ class Redis(
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
raise_on_release_error=raise_on_release_error,
)

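A hedged usage sketch of the lock() options discussed in the docstring above (blocking_timeout and thread_local appear in this diff; the key name, timeout value, and connection are assumptions):

# Hedged sketch: using the distributed lock from the asyncio client.
import asyncio

from redis.asyncio import Redis

async def main():
    client = Redis()
    async with client.lock(
        "resource-lock",        # assumed key name
        timeout=10,             # auto-release after 10 seconds
        blocking_timeout=5,     # give up acquiring after 5 seconds
        thread_local=True,      # keep the lock token in thread-local storage
    ):
        ...  # critical section
    await client.aclose()

asyncio.run(main())
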
def pubsub(self, **kwargs) -> "PubSub":
|
||||
@@ -581,9 +514,7 @@ class Redis(
|
||||
subscribe to channels and listen for messages that get published to
|
||||
them.
|
||||
"""
|
||||
return PubSub(
|
||||
self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
|
||||
)
|
||||
return PubSub(self.connection_pool, **kwargs)
|
||||
|
||||
def monitor(self) -> "Monitor":
|
||||
return Monitor(self.connection_pool)
|
||||
@@ -615,18 +546,15 @@ class Redis(
|
||||
_grl().call_exception_handler(context)
|
||||
except RuntimeError:
|
||||
pass
|
||||
self.connection._close()
|
||||
|
||||
async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
|
||||
"""
|
||||
Closes Redis client connection
|
||||
|
||||
Args:
|
||||
close_connection_pool:
|
||||
decides whether to close the connection pool used by this Redis client,
|
||||
overriding Redis.auto_close_connection_pool.
|
||||
By default, let Redis.auto_close_connection_pool decide
|
||||
whether to close the connection pool.
|
||||
:param close_connection_pool: decides whether to close the connection pool used
|
||||
by this Redis client, overriding Redis.auto_close_connection_pool. By default,
|
||||
let Redis.auto_close_connection_pool decide whether to close the connection
|
||||
pool.
|
||||
"""
|
||||
conn = self.connection
|
||||
if conn:
|
||||
@@ -637,7 +565,7 @@ class Redis(
|
||||
):
|
||||
await self.connection_pool.disconnect()
|
||||
|
||||
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
|
||||
@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
|
||||
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
|
||||
"""
|
||||
Alias for aclose(), for backwards compatibility
|
||||
@@ -651,17 +579,18 @@ class Redis(
|
||||
await conn.send_command(*args)
|
||||
return await self.parse_response(conn, command_name, **options)
|
||||
|
||||
async def _close_connection(self, conn: Connection):
|
||||
async def _disconnect_raise(self, conn: Connection, error: Exception):
|
||||
"""
|
||||
Close the connection before retrying.
|
||||
|
||||
The supported exceptions are already checked in the
|
||||
retry object so we don't need to do it here.
|
||||
|
||||
After we disconnect the connection, it will try to reconnect and
|
||||
do a health check as part of the send_command logic(on connection level).
|
||||
Close the connection and raise an exception
|
||||
if retry_on_error is not set or the error
|
||||
is not one of the specified error types
|
||||
"""
|
||||
await conn.disconnect()
|
||||
if (
|
||||
conn.retry_on_error is None
|
||||
or isinstance(error, tuple(conn.retry_on_error)) is False
|
||||
):
|
||||
raise error
|
||||
|
||||
# COMMAND EXECUTION AND PROTOCOL PARSING
|
||||
async def execute_command(self, *args, **options):
|
||||
@@ -669,7 +598,7 @@ class Redis(
|
||||
await self.initialize()
|
||||
pool = self.connection_pool
|
||||
command_name = args[0]
|
||||
conn = self.connection or await pool.get_connection()
|
||||
conn = self.connection or await pool.get_connection(command_name, **options)
|
||||
|
||||
if self.single_connection_client:
|
||||
await self._single_conn_lock.acquire()
|
||||
@@ -678,7 +607,7 @@ class Redis(
|
||||
lambda: self._send_command_parse_response(
|
||||
conn, command_name, *args, **options
|
||||
),
|
||||
lambda _: self._close_connection(conn),
|
||||
lambda error: self._disconnect_raise(conn, error),
|
||||
)
|
||||
finally:
|
||||
if self.single_connection_client:
|
||||
@@ -704,9 +633,6 @@ class Redis(
|
||||
if EMPTY_RESPONSE in options:
|
||||
options.pop(EMPTY_RESPONSE)
|
||||
|
||||
# Remove keys entry, it needs only for cache.
|
||||
options.pop("keys", None)
|
||||
|
||||
if command_name in self.response_callbacks:
|
||||
# Mypy bug: https://github.com/python/mypy/issues/10977
|
||||
command_name = cast(str, command_name)
|
||||
@@ -743,7 +669,7 @@ class Monitor:
|
||||
|
||||
async def connect(self):
|
||||
if self.connection is None:
|
||||
self.connection = await self.connection_pool.get_connection()
|
||||
self.connection = await self.connection_pool.get_connection("MONITOR")
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.connect()
|
||||
@@ -820,12 +746,7 @@ class PubSub:
|
||||
ignore_subscribe_messages: bool = False,
|
||||
encoder=None,
|
||||
push_handler_func: Optional[Callable] = None,
|
||||
event_dispatcher: Optional["EventDispatcher"] = None,
|
||||
):
|
||||
if event_dispatcher is None:
|
||||
self._event_dispatcher = EventDispatcher()
|
||||
else:
|
||||
self._event_dispatcher = event_dispatcher
|
||||
self.connection_pool = connection_pool
|
||||
self.shard_hint = shard_hint
|
||||
self.ignore_subscribe_messages = ignore_subscribe_messages
|
||||
@@ -862,7 +783,7 @@ class PubSub:
|
||||
|
||||
def __del__(self):
|
||||
if self.connection:
|
||||
self.connection.deregister_connect_callback(self.on_connect)
|
||||
self.connection._deregister_connect_callback(self.on_connect)
|
||||
|
||||
async def aclose(self):
|
||||
# In case a connection property does not yet exist
|
||||
@@ -873,7 +794,7 @@ class PubSub:
|
||||
async with self._lock:
|
||||
if self.connection:
|
||||
await self.connection.disconnect()
|
||||
self.connection.deregister_connect_callback(self.on_connect)
|
||||
self.connection._deregister_connect_callback(self.on_connect)
|
||||
await self.connection_pool.release(self.connection)
|
||||
self.connection = None
|
||||
self.channels = {}
|
||||
@@ -881,12 +802,12 @@ class PubSub:
|
||||
self.patterns = {}
|
||||
self.pending_unsubscribe_patterns = set()
|
||||
|
||||
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
|
||||
@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close")
|
||||
async def close(self) -> None:
|
||||
"""Alias for aclose(), for backwards compatibility"""
|
||||
await self.aclose()
|
||||
|
||||
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset")
|
||||
@deprecated_function(version="5.0.0", reason="Use aclose() instead", name="reset")
|
||||
async def reset(self) -> None:
|
||||
"""Alias for aclose(), for backwards compatibility"""
|
||||
await self.aclose()
|
||||
@@ -931,26 +852,26 @@ class PubSub:
|
||||
Ensure that the PubSub is connected
|
||||
"""
|
||||
if self.connection is None:
|
||||
self.connection = await self.connection_pool.get_connection()
|
||||
self.connection = await self.connection_pool.get_connection(
|
||||
"pubsub", self.shard_hint
|
||||
)
|
||||
# register a callback that re-subscribes to any channels we
|
||||
# were listening to when we were disconnected
|
||||
self.connection.register_connect_callback(self.on_connect)
|
||||
self.connection._register_connect_callback(self.on_connect)
|
||||
else:
|
||||
await self.connection.connect()
|
||||
if self.push_handler_func is not None:
|
||||
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
|
||||
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
|
||||
self.connection._parser.set_push_handler(self.push_handler_func)
|
||||
|
||||
self._event_dispatcher.dispatch(
|
||||
AfterPubSubConnectionInstantiationEvent(
|
||||
self.connection, self.connection_pool, ClientType.ASYNC, self._lock
|
||||
)
|
||||
)
|
||||
|
||||
async def _reconnect(self, conn):
|
||||
async def _disconnect_raise_connect(self, conn, error):
|
||||
"""
|
||||
Try to reconnect
|
||||
Close the connection and raise an exception
|
||||
if retry_on_timeout is not set or the error
|
||||
is not a TimeoutError. Otherwise, try to reconnect
|
||||
"""
|
||||
await conn.disconnect()
|
||||
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
|
||||
raise error
|
||||
await conn.connect()
|
||||
|
||||
async def _execute(self, conn, command, *args, **kwargs):
|
||||
@@ -963,7 +884,7 @@ class PubSub:
|
||||
"""
|
||||
return await conn.retry.call_with_retry(
|
||||
lambda: command(*args, **kwargs),
|
||||
lambda _: self._reconnect(conn),
|
||||
lambda error: self._disconnect_raise_connect(conn, error),
|
||||
)
|
||||
|
||||
async def parse_response(self, block: bool = True, timeout: float = 0):
|
||||
@@ -1232,11 +1153,13 @@ class PubSub:
|
||||
|
||||
|
||||
class PubsubWorkerExceptionHandler(Protocol):
|
||||
def __call__(self, e: BaseException, pubsub: PubSub): ...
|
||||
def __call__(self, e: BaseException, pubsub: PubSub):
|
||||
...
|
||||
|
||||
|
||||
class AsyncPubsubWorkerExceptionHandler(Protocol):
|
||||
async def __call__(self, e: BaseException, pubsub: PubSub): ...
|
||||
async def __call__(self, e: BaseException, pubsub: PubSub):
|
||||
...
|
||||
|
||||
|
||||
PSWorkerThreadExcHandlerT = Union[
|
||||
@@ -1254,8 +1177,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
|
||||
in one transmission. This is convenient for batch processing, such as
|
||||
saving all the values in a list to Redis.
|
||||
|
||||
All commands executed within a pipeline(when running in transactional mode,
|
||||
which is the default behavior) are wrapped with MULTI and EXEC
|
||||
All commands executed within a pipeline are wrapped with MULTI and EXEC
|
||||
calls. This guarantees all commands executed in the pipeline will be
|
||||
executed atomically.
|
||||
|
||||
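A hedged usage sketch of the transactional pipeline behaviour described above, including the WATCH-before-MULTI flow that the later Pipeline hunks in this diff refer to (key names and values are assumptions):

# Hedged sketch: a transactional pipeline (MULTI/EXEC) with an optimistic WATCH.
import asyncio

from redis.asyncio import Redis
from redis.exceptions import WatchError

async def main():
    client = Redis()
    async with client.pipeline(transaction=True) as pipe:
        while True:
            try:
                await pipe.watch("counter")            # immediate command, before MULTI
                current = int(await pipe.get("counter") or 0)
                pipe.multi()                           # start buffering commands
                pipe.set("counter", current + 1)
                await pipe.execute()                   # MULTI ... EXEC, atomic
                break
            except WatchError:
                continue                               # watched key changed; retry
    await client.aclose()

asyncio.run(main())
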
@@ -1284,7 +1206,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
|
||||
self.shard_hint = shard_hint
|
||||
self.watching = False
|
||||
self.command_stack: CommandStackT = []
|
||||
self.scripts: Set[Script] = set()
|
||||
self.scripts: Set["Script"] = set()
|
||||
self.explicit_transaction = False
|
||||
|
||||
async def __aenter__(self: _RedisT) -> _RedisT:
|
||||
@@ -1356,50 +1278,49 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
|
||||
return self.immediate_execute_command(*args, **kwargs)
|
||||
return self.pipeline_execute_command(*args, **kwargs)
|
||||
|
||||
async def _disconnect_reset_raise_on_watching(
|
||||
self,
|
||||
conn: Connection,
|
||||
error: Exception,
|
||||
):
|
||||
async def _disconnect_reset_raise(self, conn, error):
|
||||
"""
|
||||
Close the connection reset watching state and
|
||||
raise an exception if we were watching.
|
||||
|
||||
The supported exceptions are already checked in the
|
||||
retry object so we don't need to do it here.
|
||||
|
||||
After we disconnect the connection, it will try to reconnect and
|
||||
do a health check as part of the send_command logic(on connection level).
|
||||
Close the connection, reset watching state and
|
||||
raise an exception if we were watching,
|
||||
retry_on_timeout is not set,
|
||||
or the error is not a TimeoutError
|
||||
"""
|
||||
await conn.disconnect()
|
||||
# if we were already watching a variable, the watch is no longer
|
||||
# valid since this connection has died. raise a WatchError, which
|
||||
# indicates the user should retry this transaction.
|
||||
if self.watching:
|
||||
await self.reset()
|
||||
await self.aclose()
|
||||
raise WatchError(
|
||||
f"A {type(error).__name__} occurred while watching one or more keys"
|
||||
"A ConnectionError occurred on while watching one or more keys"
|
||||
)
|
||||
# if retry_on_timeout is not set, or the error is not
|
||||
# a TimeoutError, raise it
|
||||
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
|
||||
await self.aclose()
|
||||
raise
|
||||
|
||||
async def immediate_execute_command(self, *args, **options):
|
||||
"""
|
||||
Execute a command immediately, but don't auto-retry on the supported
|
||||
errors for retry if we're already WATCHing a variable.
|
||||
Used when issuing WATCH or subsequent commands retrieving their values but before
|
||||
Execute a command immediately, but don't auto-retry on a
|
||||
ConnectionError if we're already WATCHing a variable. Used when
|
||||
issuing WATCH or subsequent commands retrieving their values but before
|
||||
MULTI is called.
|
||||
"""
|
||||
command_name = args[0]
|
||||
conn = self.connection
|
||||
# if this is the first call, we need a connection
|
||||
if not conn:
|
||||
conn = await self.connection_pool.get_connection()
|
||||
conn = await self.connection_pool.get_connection(
|
||||
command_name, self.shard_hint
|
||||
)
|
||||
self.connection = conn
|
||||
|
||||
return await conn.retry.call_with_retry(
|
||||
lambda: self._send_command_parse_response(
|
||||
conn, command_name, *args, **options
|
||||
),
|
||||
lambda error: self._disconnect_reset_raise_on_watching(conn, error),
|
||||
lambda error: self._disconnect_reset_raise(conn, error),
|
||||
)
|
||||
|
||||
def pipeline_execute_command(self, *args, **options):
|
||||
@@ -1484,10 +1405,6 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
|
||||
if not isinstance(r, Exception):
|
||||
args, options = cmd
|
||||
command_name = args[0]
|
||||
|
||||
# Remove keys entry, it needs only for cache.
|
||||
options.pop("keys", None)
|
||||
|
||||
if command_name in self.response_callbacks:
|
||||
r = self.response_callbacks[command_name](r, **options)
|
||||
if inspect.isawaitable(r):
|
||||
@@ -1525,10 +1442,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
|
||||
self, exception: Exception, number: int, command: Iterable[object]
|
||||
) -> None:
|
||||
cmd = " ".join(map(safe_str, command))
|
||||
msg = (
|
||||
f"Command # {number} ({truncate_text(cmd)}) "
|
||||
"of pipeline caused error: {exception.args}"
|
||||
)
|
||||
msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}"
|
||||
exception.args = (msg,) + exception.args[1:]
|
||||
|
||||
async def parse_response(
|
||||
@@ -1554,15 +1468,11 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
|
||||
if not exist:
|
||||
s.sha = await immediate("SCRIPT LOAD", s.script)
|
||||
|
||||
    async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception):
    async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
        """
        Close the connection, raise an exception if we were watching.

        The supported exceptions are already checked in the
        retry object so we don't need to do it here.

        After we disconnect the connection, it will try to reconnect and
        do a health check as part of the send_command logic(on connection level).
        Close the connection, raise an exception if we were watching,
        and raise an exception if retry_on_timeout is not set,
        or the error is not a TimeoutError
        """
        await conn.disconnect()
        # if we were watching a variable, the watch is no longer valid
@@ -1570,10 +1480,15 @@ class Pipeline(Redis):  # lgtm [py/init-calls-subclass]
        # indicates the user should retry this transaction.
        if self.watching:
            raise WatchError(
                f"A {type(error).__name__} occurred while watching one or more keys"
                "A ConnectionError occurred on while watching one or more keys"
            )
        # if retry_on_timeout is not set, or the error is not
        # a TimeoutError, raise it
        if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
            await self.reset()
        raise

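Because a disconnect while WATCHing surfaces as WatchError, callers normally wrap the transaction in a retry loop. A rough caller-side sketch (not part of this commit; the key name is made up):

from redis.asyncio import Redis
from redis.exceptions import WatchError


async def incr_with_watch(r: Redis, key: str) -> int:
    while True:
        async with r.pipeline(transaction=True) as pipe:
            try:
                await pipe.watch(key)
                current = int(await pipe.get(key) or 0)
                pipe.multi()
                pipe.set(key, current + 1)
                await pipe.execute()
                return current + 1
            except WatchError:
                # the key changed (or the connection dropped while
                # watching); start the transaction over
                continue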
async def execute(self, raise_on_error: bool = True) -> List[Any]:
|
||||
async def execute(self, raise_on_error: bool = True):
|
||||
"""Execute all the commands in the current pipeline"""
|
||||
stack = self.command_stack
|
||||
if not stack and not self.watching:
|
||||
@@ -1587,7 +1502,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
|
||||
|
||||
conn = self.connection
|
||||
if not conn:
|
||||
conn = await self.connection_pool.get_connection()
|
||||
conn = await self.connection_pool.get_connection("MULTI", self.shard_hint)
|
||||
# assign to self.connection so reset() releases the connection
|
||||
# back to the pool after we're done
|
||||
self.connection = conn
|
||||
@@ -1596,7 +1511,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass]
|
||||
try:
|
||||
return await conn.retry.call_with_retry(
|
||||
lambda: execute(conn, stack, raise_on_error),
|
||||
lambda error: self._disconnect_raise_on_watching(conn, error),
|
||||
lambda error: self._disconnect_raise_reset(conn, error),
|
||||
)
|
||||
finally:
|
||||
await self.reset()
|
||||
|
||||
File diff suppressed because it is too large
@@ -3,8 +3,8 @@ import copy
|
||||
import enum
|
||||
import inspect
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import warnings
|
||||
import weakref
|
||||
from abc import abstractmethod
|
||||
from itertools import chain
|
||||
@@ -16,30 +16,14 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import ParseResult, parse_qs, unquote, urlparse
|
||||
|
||||
from ..utils import SSL_AVAILABLE
|
||||
|
||||
if SSL_AVAILABLE:
|
||||
import ssl
|
||||
from ssl import SSLContext, TLSVersion
|
||||
else:
|
||||
ssl = None
|
||||
TLSVersion = None
|
||||
SSLContext = None
|
||||
|
||||
from ..auth.token import TokenInterface
|
||||
from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
|
||||
from ..utils import deprecated_args, format_error_message
|
||||
|
||||
# the functionality is available in 3.11.x but has a major issue before
|
||||
# 3.11.3. See https://github.com/redis/redis-py/issues/2633
|
||||
if sys.version_info >= (3, 11, 3):
|
||||
@@ -49,6 +33,7 @@ else:
|
||||
|
||||
from redis.asyncio.retry import Retry
|
||||
from redis.backoff import NoBackoff
|
||||
from redis.compat import Protocol, TypedDict
|
||||
from redis.connection import DEFAULT_RESP_VERSION
|
||||
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
|
||||
from redis.exceptions import (
|
||||
@@ -93,11 +78,13 @@ else:
|
||||
|
||||
|
||||
class ConnectCallbackProtocol(Protocol):
|
||||
def __call__(self, connection: "AbstractConnection"): ...
|
||||
def __call__(self, connection: "AbstractConnection"):
|
||||
...
|
||||
|
||||
|
||||
class AsyncConnectCallbackProtocol(Protocol):
|
||||
async def __call__(self, connection: "AbstractConnection"): ...
|
||||
async def __call__(self, connection: "AbstractConnection"):
|
||||
...
|
||||
|
||||
|
||||
ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
|
||||
@@ -159,7 +146,6 @@ class AbstractConnection:
|
||||
encoder_class: Type[Encoder] = Encoder,
|
||||
credential_provider: Optional[CredentialProvider] = None,
|
||||
protocol: Optional[int] = 2,
|
||||
event_dispatcher: Optional[EventDispatcher] = None,
|
||||
):
|
||||
if (username or password) and credential_provider is not None:
|
||||
raise DataError(
|
||||
@@ -168,10 +154,6 @@ class AbstractConnection:
|
||||
"1. 'password' and (optional) 'username'\n"
|
||||
"2. 'credential_provider'"
|
||||
)
|
||||
if event_dispatcher is None:
|
||||
self._event_dispatcher = EventDispatcher()
|
||||
else:
|
||||
self._event_dispatcher = event_dispatcher
|
||||
self.db = db
|
||||
self.client_name = client_name
|
||||
self.lib_name = lib_name
|
||||
@@ -211,8 +193,6 @@ class AbstractConnection:
|
||||
self.set_parser(parser_class)
|
||||
self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
|
||||
self._buffer_cutoff = 6000
|
||||
self._re_auth_token: Optional[TokenInterface] = None
|
||||
|
||||
try:
|
||||
p = int(protocol)
|
||||
except TypeError:
|
||||
@@ -224,33 +204,9 @@ class AbstractConnection:
|
||||
raise ConnectionError("protocol must be either 2 or 3")
|
||||
self.protocol = protocol
|
||||
|
||||
def __del__(self, _warnings: Any = warnings):
|
||||
# For some reason, the individual streams don't get properly garbage
|
||||
# collected and therefore produce no resource warnings. We add one
|
||||
# here, in the same style as those from the stdlib.
|
||||
if getattr(self, "_writer", None):
|
||||
_warnings.warn(
|
||||
f"unclosed Connection {self!r}", ResourceWarning, source=self
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
self._close()
|
||||
except RuntimeError:
|
||||
# No actions been taken if pool already closed.
|
||||
pass
|
||||
|
||||
def _close(self):
|
||||
"""
|
||||
Internal method to silently close the connection without waiting
|
||||
"""
|
||||
if self._writer:
|
||||
self._writer.close()
|
||||
self._writer = self._reader = None
|
||||
|
||||
def __repr__(self):
|
||||
repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
|
||||
return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>"
|
||||
return f"{self.__class__.__name__}<{repr_args}>"
|
||||
|
||||
@abstractmethod
|
||||
def repr_pieces(self):
|
||||
@@ -260,24 +216,12 @@ class AbstractConnection:
|
||||
def is_connected(self):
|
||||
return self._reader is not None and self._writer is not None
|
||||
|
||||
def register_connect_callback(self, callback):
|
||||
"""
|
||||
Register a callback to be called when the connection is established either
|
||||
initially or reconnected. This allows listeners to issue commands that
|
||||
are ephemeral to the connection, for example pub/sub subscription or
|
||||
key tracking. The callback must be a _method_ and will be kept as
|
||||
a weak reference.
|
||||
"""
|
||||
def _register_connect_callback(self, callback):
|
||||
wm = weakref.WeakMethod(callback)
|
||||
if wm not in self._connect_callbacks:
|
||||
self._connect_callbacks.append(wm)
|
||||
|
||||
def deregister_connect_callback(self, callback):
|
||||
"""
|
||||
De-register a previously registered callback. It will no-longer receive
|
||||
notifications on connection events. Calling this is not required when the
|
||||
listener goes away, since the callbacks are kept as weak methods.
|
||||
"""
|
||||
def _deregister_connect_callback(self, callback):
|
||||
try:
|
||||
self._connect_callbacks.remove(weakref.WeakMethod(callback))
|
||||
except ValueError:
|
||||
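Because the connect callbacks are stored as weakref.WeakMethod, only bound methods survive registration. A small illustrative sketch; the class, method names, and the CLIENT TRACKING command it sends are assumptions, not taken from this commit:

class KeyTracker:
    """Re-subscribes to invalidation tracking whenever the connection reconnects."""

    def __init__(self, connection):
        self.connection = connection
        # must be a bound method: a plain function or lambda would not
        # survive weakref.WeakMethod and would never be called back
        connection.register_connect_callback(self.on_reconnect)

    async def on_reconnect(self, connection):
        await connection.send_command("CLIENT", "TRACKING", "ON")
        await connection.read_response()

    def close(self):
        # optional: the weak reference also disappears when the tracker does
        self.connection.deregister_connect_callback(self.on_reconnect)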
@@ -293,20 +237,12 @@ class AbstractConnection:
|
||||
|
||||
async def connect(self):
|
||||
"""Connects to the Redis server if not already connected"""
|
||||
await self.connect_check_health(check_health=True)
|
||||
|
||||
async def connect_check_health(
|
||||
self, check_health: bool = True, retry_socket_connect: bool = True
|
||||
):
|
||||
if self.is_connected:
|
||||
return
|
||||
try:
|
||||
if retry_socket_connect:
|
||||
await self.retry.call_with_retry(
|
||||
lambda: self._connect(), lambda error: self.disconnect()
|
||||
)
|
||||
else:
|
||||
await self._connect()
|
||||
await self.retry.call_with_retry(
|
||||
lambda: self._connect(), lambda error: self.disconnect()
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise # in 3.7 and earlier, this is an Exception, not BaseException
|
||||
except (socket.timeout, asyncio.TimeoutError):
|
||||
@@ -319,14 +255,12 @@ class AbstractConnection:
|
||||
try:
|
||||
if not self.redis_connect_func:
|
||||
# Use the default on_connect function
|
||||
await self.on_connect_check_health(check_health=check_health)
|
||||
await self.on_connect()
|
||||
else:
|
||||
# Use the passed function redis_connect_func
|
||||
(
|
||||
await self.redis_connect_func(self)
|
||||
if asyncio.iscoroutinefunction(self.redis_connect_func)
|
||||
else self.redis_connect_func(self)
|
||||
)
|
||||
await self.redis_connect_func(self) if asyncio.iscoroutinefunction(
|
||||
self.redis_connect_func
|
||||
) else self.redis_connect_func(self)
|
||||
except RedisError:
|
||||
# clean up after any error in on_connect
|
||||
await self.disconnect()
|
||||
@@ -350,17 +284,12 @@ class AbstractConnection:
|
||||
def _host_error(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _error_message(self, exception: BaseException) -> str:
|
||||
return format_error_message(self._host_error(), exception)
|
||||
|
||||
def get_protocol(self):
|
||||
return self.protocol
|
||||
pass
|
||||
|
||||
async def on_connect(self) -> None:
|
||||
"""Initialize the connection, authenticate and select a database"""
|
||||
await self.on_connect_check_health(check_health=True)
|
||||
|
||||
async def on_connect_check_health(self, check_health: bool = True) -> None:
|
||||
self._parser.on_connect(self)
|
||||
parser = self._parser
|
||||
|
||||
@@ -371,8 +300,7 @@ class AbstractConnection:
|
||||
self.credential_provider
|
||||
or UsernamePasswordCredentialProvider(self.username, self.password)
|
||||
)
|
||||
auth_args = await cred_provider.get_credentials_async()
|
||||
|
||||
auth_args = cred_provider.get_credentials()
|
||||
# if resp version is specified and we have auth args,
|
||||
# we need to send them via HELLO
|
||||
if auth_args and self.protocol not in [2, "2"]:
|
||||
@@ -383,11 +311,7 @@ class AbstractConnection:
|
||||
self._parser.on_connect(self)
|
||||
if len(auth_args) == 1:
|
||||
auth_args = ["default", auth_args[0]]
|
||||
# avoid checking health here -- PING will fail if we try
|
||||
# to check the health prior to the AUTH
|
||||
await self.send_command(
|
||||
"HELLO", self.protocol, "AUTH", *auth_args, check_health=False
|
||||
)
|
||||
await self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
|
||||
response = await self.read_response()
|
||||
if response.get(b"proto") != int(self.protocol) and response.get(
|
||||
"proto"
|
||||
@@ -418,7 +342,7 @@ class AbstractConnection:
|
||||
# update cluster exception classes
|
||||
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
|
||||
self._parser.on_connect(self)
|
||||
await self.send_command("HELLO", self.protocol, check_health=check_health)
|
||||
await self.send_command("HELLO", self.protocol)
|
||||
response = await self.read_response()
|
||||
# if response.get(b"proto") != self.protocol and response.get(
|
||||
# "proto"
|
||||
@@ -427,35 +351,18 @@ class AbstractConnection:
|
||||
|
||||
# if a client_name is given, set it
|
||||
if self.client_name:
|
||||
await self.send_command(
|
||||
"CLIENT",
|
||||
"SETNAME",
|
||||
self.client_name,
|
||||
check_health=check_health,
|
||||
)
|
||||
await self.send_command("CLIENT", "SETNAME", self.client_name)
|
||||
if str_if_bytes(await self.read_response()) != "OK":
|
||||
raise ConnectionError("Error setting client name")
|
||||
|
||||
# set the library name and version, pipeline for lower startup latency
|
||||
if self.lib_name:
|
||||
await self.send_command(
|
||||
"CLIENT",
|
||||
"SETINFO",
|
||||
"LIB-NAME",
|
||||
self.lib_name,
|
||||
check_health=check_health,
|
||||
)
|
||||
await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
|
||||
if self.lib_version:
|
||||
await self.send_command(
|
||||
"CLIENT",
|
||||
"SETINFO",
|
||||
"LIB-VER",
|
||||
self.lib_version,
|
||||
check_health=check_health,
|
||||
)
|
||||
await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
|
||||
# if a database is specified, switch to it. Also pipeline this
|
||||
if self.db:
|
||||
await self.send_command("SELECT", self.db, check_health=check_health)
|
||||
await self.send_command("SELECT", self.db)
|
||||
|
||||
# read responses from pipeline
|
||||
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
|
||||
@@ -517,8 +424,8 @@ class AbstractConnection:
|
||||
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
|
||||
) -> None:
|
||||
if not self.is_connected:
|
||||
await self.connect_check_health(check_health=False)
|
||||
if check_health:
|
||||
await self.connect()
|
||||
elif check_health:
|
||||
await self.check_health()
|
||||
|
||||
try:
|
||||
@@ -581,7 +488,11 @@ class AbstractConnection:
|
||||
read_timeout = timeout if timeout is not None else self.socket_timeout
|
||||
host_error = self._host_error()
|
||||
try:
|
||||
if read_timeout is not None and self.protocol in ["3", 3]:
|
||||
if (
|
||||
read_timeout is not None
|
||||
and self.protocol in ["3", 3]
|
||||
and not HIREDIS_AVAILABLE
|
||||
):
|
||||
async with async_timeout(read_timeout):
|
||||
response = await self._parser.read_response(
|
||||
disable_decoding=disable_decoding, push_request=push_request
|
||||
@@ -591,7 +502,7 @@ class AbstractConnection:
|
||||
response = await self._parser.read_response(
|
||||
disable_decoding=disable_decoding
|
||||
)
|
||||
elif self.protocol in ["3", 3]:
|
||||
elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE:
|
||||
response = await self._parser.read_response(
|
||||
disable_decoding=disable_decoding, push_request=push_request
|
||||
)
|
||||
@@ -703,27 +614,6 @@ class AbstractConnection:
|
||||
output.append(SYM_EMPTY.join(pieces))
|
||||
return output
|
||||
|
||||
def _socket_is_empty(self):
|
||||
"""Check if the socket is empty"""
|
||||
return len(self._reader._buffer) == 0
|
||||
|
||||
async def process_invalidation_messages(self):
|
||||
while not self._socket_is_empty():
|
||||
await self.read_response(push_request=True)
|
||||
|
||||
def set_re_auth_token(self, token: TokenInterface):
|
||||
self._re_auth_token = token
|
||||
|
||||
async def re_auth(self):
|
||||
if self._re_auth_token is not None:
|
||||
await self.send_command(
|
||||
"AUTH",
|
||||
self._re_auth_token.try_get("oid"),
|
||||
self._re_auth_token.get_value(),
|
||||
)
|
||||
await self.read_response()
|
||||
self._re_auth_token = None
|
||||
|
||||
|
||||
class Connection(AbstractConnection):
|
||||
"Manages TCP communication to and from a Redis server"
|
||||
@@ -781,6 +671,27 @@ class Connection(AbstractConnection):
|
||||
def _host_error(self) -> str:
|
||||
return f"{self.host}:{self.port}"
|
||||
|
||||
    def _error_message(self, exception: BaseException) -> str:
        # args for socket.error can either be (errno, "message")
        # or just "message"

        host_error = self._host_error()

        if not exception.args:
            # asyncio has a bug where on Connection reset by peer, the
            # exception is not instantiated, so args is empty. This is the
            # workaround.
            # See: https://github.com/redis/redis-py/issues/2237
            # See: https://github.com/python/cpython/issues/94061
            return f"Error connecting to {host_error}. Connection reset by peer"
        elif len(exception.args) == 1:
            return f"Error connecting to {host_error}. {exception.args[0]}."
        else:
            return (
                f"Error {exception.args[0]} connecting to {host_error}. "
                f"{exception.args[0]}."
            )
|
||||
|
||||
class SSLConnection(Connection):
|
||||
"""Manages SSL connections to and from the Redis server(s).
|
||||
@@ -792,17 +703,12 @@ class SSLConnection(Connection):
|
||||
self,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
|
||||
ssl_cert_reqs: str = "required",
|
||||
ssl_ca_certs: Optional[str] = None,
|
||||
ssl_ca_data: Optional[str] = None,
|
||||
ssl_check_hostname: bool = True,
|
||||
ssl_min_version: Optional[TLSVersion] = None,
|
||||
ssl_ciphers: Optional[str] = None,
|
||||
ssl_check_hostname: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
if not SSL_AVAILABLE:
|
||||
raise RedisError("Python wasn't built with SSL support")
|
||||
|
||||
self.ssl_context: RedisSSLContext = RedisSSLContext(
|
||||
keyfile=ssl_keyfile,
|
||||
certfile=ssl_certfile,
|
||||
@@ -810,8 +716,6 @@ class SSLConnection(Connection):
|
||||
ca_certs=ssl_ca_certs,
|
||||
ca_data=ssl_ca_data,
|
||||
check_hostname=ssl_check_hostname,
|
||||
min_version=ssl_min_version,
|
||||
ciphers=ssl_ciphers,
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -844,10 +748,6 @@ class SSLConnection(Connection):
|
||||
def check_hostname(self):
|
||||
return self.ssl_context.check_hostname
|
||||
|
||||
@property
|
||||
def min_version(self):
|
||||
return self.ssl_context.min_version
|
||||
|
||||
|
||||
class RedisSSLContext:
|
||||
__slots__ = (
|
||||
@@ -858,30 +758,23 @@ class RedisSSLContext:
|
||||
"ca_data",
|
||||
"context",
|
||||
"check_hostname",
|
||||
"min_version",
|
||||
"ciphers",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
keyfile: Optional[str] = None,
|
||||
certfile: Optional[str] = None,
|
||||
cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
|
||||
cert_reqs: Optional[str] = None,
|
||||
ca_certs: Optional[str] = None,
|
||||
ca_data: Optional[str] = None,
|
||||
check_hostname: bool = False,
|
||||
min_version: Optional[TLSVersion] = None,
|
||||
ciphers: Optional[str] = None,
|
||||
):
|
||||
if not SSL_AVAILABLE:
|
||||
raise RedisError("Python wasn't built with SSL support")
|
||||
|
||||
self.keyfile = keyfile
|
||||
self.certfile = certfile
|
||||
if cert_reqs is None:
|
||||
cert_reqs = ssl.CERT_NONE
|
||||
self.cert_reqs = ssl.CERT_NONE
|
||||
elif isinstance(cert_reqs, str):
|
||||
CERT_REQS = { # noqa: N806
|
||||
CERT_REQS = {
|
||||
"none": ssl.CERT_NONE,
|
||||
"optional": ssl.CERT_OPTIONAL,
|
||||
"required": ssl.CERT_REQUIRED,
|
||||
@@ -890,18 +783,13 @@ class RedisSSLContext:
|
||||
raise RedisError(
|
||||
f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
|
||||
)
|
||||
cert_reqs = CERT_REQS[cert_reqs]
|
||||
self.cert_reqs = cert_reqs
|
||||
self.cert_reqs = CERT_REQS[cert_reqs]
|
||||
self.ca_certs = ca_certs
|
||||
self.ca_data = ca_data
|
||||
self.check_hostname = (
|
||||
check_hostname if self.cert_reqs != ssl.CERT_NONE else False
|
||||
)
|
||||
self.min_version = min_version
|
||||
self.ciphers = ciphers
|
||||
self.context: Optional[SSLContext] = None
|
||||
self.check_hostname = check_hostname
|
||||
self.context: Optional[ssl.SSLContext] = None
|
||||
|
||||
def get(self) -> SSLContext:
|
||||
def get(self) -> ssl.SSLContext:
|
||||
if not self.context:
|
||||
context = ssl.create_default_context()
|
||||
context.check_hostname = self.check_hostname
|
||||
@@ -910,10 +798,6 @@ class RedisSSLContext:
|
||||
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
|
||||
if self.ca_certs or self.ca_data:
|
||||
context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data)
|
||||
if self.min_version is not None:
|
||||
context.minimum_version = self.min_version
|
||||
if self.ciphers is not None:
|
||||
context.set_ciphers(self.ciphers)
|
||||
self.context = context
|
||||
return self.context
|
||||
|
||||
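Taken together, the SSL options handled by SSLConnection/RedisSSLContext above are normally passed through the client rather than by constructing the connection class directly. A hedged sketch; the hostname, port, certificate path, and cipher string are placeholders:

import ssl

from redis.asyncio import Redis

r = Redis(
    host="redis.example.com",
    port=6380,
    ssl=True,
    ssl_cert_reqs="required",          # mapped to ssl.CERT_REQUIRED by RedisSSLContext
    ssl_ca_certs="/etc/ssl/certs/ca.pem",
    ssl_check_hostname=True,
    ssl_min_version=ssl.TLSVersion.TLSv1_2,
    ssl_ciphers="ECDHE+AESGCM",
)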
@@ -941,6 +825,20 @@ class UnixDomainSocketConnection(AbstractConnection):
|
||||
def _host_error(self) -> str:
|
||||
return self.path
|
||||
|
||||
def _error_message(self, exception: BaseException) -> str:
|
||||
# args for socket.error can either be (errno, "message")
|
||||
# or just "message"
|
||||
host_error = self._host_error()
|
||||
if len(exception.args) == 1:
|
||||
return (
|
||||
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"Error {exception.args[0]} connecting to unix socket: "
|
||||
f"{host_error}. {exception.args[1]}."
|
||||
)
|
||||
|
||||
|
||||
FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")
|
||||
|
||||
@@ -963,7 +861,6 @@ URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyTy
        "max_connections": int,
        "health_check_interval": int,
        "ssl_check_hostname": to_bool,
        "timeout": float,
    }
)

@@ -990,7 +887,7 @@ def parse_url(url: str) -> ConnectKwargs:
            try:
                kwargs[name] = parser(value)
            except (TypeError, ValueError):
                raise ValueError(f"Invalid value for '{name}' in connection URL.")
                raise ValueError(f"Invalid value for `{name}` in connection URL.")
        else:
            kwargs[name] = value
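The parsers above apply to query parameters of a connection URL, so from_url accepts something like the following. The hostname and the parameter values are examples only:

from redis.asyncio import Redis

# each query argument is run through URL_QUERY_ARGUMENT_PARSERS:
# max_connections -> int, timeout -> float, ssl_check_hostname -> bool
r = Redis.from_url(
    "rediss://redis.example.com:6380/0"
    "?max_connections=20&health_check_interval=30"
    "&ssl_check_hostname=true&timeout=5.0"
)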
@@ -1042,7 +939,6 @@ class ConnectionPool:
|
||||
By default, TCP connections are created unless ``connection_class``
|
||||
is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
|
||||
unix sockets.
|
||||
:py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
|
||||
|
||||
Any additional keyword arguments are passed to the constructor of
|
||||
``connection_class``.
|
||||
@@ -1112,22 +1008,16 @@ class ConnectionPool:
|
||||
self._available_connections: List[AbstractConnection] = []
|
||||
self._in_use_connections: Set[AbstractConnection] = set()
|
||||
self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
|
||||
self._lock = asyncio.Lock()
|
||||
self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
|
||||
if self._event_dispatcher is None:
|
||||
self._event_dispatcher = EventDispatcher()
|
||||
|
||||
def __repr__(self):
|
||||
conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()])
|
||||
return (
|
||||
f"<{self.__class__.__module__}.{self.__class__.__name__}"
|
||||
f"(<{self.connection_class.__module__}.{self.connection_class.__name__}"
|
||||
f"({conn_kwargs})>)>"
|
||||
f"{self.__class__.__name__}"
|
||||
f"<{self.connection_class(**self.connection_kwargs)!r}>"
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
self._available_connections = []
|
||||
self._in_use_connections = weakref.WeakSet()
|
||||
self._in_use_connections = set()
|
||||
|
||||
def can_get_connection(self) -> bool:
|
||||
"""Return True if a connection can be retrieved from the pool."""
|
||||
@@ -1136,25 +1026,8 @@ class ConnectionPool:
|
||||
or len(self._in_use_connections) < self.max_connections
|
||||
)
|
||||
|
||||
@deprecated_args(
|
||||
args_to_warn=["*"],
|
||||
reason="Use get_connection() without args instead",
|
||||
version="5.3.0",
|
||||
)
|
||||
async def get_connection(self, command_name=None, *keys, **options):
|
||||
async with self._lock:
|
||||
"""Get a connected connection from the pool"""
|
||||
connection = self.get_available_connection()
|
||||
try:
|
||||
await self.ensure_connection(connection)
|
||||
except BaseException:
|
||||
await self.release(connection)
|
||||
raise
|
||||
|
||||
return connection
|
||||
|
||||
def get_available_connection(self):
|
||||
"""Get a connection from the pool, without making sure it is connected"""
|
||||
async def get_connection(self, command_name, *keys, **options):
|
||||
"""Get a connection from the pool"""
|
||||
try:
|
||||
connection = self._available_connections.pop()
|
||||
except IndexError:
|
||||
@@ -1162,6 +1035,13 @@ class ConnectionPool:
|
||||
raise ConnectionError("Too many connections") from None
|
||||
connection = self.make_connection()
|
||||
self._in_use_connections.add(connection)
|
||||
|
||||
try:
|
||||
await self.ensure_connection(connection)
|
||||
except BaseException:
|
||||
await self.release(connection)
|
||||
raise
|
||||
|
||||
return connection
|
||||
|
||||
def get_encoder(self):
|
||||
@@ -1187,7 +1067,7 @@ class ConnectionPool:
|
||||
try:
|
||||
if await connection.can_read_destructive():
|
||||
raise ConnectionError("Connection has data") from None
|
||||
except (ConnectionError, TimeoutError, OSError):
|
||||
except (ConnectionError, OSError):
|
||||
await connection.disconnect()
|
||||
await connection.connect()
|
||||
if await connection.can_read_destructive():
|
||||
@@ -1199,9 +1079,6 @@ class ConnectionPool:
|
||||
# not doing so is an error that will cause an exception here.
|
||||
self._in_use_connections.remove(connection)
|
||||
self._available_connections.append(connection)
|
||||
await self._event_dispatcher.dispatch_async(
|
||||
AsyncAfterConnectionReleasedEvent(connection)
|
||||
)
|
||||
|
||||
async def disconnect(self, inuse_connections: bool = True):
|
||||
"""
|
||||
@@ -1235,29 +1112,6 @@ class ConnectionPool:
|
||||
for conn in self._in_use_connections:
|
||||
conn.retry = retry
|
||||
|
||||
async def re_auth_callback(self, token: TokenInterface):
|
||||
async with self._lock:
|
||||
for conn in self._available_connections:
|
||||
await conn.retry.call_with_retry(
|
||||
lambda: conn.send_command(
|
||||
"AUTH", token.try_get("oid"), token.get_value()
|
||||
),
|
||||
lambda error: self._mock(error),
|
||||
)
|
||||
await conn.retry.call_with_retry(
|
||||
lambda: conn.read_response(), lambda error: self._mock(error)
|
||||
)
|
||||
for conn in self._in_use_connections:
|
||||
conn.set_re_auth_token(token)
|
||||
|
||||
async def _mock(self, error: RedisError):
|
||||
"""
|
||||
Dummy functions, needs to be passed as error callback to retry object.
|
||||
:param error:
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BlockingConnectionPool(ConnectionPool):
|
||||
"""
|
||||
@@ -1275,7 +1129,7 @@ class BlockingConnectionPool(ConnectionPool):
|
||||
connection from the pool when all of connections are in use, rather than
|
||||
raising a :py:class:`~redis.ConnectionError` (as the default
|
||||
:py:class:`~redis.asyncio.ConnectionPool` implementation does), it
|
||||
blocks the current `Task` for a specified number of seconds until
|
||||
makes blocks the current `Task` for a specified number of seconds until
|
||||
a connection becomes available.
|
||||
|
||||
Use ``max_connections`` to increase / decrease the pool size::
|
||||
@@ -1309,29 +1163,16 @@ class BlockingConnectionPool(ConnectionPool):
|
||||
self._condition = asyncio.Condition()
|
||||
self.timeout = timeout
|
||||
|
||||
    @deprecated_args(
        args_to_warn=["*"],
        reason="Use get_connection() without args instead",
        version="5.3.0",
    )
    async def get_connection(self, command_name=None, *keys, **options):
    async def get_connection(self, command_name, *keys, **options):
        """Gets a connection from the pool, blocking until one is available"""
        try:
            async with self._condition:
                async with async_timeout(self.timeout):
            async with async_timeout(self.timeout):
                async with self._condition:
                    await self._condition.wait_for(self.can_get_connection)
                    connection = super().get_available_connection()
                    return await super().get_connection(command_name, *keys, **options)
        except asyncio.TimeoutError as err:
            raise ConnectionError("No connection available.") from err

        # We now perform the connection check outside of the lock.
        try:
            await self.ensure_connection(connection)
            return connection
        except BaseException:
            await self.release(connection)
            raise

    async def release(self, connection: AbstractConnection):
        """Releases the connection back to the pool."""
        async with self._condition:
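A minimal sketch of what the blocking behaviour means for callers, assuming a local Redis server; the pool size and timeout are arbitrary:

import asyncio

from redis.asyncio import BlockingConnectionPool, Redis


async def main():
    # at most 10 connections; a caller needing an 11th waits up to 5 seconds,
    # then gets ConnectionError("No connection available.")
    pool = BlockingConnectionPool(max_connections=10, timeout=5)
    r = Redis(connection_pool=pool)
    await asyncio.gather(*(r.ping() for _ in range(100)))
    await pool.disconnect()


asyncio.run(main())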
@@ -1,18 +1,14 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import TYPE_CHECKING, Awaitable, Optional, Union
|
||||
|
||||
from redis.exceptions import LockError, LockNotOwnedError
|
||||
from redis.typing import Number
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis, RedisCluster
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Lock:
|
||||
"""
|
||||
@@ -86,9 +82,8 @@ class Lock:
|
||||
timeout: Optional[float] = None,
|
||||
sleep: float = 0.1,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: Optional[Number] = None,
|
||||
blocking_timeout: Optional[float] = None,
|
||||
thread_local: bool = True,
|
||||
raise_on_release_error: bool = True,
|
||||
):
|
||||
"""
|
||||
Create a new Lock instance named ``name`` using the Redis client
|
||||
@@ -132,11 +127,6 @@ class Lock:
|
||||
thread-1 would see the token value as "xyz" and would be
|
||||
able to successfully release the thread-2's lock.
|
||||
|
||||
``raise_on_release_error`` indicates whether to raise an exception when
|
||||
the lock is no longer owned when exiting the context manager. By default,
|
||||
this is True, meaning an exception will be raised. If False, the warning
|
||||
will be logged and the exception will be suppressed.
|
||||
|
||||
In some use cases it's necessary to disable thread local storage. For
|
||||
example, if you have code where one thread acquires a lock and passes
|
||||
that lock instance to a worker thread to release later. If thread
|
||||
@@ -153,7 +143,6 @@ class Lock:
|
||||
self.blocking_timeout = blocking_timeout
|
||||
self.thread_local = bool(thread_local)
|
||||
self.local = threading.local() if self.thread_local else SimpleNamespace()
|
||||
self.raise_on_release_error = raise_on_release_error
|
||||
self.local.token = None
|
||||
self.register_scripts()
|
||||
|
||||
@@ -173,19 +162,12 @@ class Lock:
|
||||
raise LockError("Unable to acquire lock within the time specified")
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
try:
|
||||
await self.release()
|
||||
except LockError:
|
||||
if self.raise_on_release_error:
|
||||
raise
|
||||
logger.warning(
|
||||
"Lock was unlocked or no longer owned when exiting context manager."
|
||||
)
|
||||
await self.release()
|
||||
|
||||
async def acquire(
|
||||
self,
|
||||
blocking: Optional[bool] = None,
|
||||
blocking_timeout: Optional[Number] = None,
|
||||
blocking_timeout: Optional[float] = None,
|
||||
token: Optional[Union[str, bytes]] = None,
|
||||
):
|
||||
"""
|
||||
@@ -267,10 +249,7 @@ class Lock:
|
||||
"""Releases the already acquired lock"""
|
||||
expected_token = self.local.token
|
||||
if expected_token is None:
|
||||
raise LockError(
|
||||
"Cannot release a lock that's not owned or is already unlocked.",
|
||||
lock_name=self.name,
|
||||
)
|
||||
raise LockError("Cannot release an unlocked lock")
|
||||
self.local.token = None
|
||||
return self.do_release(expected_token)
|
||||
|
||||
@@ -283,7 +262,7 @@ class Lock:
|
||||
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
|
||||
|
||||
def extend(
|
||||
self, additional_time: Number, replace_ttl: bool = False
|
||||
self, additional_time: float, replace_ttl: bool = False
|
||||
) -> Awaitable[bool]:
|
||||
"""
|
||||
Adds more time to an already acquired lock.
|
||||
|
||||
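A short usage sketch of the asyncio Lock described above; the lock name and the timeout values are illustrative:

import asyncio

from redis.asyncio import Redis
from redis.exceptions import LockError


async def main():
    r = Redis()
    try:
        # ttl of 10s, wait at most 2s to acquire before raising LockError
        async with r.lock("resource:report", timeout=10, blocking_timeout=2) as lock:
            ...  # do the work that must not run concurrently
            await lock.extend(5)  # add 5 more seconds while still holding it
    except LockError:
        print("could not acquire or cleanly release the lock")


asyncio.run(main())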
@@ -2,16 +2,18 @@ from asyncio import sleep
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar
|
||||
|
||||
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||
from redis.retry import AbstractRetry
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.backoff import AbstractBackoff
|
||||
|
||||
|
||||
class Retry(AbstractRetry[RedisError]):
|
||||
__hash__ = AbstractRetry.__hash__
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Retry:
|
||||
"""Retry a specific number of times after a failure"""
|
||||
|
||||
__slots__ = "_backoff", "_retries", "_supported_errors"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -22,16 +24,23 @@ class Retry(AbstractRetry[RedisError]):
|
||||
TimeoutError,
|
||||
),
|
||||
):
|
||||
super().__init__(backoff, retries, supported_errors)
|
||||
"""
|
||||
Initialize a `Retry` object with a `Backoff` object
|
||||
that retries a maximum of `retries` times.
|
||||
`retries` can be negative to retry forever.
|
||||
You can specify the types of supported errors which trigger
|
||||
a retry with the `supported_errors` parameter.
|
||||
"""
|
||||
self._backoff = backoff
|
||||
self._retries = retries
|
||||
self._supported_errors = supported_errors
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Retry):
|
||||
return NotImplemented
|
||||
|
||||
return (
|
||||
self._backoff == other._backoff
|
||||
and self._retries == other._retries
|
||||
and set(self._supported_errors) == set(other._supported_errors)
|
||||
def update_supported_errors(self, specified_errors: list):
|
||||
"""
|
||||
Updates the supported errors with the specified error types
|
||||
"""
|
||||
self._supported_errors = tuple(
|
||||
set(self._supported_errors + tuple(specified_errors))
|
||||
)
|
||||
|
||||
async def call_with_retry(
|
||||
|
||||
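Wiring the Retry class above into a client typically looks like this; a sketch with arbitrary retry counts and backoff bounds, using the default connection settings:

from redis.asyncio import Redis
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff
from redis.exceptions import ConnectionError, TimeoutError

# retry up to 3 times, sleeping min(cap, base * 2**failures) between attempts
retry = Retry(ExponentialBackoff(cap=10, base=0.1), retries=3)

r = Redis(
    retry=retry,
    retry_on_error=[ConnectionError, TimeoutError],
)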
@@ -11,12 +11,8 @@ from redis.asyncio.connection import (
|
||||
SSLConnection,
|
||||
)
|
||||
from redis.commands import AsyncSentinelCommands
|
||||
from redis.exceptions import (
|
||||
ConnectionError,
|
||||
ReadOnlyError,
|
||||
ResponseError,
|
||||
TimeoutError,
|
||||
)
|
||||
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
|
||||
from redis.utils import str_if_bytes
|
||||
|
||||
|
||||
class MasterNotFoundError(ConnectionError):
|
||||
@@ -33,18 +29,20 @@ class SentinelManagedConnection(Connection):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
s = f"<{self.__class__.__module__}.{self.__class__.__name__}"
|
||||
pool = self.connection_pool
|
||||
s = f"{self.__class__.__name__}<service={pool.service_name}"
|
||||
if self.host:
|
||||
host_info = f",host={self.host},port={self.port}"
|
||||
s += host_info
|
||||
return s + ")>"
|
||||
return s + ">"
|
||||
|
||||
async def connect_to(self, address):
|
||||
self.host, self.port = address
|
||||
await self.connect_check_health(
|
||||
check_health=self.connection_pool.check_connection,
|
||||
retry_socket_connect=False,
|
||||
)
|
||||
await super().connect()
|
||||
if self.connection_pool.check_connection:
|
||||
await self.send_command("PING")
|
||||
if str_if_bytes(await self.read_response()) != "PONG":
|
||||
raise ConnectionError("PING failed")
|
||||
|
||||
async def _connect_retry(self):
|
||||
if self._reader:
|
||||
@@ -107,11 +105,9 @@ class SentinelConnectionPool(ConnectionPool):
|
||||
def __init__(self, service_name, sentinel_manager, **kwargs):
|
||||
kwargs["connection_class"] = kwargs.get(
|
||||
"connection_class",
|
||||
(
|
||||
SentinelManagedSSLConnection
|
||||
if kwargs.pop("ssl", False)
|
||||
else SentinelManagedConnection
|
||||
),
|
||||
SentinelManagedSSLConnection
|
||||
if kwargs.pop("ssl", False)
|
||||
else SentinelManagedConnection,
|
||||
)
|
||||
self.is_master = kwargs.pop("is_master", True)
|
||||
self.check_connection = kwargs.pop("check_connection", False)
|
||||
@@ -124,8 +120,8 @@ class SentinelConnectionPool(ConnectionPool):
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<{self.__class__.__module__}.{self.__class__.__name__}"
|
||||
f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>"
|
||||
f"{self.__class__.__name__}"
|
||||
f"<service={self.service_name}({self.is_master and 'master' or 'slave'})>"
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
@@ -201,7 +197,6 @@ class Sentinel(AsyncSentinelCommands):
|
||||
sentinels,
|
||||
min_other_sentinels=0,
|
||||
sentinel_kwargs=None,
|
||||
force_master_ip=None,
|
||||
**connection_kwargs,
|
||||
):
|
||||
# if sentinel_kwargs isn't defined, use the socket_* options from
|
||||
@@ -218,7 +213,6 @@ class Sentinel(AsyncSentinelCommands):
|
||||
]
|
||||
self.min_other_sentinels = min_other_sentinels
|
||||
self.connection_kwargs = connection_kwargs
|
||||
self._force_master_ip = force_master_ip
|
||||
|
||||
async def execute_command(self, *args, **kwargs):
|
||||
"""
|
||||
@@ -226,31 +220,19 @@ class Sentinel(AsyncSentinelCommands):
|
||||
once - If set to True, then execute the resulting command on a single
|
||||
node at random, rather than across the entire sentinel cluster.
|
||||
"""
|
||||
once = bool(kwargs.pop("once", False))
|
||||
|
||||
# Check if command is supposed to return the original
|
||||
# responses instead of boolean value.
|
||||
return_responses = bool(kwargs.pop("return_responses", False))
|
||||
once = bool(kwargs.get("once", False))
|
||||
if "once" in kwargs.keys():
|
||||
kwargs.pop("once")
|
||||
|
||||
if once:
|
||||
response = await random.choice(self.sentinels).execute_command(
|
||||
*args, **kwargs
|
||||
)
|
||||
if return_responses:
|
||||
return [response]
|
||||
else:
|
||||
return True if response else False
|
||||
|
||||
tasks = [
|
||||
asyncio.Task(sentinel.execute_command(*args, **kwargs))
|
||||
for sentinel in self.sentinels
|
||||
]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
if return_responses:
|
||||
return responses
|
||||
|
||||
return all(responses)
|
||||
await random.choice(self.sentinels).execute_command(*args, **kwargs)
|
||||
else:
|
||||
tasks = [
|
||||
asyncio.Task(sentinel.execute_command(*args, **kwargs))
|
||||
for sentinel in self.sentinels
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
return True
|
||||
|
||||
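Caller-side, the once/return_responses switches added above behave roughly like this. A sketch only: the sentinel addresses and service name are made up, and calling execute_command directly with raw SENTINEL arguments is an assumption rather than something this commit shows.

import asyncio

from redis.asyncio.sentinel import Sentinel


async def main():
    sentinel = Sentinel([("sentinel-1", 26379), ("sentinel-2", 26379)])

    # broadcast to every sentinel and collapse the results to a single bool
    ok = await sentinel.execute_command("SENTINEL", "RESET", "mymaster")

    # ask one random sentinel and keep its raw response instead of a bool
    raw = await sentinel.execute_command(
        "SENTINEL", "CKQUORUM", "mymaster", once=True, return_responses=True
    )
    print(ok, raw)


asyncio.run(main())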
def __repr__(self):
|
||||
sentinel_addresses = []
|
||||
@@ -259,10 +241,7 @@ class Sentinel(AsyncSentinelCommands):
|
||||
f"{sentinel.connection_pool.connection_kwargs['host']}:"
|
||||
f"{sentinel.connection_pool.connection_kwargs['port']}"
|
||||
)
|
||||
return (
|
||||
f"<{self.__class__}.{self.__class__.__name__}"
|
||||
f"(sentinels=[{','.join(sentinel_addresses)}])>"
|
||||
)
|
||||
return f"{self.__class__.__name__}<sentinels=[{','.join(sentinel_addresses)}]>"
|
||||
|
||||
def check_master_state(self, state: dict, service_name: str) -> bool:
|
||||
if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
|
||||
@@ -294,13 +273,7 @@ class Sentinel(AsyncSentinelCommands):
|
||||
sentinel,
|
||||
self.sentinels[0],
|
||||
)
|
||||
|
||||
ip = (
|
||||
self._force_master_ip
|
||||
if self._force_master_ip is not None
|
||||
else state["ip"]
|
||||
)
|
||||
return ip, state["port"]
|
||||
return state["ip"], state["port"]
|
||||
|
||||
error_info = ""
|
||||
if len(collected_errors) > 0:
|
||||
@@ -341,8 +314,6 @@ class Sentinel(AsyncSentinelCommands):
|
||||
):
|
||||
"""
|
||||
Returns a redis client instance for the ``service_name`` master.
|
||||
Sentinel client will detect failover and reconnect Redis clients
|
||||
automatically.
|
||||
|
||||
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
|
||||
used to retrieve the master's address before establishing a new
|
||||
|
||||
@@ -16,7 +16,7 @@ def from_url(url, **kwargs):
|
||||
return Redis.from_url(url, **kwargs)
|
||||
|
||||
|
||||
class pipeline: # noqa: N801
|
||||
class pipeline:
|
||||
def __init__(self, redis_obj: "Redis"):
|
||||
self.p: "Pipeline" = redis_obj.pipeline()
|
||||
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
class RequestTokenErr(Exception):
|
||||
"""
|
||||
Represents an exception during token request.
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
class InvalidTokenSchemaErr(Exception):
|
||||
"""
|
||||
Represents an exception related to invalid token schema.
|
||||
"""
|
||||
|
||||
def __init__(self, missing_fields: Iterable[str] = []):
|
||||
super().__init__(
|
||||
"Unexpected token schema. Following fields are missing: "
|
||||
+ ", ".join(missing_fields)
|
||||
)
|
||||
|
||||
|
||||
class TokenRenewalErr(Exception):
|
||||
"""
|
||||
Represents an exception during token renewal process.
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
@@ -1,28 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from redis.auth.token import TokenInterface
|
||||
|
||||
"""
|
||||
This interface is the facade of an identity provider
|
||||
"""
|
||||
|
||||
|
||||
class IdentityProviderInterface(ABC):
|
||||
"""
|
||||
Receive a token from the identity provider.
|
||||
Receiving a token only works when being authenticated.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def request_token(self, force_refresh=False) -> TokenInterface:
|
||||
pass
|
||||
|
||||
|
||||
class IdentityProviderConfigInterface(ABC):
|
||||
"""
|
||||
Configuration class that provides a configured identity provider.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_provider(self) -> IdentityProviderInterface:
|
||||
pass
|
||||
@@ -1,130 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from redis.auth.err import InvalidTokenSchemaErr
|
||||
|
||||
|
||||
class TokenInterface(ABC):
|
||||
@abstractmethod
|
||||
def is_expired(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def ttl(self) -> float:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def try_get(self, key: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_value(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_expires_at_ms(self) -> float:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_received_at_ms(self) -> float:
|
||||
pass
|
||||
|
||||
|
||||
class TokenResponse:
|
||||
def __init__(self, token: TokenInterface):
|
||||
self._token = token
|
||||
|
||||
def get_token(self) -> TokenInterface:
|
||||
return self._token
|
||||
|
||||
def get_ttl_ms(self) -> float:
|
||||
return self._token.get_expires_at_ms() - self._token.get_received_at_ms()
|
||||
|
||||
|
||||
class SimpleToken(TokenInterface):
|
||||
def __init__(
|
||||
self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict
|
||||
) -> None:
|
||||
self.value = value
|
||||
self.expires_at = expires_at_ms
|
||||
self.received_at = received_at_ms
|
||||
self.claims = claims
|
||||
|
||||
def ttl(self) -> float:
|
||||
if self.expires_at == -1:
|
||||
return -1
|
||||
|
||||
return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
if self.expires_at == -1:
|
||||
return False
|
||||
|
||||
return self.ttl() <= 0
|
||||
|
||||
def try_get(self, key: str) -> str:
|
||||
return self.claims.get(key)
|
||||
|
||||
def get_value(self) -> str:
|
||||
return self.value
|
||||
|
||||
def get_expires_at_ms(self) -> float:
|
||||
return self.expires_at
|
||||
|
||||
def get_received_at_ms(self) -> float:
|
||||
return self.received_at
|
||||
|
||||
|
||||
class JWToken(TokenInterface):
|
||||
REQUIRED_FIELDS = {"exp"}
|
||||
|
||||
def __init__(self, token: str):
|
||||
try:
|
||||
import jwt
|
||||
except ImportError as ie:
|
||||
raise ImportError(
|
||||
f"The PyJWT library is required for {self.__class__.__name__}.",
|
||||
) from ie
|
||||
self._value = token
|
||||
self._decoded = jwt.decode(
|
||||
self._value,
|
||||
options={"verify_signature": False},
|
||||
algorithms=[jwt.get_unverified_header(self._value).get("alg")],
|
||||
)
|
||||
self._validate_token()
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
exp = self._decoded["exp"]
|
||||
if exp == -1:
|
||||
return False
|
||||
|
||||
return (
|
||||
self._decoded["exp"] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000
|
||||
)
|
||||
|
||||
def ttl(self) -> float:
|
||||
exp = self._decoded["exp"]
|
||||
if exp == -1:
|
||||
return -1
|
||||
|
||||
return (
|
||||
self._decoded["exp"] * 1000 - datetime.now(timezone.utc).timestamp() * 1000
|
||||
)
|
||||
|
||||
def try_get(self, key: str) -> str:
|
||||
return self._decoded.get(key)
|
||||
|
||||
def get_value(self) -> str:
|
||||
return self._value
|
||||
|
||||
def get_expires_at_ms(self) -> float:
|
||||
return float(self._decoded["exp"] * 1000)
|
||||
|
||||
def get_received_at_ms(self) -> float:
|
||||
return datetime.now(timezone.utc).timestamp() * 1000
|
||||
|
||||
def _validate_token(self):
|
||||
actual_fields = {x for x in self._decoded.keys()}
|
||||
|
||||
if len(self.REQUIRED_FIELDS - actual_fields) != 0:
|
||||
raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields)
|
||||
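For reference, the token helpers in this file can be exercised directly. A small sketch with fabricated timestamps, value, and claims; the import path follows the module layout shown elsewhere in this diff:

from datetime import datetime, timezone

from redis.auth.token import SimpleToken

now_ms = datetime.now(timezone.utc).timestamp() * 1000
token = SimpleToken(
    value="secret-token",            # fabricated value
    expires_at_ms=now_ms + 60_000,   # expires in one minute
    received_at_ms=now_ms,
    claims={"oid": "user-123"},
)

assert not token.is_expired()
print(token.ttl())                   # roughly 60000 (milliseconds)
print(token.try_get("oid"))          # "user-123"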
@@ -1,370 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from time import sleep
|
||||
from typing import Any, Awaitable, Callable, Union
|
||||
|
||||
from redis.auth.err import RequestTokenErr, TokenRenewalErr
|
||||
from redis.auth.idp import IdentityProviderInterface
|
||||
from redis.auth.token import TokenResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialsListener:
|
||||
"""
|
||||
Listeners that will be notified on events related to credentials.
|
||||
Accepts callbacks and awaitable callbacks.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._on_next = None
|
||||
self._on_error = None
|
||||
|
||||
@property
|
||||
def on_next(self) -> Union[Callable[[Any], None], Awaitable]:
|
||||
return self._on_next
|
||||
|
||||
@on_next.setter
|
||||
def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None:
|
||||
self._on_next = callback
|
||||
|
||||
@property
|
||||
def on_error(self) -> Union[Callable[[Exception], None], Awaitable]:
|
||||
return self._on_error
|
||||
|
||||
@on_error.setter
|
||||
def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None:
|
||||
self._on_error = callback
|
||||
|
||||
|
||||
class RetryPolicy:
|
||||
def __init__(self, max_attempts: int, delay_in_ms: float):
|
||||
self.max_attempts = max_attempts
|
||||
self.delay_in_ms = delay_in_ms
|
||||
|
||||
def get_max_attempts(self) -> int:
|
||||
"""
|
||||
Retry attempts before exception will be thrown.
|
||||
|
||||
:return: int
|
||||
"""
|
||||
return self.max_attempts
|
||||
|
||||
def get_delay_in_ms(self) -> float:
|
||||
"""
|
||||
Delay between retries in seconds.
|
||||
|
||||
:return: int
|
||||
"""
|
||||
return self.delay_in_ms
|
||||
|
||||
|
||||
class TokenManagerConfig:
|
||||
def __init__(
|
||||
self,
|
||||
expiration_refresh_ratio: float,
|
||||
lower_refresh_bound_millis: int,
|
||||
token_request_execution_timeout_in_ms: int,
|
||||
retry_policy: RetryPolicy,
|
||||
):
|
||||
self._expiration_refresh_ratio = expiration_refresh_ratio
|
||||
self._lower_refresh_bound_millis = lower_refresh_bound_millis
|
||||
self._token_request_execution_timeout_in_ms = (
|
||||
token_request_execution_timeout_in_ms
|
||||
)
|
||||
self._retry_policy = retry_policy
|
||||
|
||||
def get_expiration_refresh_ratio(self) -> float:
|
||||
"""
|
||||
Represents the ratio of a token's lifetime at which a refresh should be triggered. # noqa: E501
|
||||
For example, a value of 0.75 means the token should be refreshed
|
||||
when 75% of its lifetime has elapsed (or when 25% of its lifetime remains).
|
||||
|
||||
:return: float
|
||||
"""
|
||||
|
||||
return self._expiration_refresh_ratio
|
||||
|
||||
def get_lower_refresh_bound_millis(self) -> int:
|
||||
"""
|
||||
Represents the minimum time in milliseconds before token expiration
|
||||
to trigger a refresh, in milliseconds.
|
||||
This value sets a fixed lower bound for when a token refresh should occur,
|
||||
regardless of the token's total lifetime.
|
||||
If set to 0 there will be no lower bound and the refresh will be triggered
|
||||
based on the expirationRefreshRatio only.
|
||||
|
||||
:return: int
|
||||
"""
|
||||
return self._lower_refresh_bound_millis
|
||||
|
||||
def get_token_request_execution_timeout_in_ms(self) -> int:
|
||||
"""
|
||||
Represents the maximum time in milliseconds to wait
|
||||
for a token request to complete.
|
||||
|
||||
:return: int
|
||||
"""
|
||||
return self._token_request_execution_timeout_in_ms
|
||||
|
||||
def get_retry_policy(self) -> RetryPolicy:
|
||||
"""
|
||||
Represents the retry policy for token requests.
|
||||
|
||||
:return: RetryPolicy
|
||||
"""
|
||||
return self._retry_policy
|
||||
|
||||
|
||||
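Putting the two bounds together: with expiration_refresh_ratio=0.75 and lower_refresh_bound_millis=10_000, a token that lives 60 seconds is refreshed at whichever comes first, 45 s after issue (the ratio) or 10 s before expiry (the lower bound). A quick sketch of that arithmetic, mirroring _calculate_renewal_delay further down; the numbers are illustrative and "now" is taken to be the issue time:

# illustrative numbers, not from the commit
ttl_ms = 60_000                      # token lifetime
ratio = 0.75                         # expiration_refresh_ratio
lower_bound_ms = 10_000              # lower_refresh_bound_millis

issue_ms = 0
expire_ms = issue_ms + ttl_ms

# refresh once 75% of the lifetime has elapsed...
delay_ratio = expire_ms - (ttl_ms - ttl_ms * ratio) - issue_ms   # 45_000
# ...but never later than 10 s before expiry
delay_lower = expire_ms - lower_bound_ms - issue_ms              # 50_000

renewal_delay_s = max(0, min(delay_ratio, delay_lower)) / 1000   # 45.0 seconds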
class TokenManager:
|
||||
def __init__(
|
||||
self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig
|
||||
):
|
||||
self._idp = identity_provider
|
||||
self._config = config
|
||||
self._next_timer = None
|
||||
self._listener = None
|
||||
self._init_timer = None
|
||||
self._retries = 0
|
||||
|
||||
def __del__(self):
|
||||
logger.info("Token manager are disposed")
|
||||
self.stop()
|
||||
|
||||
def start(
|
||||
self,
|
||||
listener: CredentialsListener,
|
||||
skip_initial: bool = False,
|
||||
) -> Callable[[], None]:
|
||||
self._listener = listener
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# Run loop in a separate thread to unblock main thread.
|
||||
loop = asyncio.new_event_loop()
|
||||
thread = threading.Thread(
|
||||
target=_start_event_loop_in_thread, args=(loop,), daemon=True
|
||||
)
|
||||
thread.start()
|
||||
|
||||
# Event to block for initial execution.
|
||||
init_event = asyncio.Event()
|
||||
self._init_timer = loop.call_later(
|
||||
0, self._renew_token, skip_initial, init_event
|
||||
)
|
||||
logger.info("Token manager started")
|
||||
|
||||
# Blocks in thread-safe manner.
|
||||
asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result()
|
||||
return self.stop
|
||||
|
||||
async def start_async(
|
||||
self,
|
||||
listener: CredentialsListener,
|
||||
block_for_initial: bool = False,
|
||||
initial_delay_in_ms: float = 0,
|
||||
skip_initial: bool = False,
|
||||
) -> Callable[[], None]:
|
||||
self._listener = listener
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
init_event = asyncio.Event()
|
||||
|
||||
# Wraps the async callback with async wrapper to schedule with loop.call_later()
|
||||
wrapped = _async_to_sync_wrapper(
|
||||
loop, self._renew_token_async, skip_initial, init_event
|
||||
)
|
||||
self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped)
|
||||
logger.info("Token manager started")
|
||||
|
||||
if block_for_initial:
|
||||
await init_event.wait()
|
||||
|
||||
return self.stop
|
||||
|
||||
def stop(self):
|
||||
if self._init_timer is not None:
|
||||
self._init_timer.cancel()
|
||||
if self._next_timer is not None:
|
||||
self._next_timer.cancel()
|
||||
|
||||
def acquire_token(self, force_refresh=False) -> TokenResponse:
|
||||
try:
|
||||
token = self._idp.request_token(force_refresh)
|
||||
except RequestTokenErr as e:
|
||||
if self._retries < self._config.get_retry_policy().get_max_attempts():
|
||||
self._retries += 1
|
||||
sleep(self._config.get_retry_policy().get_delay_in_ms() / 1000)
|
||||
return self.acquire_token(force_refresh)
|
||||
else:
|
||||
raise e
|
||||
|
||||
self._retries = 0
|
||||
return TokenResponse(token)
|
||||
|
||||
async def acquire_token_async(self, force_refresh=False) -> TokenResponse:
|
||||
try:
|
||||
token = self._idp.request_token(force_refresh)
|
||||
except RequestTokenErr as e:
|
||||
if self._retries < self._config.get_retry_policy().get_max_attempts():
|
||||
self._retries += 1
|
||||
await asyncio.sleep(
|
||||
self._config.get_retry_policy().get_delay_in_ms() / 1000
|
||||
)
|
||||
return await self.acquire_token_async(force_refresh)
|
||||
else:
|
||||
raise e
|
||||
|
||||
self._retries = 0
|
||||
return TokenResponse(token)
|
||||
|
||||
def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float:
|
||||
delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date)
|
||||
delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date)
|
||||
delay = min(delay_for_ratio_refresh, delay_for_lower_refresh)
|
||||
|
||||
return 0 if delay < 0 else delay / 1000
|
||||
|
||||
def _delay_for_lower_refresh(self, expire_date: float):
|
||||
return (
|
||||
expire_date
|
||||
- self._config.get_lower_refresh_bound_millis()
|
||||
- (datetime.now(timezone.utc).timestamp() * 1000)
|
||||
)
|
||||
|
||||
def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float):
|
||||
token_ttl = expire_date - issue_date
|
||||
refresh_before = token_ttl - (
|
||||
token_ttl * self._config.get_expiration_refresh_ratio()
|
||||
)
|
||||
|
||||
return (
|
||||
expire_date
|
||||
- refresh_before
|
||||
- (datetime.now(timezone.utc).timestamp() * 1000)
|
||||
)
|
||||
|
||||
def _renew_token(
|
||||
self, skip_initial: bool = False, init_event: asyncio.Event = None
|
||||
):
|
||||
"""
|
||||
Task to renew token from identity provider.
|
||||
Schedules renewal tasks based on token TTL.
|
||||
"""
|
||||
|
||||
try:
|
||||
token_res = self.acquire_token(force_refresh=True)
|
||||
delay = self._calculate_renewal_delay(
|
||||
token_res.get_token().get_expires_at_ms(),
|
||||
token_res.get_token().get_received_at_ms(),
|
||||
)
|
||||
|
||||
if token_res.get_token().is_expired():
|
||||
raise TokenRenewalErr("Requested token is expired")
|
||||
|
||||
if self._listener.on_next is None:
|
||||
logger.warning(
|
||||
"No registered callback for token renewal task. Renewal cancelled"
|
||||
)
|
||||
return
|
||||
|
||||
if not skip_initial:
|
||||
try:
|
||||
self._listener.on_next(token_res.get_token())
|
||||
except Exception as e:
|
||||
raise TokenRenewalErr(e)
|
||||
|
||||
if delay <= 0:
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
self._next_timer = loop.call_later(delay, self._renew_token)
|
||||
logger.info(f"Next token renewal scheduled in {delay} seconds")
|
||||
return token_res
|
||||
except Exception as e:
|
||||
if self._listener.on_error is None:
|
||||
raise e
|
||||
|
||||
self._listener.on_error(e)
|
||||
finally:
|
||||
if init_event:
|
||||
init_event.set()
|
||||
|
||||
async def _renew_token_async(
|
||||
self, skip_initial: bool = False, init_event: asyncio.Event = None
|
||||
):
|
||||
"""
|
||||
Async task to renew tokens from identity provider.
|
||||
Schedules renewal tasks based on token TTL.
|
||||
"""
|
||||
|
||||
try:
|
||||
token_res = await self.acquire_token_async(force_refresh=True)
|
||||
delay = self._calculate_renewal_delay(
|
||||
token_res.get_token().get_expires_at_ms(),
|
||||
token_res.get_token().get_received_at_ms(),
|
||||
)
|
||||
|
||||
if token_res.get_token().is_expired():
|
||||
raise TokenRenewalErr("Requested token is expired")
|
||||
|
||||
if self._listener.on_next is None:
|
||||
logger.warning(
|
||||
"No registered callback for token renewal task. Renewal cancelled"
|
||||
)
|
||||
return
|
||||
|
||||
if not skip_initial:
|
||||
try:
|
||||
await self._listener.on_next(token_res.get_token())
|
||||
except Exception as e:
|
||||
raise TokenRenewalErr(e)
|
||||
|
||||
if delay <= 0:
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
wrapped = _async_to_sync_wrapper(loop, self._renew_token_async)
|
||||
logger.info(f"Next token renewal scheduled in {delay} seconds")
|
||||
loop.call_later(delay, wrapped)
|
||||
except Exception as e:
|
||||
if self._listener.on_error is None:
|
||||
raise e
|
||||
|
||||
await self._listener.on_error(e)
|
||||
finally:
|
||||
if init_event:
|
||||
init_event.set()
|
||||
|
||||
|
||||
def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):
|
||||
"""
|
||||
Wraps an asynchronous function so it can be used with loop.call_later.
|
||||
|
||||
:param loop: The event loop in which the coroutine will be executed.
|
||||
:param coro_func: The coroutine function to wrap.
|
||||
:param args: Positional arguments to pass to the coroutine function.
|
||||
:param kwargs: Keyword arguments to pass to the coroutine function.
|
||||
:return: A regular function suitable for loop.call_later.
|
||||
"""
|
||||
|
||||
def wrapped():
|
||||
# Schedule the coroutine in the event loop
|
||||
asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop):
|
||||
"""
|
||||
Starts event loop in a thread.
|
||||
Used to be able to schedule tasks using loop.call_later.
|
||||
|
||||
:param event_loop:
|
||||
:return:
|
||||
"""
|
||||
asyncio.set_event_loop(event_loop)
|
||||
event_loop.run_forever()
|
||||
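The two module-level helpers above are what let TokenManager schedule coroutines with loop.call_later from a plain thread. A standalone sketch of the same pattern rather than an import of the private helpers; the renew coroutine and the 2-second delay are stand-ins:

import asyncio
import threading


async def renew():
    # stand-in for TokenManager._renew_token_async
    print("token renewed")


def run_loop(loop: asyncio.AbstractEventLoop) -> None:
    # same idea as _start_event_loop_in_thread above
    asyncio.set_event_loop(loop)
    loop.run_forever()


loop = asyncio.new_event_loop()
threading.Thread(target=run_loop, args=(loop,), daemon=True).start()


def wrapped() -> None:
    # what _async_to_sync_wrapper builds: a plain callable that schedules
    # the coroutine on the loop when call_later fires
    asyncio.ensure_future(renew(), loop=loop)


# renew the token 2 seconds from now; scheduled thread-safely from this thread
loop.call_soon_threadsafe(loop.call_later, 2, wrapped)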
@@ -19,7 +19,7 @@ class AbstractBackoff(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, failures: int) -> float:
|
||||
def compute(self, failures):
|
||||
"""Compute backoff in seconds upon failure"""
|
||||
pass
|
||||
|
||||
@@ -27,34 +27,25 @@ class AbstractBackoff(ABC):
|
||||
class ConstantBackoff(AbstractBackoff):
|
||||
"""Constant backoff upon failure"""
|
||||
|
||||
def __init__(self, backoff: float) -> None:
|
||||
def __init__(self, backoff):
|
||||
"""`backoff`: backoff time in seconds"""
|
||||
self._backoff = backoff
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._backoff,))
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, ConstantBackoff):
|
||||
return NotImplemented
|
||||
|
||||
return self._backoff == other._backoff
|
||||
|
||||
def compute(self, failures: int) -> float:
|
||||
def compute(self, failures):
|
||||
return self._backoff
|
||||
|
||||
|
||||
class NoBackoff(ConstantBackoff):
|
||||
"""No backoff upon failure"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
super().__init__(0)
|
||||
|
||||
|
||||
class ExponentialBackoff(AbstractBackoff):
|
||||
"""Exponential backoff upon failure"""
|
||||
|
||||
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE):
|
||||
def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE):
|
||||
"""
|
||||
`cap`: maximum backoff time in seconds
|
||||
`base`: base backoff time in seconds
|
||||
@@ -62,23 +53,14 @@ class ExponentialBackoff(AbstractBackoff):
|
||||
self._cap = cap
|
||||
self._base = base
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._base, self._cap))
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, ExponentialBackoff):
|
||||
return NotImplemented
|
||||
|
||||
return self._base == other._base and self._cap == other._cap
|
||||
|
||||
def compute(self, failures: int) -> float:
|
||||
def compute(self, failures):
|
||||
return min(self._cap, self._base * 2**failures)
|
||||
|
||||
|
||||
class FullJitterBackoff(AbstractBackoff):
|
||||
"""Full jitter backoff upon failure"""
|
||||
|
||||
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
|
||||
def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE):
|
||||
"""
|
||||
`cap`: maximum backoff time in seconds
|
||||
`base`: base backoff time in seconds
|
||||
@@ -86,23 +68,14 @@ class FullJitterBackoff(AbstractBackoff):
|
||||
self._cap = cap
|
||||
self._base = base
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._base, self._cap))
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, FullJitterBackoff):
|
||||
return NotImplemented
|
||||
|
||||
return self._base == other._base and self._cap == other._cap
|
||||
|
||||
def compute(self, failures: int) -> float:
|
||||
def compute(self, failures):
|
||||
return random.uniform(0, min(self._cap, self._base * 2**failures))
|
||||
|
||||
|
||||
class EqualJitterBackoff(AbstractBackoff):
|
||||
"""Equal jitter backoff upon failure"""
|
||||
|
||||
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
|
||||
def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE):
|
||||
"""
|
||||
`cap`: maximum backoff time in seconds
|
||||
`base`: base backoff time in seconds
|
||||
@@ -110,16 +83,7 @@ class EqualJitterBackoff(AbstractBackoff):
|
||||
self._cap = cap
|
||||
self._base = base
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._base, self._cap))
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, EqualJitterBackoff):
|
||||
return NotImplemented
|
||||
|
||||
return self._base == other._base and self._cap == other._cap
|
||||
|
||||
def compute(self, failures: int) -> float:
|
||||
def compute(self, failures):
|
||||
temp = min(self._cap, self._base * 2**failures) / 2
|
||||
return temp + random.uniform(0, temp)
|
||||
|
||||
@@ -127,7 +91,7 @@ class EqualJitterBackoff(AbstractBackoff):
|
||||
class DecorrelatedJitterBackoff(AbstractBackoff):
|
||||
"""Decorrelated jitter backoff upon failure"""
|
||||
|
||||
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
|
||||
def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE):
|
||||
"""
|
||||
`cap`: maximum backoff time in seconds
|
||||
`base`: base backoff time in seconds
|
||||
@@ -136,48 +100,15 @@ class DecorrelatedJitterBackoff(AbstractBackoff):
|
||||
self._base = base
|
||||
self._previous_backoff = 0
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._base, self._cap))
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, DecorrelatedJitterBackoff):
|
||||
return NotImplemented
|
||||
|
||||
return self._base == other._base and self._cap == other._cap
|
||||
|
||||
def reset(self) -> None:
|
||||
def reset(self):
|
||||
self._previous_backoff = 0
|
||||
|
||||
def compute(self, failures: int) -> float:
|
||||
def compute(self, failures):
|
||||
max_backoff = max(self._base, self._previous_backoff * 3)
|
||||
temp = random.uniform(self._base, max_backoff)
|
||||
self._previous_backoff = min(self._cap, temp)
|
||||
return self._previous_backoff
|
||||
|
||||
|
||||
class ExponentialWithJitterBackoff(AbstractBackoff):
|
||||
"""Exponential backoff upon failure, with jitter"""
|
||||
|
||||
def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None:
|
||||
"""
|
||||
`cap`: maximum backoff time in seconds
|
||||
`base`: base backoff time in seconds
|
||||
"""
|
||||
self._cap = cap
|
||||
self._base = base
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self._base, self._cap))
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, ExponentialWithJitterBackoff):
|
||||
return NotImplemented
|
||||
|
||||
return self._base == other._base and self._cap == other._cap
|
||||
|
||||
def compute(self, failures: int) -> float:
|
||||
return min(self._cap, random.random() * self._base * 2**failures)
|
||||
|
||||
|
||||
def default_backoff():
|
||||
return EqualJitterBackoff()
|
||||
|
||||
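# Editor's sketch (not part of the diff): how the backoff classes above are
# typically paired with Retry, mirroring the Retry(...) defaults used elsewhere
# in this changeset.
from redis.backoff import ExponentialBackoff, default_backoff
from redis.retry import Retry

# Delay grows as base * 2**failures, capped at `cap`.
b = ExponentialBackoff(cap=10, base=1)
print([b.compute(f) for f in range(5)])  # [1, 2, 4, 8, 10]

# default_backoff() returns an EqualJitterBackoff instance.
retry = Retry(backoff=default_backoff(), retries=3)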
@@ -1,401 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
|
||||
class CacheEntryStatus(Enum):
|
||||
VALID = "VALID"
|
||||
IN_PROGRESS = "IN_PROGRESS"
|
||||
|
||||
|
||||
class EvictionPolicyType(Enum):
|
||||
time_based = "time_based"
|
||||
frequency_based = "frequency_based"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CacheKey:
|
||||
command: str
|
||||
redis_keys: tuple
|
||||
|
||||
|
||||
class CacheEntry:
|
||||
def __init__(
|
||||
self,
|
||||
cache_key: CacheKey,
|
||||
cache_value: bytes,
|
||||
status: CacheEntryStatus,
|
||||
connection_ref,
|
||||
):
|
||||
self.cache_key = cache_key
|
||||
self.cache_value = cache_value
|
||||
self.status = status
|
||||
self.connection_ref = connection_ref
|
||||
|
||||
def __hash__(self):
|
||||
return hash(
|
||||
(self.cache_key, self.cache_value, self.status, self.connection_ref)
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return hash(self) == hash(other)
|
||||
|
||||
|
||||
class EvictionPolicyInterface(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def cache(self):
|
||||
pass
|
||||
|
||||
@cache.setter
|
||||
def cache(self, value):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> EvictionPolicyType:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evict_next(self) -> CacheKey:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evict_many(self, count: int) -> List[CacheKey]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def touch(self, cache_key: CacheKey) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class CacheConfigurationInterface(ABC):
|
||||
@abstractmethod
|
||||
def get_cache_class(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_max_size(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_eviction_policy(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_exceeds_max_size(self, count: int) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_allowed_to_cache(self, command: str) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class CacheInterface(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def collection(self) -> OrderedDict:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> CacheConfigurationInterface:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def eviction_policy(self) -> EvictionPolicyInterface:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def size(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: CacheKey) -> Union[CacheEntry, None]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, entry: CacheEntry) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def flush(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_cachable(self, key: CacheKey) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class DefaultCache(CacheInterface):
|
||||
def __init__(
|
||||
self,
|
||||
cache_config: CacheConfigurationInterface,
|
||||
) -> None:
|
||||
self._cache = OrderedDict()
|
||||
self._cache_config = cache_config
|
||||
self._eviction_policy = self._cache_config.get_eviction_policy().value()
|
||||
self._eviction_policy.cache = self
|
||||
|
||||
@property
|
||||
def collection(self) -> OrderedDict:
|
||||
return self._cache
|
||||
|
||||
@property
|
||||
def config(self) -> CacheConfigurationInterface:
|
||||
return self._cache_config
|
||||
|
||||
@property
|
||||
def eviction_policy(self) -> EvictionPolicyInterface:
|
||||
return self._eviction_policy
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return len(self._cache)
|
||||
|
||||
def set(self, entry: CacheEntry) -> bool:
|
||||
if not self.is_cachable(entry.cache_key):
|
||||
return False
|
||||
|
||||
self._cache[entry.cache_key] = entry
|
||||
self._eviction_policy.touch(entry.cache_key)
|
||||
|
||||
if self._cache_config.is_exceeds_max_size(len(self._cache)):
|
||||
self._eviction_policy.evict_next()
|
||||
|
||||
return True
|
||||
|
||||
def get(self, key: CacheKey) -> Union[CacheEntry, None]:
|
||||
entry = self._cache.get(key, None)
|
||||
|
||||
if entry is None:
|
||||
return None
|
||||
|
||||
self._eviction_policy.touch(key)
|
||||
return entry
|
||||
|
||||
def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
|
||||
response = []
|
||||
|
||||
for key in cache_keys:
|
||||
if self.get(key) is not None:
|
||||
self._cache.pop(key)
|
||||
response.append(True)
|
||||
else:
|
||||
response.append(False)
|
||||
|
||||
return response
|
||||
|
||||
def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
|
||||
response = []
|
||||
keys_to_delete = []
|
||||
|
||||
for redis_key in redis_keys:
|
||||
if isinstance(redis_key, bytes):
|
||||
redis_key = redis_key.decode()
|
||||
for cache_key in self._cache:
|
||||
if redis_key in cache_key.redis_keys:
|
||||
keys_to_delete.append(cache_key)
|
||||
response.append(True)
|
||||
|
||||
for key in keys_to_delete:
|
||||
self._cache.pop(key)
|
||||
|
||||
return response
|
||||
|
||||
def flush(self) -> int:
|
||||
elem_count = len(self._cache)
|
||||
self._cache.clear()
|
||||
return elem_count
|
||||
|
||||
def is_cachable(self, key: CacheKey) -> bool:
|
||||
return self._cache_config.is_allowed_to_cache(key.command)
|
||||
|
||||
|
||||
class LRUPolicy(EvictionPolicyInterface):
|
||||
def __init__(self):
|
||||
self.cache = None
|
||||
|
||||
@property
|
||||
def cache(self):
|
||||
return self._cache
|
||||
|
||||
@cache.setter
|
||||
def cache(self, cache: CacheInterface):
|
||||
self._cache = cache
|
||||
|
||||
@property
|
||||
def type(self) -> EvictionPolicyType:
|
||||
return EvictionPolicyType.time_based
|
||||
|
||||
def evict_next(self) -> CacheKey:
|
||||
self._assert_cache()
|
||||
popped_entry = self._cache.collection.popitem(last=False)
|
||||
return popped_entry[0]
|
||||
|
||||
def evict_many(self, count: int) -> List[CacheKey]:
|
||||
self._assert_cache()
|
||||
if count > len(self._cache.collection):
|
||||
raise ValueError("Evictions count is above cache size")
|
||||
|
||||
popped_keys = []
|
||||
|
||||
for _ in range(count):
|
||||
popped_entry = self._cache.collection.popitem(last=False)
|
||||
popped_keys.append(popped_entry[0])
|
||||
|
||||
return popped_keys
|
||||
|
||||
def touch(self, cache_key: CacheKey) -> None:
|
||||
self._assert_cache()
|
||||
|
||||
if self._cache.collection.get(cache_key) is None:
|
||||
raise ValueError("Given entry does not belong to the cache")
|
||||
|
||||
self._cache.collection.move_to_end(cache_key)
|
||||
|
||||
def _assert_cache(self):
|
||||
if self.cache is None or not isinstance(self.cache, CacheInterface):
|
||||
raise ValueError("Eviction policy should be associated with valid cache.")
|
||||
|
||||
|
||||
class EvictionPolicy(Enum):
|
||||
LRU = LRUPolicy
|
||||
|
||||
|
||||
class CacheConfig(CacheConfigurationInterface):
|
||||
DEFAULT_CACHE_CLASS = DefaultCache
|
||||
DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU
|
||||
DEFAULT_MAX_SIZE = 10000
|
||||
|
||||
DEFAULT_ALLOW_LIST = [
|
||||
"BITCOUNT",
|
||||
"BITFIELD_RO",
|
||||
"BITPOS",
|
||||
"EXISTS",
|
||||
"GEODIST",
|
||||
"GEOHASH",
|
||||
"GEOPOS",
|
||||
"GEORADIUSBYMEMBER_RO",
|
||||
"GEORADIUS_RO",
|
||||
"GEOSEARCH",
|
||||
"GET",
|
||||
"GETBIT",
|
||||
"GETRANGE",
|
||||
"HEXISTS",
|
||||
"HGET",
|
||||
"HGETALL",
|
||||
"HKEYS",
|
||||
"HLEN",
|
||||
"HMGET",
|
||||
"HSTRLEN",
|
||||
"HVALS",
|
||||
"JSON.ARRINDEX",
|
||||
"JSON.ARRLEN",
|
||||
"JSON.GET",
|
||||
"JSON.MGET",
|
||||
"JSON.OBJKEYS",
|
||||
"JSON.OBJLEN",
|
||||
"JSON.RESP",
|
||||
"JSON.STRLEN",
|
||||
"JSON.TYPE",
|
||||
"LCS",
|
||||
"LINDEX",
|
||||
"LLEN",
|
||||
"LPOS",
|
||||
"LRANGE",
|
||||
"MGET",
|
||||
"SCARD",
|
||||
"SDIFF",
|
||||
"SINTER",
|
||||
"SINTERCARD",
|
||||
"SISMEMBER",
|
||||
"SMEMBERS",
|
||||
"SMISMEMBER",
|
||||
"SORT_RO",
|
||||
"STRLEN",
|
||||
"SUBSTR",
|
||||
"SUNION",
|
||||
"TS.GET",
|
||||
"TS.INFO",
|
||||
"TS.RANGE",
|
||||
"TS.REVRANGE",
|
||||
"TYPE",
|
||||
"XLEN",
|
||||
"XPENDING",
|
||||
"XRANGE",
|
||||
"XREAD",
|
||||
"XREVRANGE",
|
||||
"ZCARD",
|
||||
"ZCOUNT",
|
||||
"ZDIFF",
|
||||
"ZINTER",
|
||||
"ZINTERCARD",
|
||||
"ZLEXCOUNT",
|
||||
"ZMSCORE",
|
||||
"ZRANGE",
|
||||
"ZRANGEBYLEX",
|
||||
"ZRANGEBYSCORE",
|
||||
"ZRANK",
|
||||
"ZREVRANGE",
|
||||
"ZREVRANGEBYLEX",
|
||||
"ZREVRANGEBYSCORE",
|
||||
"ZREVRANK",
|
||||
"ZSCORE",
|
||||
"ZUNION",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int = DEFAULT_MAX_SIZE,
|
||||
cache_class: Any = DEFAULT_CACHE_CLASS,
|
||||
eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
|
||||
):
|
||||
self._cache_class = cache_class
|
||||
self._max_size = max_size
|
||||
self._eviction_policy = eviction_policy
|
||||
|
||||
def get_cache_class(self):
|
||||
return self._cache_class
|
||||
|
||||
def get_max_size(self) -> int:
|
||||
return self._max_size
|
||||
|
||||
def get_eviction_policy(self) -> EvictionPolicy:
|
||||
return self._eviction_policy
|
||||
|
||||
def is_exceeds_max_size(self, count: int) -> bool:
|
||||
return count > self._max_size
|
||||
|
||||
def is_allowed_to_cache(self, command: str) -> bool:
|
||||
return command in self.DEFAULT_ALLOW_LIST
|
||||
|
||||
|
||||
class CacheFactoryInterface(ABC):
|
||||
@abstractmethod
|
||||
def get_cache(self) -> CacheInterface:
|
||||
pass
|
||||
|
||||
|
||||
class CacheFactory(CacheFactoryInterface):
|
||||
def __init__(self, cache_config: Optional[CacheConfig] = None):
|
||||
self._config = cache_config
|
||||
|
||||
if self._config is None:
|
||||
self._config = CacheConfig()
|
||||
|
||||
def get_cache(self) -> CacheInterface:
|
||||
cache_class = self._config.get_cache_class()
|
||||
return cache_class(cache_config=self._config)
|
||||
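# Editor's sketch (not part of the diff): how the cache module removed above was
# typically enabled before this change, through the Redis client rather than by
# instantiating DefaultCache directly. Client-side caching required RESP3.
from redis import Redis
from redis.cache import CacheConfig

r = Redis(protocol=3, cache_config=CacheConfig(max_size=1000))
r.set("answer", 42)
r.get("answer")  # a repeated read can be served from the local cache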
venv/lib/python3.12/site-packages/redis/client.py (421 changed lines, Executable file → Normal file)
@@ -2,19 +2,9 @@ import copy
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
from redis._parsers.encoders import Encoder
|
||||
from redis._parsers.helpers import (
|
||||
@@ -23,54 +13,33 @@ from redis._parsers.helpers import (
|
||||
_RedisCallbacksRESP3,
|
||||
bool_ok,
|
||||
)
|
||||
from redis.backoff import ExponentialWithJitterBackoff
|
||||
from redis.cache import CacheConfig, CacheInterface
|
||||
from redis.commands import (
|
||||
CoreCommands,
|
||||
RedisModuleCommands,
|
||||
SentinelCommands,
|
||||
list_or_args,
|
||||
)
|
||||
from redis.commands.core import Script
|
||||
from redis.connection import (
|
||||
AbstractConnection,
|
||||
Connection,
|
||||
ConnectionPool,
|
||||
SSLConnection,
|
||||
UnixDomainSocketConnection,
|
||||
)
|
||||
from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection
|
||||
from redis.credentials import CredentialProvider
|
||||
from redis.event import (
|
||||
AfterPooledConnectionsInstantiationEvent,
|
||||
AfterPubSubConnectionInstantiationEvent,
|
||||
AfterSingleConnectionInstantiationEvent,
|
||||
ClientType,
|
||||
EventDispatcher,
|
||||
)
|
||||
from redis.exceptions import (
|
||||
ConnectionError,
|
||||
ExecAbortError,
|
||||
PubSubError,
|
||||
RedisError,
|
||||
ResponseError,
|
||||
TimeoutError,
|
||||
WatchError,
|
||||
)
|
||||
from redis.lock import Lock
|
||||
from redis.retry import Retry
|
||||
from redis.utils import (
|
||||
HIREDIS_AVAILABLE,
|
||||
_set_info_logger,
|
||||
deprecated_args,
|
||||
get_lib_version,
|
||||
safe_str,
|
||||
str_if_bytes,
|
||||
truncate_text,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import ssl
|
||||
|
||||
import OpenSSL
|
||||
|
||||
SYM_EMPTY = b""
|
||||
EMPTY_RESPONSE = "EMPTY_RESPONSE"
|
||||
|
||||
@@ -125,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs) -> "Redis":
|
||||
def from_url(cls, url: str, **kwargs) -> None:
|
||||
"""
|
||||
Return a Redis client object configured from the given URL
|
||||
|
||||
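# Editor's sketch (not part of the diff): typical from_url usage.
from redis import Redis

r = Redis.from_url("redis://localhost:6379/0", decode_responses=True)
r.ping()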
@@ -191,80 +160,56 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
client.auto_close_connection_pool = True
|
||||
return client
|
||||
|
||||
@deprecated_args(
|
||||
args_to_warn=["retry_on_timeout"],
|
||||
reason="TimeoutError is included by default.",
|
||||
version="6.0.0",
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "localhost",
|
||||
port: int = 6379,
|
||||
db: int = 0,
|
||||
password: Optional[str] = None,
|
||||
socket_timeout: Optional[float] = None,
|
||||
socket_connect_timeout: Optional[float] = None,
|
||||
socket_keepalive: Optional[bool] = None,
|
||||
socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
|
||||
connection_pool: Optional[ConnectionPool] = None,
|
||||
unix_socket_path: Optional[str] = None,
|
||||
encoding: str = "utf-8",
|
||||
encoding_errors: str = "strict",
|
||||
decode_responses: bool = False,
|
||||
retry_on_timeout: bool = False,
|
||||
retry: Retry = Retry(
|
||||
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
|
||||
),
|
||||
retry_on_error: Optional[List[Type[Exception]]] = None,
|
||||
ssl: bool = False,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
|
||||
ssl_ca_certs: Optional[str] = None,
|
||||
ssl_ca_path: Optional[str] = None,
|
||||
ssl_ca_data: Optional[str] = None,
|
||||
ssl_check_hostname: bool = True,
|
||||
ssl_password: Optional[str] = None,
|
||||
ssl_validate_ocsp: bool = False,
|
||||
ssl_validate_ocsp_stapled: bool = False,
|
||||
ssl_ocsp_context: Optional["OpenSSL.SSL.Context"] = None,
|
||||
ssl_ocsp_expected_cert: Optional[str] = None,
|
||||
ssl_min_version: Optional["ssl.TLSVersion"] = None,
|
||||
ssl_ciphers: Optional[str] = None,
|
||||
max_connections: Optional[int] = None,
|
||||
single_connection_client: bool = False,
|
||||
health_check_interval: int = 0,
|
||||
client_name: Optional[str] = None,
|
||||
lib_name: Optional[str] = "redis-py",
|
||||
lib_version: Optional[str] = get_lib_version(),
|
||||
username: Optional[str] = None,
|
||||
redis_connect_func: Optional[Callable[[], None]] = None,
|
||||
host="localhost",
|
||||
port=6379,
|
||||
db=0,
|
||||
password=None,
|
||||
socket_timeout=None,
|
||||
socket_connect_timeout=None,
|
||||
socket_keepalive=None,
|
||||
socket_keepalive_options=None,
|
||||
connection_pool=None,
|
||||
unix_socket_path=None,
|
||||
encoding="utf-8",
|
||||
encoding_errors="strict",
|
||||
charset=None,
|
||||
errors=None,
|
||||
decode_responses=False,
|
||||
retry_on_timeout=False,
|
||||
retry_on_error=None,
|
||||
ssl=False,
|
||||
ssl_keyfile=None,
|
||||
ssl_certfile=None,
|
||||
ssl_cert_reqs="required",
|
||||
ssl_ca_certs=None,
|
||||
ssl_ca_path=None,
|
||||
ssl_ca_data=None,
|
||||
ssl_check_hostname=False,
|
||||
ssl_password=None,
|
||||
ssl_validate_ocsp=False,
|
||||
ssl_validate_ocsp_stapled=False,
|
||||
ssl_ocsp_context=None,
|
||||
ssl_ocsp_expected_cert=None,
|
||||
max_connections=None,
|
||||
single_connection_client=False,
|
||||
health_check_interval=0,
|
||||
client_name=None,
|
||||
lib_name="redis-py",
|
||||
lib_version=get_lib_version(),
|
||||
username=None,
|
||||
retry=None,
|
||||
redis_connect_func=None,
|
||||
credential_provider: Optional[CredentialProvider] = None,
|
||||
protocol: Optional[int] = 2,
|
||||
cache: Optional[CacheInterface] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
event_dispatcher: Optional[EventDispatcher] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a new Redis client.
|
||||
|
||||
To specify a retry policy for specific errors, you have two options:
|
||||
|
||||
1. Set the `retry_on_error` to a list of the error/s to retry on, and
|
||||
you can also set `retry` to a valid `Retry` object(in case the default
|
||||
one is not appropriate) - with this approach the retries will be triggered
|
||||
on the default errors specified in the Retry object enriched with the
|
||||
errors specified in `retry_on_error`.
|
||||
|
||||
2. Define a `Retry` object with configured 'supported_errors' and set
|
||||
it to the `retry` parameter - with this approach you completely redefine
|
||||
the errors on which retries will happen.
|
||||
|
||||
`retry_on_timeout` is deprecated - please include the TimeoutError
|
||||
either in the Retry object or in the `retry_on_error` list.
|
||||
|
||||
When 'connection_pool' is provided - the retry configuration of the
|
||||
provided pool will be used.
|
||||
To specify a retry policy for specific errors, first set
|
||||
`retry_on_error` to a list of the error/s to retry on, then set
|
||||
`retry` to a valid `Retry` object.
|
||||
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -272,13 +217,25 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
if `True`, connection pool is not used. In that case `Redis`
|
||||
instance use is not thread safe.
|
||||
"""
|
||||
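# Editor's sketch (not part of the diff): the two retry-configuration styles
# described in the docstring above.
from redis import Redis
from redis.backoff import ExponentialBackoff
from redis.exceptions import BusyLoadingError, ConnectionError
from redis.retry import Retry

# Option 1: keep the Retry defaults and extend the retryable errors.
r1 = Redis(retry=Retry(ExponentialBackoff(), 3), retry_on_error=[BusyLoadingError])

# Option 2: redefine the retryable errors entirely via supported_errors.
r2 = Redis(retry=Retry(ExponentialBackoff(), 3, supported_errors=(ConnectionError,)))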
if event_dispatcher is None:
|
||||
self._event_dispatcher = EventDispatcher()
|
||||
else:
|
||||
self._event_dispatcher = event_dispatcher
|
||||
if not connection_pool:
|
||||
if charset is not None:
|
||||
warnings.warn(
|
||||
DeprecationWarning(
|
||||
'"charset" is deprecated. Use "encoding" instead'
|
||||
)
|
||||
)
|
||||
encoding = charset
|
||||
if errors is not None:
|
||||
warnings.warn(
|
||||
DeprecationWarning(
|
||||
'"errors" is deprecated. Use "encoding_errors" instead'
|
||||
)
|
||||
)
|
||||
encoding_errors = errors
|
||||
if not retry_on_error:
|
||||
retry_on_error = []
|
||||
if retry_on_timeout is True:
|
||||
retry_on_error.append(TimeoutError)
|
||||
kwargs = {
|
||||
"db": db,
|
||||
"username": username,
|
||||
@@ -334,50 +291,17 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
"ssl_validate_ocsp": ssl_validate_ocsp,
|
||||
"ssl_ocsp_context": ssl_ocsp_context,
|
||||
"ssl_ocsp_expected_cert": ssl_ocsp_expected_cert,
|
||||
"ssl_min_version": ssl_min_version,
|
||||
"ssl_ciphers": ssl_ciphers,
|
||||
}
|
||||
)
|
||||
if (cache_config or cache) and protocol in [3, "3"]:
|
||||
kwargs.update(
|
||||
{
|
||||
"cache": cache,
|
||||
"cache_config": cache_config,
|
||||
}
|
||||
)
|
||||
connection_pool = ConnectionPool(**kwargs)
|
||||
self._event_dispatcher.dispatch(
|
||||
AfterPooledConnectionsInstantiationEvent(
|
||||
[connection_pool], ClientType.SYNC, credential_provider
|
||||
)
|
||||
)
|
||||
self.auto_close_connection_pool = True
|
||||
else:
|
||||
self.auto_close_connection_pool = False
|
||||
self._event_dispatcher.dispatch(
|
||||
AfterPooledConnectionsInstantiationEvent(
|
||||
[connection_pool], ClientType.SYNC, credential_provider
|
||||
)
|
||||
)
|
||||
|
||||
self.connection_pool = connection_pool
|
||||
|
||||
if (cache_config or cache) and self.connection_pool.get_protocol() not in [
|
||||
3,
|
||||
"3",
|
||||
]:
|
||||
raise RedisError("Client caching is only supported with RESP version 3")
|
||||
|
||||
self.single_connection_lock = threading.RLock()
|
||||
self.connection = None
|
||||
self._single_connection_client = single_connection_client
|
||||
if self._single_connection_client:
|
||||
self.connection = self.connection_pool.get_connection()
|
||||
self._event_dispatcher.dispatch(
|
||||
AfterSingleConnectionInstantiationEvent(
|
||||
self.connection, ClientType.SYNC, self.single_connection_lock
|
||||
)
|
||||
)
|
||||
if single_connection_client:
|
||||
self.connection = self.connection_pool.get_connection("_")
|
||||
|
||||
self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks)
|
||||
|
||||
@@ -387,10 +311,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
self.response_callbacks.update(_RedisCallbacksRESP2)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<{type(self).__module__}.{type(self).__name__}"
|
||||
f"({repr(self.connection_pool)})>"
|
||||
)
|
||||
return f"{type(self).__name__}<{repr(self.connection_pool)}>"
|
||||
|
||||
def get_encoder(self) -> "Encoder":
|
||||
"""Get the connection pool's encoder"""
|
||||
@@ -400,10 +321,10 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
"""Get the connection's key-word arguments"""
|
||||
return self.connection_pool.connection_kwargs
|
||||
|
||||
def get_retry(self) -> Optional[Retry]:
|
||||
def get_retry(self) -> Optional["Retry"]:
|
||||
return self.get_connection_kwargs().get("retry")
|
||||
|
||||
def set_retry(self, retry: Retry) -> None:
|
||||
def set_retry(self, retry: "Retry") -> None:
|
||||
self.get_connection_kwargs().update({"retry": retry})
|
||||
self.connection_pool.set_retry(retry)
|
||||
|
||||
@@ -448,7 +369,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
|
||||
def transaction(
|
||||
self, func: Callable[["Pipeline"], None], *watches, **kwargs
|
||||
) -> Union[List[Any], Any, None]:
|
||||
) -> None:
|
||||
"""
|
||||
Convenience method for executing the callable `func` as a transaction
|
||||
while watching all keys specified in `watches`. The 'func' callable
|
||||
@@ -479,7 +400,6 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
blocking_timeout: Optional[float] = None,
|
||||
lock_class: Union[None, Any] = None,
|
||||
thread_local: bool = True,
|
||||
raise_on_release_error: bool = True,
|
||||
):
|
||||
"""
|
||||
Return a new Lock object using key ``name`` that mimics
|
||||
@@ -526,11 +446,6 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
thread-1 would see the token value as "xyz" and would be
|
||||
able to successfully release the thread-2's lock.
|
||||
|
||||
``raise_on_release_error`` indicates whether to raise an exception when
|
||||
the lock is no longer owned when exiting the context manager. By default,
|
||||
this is True, meaning an exception will be raised. If False, the warning
|
||||
will be logged and the exception will be suppressed.
|
||||
|
||||
In some use cases it's necessary to disable thread local storage. For
|
||||
example, if you have code where one thread acquires a lock and passes
|
||||
that lock instance to a worker thread to release later. If thread
|
||||
@@ -548,7 +463,6 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
thread_local=thread_local,
|
||||
raise_on_release_error=raise_on_release_error,
|
||||
)
|
||||
|
||||
def pubsub(self, **kwargs):
|
||||
@@ -557,9 +471,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
subscribe to channels and listen for messages that get published to
|
||||
them.
|
||||
"""
|
||||
return PubSub(
|
||||
self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
|
||||
)
|
||||
return PubSub(self.connection_pool, **kwargs)
|
||||
|
||||
def monitor(self):
|
||||
return Monitor(self.connection_pool)
|
||||
@@ -576,12 +488,9 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
self.close()
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
def close(self):
|
||||
# In case a connection property does not yet exist
|
||||
# (due to a crash earlier in the Redis() constructor), return
|
||||
# immediately as there is nothing to clean-up.
|
||||
@@ -600,44 +509,37 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
"""
|
||||
Send a command and parse the response
|
||||
"""
|
||||
conn.send_command(*args, **options)
|
||||
conn.send_command(*args)
|
||||
return self.parse_response(conn, command_name, **options)
|
||||
|
||||
def _close_connection(self, conn) -> None:
|
||||
def _disconnect_raise(self, conn, error):
|
||||
"""
|
||||
Close the connection before retrying.
|
||||
|
||||
The supported exceptions are already checked in the
|
||||
retry object so we don't need to do it here.
|
||||
|
||||
After we disconnect the connection, it will try to reconnect and
|
||||
do a health check as part of the send_command logic(on connection level).
|
||||
Close the connection and raise an exception
|
||||
if retry_on_error is not set or the error
|
||||
is not one of the specified error types
|
||||
"""
|
||||
|
||||
conn.disconnect()
|
||||
if (
|
||||
conn.retry_on_error is None
|
||||
or isinstance(error, tuple(conn.retry_on_error)) is False
|
||||
):
|
||||
raise error
|
||||
|
||||
# COMMAND EXECUTION AND PROTOCOL PARSING
|
||||
def execute_command(self, *args, **options):
|
||||
return self._execute_command(*args, **options)
|
||||
|
||||
def _execute_command(self, *args, **options):
|
||||
"""Execute a command and return a parsed response"""
|
||||
pool = self.connection_pool
|
||||
command_name = args[0]
|
||||
conn = self.connection or pool.get_connection()
|
||||
conn = self.connection or pool.get_connection(command_name, **options)
|
||||
|
||||
if self._single_connection_client:
|
||||
self.single_connection_lock.acquire()
|
||||
try:
|
||||
return conn.retry.call_with_retry(
|
||||
lambda: self._send_command_parse_response(
|
||||
conn, command_name, *args, **options
|
||||
),
|
||||
lambda _: self._close_connection(conn),
|
||||
lambda error: self._disconnect_raise(conn, error),
|
||||
)
|
||||
finally:
|
||||
if self._single_connection_client:
|
||||
self.single_connection_lock.release()
|
||||
if not self.connection:
|
||||
pool.release(conn)
|
||||
|
||||
@@ -657,16 +559,10 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
|
||||
if EMPTY_RESPONSE in options:
|
||||
options.pop(EMPTY_RESPONSE)
|
||||
|
||||
# Remove keys entry, it needs only for cache.
|
||||
options.pop("keys", None)
|
||||
|
||||
if command_name in self.response_callbacks:
|
||||
return self.response_callbacks[command_name](response, **options)
|
||||
return response
|
||||
|
||||
def get_cache(self) -> Optional[CacheInterface]:
|
||||
return self.connection_pool.cache
|
||||
|
||||
|
||||
StrictRedis = Redis
|
||||
|
||||
@@ -683,7 +579,7 @@ class Monitor:
|
||||
|
||||
def __init__(self, connection_pool):
|
||||
self.connection_pool = connection_pool
|
||||
self.connection = self.connection_pool.get_connection()
|
||||
self.connection = self.connection_pool.get_connection("MONITOR")
|
||||
|
||||
def __enter__(self):
|
||||
self.connection.send_command("MONITOR")
|
||||
@@ -758,7 +654,6 @@ class PubSub:
|
||||
ignore_subscribe_messages: bool = False,
|
||||
encoder: Optional["Encoder"] = None,
|
||||
push_handler_func: Union[None, Callable[[str], None]] = None,
|
||||
event_dispatcher: Optional["EventDispatcher"] = None,
|
||||
):
|
||||
self.connection_pool = connection_pool
|
||||
self.shard_hint = shard_hint
|
||||
@@ -769,12 +664,6 @@ class PubSub:
|
||||
# to lookup channel and pattern names for callback handlers.
|
||||
self.encoder = encoder
|
||||
self.push_handler_func = push_handler_func
|
||||
if event_dispatcher is None:
|
||||
self._event_dispatcher = EventDispatcher()
|
||||
else:
|
||||
self._event_dispatcher = event_dispatcher
|
||||
|
||||
self._lock = threading.RLock()
|
||||
if self.encoder is None:
|
||||
self.encoder = self.connection_pool.get_encoder()
|
||||
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
|
||||
@@ -804,7 +693,7 @@ class PubSub:
|
||||
def reset(self) -> None:
|
||||
if self.connection:
|
||||
self.connection.disconnect()
|
||||
self.connection.deregister_connect_callback(self.on_connect)
|
||||
self.connection._deregister_connect_callback(self.on_connect)
|
||||
self.connection_pool.release(self.connection)
|
||||
self.connection = None
|
||||
self.health_check_response_counter = 0
|
||||
@@ -857,23 +746,19 @@ class PubSub:
|
||||
# subscribed to one or more channels
|
||||
|
||||
if self.connection is None:
|
||||
self.connection = self.connection_pool.get_connection()
|
||||
self.connection = self.connection_pool.get_connection(
|
||||
"pubsub", self.shard_hint
|
||||
)
|
||||
# register a callback that re-subscribes to any channels we
|
||||
# were listening to when we were disconnected
|
||||
self.connection.register_connect_callback(self.on_connect)
|
||||
if self.push_handler_func is not None:
|
||||
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
|
||||
self._event_dispatcher.dispatch(
|
||||
AfterPubSubConnectionInstantiationEvent(
|
||||
self.connection, self.connection_pool, ClientType.SYNC, self._lock
|
||||
)
|
||||
)
|
||||
self.connection._register_connect_callback(self.on_connect)
|
||||
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
|
||||
self.connection._parser.set_push_handler(self.push_handler_func)
|
||||
connection = self.connection
|
||||
kwargs = {"check_health": not self.subscribed}
|
||||
if not self.subscribed:
|
||||
self.clean_health_check_responses()
|
||||
with self._lock:
|
||||
self._execute(connection, connection.send_command, *args, **kwargs)
|
||||
self._execute(connection, connection.send_command, *args, **kwargs)
|
||||
|
||||
def clean_health_check_responses(self) -> None:
|
||||
"""
|
||||
@@ -889,18 +774,19 @@ class PubSub:
|
||||
else:
|
||||
raise PubSubError(
|
||||
"A non health check response was cleaned by "
|
||||
"execute_command: {}".format(response)
|
||||
"execute_command: {0}".format(response)
|
||||
)
|
||||
ttl -= 1
|
||||
|
||||
def _reconnect(self, conn) -> None:
|
||||
def _disconnect_raise_connect(self, conn, error) -> None:
|
||||
"""
|
||||
The supported exceptions are already checked in the
|
||||
retry object so we don't need to do it here.
|
||||
|
||||
In this error handler we are trying to reconnect to the server.
|
||||
Close the connection and raise an exception
|
||||
if retry_on_timeout is not set or the error
|
||||
is not a TimeoutError. Otherwise, try to reconnect
|
||||
"""
|
||||
conn.disconnect()
|
||||
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
|
||||
raise error
|
||||
conn.connect()
|
||||
|
||||
def _execute(self, conn, command, *args, **kwargs):
|
||||
@@ -913,7 +799,7 @@ class PubSub:
|
||||
"""
|
||||
return conn.retry.call_with_retry(
|
||||
lambda: command(*args, **kwargs),
|
||||
lambda _: self._reconnect(conn),
|
||||
lambda error: self._disconnect_raise_connect(conn, error),
|
||||
)
|
||||
|
||||
def parse_response(self, block=True, timeout=0):
|
||||
@@ -962,7 +848,7 @@ class PubSub:
|
||||
"did you forget to call subscribe() or psubscribe()?"
|
||||
)
|
||||
|
||||
if conn.health_check_interval and time.monotonic() > conn.next_health_check:
|
||||
if conn.health_check_interval and time.time() > conn.next_health_check:
|
||||
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
|
||||
self.health_check_response_counter += 1
|
||||
|
||||
@@ -1112,12 +998,12 @@ class PubSub:
|
||||
"""
|
||||
if not self.subscribed:
|
||||
# Wait for subscription
|
||||
start_time = time.monotonic()
|
||||
start_time = time.time()
|
||||
if self.subscribed_event.wait(timeout) is True:
|
||||
# The connection was subscribed during the timeout time frame.
|
||||
# The timeout should be adjusted based on the time spent
|
||||
# waiting for the subscription
|
||||
time_spent = time.monotonic() - start_time
|
||||
time_spent = time.time() - start_time
|
||||
timeout = max(0.0, timeout - time_spent)
|
||||
else:
|
||||
# The connection isn't subscribed to any channels or patterns,
|
||||
@@ -1214,7 +1100,7 @@ class PubSub:
|
||||
|
||||
def run_in_thread(
|
||||
self,
|
||||
sleep_time: float = 0.0,
|
||||
sleep_time: int = 0,
|
||||
daemon: bool = False,
|
||||
exception_handler: Optional[Callable] = None,
|
||||
) -> "PubSubWorkerThread":
|
||||
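# Editor's sketch (not part of the diff): running a PubSub listener in a
# background thread with run_in_thread; `handler` is a hypothetical callback.
from redis import Redis

r = Redis()

def handler(message):
    print(message["channel"], message["data"])

p = r.pubsub()
p.subscribe(**{"notifications": handler})
thread = p.run_in_thread(sleep_time=0.001, daemon=True)
# ... later, when done listening
thread.stop()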
@@ -1282,8 +1168,7 @@ class Pipeline(Redis):
|
||||
in one transmission. This is convenient for batch processing, such as
|
||||
saving all the values in a list to Redis.
|
||||
|
||||
All commands executed within a pipeline(when running in transactional mode,
|
||||
which is the default behavior) are wrapped with MULTI and EXEC
|
||||
All commands executed within a pipeline are wrapped with MULTI and EXEC
|
||||
calls. This guarantees all commands executed in the pipeline will be
|
||||
executed atomically.
|
||||
|
||||
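# Editor's sketch (not part of the diff): batching commands in a transactional
# pipeline (the default), which wraps them in MULTI/EXEC on execute().
from redis import Redis

r = Redis()
with r.pipeline() as pipe:
    pipe.set("a", 1)
    pipe.incr("a")
    pipe.get("a")
    results = pipe.execute()  # e.g. [True, 2, b'2']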
@@ -1298,22 +1183,15 @@ class Pipeline(Redis):
|
||||
|
||||
UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_pool: ConnectionPool,
|
||||
response_callbacks,
|
||||
transaction,
|
||||
shard_hint,
|
||||
):
|
||||
def __init__(self, connection_pool, response_callbacks, transaction, shard_hint):
|
||||
self.connection_pool = connection_pool
|
||||
self.connection: Optional[Connection] = None
|
||||
self.connection = None
|
||||
self.response_callbacks = response_callbacks
|
||||
self.transaction = transaction
|
||||
self.shard_hint = shard_hint
|
||||
|
||||
self.watching = False
|
||||
self.command_stack = []
|
||||
self.scripts: Set[Script] = set()
|
||||
self.explicit_transaction = False
|
||||
self.reset()
|
||||
|
||||
def __enter__(self) -> "Pipeline":
|
||||
return self
|
||||
@@ -1379,51 +1257,47 @@ class Pipeline(Redis):
|
||||
return self.immediate_execute_command(*args, **kwargs)
|
||||
return self.pipeline_execute_command(*args, **kwargs)
|
||||
|
||||
def _disconnect_reset_raise_on_watching(
|
||||
self,
|
||||
conn: AbstractConnection,
|
||||
error: Exception,
|
||||
) -> None:
|
||||
def _disconnect_reset_raise(self, conn, error) -> None:
|
||||
"""
|
||||
Close the connection reset watching state and
|
||||
raise an exception if we were watching.
|
||||
|
||||
The supported exceptions are already checked in the
|
||||
retry object so we don't need to do it here.
|
||||
|
||||
After we disconnect the connection, it will try to reconnect and
|
||||
do a health check as part of the send_command logic(on connection level).
|
||||
Close the connection, reset watching state and
|
||||
raise an exception if we were watching,
|
||||
retry_on_timeout is not set,
|
||||
or the error is not a TimeoutError
|
||||
"""
|
||||
conn.disconnect()
|
||||
|
||||
# if we were already watching a variable, the watch is no longer
|
||||
# valid since this connection has died. raise a WatchError, which
|
||||
# indicates the user should retry this transaction.
|
||||
if self.watching:
|
||||
self.reset()
|
||||
raise WatchError(
|
||||
f"A {type(error).__name__} occurred while watching one or more keys"
|
||||
"A ConnectionError occurred on while watching one or more keys"
|
||||
)
|
||||
# if retry_on_timeout is not set, or the error is not
|
||||
# a TimeoutError, raise it
|
||||
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
|
||||
self.reset()
|
||||
raise
|
||||
|
||||
def immediate_execute_command(self, *args, **options):
|
||||
"""
|
||||
Execute a command immediately, but don't auto-retry on the supported
|
||||
errors for retry if we're already WATCHing a variable.
|
||||
Used when issuing WATCH or subsequent commands retrieving their values but before
|
||||
Execute a command immediately, but don't auto-retry on a
|
||||
ConnectionError if we're already WATCHing a variable. Used when
|
||||
issuing WATCH or subsequent commands retrieving their values but before
|
||||
MULTI is called.
|
||||
"""
|
||||
command_name = args[0]
|
||||
conn = self.connection
|
||||
# if this is the first call, we need a connection
|
||||
if not conn:
|
||||
conn = self.connection_pool.get_connection()
|
||||
conn = self.connection_pool.get_connection(command_name, self.shard_hint)
|
||||
self.connection = conn
|
||||
|
||||
return conn.retry.call_with_retry(
|
||||
lambda: self._send_command_parse_response(
|
||||
conn, command_name, *args, **options
|
||||
),
|
||||
lambda error: self._disconnect_reset_raise_on_watching(conn, error),
|
||||
lambda error: self._disconnect_reset_raise(conn, error),
|
||||
)
|
||||
|
||||
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
|
||||
@@ -1441,9 +1315,7 @@ class Pipeline(Redis):
|
||||
self.command_stack.append((args, options))
|
||||
return self
|
||||
|
||||
def _execute_transaction(
|
||||
self, connection: Connection, commands, raise_on_error
|
||||
) -> List:
|
||||
def _execute_transaction(self, connection, commands, raise_on_error) -> List:
|
||||
cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})])
|
||||
all_cmds = connection.pack_commands(
|
||||
[args for args, options in cmds if EMPTY_RESPONSE not in options]
|
||||
@@ -1504,8 +1376,6 @@ class Pipeline(Redis):
|
||||
for r, cmd in zip(response, commands):
|
||||
if not isinstance(r, Exception):
|
||||
args, options = cmd
|
||||
# Remove keys entry, it needs only for cache.
|
||||
options.pop("keys", None)
|
||||
command_name = args[0]
|
||||
if command_name in self.response_callbacks:
|
||||
r = self.response_callbacks[command_name](r, **options)
|
||||
@@ -1537,7 +1407,7 @@ class Pipeline(Redis):
|
||||
def annotate_exception(self, exception, number, command):
|
||||
cmd = " ".join(map(safe_str, command))
|
||||
msg = (
|
||||
f"Command # {number} ({truncate_text(cmd)}) of pipeline "
|
||||
f"Command # {number} ({cmd}) of pipeline "
|
||||
f"caused error: {exception.args[0]}"
|
||||
)
|
||||
exception.args = (msg,) + exception.args[1:]
|
||||
@@ -1563,19 +1433,11 @@ class Pipeline(Redis):
|
||||
if not exist:
|
||||
s.sha = immediate("SCRIPT LOAD", s.script)
|
||||
|
||||
def _disconnect_raise_on_watching(
|
||||
self,
|
||||
conn: AbstractConnection,
|
||||
error: Exception,
|
||||
) -> None:
|
||||
def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None:
|
||||
"""
|
||||
Close the connection, raise an exception if we were watching.
|
||||
|
||||
The supported exceptions are already checked in the
|
||||
retry object so we don't need to do it here.
|
||||
|
||||
After we disconnect the connection, it will try to reconnect and
|
||||
do a health check as part of the send_command logic(on connection level).
|
||||
Close the connection, raise an exception if we were watching,
|
||||
and raise an exception if TimeoutError is not part of retry_on_error,
|
||||
or the error is not a TimeoutError
|
||||
"""
|
||||
conn.disconnect()
|
||||
# if we were watching a variable, the watch is no longer valid
|
||||
@@ -1583,10 +1445,17 @@ class Pipeline(Redis):
|
||||
# indicates the user should retry this transaction.
|
||||
if self.watching:
|
||||
raise WatchError(
|
||||
f"A {type(error).__name__} occurred while watching one or more keys"
|
||||
"A ConnectionError occurred on while watching one or more keys"
|
||||
)
|
||||
# if TimeoutError is not part of retry_on_error, or the error
|
||||
# is not a TimeoutError, raise it
|
||||
if not (
|
||||
TimeoutError in conn.retry_on_error and isinstance(error, TimeoutError)
|
||||
):
|
||||
self.reset()
|
||||
raise error
|
||||
|
||||
def execute(self, raise_on_error: bool = True) -> List[Any]:
|
||||
def execute(self, raise_on_error=True):
|
||||
"""Execute all the commands in the current pipeline"""
|
||||
stack = self.command_stack
|
||||
if not stack and not self.watching:
|
||||
@@ -1600,7 +1469,7 @@ class Pipeline(Redis):
|
||||
|
||||
conn = self.connection
|
||||
if not conn:
|
||||
conn = self.connection_pool.get_connection()
|
||||
conn = self.connection_pool.get_connection("MULTI", self.shard_hint)
|
||||
# assign to self.connection so reset() releases the connection
|
||||
# back to the pool after we're done
|
||||
self.connection = conn
|
||||
@@ -1608,7 +1477,7 @@ class Pipeline(Redis):
|
||||
try:
|
||||
return conn.retry.call_with_retry(
|
||||
lambda: execute(conn, stack, raise_on_error),
|
||||
lambda error: self._disconnect_raise_on_watching(conn, error),
|
||||
lambda error: self._disconnect_raise_reset(conn, error),
|
||||
)
|
||||
finally:
|
||||
self.reset()
|
||||
|
||||
File diff suppressed because it is too large
@@ -5,7 +5,7 @@ from .commands import * # noqa
|
||||
from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo
|
||||
|
||||
|
||||
class AbstractBloom:
|
||||
class AbstractBloom(object):
|
||||
"""
|
||||
The client allows to interact with RedisBloom and use all of
|
||||
it's functionality.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from redis.client import NEVER_DECODE
|
||||
from redis.utils import deprecated_function
|
||||
from redis.exceptions import ModuleError
|
||||
from redis.utils import HIREDIS_AVAILABLE, deprecated_function
|
||||
|
||||
BF_RESERVE = "BF.RESERVE"
|
||||
BF_ADD = "BF.ADD"
|
||||
@@ -138,6 +139,9 @@ class BFCommands:
|
||||
This command will return successive (iter, data) pairs until (0, NULL) to indicate completion.
|
||||
For more information see `BF.SCANDUMP <https://redis.io/commands/bf.scandump>`_.
|
||||
""" # noqa
|
||||
if HIREDIS_AVAILABLE:
|
||||
raise ModuleError("This command cannot be used when hiredis is available.")
|
||||
|
||||
params = [key, iter]
|
||||
options = {}
|
||||
options[NEVER_DECODE] = []
|
||||
|
||||
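# Editor's sketch (not part of the diff): dumping and restoring a Bloom filter
# with the successive (iter, data) pairs described in the docstring above
# (assuming hiredis is not installed, per the guard added in this hunk).
from redis import Redis

r = Redis()
r.bf().create("bf", 0.01, 1000)
r.bf().add("bf", "item-1")

chunks = []
it = 0
while True:
    it, data = r.bf().scandump("bf", it)
    if it == 0:
        break  # (0, NULL) signals completion
    chunks.append((it, data))

r.delete("bf")
for it, data in chunks:
    r.bf().loadchunk("bf", it, data)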
@@ -1,7 +1,7 @@
|
||||
from ..helpers import nativestr
|
||||
|
||||
|
||||
class BFInfo:
|
||||
class BFInfo(object):
|
||||
capacity = None
|
||||
size = None
|
||||
filterNum = None
|
||||
@@ -26,7 +26,7 @@ class BFInfo:
|
||||
return getattr(self, item)
|
||||
|
||||
|
||||
class CFInfo:
|
||||
class CFInfo(object):
|
||||
size = None
|
||||
bucketNum = None
|
||||
filterNum = None
|
||||
@@ -57,7 +57,7 @@ class CFInfo:
|
||||
return getattr(self, item)
|
||||
|
||||
|
||||
class CMSInfo:
|
||||
class CMSInfo(object):
|
||||
width = None
|
||||
depth = None
|
||||
count = None
|
||||
@@ -72,7 +72,7 @@ class CMSInfo:
|
||||
return getattr(self, item)
|
||||
|
||||
|
||||
class TopKInfo:
|
||||
class TopKInfo(object):
|
||||
k = None
|
||||
width = None
|
||||
depth = None
|
||||
@@ -89,7 +89,7 @@ class TopKInfo:
|
||||
return getattr(self, item)
|
||||
|
||||
|
||||
class TDigestInfo:
|
||||
class TDigestInfo(object):
|
||||
compression = None
|
||||
capacity = None
|
||||
merged_nodes = None
|
||||
|
||||
@@ -7,13 +7,13 @@ from typing import (
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from redis.compat import Literal
|
||||
from redis.crc import key_slot
|
||||
from redis.exceptions import RedisClusterException, RedisError
|
||||
from redis.typing import (
|
||||
@@ -23,7 +23,6 @@ from redis.typing import (
|
||||
KeysT,
|
||||
KeyT,
|
||||
PatternT,
|
||||
ResponseT,
|
||||
)
|
||||
|
||||
from .core import (
|
||||
@@ -31,18 +30,21 @@ from .core import (
|
||||
AsyncACLCommands,
|
||||
AsyncDataAccessCommands,
|
||||
AsyncFunctionCommands,
|
||||
AsyncGearsCommands,
|
||||
AsyncManagementCommands,
|
||||
AsyncModuleCommands,
|
||||
AsyncScriptCommands,
|
||||
DataAccessCommands,
|
||||
FunctionCommands,
|
||||
GearsCommands,
|
||||
ManagementCommands,
|
||||
ModuleCommands,
|
||||
PubSubCommands,
|
||||
ResponseT,
|
||||
ScriptCommands,
|
||||
)
|
||||
from .helpers import list_or_args
|
||||
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
|
||||
from .redismodules import RedisModuleCommands
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio.cluster import TargetNodesT
|
||||
@@ -223,7 +225,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
|
||||
The keys are first split up into slots
|
||||
and then an DEL command is sent for every slot
|
||||
|
||||
Non-existent keys are ignored.
|
||||
Non-existant keys are ignored.
|
||||
Returns the number of keys that were deleted.
|
||||
|
||||
For more information see https://redis.io/commands/del
|
||||
@@ -238,7 +240,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
|
||||
The keys are first split up into slots
|
||||
and then an TOUCH command is sent for every slot
|
||||
|
||||
Non-existent keys are ignored.
|
||||
Non-existant keys are ignored.
|
||||
Returns the number of keys that were touched.
|
||||
|
||||
For more information see https://redis.io/commands/touch
|
||||
@@ -252,7 +254,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
|
||||
The keys are first split up into slots
|
||||
and then an TOUCH command is sent for every slot
|
||||
|
||||
Non-existent keys are ignored.
|
||||
Non-existant keys are ignored.
|
||||
Returns the number of keys that were unlinked.
|
||||
|
||||
For more information see https://redis.io/commands/unlink
|
||||
@@ -593,7 +595,7 @@ class ClusterManagementCommands(ManagementCommands):
|
||||
"CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node
|
||||
)
|
||||
elif state.upper() == "STABLE":
|
||||
raise RedisError('For "stable" state please use cluster_setslot_stable')
|
||||
raise RedisError('For "stable" state please use ' "cluster_setslot_stable")
|
||||
else:
|
||||
raise RedisError(f"Invalid slot state: {state}")
|
||||
|
||||
@@ -691,6 +693,12 @@ class ClusterManagementCommands(ManagementCommands):
|
||||
self.read_from_replicas = False
|
||||
return self.execute_command("READWRITE", target_nodes=target_nodes)
|
||||
|
||||
def gears_refresh_cluster(self, **kwargs) -> ResponseT:
|
||||
"""
|
||||
On an OSS cluster, before executing any gears function, you must call this command. # noqa
|
||||
"""
|
||||
return self.execute_command("REDISGEARS_2.REFRESHCLUSTER", **kwargs)
|
||||
|
||||
|
||||
class AsyncClusterManagementCommands(
|
||||
ClusterManagementCommands, AsyncManagementCommands
|
||||
@@ -866,6 +874,7 @@ class RedisClusterCommands(
|
||||
ClusterDataAccessCommands,
|
||||
ScriptCommands,
|
||||
FunctionCommands,
|
||||
GearsCommands,
|
||||
ModuleCommands,
|
||||
RedisModuleCommands,
|
||||
):
|
||||
@@ -896,8 +905,8 @@ class AsyncRedisClusterCommands(
|
||||
AsyncClusterDataAccessCommands,
|
||||
AsyncScriptCommands,
|
||||
AsyncFunctionCommands,
|
||||
AsyncGearsCommands,
|
||||
AsyncModuleCommands,
|
||||
AsyncRedisModuleCommands,
|
||||
):
|
||||
"""
|
||||
A class for all Redis Cluster commands
|
||||
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,263 @@
|
||||
import warnings
|
||||
|
||||
from ..helpers import quote_string, random_string, stringify_param_value
|
||||
from .commands import AsyncGraphCommands, GraphCommands
|
||||
from .edge import Edge # noqa
|
||||
from .node import Node # noqa
|
||||
from .path import Path # noqa
|
||||
|
||||
DB_LABELS = "DB.LABELS"
|
||||
DB_RAELATIONSHIPTYPES = "DB.RELATIONSHIPTYPES"
|
||||
DB_PROPERTYKEYS = "DB.PROPERTYKEYS"
|
||||
|
||||
|
||||
class Graph(GraphCommands):
|
||||
"""
|
||||
Graph, collection of nodes and edges.
|
||||
"""
|
||||
|
||||
def __init__(self, client, name=random_string()):
|
||||
"""
|
||||
Create a new graph.
|
||||
"""
|
||||
warnings.warn(
|
||||
DeprecationWarning(
|
||||
"RedisGraph support is deprecated as of Redis Stack 7.2 \
|
||||
(https://redis.com/blog/redisgraph-eol/)"
|
||||
)
|
||||
)
|
||||
self.NAME = name # Graph key
|
||||
self.client = client
|
||||
self.execute_command = client.execute_command
|
||||
|
||||
self.nodes = {}
|
||||
self.edges = []
|
||||
self._labels = [] # List of node labels.
|
||||
self._properties = [] # List of properties.
|
||||
self._relationship_types = [] # List of relation types.
|
||||
self.version = 0 # Graph version
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.NAME
|
||||
|
||||
def _clear_schema(self):
|
||||
self._labels = []
|
||||
self._properties = []
|
||||
self._relationship_types = []
|
||||
|
||||
def _refresh_schema(self):
|
||||
self._clear_schema()
|
||||
self._refresh_labels()
|
||||
self._refresh_relations()
|
||||
self._refresh_attributes()
|
||||
|
||||
def _refresh_labels(self):
|
||||
lbls = self.labels()
|
||||
|
||||
# Unpack data.
|
||||
self._labels = [l[0] for _, l in enumerate(lbls)]
|
||||
|
||||
def _refresh_relations(self):
|
||||
rels = self.relationship_types()
|
||||
|
||||
# Unpack data.
|
||||
self._relationship_types = [r[0] for _, r in enumerate(rels)]
|
||||
|
||||
def _refresh_attributes(self):
|
||||
props = self.property_keys()
|
||||
|
||||
# Unpack data.
|
||||
self._properties = [p[0] for _, p in enumerate(props)]
|
||||
|
||||
def get_label(self, idx):
|
||||
"""
|
||||
Returns a label by it's index
|
||||
|
||||
Args:
|
||||
|
||||
idx:
|
||||
The index of the label
|
||||
"""
|
||||
try:
|
||||
label = self._labels[idx]
|
||||
except IndexError:
|
||||
# Refresh labels.
|
||||
self._refresh_labels()
|
||||
label = self._labels[idx]
|
||||
return label
|
||||
|
||||
def get_relation(self, idx):
|
||||
"""
|
||||
Returns a relationship type by it's index
|
||||
|
||||
Args:
|
||||
|
||||
idx:
|
||||
The index of the relation
|
||||
"""
|
||||
try:
|
||||
relationship_type = self._relationship_types[idx]
|
||||
except IndexError:
|
||||
# Refresh relationship types.
|
||||
self._refresh_relations()
|
||||
relationship_type = self._relationship_types[idx]
|
||||
return relationship_type
|
||||
|
||||
def get_property(self, idx):
|
||||
"""
|
||||
Returns a property by it's index
|
||||
|
||||
Args:
|
||||
|
||||
idx:
|
||||
The index of the property
|
||||
"""
|
||||
try:
|
||||
p = self._properties[idx]
|
||||
except IndexError:
|
||||
# Refresh properties.
|
||||
self._refresh_attributes()
|
||||
p = self._properties[idx]
|
||||
return p
|
||||
|
||||
def add_node(self, node):
|
||||
"""
|
||||
Adds a node to the graph.
|
||||
"""
|
||||
if node.alias is None:
|
||||
node.alias = random_string()
|
||||
self.nodes[node.alias] = node
|
||||
|
||||
def add_edge(self, edge):
|
||||
"""
|
||||
Adds an edge to the graph.
|
||||
"""
|
||||
if not (self.nodes[edge.src_node.alias] and self.nodes[edge.dest_node.alias]):
|
||||
raise AssertionError("Both edge's end must be in the graph")
|
||||
|
||||
self.edges.append(edge)
|
||||
|
||||
def _build_params_header(self, params):
|
||||
if params is None:
|
||||
return ""
|
||||
if not isinstance(params, dict):
|
||||
raise TypeError("'params' must be a dict")
|
||||
# Header starts with "CYPHER"
|
||||
params_header = "CYPHER "
|
||||
for key, value in params.items():
|
||||
params_header += str(key) + "=" + stringify_param_value(value) + " "
|
||||
return params_header
|
||||
|
||||
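# Editor's note (not part of the diff): a worked example of the CYPHER
# parameter header built by _build_params_header above.
#   _build_params_header({"name": "Alice", "age": 30})
#   -> 'CYPHER name="Alice" age=30 '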
# Procedures.
|
||||
def call_procedure(self, procedure, *args, read_only=False, **kwagrs):
|
||||
args = [quote_string(arg) for arg in args]
|
||||
q = f"CALL {procedure}({','.join(args)})"
|
||||
|
||||
y = kwagrs.get("y", None)
|
||||
if y is not None:
|
||||
q += f"YIELD {','.join(y)}"
|
||||
|
||||
return self.query(q, read_only=read_only)
|
||||
|
||||
def labels(self):
|
||||
return self.call_procedure(DB_LABELS, read_only=True).result_set
|
||||
|
||||
def relationship_types(self):
|
||||
return self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True).result_set
|
||||
|
||||
def property_keys(self):
|
||||
return self.call_procedure(DB_PROPERTYKEYS, read_only=True).result_set
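A hedged usage sketch of the schema helpers above (illustrative; `r` is a plain Redis client and `graph()` is the accessor added later in this diff):

import redis

r = redis.Redis()
g = r.graph("social")
g.labels()               # CALL db.labels()            -> e.g. [["Person"], ["Company"]]
g.relationship_types()   # CALL db.relationshipTypes() -> e.g. [["WORKS_AT"]]
g.property_keys()        # CALL db.propertyKeys()      -> e.g. [["name"], ["since"]]
# get_label/get_relation/get_property translate the integer offsets used by the
# compact result format back to these cached strings, refreshing on IndexError.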
|
||||
|
||||
|
||||
class AsyncGraph(Graph, AsyncGraphCommands):
|
||||
"""Async version for Graph"""
|
||||
|
||||
async def _refresh_labels(self):
|
||||
lbls = await self.labels()
|
||||
|
||||
# Unpack data.
|
||||
self._labels = [l[0] for _, l in enumerate(lbls)]
|
||||
|
||||
async def _refresh_attributes(self):
|
||||
props = await self.property_keys()
|
||||
|
||||
# Unpack data.
|
||||
self._properties = [p[0] for _, p in enumerate(props)]
|
||||
|
||||
async def _refresh_relations(self):
|
||||
rels = await self.relationship_types()
|
||||
|
||||
# Unpack data.
|
||||
self._relationship_types = [r[0] for _, r in enumerate(rels)]
|
||||
|
||||
async def get_label(self, idx):
|
||||
"""
|
||||
Returns a label by its index
|
||||
|
||||
Args:
|
||||
|
||||
idx:
|
||||
The index of the label
|
||||
"""
|
||||
try:
|
||||
label = self._labels[idx]
|
||||
except IndexError:
|
||||
# Refresh labels.
|
||||
await self._refresh_labels()
|
||||
label = self._labels[idx]
|
||||
return label
|
||||
|
||||
async def get_property(self, idx):
|
||||
"""
|
||||
Returns a property by its index
|
||||
|
||||
Args:
|
||||
|
||||
idx:
|
||||
The index of the property
|
||||
"""
|
||||
try:
|
||||
p = self._properties[idx]
|
||||
except IndexError:
|
||||
# Refresh properties.
|
||||
await self._refresh_attributes()
|
||||
p = self._properties[idx]
|
||||
return p
|
||||
|
||||
async def get_relation(self, idx):
|
||||
"""
|
||||
Returns a relationship type by its index
|
||||
|
||||
Args:
|
||||
|
||||
idx:
|
||||
The index of the relation
|
||||
"""
|
||||
try:
|
||||
relationship_type = self._relationship_types[idx]
|
||||
except IndexError:
|
||||
# Refresh relationship types.
|
||||
await self._refresh_relations()
|
||||
relationship_type = self._relationship_types[idx]
|
||||
return relationship_type
|
||||
|
||||
async def call_procedure(self, procedure, *args, read_only=False, **kwargs):
|
||||
args = [quote_string(arg) for arg in args]
|
||||
q = f"CALL {procedure}({','.join(args)})"
|
||||
|
||||
y = kwargs.get("y", None)
|
||||
if y is not None:
|
||||
f"YIELD {','.join(y)}"
|
||||
return await self.query(q, read_only=read_only)
|
||||
|
||||
async def labels(self):
|
||||
return (await self.call_procedure(DB_LABELS, read_only=True)).result_set
|
||||
|
||||
async def property_keys(self):
|
||||
return (await self.call_procedure(DB_PROPERTYKEYS, read_only=True)).result_set
|
||||
|
||||
async def relationship_types(self):
|
||||
return (
|
||||
await self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True)
|
||||
).result_set
|
||||
@@ -0,0 +1,313 @@
|
||||
from redis import DataError
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
from .exceptions import VersionMismatchException
|
||||
from .execution_plan import ExecutionPlan
|
||||
from .query_result import AsyncQueryResult, QueryResult
|
||||
|
||||
PROFILE_CMD = "GRAPH.PROFILE"
|
||||
RO_QUERY_CMD = "GRAPH.RO_QUERY"
|
||||
QUERY_CMD = "GRAPH.QUERY"
|
||||
DELETE_CMD = "GRAPH.DELETE"
|
||||
SLOWLOG_CMD = "GRAPH.SLOWLOG"
|
||||
CONFIG_CMD = "GRAPH.CONFIG"
|
||||
LIST_CMD = "GRAPH.LIST"
|
||||
EXPLAIN_CMD = "GRAPH.EXPLAIN"
|
||||
|
||||
|
||||
class GraphCommands:
|
||||
"""RedisGraph Commands"""
|
||||
|
||||
def commit(self):
|
||||
"""
|
||||
Create entire graph.
|
||||
"""
|
||||
if len(self.nodes) == 0 and len(self.edges) == 0:
|
||||
return None
|
||||
|
||||
query = "CREATE "
|
||||
for _, node in self.nodes.items():
|
||||
query += str(node) + ","
|
||||
|
||||
query += ",".join([str(edge) for edge in self.edges])
|
||||
|
||||
# Discard leading comma.
|
||||
if query[-1] == ",":
|
||||
query = query[:-1]
|
||||
|
||||
return self.query(query)
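A commit sketch (illustrative; Node and Edge are the entity classes added elsewhere in this diff):

g.add_node(Node(alias="a", label="Person", properties={"name": "Ann"}))
g.add_node(Node(alias="b", label="Person", properties={"name": "Ben"}))
g.add_edge(Edge(g.nodes["a"], "KNOWS", g.nodes["b"]))
g.commit()   # issues a single CREATE query covering all pending nodes and edges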
|
||||
|
||||
def query(self, q, params=None, timeout=None, read_only=False, profile=False):
|
||||
"""
|
||||
Executes a query against the graph.
|
||||
For more information see `GRAPH.QUERY <https://redis.io/commands/graph.query>`_. # noqa
|
||||
|
||||
Args:
|
||||
|
||||
q : str
|
||||
The query.
|
||||
params : dict
|
||||
Query parameters.
|
||||
timeout : int
|
||||
Maximum runtime for read queries in milliseconds.
|
||||
read_only : bool
|
||||
Executes a readonly query if set to True.
|
||||
profile : bool
|
||||
Return details on results produced by and time
|
||||
spent in each operation.
|
||||
"""
|
||||
|
||||
# maintain original 'q'
|
||||
query = q
|
||||
|
||||
# handle query parameters
|
||||
query = self._build_params_header(params) + query
|
||||
|
||||
# construct query command
|
||||
# ask for compact result-set format
|
||||
# specify known graph version
|
||||
if profile:
|
||||
cmd = PROFILE_CMD
|
||||
else:
|
||||
cmd = RO_QUERY_CMD if read_only else QUERY_CMD
|
||||
command = [cmd, self.name, query, "--compact"]
|
||||
|
||||
# include timeout if specified
|
||||
if isinstance(timeout, int):
|
||||
command.extend(["timeout", timeout])
|
||||
elif timeout is not None:
|
||||
raise Exception("Timeout argument must be a positive integer")
|
||||
|
||||
# issue query
|
||||
try:
|
||||
response = self.execute_command(*command)
|
||||
return QueryResult(self, response, profile)
|
||||
except ResponseError as e:
|
||||
if "unknown command" in str(e) and read_only:
|
||||
# `GRAPH.RO_QUERY` is unavailable in older versions.
|
||||
return self.query(q, params, timeout, read_only=False)
|
||||
raise e
|
||||
except VersionMismatchException as e:
|
||||
# client view over the graph schema is out of sync
|
||||
# set client version and refresh local schema
|
||||
self.version = e.version
|
||||
self._refresh_schema()
|
||||
# re-issue query
|
||||
return self.query(q, params, timeout, read_only)
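A usage sketch for query() (illustrative; the query text and parameter names are only examples):

result = g.query(
    "MATCH (p:Person) WHERE p.age > $min_age RETURN p.name",
    params={"min_age": 30},
    timeout=2000,      # milliseconds
    read_only=True,    # routed to GRAPH.RO_QUERY where the server supports it
)
for record in result.result_set:
    print(record[0])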
|
||||
|
||||
def merge(self, pattern):
|
||||
"""
|
||||
Merge pattern.
|
||||
"""
|
||||
query = "MERGE "
|
||||
query += str(pattern)
|
||||
|
||||
return self.query(query)
|
||||
|
||||
def delete(self):
|
||||
"""
|
||||
Deletes graph.
|
||||
For more information see `DELETE <https://redis.io/commands/graph.delete>`_. # noqa
|
||||
"""
|
||||
self._clear_schema()
|
||||
return self.execute_command(DELETE_CMD, self.name)
|
||||
|
||||
# declared here, to override the built-in redis.db.flush()
|
||||
def flush(self):
|
||||
"""
|
||||
Commit the graph and reset the edges and the nodes to zero length.
|
||||
"""
|
||||
self.commit()
|
||||
self.nodes = {}
|
||||
self.edges = []
|
||||
|
||||
def bulk(self, **kwargs):
|
||||
"""Internal only. Not supported."""
|
||||
raise NotImplementedError(
|
||||
"GRAPH.BULK is internal only. "
|
||||
"Use https://github.com/redisgraph/redisgraph-bulk-loader."
|
||||
)
|
||||
|
||||
def profile(self, query):
|
||||
"""
|
||||
Execute a query and produce an execution plan augmented with metrics
|
||||
for each operation's execution. Return a string representation of a
|
||||
query execution plan, with details on results produced by and time
|
||||
spent in each operation.
|
||||
For more information see `GRAPH.PROFILE <https://redis.io/commands/graph.profile>`_. # noqa
|
||||
"""
|
||||
return self.query(query, profile=True)
|
||||
|
||||
def slowlog(self):
|
||||
"""
|
||||
Get a list containing up to 10 of the slowest queries issued
|
||||
against the given graph ID.
|
||||
For more information see `GRAPH.SLOWLOG <https://redis.io/commands/graph.slowlog>`_. # noqa
|
||||
|
||||
Each item in the list has the following structure:
|
||||
1. A unix timestamp at which the log entry was processed.
|
||||
2. The issued command.
|
||||
3. The issued query.
|
||||
4. The amount of time needed for its execution, in milliseconds.
|
||||
"""
|
||||
return self.execute_command(SLOWLOG_CMD, self.name)
|
||||
|
||||
def config(self, name, value=None, set=False):
|
||||
"""
|
||||
Retrieve or update a RedisGraph configuration.
|
||||
For more information see `GRAPH.CONFIG <https://redis.io/commands/graph.config-get/>`_. # noqa
|
||||
|
||||
Args:
|
||||
|
||||
name : str
|
||||
The name of the configuration
|
||||
value :
|
||||
The value we want to set (can be used only when `set` is on)
|
||||
set : bool
|
||||
Turn on to set a configuration. Default behavior is get.
|
||||
"""
|
||||
params = ["SET" if set else "GET", name]
|
||||
if value is not None:
|
||||
if set:
|
||||
params.append(value)
|
||||
else:
|
||||
raise DataError(
|
||||
"``value`` can be provided only when ``set`` is True"
|
||||
) # noqa
|
||||
return self.execute_command(CONFIG_CMD, *params)
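A config sketch (illustrative; RESULTSET_SIZE is just an example configuration name):

g.config("RESULTSET_SIZE")                   # GRAPH.CONFIG GET RESULTSET_SIZE
g.config("RESULTSET_SIZE", 10000, set=True)  # GRAPH.CONFIG SET RESULTSET_SIZE 10000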
|
||||
|
||||
def list_keys(self):
|
||||
"""
|
||||
Lists all graph keys in the keyspace.
|
||||
For more information see `GRAPH.LIST <https://redis.io/commands/graph.list>`_. # noqa
|
||||
"""
|
||||
return self.execute_command(LIST_CMD)
|
||||
|
||||
def execution_plan(self, query, params=None):
|
||||
"""
|
||||
Get the execution plan for given query,
|
||||
GRAPH.EXPLAIN returns an array of operations.
|
||||
|
||||
Args:
|
||||
query: the query that will be executed
|
||||
params: query parameters
|
||||
"""
|
||||
query = self._build_params_header(params) + query
|
||||
|
||||
plan = self.execute_command(EXPLAIN_CMD, self.name, query)
|
||||
if isinstance(plan[0], bytes):
|
||||
plan = [b.decode() for b in plan]
|
||||
return "\n".join(plan)
|
||||
|
||||
def explain(self, query, params=None):
|
||||
"""
|
||||
Get the execution plan for given query,
|
||||
GRAPH.EXPLAIN returns ExecutionPlan object.
|
||||
For more information see `GRAPH.EXPLAIN <https://redis.io/commands/graph.explain>`_. # noqa
|
||||
|
||||
Args:
|
||||
query: the query that will be executed
|
||||
params: query parameters
|
||||
"""
|
||||
query = self._build_params_header(params) + query
|
||||
|
||||
plan = self.execute_command(EXPLAIN_CMD, self.name, query)
|
||||
return ExecutionPlan(plan)
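An explain sketch (illustrative):

plan = g.explain("MATCH (p:Person) RETURN p.name")
print(plan)                                                # indented operation tree
raw = g.execution_plan("MATCH (p:Person) RETURN p.name")   # same data as a plain string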
|
||||
|
||||
|
||||
class AsyncGraphCommands(GraphCommands):
|
||||
async def query(self, q, params=None, timeout=None, read_only=False, profile=False):
|
||||
"""
|
||||
Executes a query against the graph.
|
||||
For more information see `GRAPH.QUERY <https://oss.redis.com/redisgraph/master/commands/#graphquery>`_. # noqa
|
||||
|
||||
Args:
|
||||
|
||||
q : str
|
||||
The query.
|
||||
params : dict
|
||||
Query parameters.
|
||||
timeout : int
|
||||
Maximum runtime for read queries in milliseconds.
|
||||
read_only : bool
|
||||
Executes a readonly query if set to True.
|
||||
profile : bool
|
||||
Return details on results produced by and time
|
||||
spent in each operation.
|
||||
"""
|
||||
|
||||
# maintain original 'q'
|
||||
query = q
|
||||
|
||||
# handle query parameters
|
||||
query = self._build_params_header(params) + query
|
||||
|
||||
# construct query command
|
||||
# ask for compact result-set format
|
||||
# specify known graph version
|
||||
if profile:
|
||||
cmd = PROFILE_CMD
|
||||
else:
|
||||
cmd = RO_QUERY_CMD if read_only else QUERY_CMD
|
||||
command = [cmd, self.name, query, "--compact"]
|
||||
|
||||
# include timeout if specified
|
||||
if isinstance(timeout, int):
|
||||
command.extend(["timeout", timeout])
|
||||
elif timeout is not None:
|
||||
raise Exception("Timeout argument must be a positive integer")
|
||||
|
||||
# issue query
|
||||
try:
|
||||
response = await self.execute_command(*command)
|
||||
return await AsyncQueryResult().initialize(self, response, profile)
|
||||
except ResponseError as e:
|
||||
if "unknown command" in str(e) and read_only:
|
||||
# `GRAPH.RO_QUERY` is unavailable in older versions.
|
||||
return await self.query(q, params, timeout, read_only=False)
|
||||
raise e
|
||||
except VersionMismatchException as e:
|
||||
# client view over the graph schema is out of sync
|
||||
# set client version and refresh local schema
|
||||
self.version = e.version
|
||||
self._refresh_schema()
|
||||
# re-issue query
|
||||
return await self.query(q, params, timeout, read_only)
|
||||
|
||||
async def execution_plan(self, query, params=None):
|
||||
"""
|
||||
Get the execution plan for given query,
|
||||
GRAPH.EXPLAIN returns an array of operations.
|
||||
|
||||
Args:
|
||||
query: the query that will be executed
|
||||
params: query parameters
|
||||
"""
|
||||
query = self._build_params_header(params) + query
|
||||
|
||||
plan = await self.execute_command(EXPLAIN_CMD, self.name, query)
|
||||
if isinstance(plan[0], bytes):
|
||||
plan = [b.decode() for b in plan]
|
||||
return "\n".join(plan)
|
||||
|
||||
async def explain(self, query, params=None):
|
||||
"""
|
||||
Get the execution plan for given query,
|
||||
GRAPH.EXPLAIN returns ExecutionPlan object.
|
||||
|
||||
Args:
|
||||
query: the query that will be executed
|
||||
params: query parameters
|
||||
"""
|
||||
query = self._build_params_header(params) + query
|
||||
|
||||
plan = await self.execute_command(EXPLAIN_CMD, self.name, query)
|
||||
return ExecutionPlan(plan)
|
||||
|
||||
async def flush(self):
|
||||
"""
|
||||
Commit the graph and reset the edges and the nodes to zero length.
|
||||
"""
|
||||
await self.commit()
|
||||
self.nodes = {}
|
||||
self.edges = []
|
||||
@@ -0,0 +1,91 @@
|
||||
from ..helpers import quote_string
|
||||
from .node import Node
|
||||
|
||||
|
||||
class Edge:
|
||||
"""
|
||||
An edge connecting two nodes.
|
||||
"""
|
||||
|
||||
def __init__(self, src_node, relation, dest_node, edge_id=None, properties=None):
|
||||
"""
|
||||
Create a new edge.
|
||||
"""
|
||||
if src_node is None or dest_node is None:
|
||||
# NOTE(bors-42): It makes sense to change AssertionError to
|
||||
# ValueError here
|
||||
raise AssertionError("Both src_node & dest_node must be provided")
|
||||
|
||||
self.id = edge_id
|
||||
self.relation = relation or ""
|
||||
self.properties = properties or {}
|
||||
self.src_node = src_node
|
||||
self.dest_node = dest_node
|
||||
|
||||
def to_string(self):
|
||||
res = ""
|
||||
if self.properties:
|
||||
props = ",".join(
|
||||
key + ":" + str(quote_string(val))
|
||||
for key, val in sorted(self.properties.items())
|
||||
)
|
||||
res += "{" + props + "}"
|
||||
|
||||
return res
|
||||
|
||||
def __str__(self):
|
||||
# Source node.
|
||||
if isinstance(self.src_node, Node):
|
||||
res = str(self.src_node)
|
||||
else:
|
||||
res = "()"
|
||||
|
||||
# Edge
|
||||
res += "-["
|
||||
if self.relation:
|
||||
res += ":" + self.relation
|
||||
if self.properties:
|
||||
props = ",".join(
|
||||
key + ":" + str(quote_string(val))
|
||||
for key, val in sorted(self.properties.items())
|
||||
)
|
||||
res += "{" + props + "}"
|
||||
res += "]->"
|
||||
|
||||
# Dest node.
|
||||
if isinstance(self.dest_node, Node):
|
||||
res += str(self.dest_node)
|
||||
else:
|
||||
res += "()"
|
||||
|
||||
return res
|
||||
|
||||
def __eq__(self, rhs):
|
||||
# Type checking
|
||||
if not isinstance(rhs, Edge):
|
||||
return False
|
||||
|
||||
# Quick positive check, if both IDs are set.
|
||||
if self.id is not None and rhs.id is not None and self.id == rhs.id:
|
||||
return True
|
||||
|
||||
# Source and destination nodes should match.
|
||||
if self.src_node != rhs.src_node:
|
||||
return False
|
||||
|
||||
if self.dest_node != rhs.dest_node:
|
||||
return False
|
||||
|
||||
# Relation should match.
|
||||
if self.relation != rhs.relation:
|
||||
return False
|
||||
|
||||
# Quick check for number of properties.
|
||||
if len(self.properties) != len(rhs.properties):
|
||||
return False
|
||||
|
||||
# Compare properties.
|
||||
if self.properties != rhs.properties:
|
||||
return False
|
||||
|
||||
return True
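An Edge sketch (illustrative; Node is the class added later in this diff):

john = Node(alias="j", label="Person", properties={"name": "John"})
acme = Node(alias="a", label="Company", properties={"name": "Acme"})
works_at = Edge(john, "WORKS_AT", acme, properties={"since": 2020})
str(works_at)
# '(j:Person{name:"John"})-[:WORKS_AT{since:2020}]->(a:Company{name:"Acme"})'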
@@ -0,0 +1,3 @@
|
||||
class VersionMismatchException(Exception):
|
||||
def __init__(self, version):
|
||||
self.version = version
|
||||
@@ -0,0 +1,211 @@
|
||||
import re
|
||||
|
||||
|
||||
class ProfileStats:
|
||||
"""
|
||||
ProfileStats, runtime execution statistics of operation.
|
||||
"""
|
||||
|
||||
def __init__(self, records_produced, execution_time):
|
||||
self.records_produced = records_produced
|
||||
self.execution_time = execution_time
|
||||
|
||||
|
||||
class Operation:
|
||||
"""
|
||||
Operation, single operation within execution plan.
|
||||
"""
|
||||
|
||||
def __init__(self, name, args=None, profile_stats=None):
|
||||
"""
|
||||
Create a new operation.
|
||||
|
||||
Args:
|
||||
name: string that represents the name of the operation
|
||||
args: operation arguments
|
||||
profile_stats: profile statistics
|
||||
"""
|
||||
self.name = name
|
||||
self.args = args
|
||||
self.profile_stats = profile_stats
|
||||
self.children = []
|
||||
|
||||
def append_child(self, child):
|
||||
if not isinstance(child, Operation) or self is child:
|
||||
raise Exception("child must be Operation")
|
||||
|
||||
self.children.append(child)
|
||||
return self
|
||||
|
||||
def child_count(self):
|
||||
return len(self.children)
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
if not isinstance(o, Operation):
|
||||
return False
|
||||
|
||||
return self.name == o.name and self.args == o.args
|
||||
|
||||
def __str__(self) -> str:
|
||||
args_str = "" if self.args is None else " | " + self.args
|
||||
return f"{self.name}{args_str}"
|
||||
|
||||
|
||||
class ExecutionPlan:
|
||||
"""
|
||||
ExecutionPlan, collection of operations.
|
||||
"""
|
||||
|
||||
def __init__(self, plan):
|
||||
"""
|
||||
Create a new execution plan.
|
||||
|
||||
Args:
|
||||
plan: array of strings that represents the collection operations
|
||||
the output from GRAPH.EXPLAIN
|
||||
"""
|
||||
if not isinstance(plan, list):
|
||||
raise Exception("plan must be an array")
|
||||
|
||||
if isinstance(plan[0], bytes):
|
||||
plan = [b.decode() for b in plan]
|
||||
|
||||
self.plan = plan
|
||||
self.structured_plan = self._operation_tree()
|
||||
|
||||
def _compare_operations(self, root_a, root_b):
|
||||
"""
|
||||
Compare execution plan operation tree
|
||||
|
||||
Return: True if operation trees are equal, False otherwise
|
||||
"""
|
||||
|
||||
# compare current root
|
||||
if root_a != root_b:
|
||||
return False
|
||||
|
||||
# make sure both roots have the same number of children
|
||||
if root_a.child_count() != root_b.child_count():
|
||||
return False
|
||||
|
||||
# recursively compare children
|
||||
for i in range(root_a.child_count()):
|
||||
if not self._compare_operations(root_a.children[i], root_b.children[i]):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def __str__(self) -> str:
|
||||
def aggregate_str(str_children):
|
||||
return "\n".join(
|
||||
[
|
||||
" " + line
|
||||
for str_child in str_children
|
||||
for line in str_child.splitlines()
|
||||
]
|
||||
)
|
||||
|
||||
def combine_str(x, y):
|
||||
return f"{x}\n{y}"
|
||||
|
||||
return self._operation_traverse(
|
||||
self.structured_plan, str, aggregate_str, combine_str
|
||||
)
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
"""Compares two execution plans
|
||||
|
||||
Return: True if the two plans are equal, False otherwise
|
||||
"""
|
||||
# make sure 'o' is an execution-plan
|
||||
if not isinstance(o, ExecutionPlan):
|
||||
return False
|
||||
|
||||
# get root for both plans
|
||||
root_a = self.structured_plan
|
||||
root_b = o.structured_plan
|
||||
|
||||
# compare execution trees
|
||||
return self._compare_operations(root_a, root_b)
|
||||
|
||||
def _operation_traverse(self, op, op_f, aggregate_f, combine_f):
|
||||
"""
|
||||
Traverse operation tree recursively applying functions
|
||||
|
||||
Args:
|
||||
op: operation to traverse
|
||||
op_f: function applied for each operation
|
||||
aggregate_f: aggregation function applied for all children of a single operation
|
||||
combine_f: combine function applied for the operation result and the children result
|
||||
""" # noqa
|
||||
# apply op_f for each operation
|
||||
op_res = op_f(op)
|
||||
if len(op.children) == 0:
|
||||
return op_res # no children return
|
||||
else:
|
||||
# apply _operation_traverse recursively
|
||||
children = [
|
||||
self._operation_traverse(child, op_f, aggregate_f, combine_f)
|
||||
for child in op.children
|
||||
]
|
||||
# combine the operation result with the children aggregated result
|
||||
return combine_f(op_res, aggregate_f(children))
|
||||
|
||||
def _operation_tree(self):
|
||||
"""Build the operation tree from the string representation"""
|
||||
|
||||
# initial state
|
||||
i = 0
|
||||
level = 0
|
||||
stack = []
|
||||
current = None
|
||||
|
||||
def _create_operation(args):
|
||||
profile_stats = None
|
||||
name = args[0].strip()
|
||||
args.pop(0)
|
||||
if len(args) > 0 and "Records produced" in args[-1]:
|
||||
records_produced = int(
|
||||
re.search("Records produced: (\\d+)", args[-1]).group(1)
|
||||
)
|
||||
execution_time = float(
|
||||
re.search("Execution time: (\\d+.\\d+) ms", args[-1]).group(1)
|
||||
)
|
||||
profile_stats = ProfileStats(records_produced, execution_time)
|
||||
args.pop(-1)
|
||||
return Operation(
|
||||
name, None if len(args) == 0 else args[0].strip(), profile_stats
|
||||
)
|
||||
|
||||
# iterate plan operations
|
||||
while i < len(self.plan):
|
||||
current_op = self.plan[i]
|
||||
op_level = current_op.count(" ")
|
||||
if op_level == level:
|
||||
# if the operation level equal to the current level
|
||||
# set the current operation and move next
|
||||
child = _create_operation(current_op.split("|"))
|
||||
if current:
|
||||
current = stack.pop()
|
||||
current.append_child(child)
|
||||
current = child
|
||||
i += 1
|
||||
elif op_level == level + 1:
|
||||
# if the operation is child of the current operation
|
||||
# add it as child and set as current operation
|
||||
child = _create_operation(current_op.split("|"))
|
||||
current.append_child(child)
|
||||
stack.append(current)
|
||||
current = child
|
||||
level += 1
|
||||
i += 1
|
||||
elif op_level < level:
|
||||
# if the operation is not child of current operation
|
||||
# go back to its parent operation
|
||||
levels_back = level - op_level + 1
|
||||
for _ in range(levels_back):
|
||||
current = stack.pop()
|
||||
level -= levels_back
|
||||
else:
|
||||
raise Exception("corrupted plan")
|
||||
return stack[0]
|
||||
@@ -0,0 +1,88 @@
|
||||
from ..helpers import quote_string
|
||||
|
||||
|
||||
class Node:
|
||||
"""
|
||||
A node within the graph.
|
||||
"""
|
||||
|
||||
def __init__(self, node_id=None, alias=None, label=None, properties=None):
|
||||
"""
|
||||
Create a new node.
|
||||
"""
|
||||
self.id = node_id
|
||||
self.alias = alias
|
||||
if isinstance(label, list):
|
||||
label = [inner_label for inner_label in label if inner_label != ""]
|
||||
|
||||
if (
|
||||
label is None
|
||||
or label == ""
|
||||
or (isinstance(label, list) and len(label) == 0)
|
||||
):
|
||||
self.label = None
|
||||
self.labels = None
|
||||
elif isinstance(label, str):
|
||||
self.label = label
|
||||
self.labels = [label]
|
||||
elif isinstance(label, list) and all(
|
||||
[isinstance(inner_label, str) for inner_label in label]
|
||||
):
|
||||
self.label = label[0]
|
||||
self.labels = label
|
||||
else:
|
||||
raise AssertionError(
|
||||
"label should be either None, string or a list of strings"
|
||||
)
|
||||
|
||||
self.properties = properties or {}
|
||||
|
||||
def to_string(self):
|
||||
res = ""
|
||||
if self.properties:
|
||||
props = ",".join(
|
||||
key + ":" + str(quote_string(val))
|
||||
for key, val in sorted(self.properties.items())
|
||||
)
|
||||
res += "{" + props + "}"
|
||||
|
||||
return res
|
||||
|
||||
def __str__(self):
|
||||
res = "("
|
||||
if self.alias:
|
||||
res += self.alias
|
||||
if self.labels:
|
||||
res += ":" + ":".join(self.labels)
|
||||
if self.properties:
|
||||
props = ",".join(
|
||||
key + ":" + str(quote_string(val))
|
||||
for key, val in sorted(self.properties.items())
|
||||
)
|
||||
res += "{" + props + "}"
|
||||
res += ")"
|
||||
|
||||
return res
|
||||
|
||||
def __eq__(self, rhs):
|
||||
# Type checking
|
||||
if not isinstance(rhs, Node):
|
||||
return False
|
||||
|
||||
# Quick positive check, if both IDs are set.
|
||||
if self.id is not None and rhs.id is not None and self.id != rhs.id:
|
||||
return False
|
||||
|
||||
# Label should match.
|
||||
if self.label != rhs.label:
|
||||
return False
|
||||
|
||||
# Quick check for number of properties.
|
||||
if len(self.properties) != len(rhs.properties):
|
||||
return False
|
||||
|
||||
# Compare properties.
|
||||
if self.properties != rhs.properties:
|
||||
return False
|
||||
|
||||
return True
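A Node sketch (illustrative):

alice = Node(node_id=1, alias="a", label=["Person", "Employee"], properties={"name": "Alice"})
str(alice)     # '(a:Person:Employee{name:"Alice"})'
alice.label    # 'Person' -- first label kept for backwards compatibility
alice.labels   # ['Person', 'Employee']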
|
||||
@@ -0,0 +1,78 @@
|
||||
from .edge import Edge
|
||||
from .node import Node
|
||||
|
||||
|
||||
class Path:
|
||||
def __init__(self, nodes, edges):
|
||||
if not (isinstance(nodes, list) and isinstance(edges, list)):
|
||||
raise TypeError("nodes and edges must be list")
|
||||
|
||||
self._nodes = nodes
|
||||
self._edges = edges
|
||||
self.append_type = Node
|
||||
|
||||
@classmethod
|
||||
def new_empty_path(cls):
|
||||
return cls([], [])
|
||||
|
||||
def nodes(self):
|
||||
return self._nodes
|
||||
|
||||
def edges(self):
|
||||
return self._edges
|
||||
|
||||
def get_node(self, index):
|
||||
return self._nodes[index]
|
||||
|
||||
def get_relationship(self, index):
|
||||
return self._edges[index]
|
||||
|
||||
def first_node(self):
|
||||
return self._nodes[0]
|
||||
|
||||
def last_node(self):
|
||||
return self._nodes[-1]
|
||||
|
||||
def edge_count(self):
|
||||
return len(self._edges)
|
||||
|
||||
def nodes_count(self):
|
||||
return len(self._nodes)
|
||||
|
||||
def add_node(self, node):
|
||||
if not isinstance(node, self.append_type):
|
||||
raise AssertionError("Add Edge before adding Node")
|
||||
self._nodes.append(node)
|
||||
self.append_type = Edge
|
||||
return self
|
||||
|
||||
def add_edge(self, edge):
|
||||
if not isinstance(edge, self.append_type):
|
||||
raise AssertionError("Add Node before adding Edge")
|
||||
self._edges.append(edge)
|
||||
self.append_type = Node
|
||||
return self
|
||||
|
||||
def __eq__(self, other):
|
||||
# Type checking
|
||||
if not isinstance(other, Path):
|
||||
return False
|
||||
|
||||
return self.nodes() == other.nodes() and self.edges() == other.edges()
|
||||
|
||||
def __str__(self):
|
||||
res = "<"
|
||||
edge_count = self.edge_count()
|
||||
for i in range(0, edge_count):
|
||||
node_id = self.get_node(i).id
|
||||
res += "(" + str(node_id) + ")"
|
||||
edge = self.get_relationship(i)
|
||||
res += (
|
||||
"-[" + str(int(edge.id)) + "]->"
|
||||
if edge.src_node == node_id
|
||||
else "<-[" + str(int(edge.id)) + "]-"
|
||||
)
|
||||
node_id = self.get_node(edge_count).id
|
||||
res += "(" + str(node_id) + ")"
|
||||
res += ">"
|
||||
return res
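A Path sketch (illustrative; alice, bob and knows are assumed Node/Edge instances; entries must alternate, starting with a node):

p = Path.new_empty_path()
p.add_node(alice).add_edge(knows).add_node(bob)
p.nodes_count()   # 2
p.edge_count()    # 1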
|
||||
@@ -0,0 +1,573 @@
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from distutils.util import strtobool
|
||||
|
||||
# from prettytable import PrettyTable
|
||||
from redis import ResponseError
|
||||
|
||||
from .edge import Edge
|
||||
from .exceptions import VersionMismatchException
|
||||
from .node import Node
|
||||
from .path import Path
|
||||
|
||||
LABELS_ADDED = "Labels added"
|
||||
LABELS_REMOVED = "Labels removed"
|
||||
NODES_CREATED = "Nodes created"
|
||||
NODES_DELETED = "Nodes deleted"
|
||||
RELATIONSHIPS_DELETED = "Relationships deleted"
|
||||
PROPERTIES_SET = "Properties set"
|
||||
PROPERTIES_REMOVED = "Properties removed"
|
||||
RELATIONSHIPS_CREATED = "Relationships created"
|
||||
INDICES_CREATED = "Indices created"
|
||||
INDICES_DELETED = "Indices deleted"
|
||||
CACHED_EXECUTION = "Cached execution"
|
||||
INTERNAL_EXECUTION_TIME = "internal execution time"
|
||||
|
||||
STATS = [
|
||||
LABELS_ADDED,
|
||||
LABELS_REMOVED,
|
||||
NODES_CREATED,
|
||||
PROPERTIES_SET,
|
||||
PROPERTIES_REMOVED,
|
||||
RELATIONSHIPS_CREATED,
|
||||
NODES_DELETED,
|
||||
RELATIONSHIPS_DELETED,
|
||||
INDICES_CREATED,
|
||||
INDICES_DELETED,
|
||||
CACHED_EXECUTION,
|
||||
INTERNAL_EXECUTION_TIME,
|
||||
]
|
||||
|
||||
|
||||
class ResultSetColumnTypes:
|
||||
COLUMN_UNKNOWN = 0
|
||||
COLUMN_SCALAR = 1
|
||||
COLUMN_NODE = 2 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa
|
||||
COLUMN_RELATION = 3 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa
|
||||
|
||||
|
||||
class ResultSetScalarTypes:
|
||||
VALUE_UNKNOWN = 0
|
||||
VALUE_NULL = 1
|
||||
VALUE_STRING = 2
|
||||
VALUE_INTEGER = 3
|
||||
VALUE_BOOLEAN = 4
|
||||
VALUE_DOUBLE = 5
|
||||
VALUE_ARRAY = 6
|
||||
VALUE_EDGE = 7
|
||||
VALUE_NODE = 8
|
||||
VALUE_PATH = 9
|
||||
VALUE_MAP = 10
|
||||
VALUE_POINT = 11
|
||||
|
||||
|
||||
class QueryResult:
|
||||
def __init__(self, graph, response, profile=False):
|
||||
"""
|
||||
A class that represents a result of the query operation.
|
||||
|
||||
Args:
|
||||
|
||||
graph:
|
||||
The graph on which the query was executed.
|
||||
response:
|
||||
The response from the server.
|
||||
profile:
|
||||
A boolean indicating if the query command was "GRAPH.PROFILE"
|
||||
"""
|
||||
self.graph = graph
|
||||
self.header = []
|
||||
self.result_set = []
|
||||
|
||||
# in case of an error an exception will be raised
|
||||
self._check_for_errors(response)
|
||||
|
||||
if len(response) == 1:
|
||||
self.parse_statistics(response[0])
|
||||
elif profile:
|
||||
self.parse_profile(response)
|
||||
else:
|
||||
# start by parsing statistics, matches the one we have
|
||||
self.parse_statistics(response[-1]) # Last element.
|
||||
self.parse_results(response)
|
||||
|
||||
def _check_for_errors(self, response):
|
||||
"""
|
||||
Check if the response contains an error.
|
||||
"""
|
||||
if isinstance(response[0], ResponseError):
|
||||
error = response[0]
|
||||
if str(error) == "version mismatch":
|
||||
version = response[1]
|
||||
error = VersionMismatchException(version)
|
||||
raise error
|
||||
|
||||
# If we encountered a run-time error, the last response
|
||||
# element will be an exception
|
||||
if isinstance(response[-1], ResponseError):
|
||||
raise response[-1]
|
||||
|
||||
def parse_results(self, raw_result_set):
|
||||
"""
|
||||
Parse the query execution result returned from the server.
|
||||
"""
|
||||
self.header = self.parse_header(raw_result_set)
|
||||
|
||||
# Empty header.
|
||||
if len(self.header) == 0:
|
||||
return
|
||||
|
||||
self.result_set = self.parse_records(raw_result_set)
|
||||
|
||||
def parse_statistics(self, raw_statistics):
|
||||
"""
|
||||
Parse the statistics returned in the response.
|
||||
"""
|
||||
self.statistics = {}
|
||||
|
||||
# decode statistics
|
||||
for idx, stat in enumerate(raw_statistics):
|
||||
if isinstance(stat, bytes):
|
||||
raw_statistics[idx] = stat.decode()
|
||||
|
||||
for s in STATS:
|
||||
v = self._get_value(s, raw_statistics)
|
||||
if v is not None:
|
||||
self.statistics[s] = v
|
||||
|
||||
def parse_header(self, raw_result_set):
|
||||
"""
|
||||
Parse the header of the result.
|
||||
"""
|
||||
# An array of column name/column type pairs.
|
||||
header = raw_result_set[0]
|
||||
return header
|
||||
|
||||
def parse_records(self, raw_result_set):
|
||||
"""
|
||||
Parses the result set and returns a list of records.
|
||||
"""
|
||||
records = [
|
||||
[
|
||||
self.parse_record_types[self.header[idx][0]](cell)
|
||||
for idx, cell in enumerate(row)
|
||||
]
|
||||
for row in raw_result_set[1]
|
||||
]
|
||||
|
||||
return records
|
||||
|
||||
def parse_entity_properties(self, props):
|
||||
"""
|
||||
Parse node / edge properties.
|
||||
"""
|
||||
# [[name, value type, value] X N]
|
||||
properties = {}
|
||||
for prop in props:
|
||||
prop_name = self.graph.get_property(prop[0])
|
||||
prop_value = self.parse_scalar(prop[1:])
|
||||
properties[prop_name] = prop_value
|
||||
|
||||
return properties
|
||||
|
||||
def parse_string(self, cell):
|
||||
"""
|
||||
Parse the cell as a string.
|
||||
"""
|
||||
if isinstance(cell, bytes):
|
||||
return cell.decode()
|
||||
elif not isinstance(cell, str):
|
||||
return str(cell)
|
||||
else:
|
||||
return cell
|
||||
|
||||
def parse_node(self, cell):
|
||||
"""
|
||||
Parse the cell to a node.
|
||||
"""
|
||||
# Node ID (integer),
|
||||
# [label string offset (integer)],
|
||||
# [[name, value type, value] X N]
|
||||
|
||||
node_id = int(cell[0])
|
||||
labels = None
|
||||
if len(cell[1]) > 0:
|
||||
labels = []
|
||||
for inner_label in cell[1]:
|
||||
labels.append(self.graph.get_label(inner_label))
|
||||
properties = self.parse_entity_properties(cell[2])
|
||||
return Node(node_id=node_id, label=labels, properties=properties)
|
||||
|
||||
def parse_edge(self, cell):
|
||||
"""
|
||||
Parse the cell to an edge.
|
||||
"""
|
||||
# Edge ID (integer),
|
||||
# reltype string offset (integer),
|
||||
# src node ID offset (integer),
|
||||
# dest node ID offset (integer),
|
||||
# [[name, value, value type] X N]
|
||||
|
||||
edge_id = int(cell[0])
|
||||
relation = self.graph.get_relation(cell[1])
|
||||
src_node_id = int(cell[2])
|
||||
dest_node_id = int(cell[3])
|
||||
properties = self.parse_entity_properties(cell[4])
|
||||
return Edge(
|
||||
src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties
|
||||
)
|
||||
|
||||
def parse_path(self, cell):
|
||||
"""
|
||||
Parse the cell to a path.
|
||||
"""
|
||||
nodes = self.parse_scalar(cell[0])
|
||||
edges = self.parse_scalar(cell[1])
|
||||
return Path(nodes, edges)
|
||||
|
||||
def parse_map(self, cell):
|
||||
"""
|
||||
Parse the cell as a map.
|
||||
"""
|
||||
m = OrderedDict()
|
||||
n_entries = len(cell)
|
||||
|
||||
# A map is an array of key value pairs.
|
||||
# 1. key (string)
|
||||
# 2. array: (value type, value)
|
||||
for i in range(0, n_entries, 2):
|
||||
key = self.parse_string(cell[i])
|
||||
m[key] = self.parse_scalar(cell[i + 1])
|
||||
|
||||
return m
|
||||
|
||||
def parse_point(self, cell):
|
||||
"""
|
||||
Parse the cell to point.
|
||||
"""
|
||||
p = {}
|
||||
# A point is received an array of the form: [latitude, longitude]
|
||||
# It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa
|
||||
p["latitude"] = float(cell[0])
|
||||
p["longitude"] = float(cell[1])
|
||||
return p
|
||||
|
||||
def parse_null(self, cell):
|
||||
"""
|
||||
Parse a null value.
|
||||
"""
|
||||
return None
|
||||
|
||||
def parse_integer(self, cell):
|
||||
"""
|
||||
Parse the integer value from the cell.
|
||||
"""
|
||||
return int(cell)
|
||||
|
||||
def parse_boolean(self, value):
|
||||
"""
|
||||
Parse the cell value as a boolean.
|
||||
"""
|
||||
value = value.decode() if isinstance(value, bytes) else value
|
||||
try:
|
||||
scalar = True if strtobool(value) else False
|
||||
except ValueError:
|
||||
sys.stderr.write("unknown boolean type\n")
|
||||
scalar = None
|
||||
return scalar
|
||||
|
||||
def parse_double(self, cell):
|
||||
"""
|
||||
Parse the cell as a double.
|
||||
"""
|
||||
return float(cell)
|
||||
|
||||
def parse_array(self, value):
|
||||
"""
|
||||
Parse an array of values.
|
||||
"""
|
||||
scalar = [self.parse_scalar(value[i]) for i in range(len(value))]
|
||||
return scalar
|
||||
|
||||
def parse_unknown(self, cell):
|
||||
"""
|
||||
Parse a cell of unknown type.
|
||||
"""
|
||||
sys.stderr.write("Unknown type\n")
|
||||
return None
|
||||
|
||||
def parse_scalar(self, cell):
|
||||
"""
|
||||
Parse a scalar value from a cell in the result set.
|
||||
"""
|
||||
scalar_type = int(cell[0])
|
||||
value = cell[1]
|
||||
scalar = self.parse_scalar_types[scalar_type](value)
|
||||
|
||||
return scalar
|
||||
|
||||
def parse_profile(self, response):
|
||||
self.result_set = [x[0 : x.index(",")].strip() for x in response]
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.result_set) == 0
|
||||
|
||||
@staticmethod
|
||||
def _get_value(prop, statistics):
|
||||
for stat in statistics:
|
||||
if prop in stat:
|
||||
return float(stat.split(": ")[1].split(" ")[0])
|
||||
|
||||
return None
|
||||
|
||||
def _get_stat(self, stat):
|
||||
return self.statistics[stat] if stat in self.statistics else 0
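A statistics sketch (illustrative; the returned numbers are examples):

res = g.query("CREATE (:Person {name: 'Kim'})")
res.nodes_created     # 1
res.properties_set    # 1
res.run_time_ms       # e.g. 0.2 -- the "internal execution time" statistic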
|
||||
|
||||
@property
|
||||
def labels_added(self):
|
||||
"""Returns the number of labels added in the query"""
|
||||
return self._get_stat(LABELS_ADDED)
|
||||
|
||||
@property
|
||||
def labels_removed(self):
|
||||
"""Returns the number of labels removed in the query"""
|
||||
return self._get_stat(LABELS_REMOVED)
|
||||
|
||||
@property
|
||||
def nodes_created(self):
|
||||
"""Returns the number of nodes created in the query"""
|
||||
return self._get_stat(NODES_CREATED)
|
||||
|
||||
@property
|
||||
def nodes_deleted(self):
|
||||
"""Returns the number of nodes deleted in the query"""
|
||||
return self._get_stat(NODES_DELETED)
|
||||
|
||||
@property
|
||||
def properties_set(self):
|
||||
"""Returns the number of properties set in the query"""
|
||||
return self._get_stat(PROPERTIES_SET)
|
||||
|
||||
@property
|
||||
def properties_removed(self):
|
||||
"""Returns the number of properties removed in the query"""
|
||||
return self._get_stat(PROPERTIES_REMOVED)
|
||||
|
||||
@property
|
||||
def relationships_created(self):
|
||||
"""Returns the number of relationships created in the query"""
|
||||
return self._get_stat(RELATIONSHIPS_CREATED)
|
||||
|
||||
@property
|
||||
def relationships_deleted(self):
|
||||
"""Returns the number of relationships deleted in the query"""
|
||||
return self._get_stat(RELATIONSHIPS_DELETED)
|
||||
|
||||
@property
|
||||
def indices_created(self):
|
||||
"""Returns the number of indices created in the query"""
|
||||
return self._get_stat(INDICES_CREATED)
|
||||
|
||||
@property
|
||||
def indices_deleted(self):
|
||||
"""Returns the number of indices deleted in the query"""
|
||||
return self._get_stat(INDICES_DELETED)
|
||||
|
||||
@property
|
||||
def cached_execution(self):
|
||||
"""Returns whether or not the query execution plan was cached"""
|
||||
return self._get_stat(CACHED_EXECUTION) == 1
|
||||
|
||||
@property
|
||||
def run_time_ms(self):
|
||||
"""Returns the server execution time of the query"""
|
||||
return self._get_stat(INTERNAL_EXECUTION_TIME)
|
||||
|
||||
@property
|
||||
def parse_scalar_types(self):
|
||||
return {
|
||||
ResultSetScalarTypes.VALUE_NULL: self.parse_null,
|
||||
ResultSetScalarTypes.VALUE_STRING: self.parse_string,
|
||||
ResultSetScalarTypes.VALUE_INTEGER: self.parse_integer,
|
||||
ResultSetScalarTypes.VALUE_BOOLEAN: self.parse_boolean,
|
||||
ResultSetScalarTypes.VALUE_DOUBLE: self.parse_double,
|
||||
ResultSetScalarTypes.VALUE_ARRAY: self.parse_array,
|
||||
ResultSetScalarTypes.VALUE_NODE: self.parse_node,
|
||||
ResultSetScalarTypes.VALUE_EDGE: self.parse_edge,
|
||||
ResultSetScalarTypes.VALUE_PATH: self.parse_path,
|
||||
ResultSetScalarTypes.VALUE_MAP: self.parse_map,
|
||||
ResultSetScalarTypes.VALUE_POINT: self.parse_point,
|
||||
ResultSetScalarTypes.VALUE_UNKNOWN: self.parse_unknown,
|
||||
}
|
||||
|
||||
@property
|
||||
def parse_record_types(self):
|
||||
return {
|
||||
ResultSetColumnTypes.COLUMN_SCALAR: self.parse_scalar,
|
||||
ResultSetColumnTypes.COLUMN_NODE: self.parse_node,
|
||||
ResultSetColumnTypes.COLUMN_RELATION: self.parse_edge,
|
||||
ResultSetColumnTypes.COLUMN_UNKNOWN: self.parse_unknown,
|
||||
}
|
||||
|
||||
|
||||
class AsyncQueryResult(QueryResult):
|
||||
"""
|
||||
Async version for the QueryResult class - a class that
|
||||
represents a result of the query operation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
To initialize the class, call self.initialize().
|
||||
"""
|
||||
pass
|
||||
|
||||
async def initialize(self, graph, response, profile=False):
|
||||
"""
|
||||
Initializes the class.
|
||||
Args:
|
||||
|
||||
graph:
|
||||
The graph on which the query was executed.
|
||||
response:
|
||||
The response from the server.
|
||||
profile:
|
||||
A boolean indicating if the query command was "GRAPH.PROFILE"
|
||||
"""
|
||||
self.graph = graph
|
||||
self.header = []
|
||||
self.result_set = []
|
||||
|
||||
# in case of an error an exception will be raised
|
||||
self._check_for_errors(response)
|
||||
|
||||
if len(response) == 1:
|
||||
self.parse_statistics(response[0])
|
||||
elif profile:
|
||||
self.parse_profile(response)
|
||||
else:
|
||||
# start by parsing statistics, matches the one we have
|
||||
self.parse_statistics(response[-1]) # Last element.
|
||||
await self.parse_results(response)
|
||||
|
||||
return self
|
||||
|
||||
async def parse_node(self, cell):
|
||||
"""
|
||||
Parses a node from the cell.
|
||||
"""
|
||||
# Node ID (integer),
|
||||
# [label string offset (integer)],
|
||||
# [[name, value type, value] X N]
|
||||
|
||||
labels = None
|
||||
if len(cell[1]) > 0:
|
||||
labels = []
|
||||
for inner_label in cell[1]:
|
||||
labels.append(await self.graph.get_label(inner_label))
|
||||
properties = await self.parse_entity_properties(cell[2])
|
||||
node_id = int(cell[0])
|
||||
return Node(node_id=node_id, label=labels, properties=properties)
|
||||
|
||||
async def parse_scalar(self, cell):
|
||||
"""
|
||||
Parses a scalar value from the server response.
|
||||
"""
|
||||
scalar_type = int(cell[0])
|
||||
value = cell[1]
|
||||
try:
|
||||
scalar = await self.parse_scalar_types[scalar_type](value)
|
||||
except TypeError:
|
||||
# Not all of the functions are async
|
||||
scalar = self.parse_scalar_types[scalar_type](value)
|
||||
|
||||
return scalar
|
||||
|
||||
async def parse_records(self, raw_result_set):
|
||||
"""
|
||||
Parses the result set and returns a list of records.
|
||||
"""
|
||||
records = []
|
||||
for row in raw_result_set[1]:
|
||||
record = [
|
||||
await self.parse_record_types[self.header[idx][0]](cell)
|
||||
for idx, cell in enumerate(row)
|
||||
]
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
async def parse_results(self, raw_result_set):
|
||||
"""
|
||||
Parse the query execution result returned from the server.
|
||||
"""
|
||||
self.header = self.parse_header(raw_result_set)
|
||||
|
||||
# Empty header.
|
||||
if len(self.header) == 0:
|
||||
return
|
||||
|
||||
self.result_set = await self.parse_records(raw_result_set)
|
||||
|
||||
async def parse_entity_properties(self, props):
|
||||
"""
|
||||
Parse node / edge properties.
|
||||
"""
|
||||
# [[name, value type, value] X N]
|
||||
properties = {}
|
||||
for prop in props:
|
||||
prop_name = await self.graph.get_property(prop[0])
|
||||
prop_value = await self.parse_scalar(prop[1:])
|
||||
properties[prop_name] = prop_value
|
||||
|
||||
return properties
|
||||
|
||||
async def parse_edge(self, cell):
|
||||
"""
|
||||
Parse the cell to an edge.
|
||||
"""
|
||||
# Edge ID (integer),
|
||||
# reltype string offset (integer),
|
||||
# src node ID offset (integer),
|
||||
# dest node ID offset (integer),
|
||||
# [[name, value, value type] X N]
|
||||
|
||||
edge_id = int(cell[0])
|
||||
relation = await self.graph.get_relation(cell[1])
|
||||
src_node_id = int(cell[2])
|
||||
dest_node_id = int(cell[3])
|
||||
properties = await self.parse_entity_properties(cell[4])
|
||||
return Edge(
|
||||
src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties
|
||||
)
|
||||
|
||||
async def parse_path(self, cell):
|
||||
"""
|
||||
Parse the cell to a path.
|
||||
"""
|
||||
nodes = await self.parse_scalar(cell[0])
|
||||
edges = await self.parse_scalar(cell[1])
|
||||
return Path(nodes, edges)
|
||||
|
||||
async def parse_map(self, cell):
|
||||
"""
|
||||
Parse the cell to a map.
|
||||
"""
|
||||
m = OrderedDict()
|
||||
n_entries = len(cell)
|
||||
|
||||
# A map is an array of key value pairs.
|
||||
# 1. key (string)
|
||||
# 2. array: (value type, value)
|
||||
for i in range(0, n_entries, 2):
|
||||
key = self.parse_string(cell[i])
|
||||
m[key] = await self.parse_scalar(cell[i + 1])
|
||||
|
||||
return m
|
||||
|
||||
async def parse_array(self, value):
|
||||
"""
|
||||
Parse array value.
|
||||
"""
|
||||
scalar = [await self.parse_scalar(value[i]) for i in range(len(value))]
|
||||
return scalar
|
||||
@@ -43,32 +43,19 @@ def parse_to_list(response):
|
||||
"""Optimistically parse the response to a list."""
|
||||
res = []
|
||||
|
||||
special_values = {"infinity", "nan", "-infinity"}
|
||||
|
||||
if response is None:
|
||||
return res
|
||||
|
||||
for item in response:
|
||||
if item is None:
|
||||
res.append(None)
|
||||
continue
|
||||
try:
|
||||
item_str = nativestr(item)
|
||||
res.append(int(item))
|
||||
except ValueError:
|
||||
try:
|
||||
res.append(float(item))
|
||||
except ValueError:
|
||||
res.append(nativestr(item))
|
||||
except TypeError:
|
||||
res.append(None)
|
||||
continue
|
||||
|
||||
if isinstance(item_str, str) and item_str.lower() in special_values:
|
||||
res.append(item_str) # Keep as string
|
||||
else:
|
||||
try:
|
||||
res.append(int(item))
|
||||
except ValueError:
|
||||
try:
|
||||
res.append(float(item))
|
||||
except ValueError:
|
||||
res.append(item_str)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@@ -77,11 +64,6 @@ def parse_list_to_dict(response):
|
||||
for i in range(0, len(response), 2):
|
||||
if isinstance(response[i], list):
|
||||
res["Child iterators"].append(parse_list_to_dict(response[i]))
|
||||
try:
|
||||
if isinstance(response[i + 1], list):
|
||||
res["Child iterators"].append(parse_list_to_dict(response[i + 1]))
|
||||
except IndexError:
|
||||
pass
|
||||
elif isinstance(response[i + 1], list):
|
||||
res["Child iterators"] = [parse_list_to_dict(response[i + 1])]
|
||||
else:
|
||||
@@ -92,6 +74,25 @@ def parse_list_to_dict(response):
|
||||
return res
|
||||
|
||||
|
||||
def parse_to_dict(response):
|
||||
if response is None:
|
||||
return {}
|
||||
|
||||
res = {}
|
||||
for det in response:
|
||||
if isinstance(det[1], list):
|
||||
res[det[0]] = parse_list_to_dict(det[1])
|
||||
else:
|
||||
try: # try to set the attribute. may be provided without value
|
||||
try: # try to convert the value to float
|
||||
res[det[0]] = float(det[1])
|
||||
except (TypeError, ValueError):
|
||||
res[det[0]] = det[1]
|
||||
except IndexError:
|
||||
pass
|
||||
return res
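A parse_to_dict sketch (illustrative input):

parse_to_dict([["Type", "Index"], ["Time", "0.5"]])
# {'Type': 'Index', 'Time': 0.5}   -- values are converted to float when possible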
|
||||
|
||||
|
||||
def random_string(length=10):
|
||||
"""
|
||||
Returns a random N character long string.
|
||||
@@ -101,6 +102,26 @@ def random_string(length=10):
|
||||
)
|
||||
|
||||
|
||||
def quote_string(v):
|
||||
"""
|
||||
RedisGraph strings must be quoted,
|
||||
quote_string wraps the given v with quotes in case
|
||||
v is a string.
|
||||
"""
|
||||
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode()
|
||||
elif not isinstance(v, str):
|
||||
return v
|
||||
if len(v) == 0:
|
||||
return '""'
|
||||
|
||||
v = v.replace("\\", "\\\\")
|
||||
v = v.replace('"', '\\"')
|
||||
|
||||
return f'"{v}"'
|
||||
|
||||
|
||||
def decode_dict_keys(obj):
|
||||
"""Decode the keys of the given dictionary with utf-8."""
|
||||
newobj = copy.copy(obj)
|
||||
@@ -111,6 +132,33 @@ def decode_dict_keys(obj):
|
||||
return newobj
|
||||
|
||||
|
||||
def stringify_param_value(value):
|
||||
"""
|
||||
Turn a parameter value into a string suitable for the params header of
|
||||
a Cypher command.
|
||||
You may pass any value that would be accepted by `json.dumps()`.
|
||||
|
||||
Ways in which output differs from that of `str()`:
|
||||
* Strings are quoted.
|
||||
* None --> "null".
|
||||
* In dictionaries, keys are _not_ quoted.
|
||||
|
||||
:param value: The parameter value to be turned into a string.
|
||||
:return: string
|
||||
"""
|
||||
|
||||
if isinstance(value, str):
|
||||
return quote_string(value)
|
||||
elif value is None:
|
||||
return "null"
|
||||
elif isinstance(value, (list, tuple)):
|
||||
return f'[{",".join(map(stringify_param_value, value))}]'
|
||||
elif isinstance(value, dict):
|
||||
return f'{{{",".join(f"{k}:{stringify_param_value(v)}" for k, v in value.items())}}}' # noqa
|
||||
else:
|
||||
return str(value)
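A stringify_param_value sketch (illustrative):

stringify_param_value({"name": "Alice", "ids": [1, 2], "extra": None})
# '{name:"Alice",ids:[1,2],extra:null}'  -- keys unquoted, strings quoted, None -> null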
|
||||
|
||||
|
||||
def get_protocol_version(client):
|
||||
if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis):
|
||||
return client.connection_pool.connection_kwargs.get("protocol")
|
||||
|
||||
@@ -120,7 +120,7 @@ class JSON(JSONCommands):
|
||||
startup_nodes=self.client.nodes_manager.startup_nodes,
|
||||
result_callbacks=self.client.result_callbacks,
|
||||
cluster_response_callbacks=self.client.cluster_response_callbacks,
|
||||
cluster_error_retry_attempts=self.client.retry.get_retries(),
|
||||
cluster_error_retry_attempts=self.client.cluster_error_retry_attempts,
|
||||
read_from_replicas=self.client.read_from_replicas,
|
||||
reinitialize_steps=self.client.reinitialize_steps,
|
||||
lock=self.client._lock,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List, Mapping, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
JsonType = Union[
|
||||
str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"]
|
||||
]
|
||||
JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
|
||||
|
||||
@@ -15,7 +15,7 @@ class JSONCommands:
|
||||
|
||||
def arrappend(
|
||||
self, name: str, path: Optional[str] = Path.root_path(), *args: List[JsonType]
|
||||
) -> List[Optional[int]]:
|
||||
) -> List[Union[int, None]]:
|
||||
"""Append the objects ``args`` to the array under the
|
||||
``path`` in key ``name``.
|
||||
|
||||
@@ -33,7 +33,7 @@ class JSONCommands:
|
||||
scalar: int,
|
||||
start: Optional[int] = None,
|
||||
stop: Optional[int] = None,
|
||||
) -> List[Optional[int]]:
|
||||
) -> List[Union[int, None]]:
|
||||
"""
|
||||
Return the index of ``scalar`` in the JSON array under ``path`` at key
|
||||
``name``.
|
||||
@@ -49,11 +49,11 @@ class JSONCommands:
|
||||
if stop is not None:
|
||||
pieces.append(stop)
|
||||
|
||||
return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name])
|
||||
return self.execute_command("JSON.ARRINDEX", *pieces)
|
||||
|
||||
def arrinsert(
|
||||
self, name: str, path: str, index: int, *args: List[JsonType]
|
||||
) -> List[Optional[int]]:
|
||||
) -> List[Union[int, None]]:
|
||||
"""Insert the objects ``args`` to the array at index ``index``
|
||||
under the ``path`` in key ``name``.
|
||||
|
||||
@@ -66,20 +66,20 @@ class JSONCommands:
|
||||
|
||||
def arrlen(
|
||||
self, name: str, path: Optional[str] = Path.root_path()
|
||||
) -> List[Optional[int]]:
|
||||
) -> List[Union[int, None]]:
|
||||
"""Return the length of the array JSON value under ``path``
|
||||
at key ``name``.
|
||||
|
||||
For more information see `JSON.ARRLEN <https://redis.io/commands/json.arrlen>`_.
|
||||
""" # noqa
|
||||
return self.execute_command("JSON.ARRLEN", name, str(path), keys=[name])
|
||||
return self.execute_command("JSON.ARRLEN", name, str(path))
|
||||
|
||||
def arrpop(
|
||||
self,
|
||||
name: str,
|
||||
path: Optional[str] = Path.root_path(),
|
||||
index: Optional[int] = -1,
|
||||
) -> List[Optional[str]]:
|
||||
) -> List[Union[str, None]]:
|
||||
"""Pop the element at ``index`` in the array JSON value under
|
||||
``path`` at key ``name``.
|
||||
|
||||
@@ -89,7 +89,7 @@ class JSONCommands:
|
||||
|
||||
def arrtrim(
|
||||
self, name: str, path: str, start: int, stop: int
|
||||
) -> List[Optional[int]]:
|
||||
) -> List[Union[int, None]]:
|
||||
"""Trim the array JSON value under ``path`` at key ``name`` to the
|
||||
inclusive range given by ``start`` and ``stop``.
|
||||
|
||||
@@ -102,34 +102,32 @@ class JSONCommands:
|
||||
|
||||
For more information see `JSON.TYPE <https://redis.io/commands/json.type>`_.
|
||||
""" # noqa
|
||||
return self.execute_command("JSON.TYPE", name, str(path), keys=[name])
|
||||
return self.execute_command("JSON.TYPE", name, str(path))
|
||||
|
||||
def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List:
|
||||
"""Return the JSON value under ``path`` at key ``name``.
|
||||
|
||||
For more information see `JSON.RESP <https://redis.io/commands/json.resp>`_.
|
||||
""" # noqa
|
||||
return self.execute_command("JSON.RESP", name, str(path), keys=[name])
|
||||
return self.execute_command("JSON.RESP", name, str(path))
|
||||
|
||||
def objkeys(
|
||||
self, name: str, path: Optional[str] = Path.root_path()
|
||||
) -> List[Optional[List[str]]]:
|
||||
) -> List[Union[List[str], None]]:
|
||||
"""Return the key names in the dictionary JSON value under ``path`` at
|
||||
key ``name``.
|
||||
|
||||
For more information see `JSON.OBJKEYS <https://redis.io/commands/json.objkeys>`_.
|
||||
""" # noqa
|
||||
return self.execute_command("JSON.OBJKEYS", name, str(path), keys=[name])
|
||||
return self.execute_command("JSON.OBJKEYS", name, str(path))
|
||||
|
||||
def objlen(
|
||||
self, name: str, path: Optional[str] = Path.root_path()
|
||||
) -> List[Optional[int]]:
|
||||
def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int:
|
||||
"""Return the length of the dictionary JSON value under ``path`` at key
|
||||
``name``.
|
||||
|
||||
For more information see `JSON.OBJLEN <https://redis.io/commands/json.objlen>`_.
|
||||
""" # noqa
|
||||
return self.execute_command("JSON.OBJLEN", name, str(path), keys=[name])
|
||||
return self.execute_command("JSON.OBJLEN", name, str(path))
|
||||
|
||||
def numincrby(self, name: str, path: str, number: int) -> str:
|
||||
"""Increment the numeric (integer or floating point) JSON value under
|
||||
@@ -175,7 +173,7 @@ class JSONCommands:
|
||||
|
||||
def get(
|
||||
self, name: str, *args, no_escape: Optional[bool] = False
|
||||
) -> Optional[List[JsonType]]:
|
||||
) -> List[JsonType]:
|
||||
"""
|
||||
Get the object stored as a JSON value at key ``name``.
|
||||
|
||||
@@ -199,7 +197,7 @@ class JSONCommands:
|
||||
# Handle case where key doesn't exist. The JSONDecoder would raise a
|
||||
# TypeError exception since it can't decode None
|
||||
try:
|
||||
return self.execute_command("JSON.GET", *pieces, keys=[name])
|
||||
return self.execute_command("JSON.GET", *pieces)
|
||||
except TypeError:
|
||||
return None
|
||||
|
||||
@@ -213,7 +211,7 @@ class JSONCommands:
|
||||
pieces = []
|
||||
pieces += keys
|
||||
pieces.append(str(path))
|
||||
return self.execute_command("JSON.MGET", *pieces, keys=keys)
|
||||
return self.execute_command("JSON.MGET", *pieces)
|
||||
|
||||
def set(
|
||||
self,
|
||||
@@ -314,7 +312,7 @@ class JSONCommands:
|
||||
|
||||
"""
|
||||
|
||||
with open(file_name) as fp:
|
||||
with open(file_name, "r") as fp:
|
||||
file_content = loads(fp.read())
|
||||
|
||||
return self.set(name, path, file_content, nx=nx, xx=xx, decode_keys=decode_keys)
|
||||
@@ -326,7 +324,7 @@ class JSONCommands:
|
||||
nx: Optional[bool] = False,
|
||||
xx: Optional[bool] = False,
|
||||
decode_keys: Optional[bool] = False,
|
||||
) -> Dict[str, bool]:
|
||||
) -> List[Dict[str, bool]]:
|
||||
"""
|
||||
Iterate over ``root_folder`` and set each JSON file to a value
|
||||
under ``json_path`` with the file name as the key.
|
||||
@@ -357,7 +355,7 @@ class JSONCommands:
|
||||
|
||||
return set_files_result
|
||||
|
||||
def strlen(self, name: str, path: Optional[str] = None) -> List[Optional[int]]:
|
||||
def strlen(self, name: str, path: Optional[str] = None) -> List[Union[int, None]]:
|
||||
"""Return the length of the string JSON value under ``path`` at key
|
||||
``name``.
|
||||
|
||||
@@ -366,7 +364,7 @@ class JSONCommands:
|
||||
pieces = [name]
|
||||
if path is not None:
|
||||
pieces.append(str(path))
|
||||
return self.execute_command("JSON.STRLEN", *pieces, keys=[name])
|
||||
return self.execute_command("JSON.STRLEN", *pieces)
|
||||
|
||||
def toggle(
|
||||
self, name: str, path: Optional[str] = Path.root_path()
|
||||
@@ -379,7 +377,7 @@ class JSONCommands:
|
||||
return self.execute_command("JSON.TOGGLE", name, str(path))
|
||||
|
||||
def strappend(
|
||||
self, name: str, value: str, path: Optional[str] = Path.root_path()
|
||||
self, name: str, value: str, path: Optional[int] = Path.root_path()
|
||||
) -> Union[int, List[Optional[int]]]:
|
||||
"""Append to the string JSON value. If two options are specified after
|
||||
the key name, the path is determined to be the first. If a single
|
||||
|
||||
@@ -1,14 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from json import JSONDecoder, JSONEncoder
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bf import BFBloom, CFBloom, CMSBloom, TDigestBloom, TOPKBloom
|
||||
from .json import JSON
|
||||
from .search import AsyncSearch, Search
|
||||
from .timeseries import TimeSeries
|
||||
from .vectorset import VectorSet
|
||||
|
||||
|
||||
class RedisModuleCommands:
|
||||
@@ -16,7 +6,7 @@ class RedisModuleCommands:
|
||||
modules into the command namespace.
|
||||
"""
|
||||
|
||||
def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()) -> JSON:
|
||||
def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()):
|
||||
"""Access the json namespace, providing support for redis json."""
|
||||
|
||||
from .json import JSON
|
||||
@@ -24,7 +14,7 @@ class RedisModuleCommands:
|
||||
jj = JSON(client=self, encoder=encoder, decoder=decoder)
|
||||
return jj
|
||||
|
||||
def ft(self, index_name="idx") -> Search:
|
||||
def ft(self, index_name="idx"):
|
||||
"""Access the search namespace, providing support for redis search."""
|
||||
|
||||
from .search import Search
|
||||
@@ -32,7 +22,7 @@ class RedisModuleCommands:
|
||||
s = Search(client=self, index_name=index_name)
|
||||
return s
|
||||
|
||||
def ts(self) -> TimeSeries:
|
||||
def ts(self):
|
||||
"""Access the timeseries namespace, providing support for
|
||||
redis timeseries data.
|
||||
"""
|
||||
@@ -42,7 +32,7 @@ class RedisModuleCommands:
|
||||
s = TimeSeries(client=self)
|
||||
return s
|
||||
|
||||
def bf(self) -> BFBloom:
|
||||
def bf(self):
|
||||
"""Access the bloom namespace."""
|
||||
|
||||
from .bf import BFBloom
|
||||
@@ -50,7 +40,7 @@ class RedisModuleCommands:
|
||||
bf = BFBloom(client=self)
|
||||
return bf
|
||||
|
||||
def cf(self) -> CFBloom:
|
||||
def cf(self):
|
||||
"""Access the bloom namespace."""
|
||||
|
||||
from .bf import CFBloom
|
||||
@@ -58,7 +48,7 @@ class RedisModuleCommands:
|
||||
cf = CFBloom(client=self)
|
||||
return cf
|
||||
|
||||
def cms(self) -> CMSBloom:
|
||||
def cms(self):
|
||||
"""Access the bloom namespace."""
|
||||
|
||||
from .bf import CMSBloom
|
||||
@@ -66,7 +56,7 @@ class RedisModuleCommands:
|
||||
cms = CMSBloom(client=self)
|
||||
return cms
|
||||
|
||||
def topk(self) -> TOPKBloom:
|
||||
def topk(self):
|
||||
"""Access the bloom namespace."""
|
||||
|
||||
from .bf import TOPKBloom
|
||||
@@ -74,7 +64,7 @@ class RedisModuleCommands:
|
||||
topk = TOPKBloom(client=self)
|
||||
return topk
|
||||
|
||||
def tdigest(self) -> TDigestBloom:
|
||||
def tdigest(self):
|
||||
"""Access the bloom namespace."""
|
||||
|
||||
from .bf import TDigestBloom
|
||||
@@ -82,20 +72,32 @@ class RedisModuleCommands:
|
||||
tdigest = TDigestBloom(client=self)
|
||||
return tdigest
|
||||
|
||||
def vset(self) -> VectorSet:
|
||||
"""Access the VectorSet commands namespace."""
|
||||
def graph(self, index_name="idx"):
|
||||
"""Access the graph namespace, providing support for
|
||||
redis graph data.
|
||||
"""
|
||||
|
||||
from .vectorset import VectorSet
|
||||
from .graph import Graph
|
||||
|
||||
vset = VectorSet(client=self)
|
||||
return vset
|
||||
g = Graph(client=self, name=index_name)
|
||||
return g
|
||||
|
||||
|
||||
class AsyncRedisModuleCommands(RedisModuleCommands):
|
||||
def ft(self, index_name="idx") -> AsyncSearch:
|
||||
def ft(self, index_name="idx"):
|
||||
"""Access the search namespace, providing support for redis search."""
|
||||
|
||||
from .search import AsyncSearch
|
||||
|
||||
s = AsyncSearch(client=self, index_name=index_name)
|
||||
return s
|
||||
|
||||
def graph(self, index_name="idx"):
|
||||
"""Access the graph namespace, providing support for
|
||||
redis graph data.
|
||||
"""
|
||||
|
||||
from .graph import AsyncGraph
|
||||
|
||||
g = AsyncGraph(client=self, name=index_name)
|
||||
return g
|
||||
|
||||
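Each accessor in the class above lazily imports its module client and binds it to the same connection as the parent client. A rough sketch of how these namespaces are used (key and index names are invented, and the ``idx`` index is assumed to exist already):

import redis

r = redis.Redis()

r.json().set("doc:1", "$", {"title": "hello"})   # RedisJSON via .json()
r.ts().create("sensor:temp")                     # RedisTimeSeries via .ts()
r.bf().create("seen", 0.01, 1000)                # Bloom filter via .bf()
results = r.ft("idx").search("hello")            # RediSearch via .ft("idx")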
@@ -1,7 +1,7 @@
|
||||
def to_string(s, encoding: str = "utf-8"):
|
||||
def to_string(s):
|
||||
if isinstance(s, str):
|
||||
return s
|
||||
elif isinstance(s, bytes):
|
||||
return s.decode(encoding, "ignore")
|
||||
return s.decode("utf-8", "ignore")
|
||||
else:
|
||||
return s # Not a string we care about
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import List, Union
|
||||
|
||||
from redis.commands.search.dialect import DEFAULT_DIALECT
|
||||
|
||||
FIELDNAME = object()
|
||||
|
||||
|
||||
@@ -26,7 +24,7 @@ class Reducer:
|
||||
|
||||
NAME = None
|
||||
|
||||
def __init__(self, *args: str) -> None:
|
||||
def __init__(self, *args: List[str]) -> None:
|
||||
self._args = args
|
||||
self._field = None
|
||||
self._alias = None
|
||||
@@ -112,11 +110,9 @@ class AggregateRequest:
|
||||
self._with_schema = False
|
||||
self._verbatim = False
|
||||
self._cursor = []
|
||||
self._dialect = DEFAULT_DIALECT
|
||||
self._add_scores = False
|
||||
self._scorer = "TFIDF"
|
||||
self._dialect = None
|
||||
|
||||
def load(self, *fields: str) -> "AggregateRequest":
|
||||
def load(self, *fields: List[str]) -> "AggregateRequest":
|
||||
"""
|
||||
Indicate the fields to be returned in the response. These fields are
|
||||
returned in addition to any others implicitly specified.
|
||||
@@ -223,7 +219,7 @@ class AggregateRequest:
|
||||
self._aggregateplan.extend(_limit.build_args())
|
||||
return self
|
||||
|
||||
def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
|
||||
def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
|
||||
"""
|
||||
Indicate how the results should be sorted. This can also be used for
|
||||
*top-N* style queries
|
||||
@@ -296,24 +292,6 @@ class AggregateRequest:
|
||||
self._with_schema = True
|
||||
return self
|
||||
|
||||
def add_scores(self) -> "AggregateRequest":
|
||||
"""
|
||||
If set, includes the score as an ordinary field of the row.
|
||||
"""
|
||||
self._add_scores = True
|
||||
return self
|
||||
|
||||
def scorer(self, scorer: str) -> "AggregateRequest":
|
||||
"""
|
||||
Use a different scoring function to evaluate document relevance.
|
||||
Default is `TFIDF`.
|
||||
|
||||
:param scorer: The scoring function to use
|
||||
(e.g. `TFIDF.DOCNORM` or `BM25`)
|
||||
"""
|
||||
self._scorer = scorer
|
||||
return self
|
||||
|
||||
def verbatim(self) -> "AggregateRequest":
|
||||
self._verbatim = True
|
||||
return self
|
||||
@@ -337,19 +315,12 @@ class AggregateRequest:
|
||||
if self._verbatim:
|
||||
ret.append("VERBATIM")
|
||||
|
||||
if self._scorer:
|
||||
ret.extend(["SCORER", self._scorer])
|
||||
|
||||
if self._add_scores:
|
||||
ret.append("ADDSCORES")
|
||||
|
||||
if self._cursor:
|
||||
ret += self._cursor
|
||||
|
||||
if self._loadall:
|
||||
ret.append("LOAD")
|
||||
ret.append("*")
|
||||
|
||||
elif self._loadfields:
|
||||
ret.append("LOAD")
|
||||
ret.append(str(len(self._loadfields)))
|
||||
|
||||
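For context, the SCORER/ADDSCORES arguments emitted above come from the ``scorer``/``add_scores`` builders that one side of this diff removes. A hedged sketch of composing an aggregation (index, query, and field names are made up):

import redis
from redis.commands.search import reducers
from redis.commands.search.aggregation import AggregateRequest

r = redis.Redis()
# Group matching documents by @brand and count them, sorted by the count.
req = (
    AggregateRequest("@title:phone")
    .group_by("@brand", reducers.count().alias("n"))
    .sort_by("@n")
)
# On the side of the diff that still has the helpers, a custom scorer could be added:
# req = req.scorer("BM25").add_scores()
res = r.ft("idx").aggregate(req)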
@@ -2,16 +2,13 @@ import itertools
|
||||
import time
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from redis.client import NEVER_DECODE, Pipeline
|
||||
from redis.client import Pipeline
|
||||
from redis.utils import deprecated_function
|
||||
|
||||
from ..helpers import get_protocol_version
|
||||
from ..helpers import get_protocol_version, parse_to_dict
|
||||
from ._util import to_string
|
||||
from .aggregation import AggregateRequest, AggregateResult, Cursor
|
||||
from .document import Document
|
||||
from .field import Field
|
||||
from .index_definition import IndexDefinition
|
||||
from .profile_information import ProfileInformation
|
||||
from .query import Query
|
||||
from .result import Result
|
||||
from .suggestion import SuggestionParser
|
||||
@@ -23,6 +20,7 @@ ALTER_CMD = "FT.ALTER"
|
||||
SEARCH_CMD = "FT.SEARCH"
|
||||
ADD_CMD = "FT.ADD"
|
||||
ADDHASH_CMD = "FT.ADDHASH"
|
||||
DROP_CMD = "FT.DROP"
|
||||
DROPINDEX_CMD = "FT.DROPINDEX"
|
||||
EXPLAIN_CMD = "FT.EXPLAIN"
|
||||
EXPLAINCLI_CMD = "FT.EXPLAINCLI"
|
||||
@@ -34,6 +32,7 @@ SPELLCHECK_CMD = "FT.SPELLCHECK"
|
||||
DICT_ADD_CMD = "FT.DICTADD"
|
||||
DICT_DEL_CMD = "FT.DICTDEL"
|
||||
DICT_DUMP_CMD = "FT.DICTDUMP"
|
||||
GET_CMD = "FT.GET"
|
||||
MGET_CMD = "FT.MGET"
|
||||
CONFIG_CMD = "FT.CONFIG"
|
||||
TAGVALS_CMD = "FT.TAGVALS"
|
||||
@@ -66,7 +65,7 @@ class SearchCommands:
|
||||
|
||||
def _parse_results(self, cmd, res, **kwargs):
|
||||
if get_protocol_version(self.client) in ["3", 3]:
|
||||
return ProfileInformation(res) if cmd == "FT.PROFILE" else res
|
||||
return res
|
||||
else:
|
||||
return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs)
|
||||
|
||||
@@ -81,7 +80,6 @@ class SearchCommands:
|
||||
duration=kwargs["duration"],
|
||||
has_payload=kwargs["query"]._with_payloads,
|
||||
with_scores=kwargs["query"]._with_scores,
|
||||
field_encodings=kwargs["query"]._return_fields_decode_as,
|
||||
)
|
||||
|
||||
def _parse_aggregate(self, res, **kwargs):
|
||||
@@ -100,7 +98,7 @@ class SearchCommands:
|
||||
with_scores=query._with_scores,
|
||||
)
|
||||
|
||||
return result, ProfileInformation(res[1])
|
||||
return result, parse_to_dict(res[1])
|
||||
|
||||
def _parse_spellcheck(self, res, **kwargs):
|
||||
corrections = {}
|
||||
@@ -153,43 +151,44 @@ class SearchCommands:
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
fields: List[Field],
|
||||
no_term_offsets: bool = False,
|
||||
no_field_flags: bool = False,
|
||||
stopwords: Optional[List[str]] = None,
|
||||
definition: Optional[IndexDefinition] = None,
|
||||
fields,
|
||||
no_term_offsets=False,
|
||||
no_field_flags=False,
|
||||
stopwords=None,
|
||||
definition=None,
|
||||
max_text_fields=False,
|
||||
temporary=None,
|
||||
no_highlight: bool = False,
|
||||
no_term_frequencies: bool = False,
|
||||
skip_initial_scan: bool = False,
|
||||
no_highlight=False,
|
||||
no_term_frequencies=False,
|
||||
skip_initial_scan=False,
|
||||
):
|
||||
"""
|
||||
Creates the search index. The index must not already exist.
|
||||
Create the search index. The index must not already exist.
|
||||
|
||||
For more information, see https://redis.io/commands/ft.create/
|
||||
### Parameters:
|
||||
|
||||
Args:
|
||||
fields: A list of Field objects.
|
||||
no_term_offsets: If `true`, term offsets will not be saved in the index.
|
||||
no_field_flags: If true, field flags that allow searching in specific fields
|
||||
will not be saved.
|
||||
stopwords: If provided, the index will be created with this custom stopword
|
||||
list. The list can be empty.
|
||||
definition: If provided, the index will be created with this custom index
|
||||
definition.
|
||||
max_text_fields: If true, indexes will be encoded as if there were more than
|
||||
32 text fields, allowing for additional fields beyond 32.
|
||||
temporary: Creates a lightweight temporary index which will expire after the
|
||||
specified period of inactivity. The internal idle timer is reset
|
||||
whenever the index is searched or added to.
|
||||
no_highlight: If true, disables highlighting support. Also implied by
|
||||
`no_term_offsets`.
|
||||
no_term_frequencies: If true, term frequencies will not be saved in the
|
||||
index.
|
||||
skip_initial_scan: If true, the initial scan and indexing will be skipped.
|
||||
- **fields**: a list of TextField or NumericField objects
|
||||
- **no_term_offsets**: If true, we will not save term offsets in
|
||||
the index
|
||||
- **no_field_flags**: If true, we will not save field flags that
|
||||
allow searching in specific fields
|
||||
- **stopwords**: If not None, we create the index with this custom
|
||||
stopword list. The list can be empty
|
||||
- **max_text_fields**: If true, we will encode indexes as if there
|
||||
were more than 32 text fields which allows you to add additional
|
||||
fields (beyond 32).
|
||||
- **temporary**: Create a lightweight temporary index which will
|
||||
expire after the specified period of inactivity (in seconds). The
|
||||
internal idle timer is reset whenever the index is searched or added to.
|
||||
- **no_highlight**: If true, disables highlighting support.
|
||||
Also implied by no_term_offsets.
|
||||
- **no_term_frequencies**: If true, we avoid saving the term frequencies
|
||||
in the index.
|
||||
- **skip_initial_scan**: If true, we do not scan and index.
|
||||
|
||||
For more information see `FT.CREATE <https://redis.io/commands/ft.create>`_.
|
||||
""" # noqa
|
||||
|
||||
"""
|
||||
args = [CREATE_CMD, self.index_name]
|
||||
if definition is not None:
|
||||
args += definition.args
|
||||
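A small sketch of the call this signature serves, creating an index over hash keys (the prefix, index name, and fields are hypothetical):

import redis
from redis.commands.search.field import NumericField, TextField
from redis.commands.search.index_definition import IndexDefinition, IndexType

r = redis.Redis()
definition = IndexDefinition(prefix=["product:"], index_type=IndexType.HASH)
r.ft("products").create_index(
    fields=[TextField("title"), NumericField("price", sortable=True)],
    definition=definition,
    stopwords=["a", "the"],   # custom stopword list; may be empty
)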
@@ -253,18 +252,8 @@ class SearchCommands:
|
||||
|
||||
For more information see `FT.DROPINDEX <https://redis.io/commands/ft.dropindex>`_.
|
||||
""" # noqa
|
||||
args = [DROPINDEX_CMD, self.index_name]
|
||||
|
||||
delete_str = (
|
||||
"DD"
|
||||
if isinstance(delete_documents, bool) and delete_documents is True
|
||||
else ""
|
||||
)
|
||||
|
||||
if delete_str:
|
||||
args.append(delete_str)
|
||||
|
||||
return self.execute_command(*args)
|
||||
delete_str = "DD" if delete_documents else ""
|
||||
return self.execute_command(DROPINDEX_CMD, self.index_name, delete_str)
|
||||
|
||||
def _add_document(
|
||||
self,
|
||||
@@ -346,30 +335,30 @@ class SearchCommands:
|
||||
"""
|
||||
Add a single document to the index.
|
||||
|
||||
Args:
|
||||
### Parameters
|
||||
|
||||
doc_id: the id of the saved document.
|
||||
nosave: if set to true, we just index the document, and don't
|
||||
- **doc_id**: the id of the saved document.
|
||||
- **nosave**: if set to true, we just index the document, and don't
|
||||
save a copy of it. This means that searches will just
|
||||
return ids.
|
||||
score: the document ranking, between 0.0 and 1.0
|
||||
payload: optional inner-index payload we can save for fast
|
||||
access in scoring functions
|
||||
replace: if True, and the document already is in the index,
|
||||
we perform an update and reindex the document
|
||||
partial: if True, the fields specified will be added to the
|
||||
- **score**: the document ranking, between 0.0 and 1.0
|
||||
- **payload**: optional inner-index payload we can save for fast
|
||||
access in scoring functions
|
||||
- **replace**: if True, and the document already is in the index,
|
||||
we perform an update and reindex the document
|
||||
- **partial**: if True, the fields specified will be added to the
|
||||
existing document.
|
||||
This has the added benefit that any fields specified
|
||||
with `no_index`
|
||||
will not be reindexed again. Implies `replace`
|
||||
language: Specify the language used for document tokenization.
|
||||
no_create: if True, the document is only updated and reindexed
|
||||
- **language**: Specify the language used for document tokenization.
|
||||
- **no_create**: if True, the document is only updated and reindexed
|
||||
if it already exists.
|
||||
If the document does not exist, an error will be
|
||||
returned. Implies `replace`
|
||||
fields: kwargs dictionary of the document fields to be saved
|
||||
and/or indexed.
|
||||
NOTE: Geo points should be encoded as strings of "lon,lat"
|
||||
- **fields** kwargs dictionary of the document fields to be saved
|
||||
and/or indexed.
|
||||
NOTE: Geo points should be encoded as strings of "lon,lat"
|
||||
""" # noqa
|
||||
return self._add_document(
|
||||
doc_id,
|
||||
@@ -404,7 +393,6 @@ class SearchCommands:
|
||||
doc_id, conn=None, score=score, language=language, replace=replace
|
||||
)
|
||||
|
||||
@deprecated_function(version="2.0.0", reason="deprecated since redisearch 2.0")
|
||||
def delete_document(self, doc_id, conn=None, delete_actual_document=False):
|
||||
"""
|
||||
Delete a document from index
|
||||
@@ -439,7 +427,6 @@ class SearchCommands:
|
||||
|
||||
return Document(id=id, **fields)
|
||||
|
||||
@deprecated_function(version="2.0.0", reason="deprecated since redisearch 2.0")
|
||||
def get(self, *ids):
|
||||
"""
|
||||
Returns the full contents of multiple documents.
|
||||
@@ -510,19 +497,14 @@ class SearchCommands:
|
||||
For more information see `FT.SEARCH <https://redis.io/commands/ft.search>`_.
|
||||
""" # noqa
|
||||
args, query = self._mk_query_args(query, query_params=query_params)
|
||||
st = time.monotonic()
|
||||
|
||||
options = {}
|
||||
if get_protocol_version(self.client) not in ["3", 3]:
|
||||
options[NEVER_DECODE] = True
|
||||
|
||||
res = self.execute_command(SEARCH_CMD, *args, **options)
|
||||
st = time.time()
|
||||
res = self.execute_command(SEARCH_CMD, *args)
|
||||
|
||||
if isinstance(res, Pipeline):
|
||||
return res
|
||||
|
||||
return self._parse_results(
|
||||
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
|
||||
SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0
|
||||
)
|
||||
|
||||
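The timing and decode handling above wrap an ordinary FT.SEARCH round trip; from the caller's side it looks roughly like this (index name and query text are placeholders):

import redis
from redis.commands.search.query import Query

r = redis.Redis()
q = Query("hello world").paging(0, 10).with_scores()
res = r.ft("idx").search(q)
print(res.total, res.duration)   # duration is the measured time in milliseconds
for doc in res.docs:
    print(doc.id)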
def explain(
|
||||
@@ -542,7 +524,7 @@ class SearchCommands:
|
||||
|
||||
def aggregate(
|
||||
self,
|
||||
query: Union[AggregateRequest, Cursor],
|
||||
query: Union[str, Query],
|
||||
query_params: Dict[str, Union[str, int, float]] = None,
|
||||
):
|
||||
"""
|
||||
@@ -573,7 +555,7 @@ class SearchCommands:
|
||||
)
|
||||
|
||||
def _get_aggregate_result(
|
||||
self, raw: List, query: Union[AggregateRequest, Cursor], has_cursor: bool
|
||||
self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool
|
||||
):
|
||||
if has_cursor:
|
||||
if isinstance(query, Cursor):
|
||||
@@ -596,7 +578,7 @@ class SearchCommands:
|
||||
|
||||
def profile(
|
||||
self,
|
||||
query: Union[Query, AggregateRequest],
|
||||
query: Union[str, Query, AggregateRequest],
|
||||
limited: bool = False,
|
||||
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
|
||||
):
|
||||
@@ -606,13 +588,13 @@ class SearchCommands:
|
||||
|
||||
### Parameters
|
||||
|
||||
**query**: This can be either an `AggregateRequest` or `Query`.
|
||||
**query**: This can be either an `AggregateRequest`, `Query` or string.
|
||||
**limited**: If set to True, removes details of reader iterator.
|
||||
**query_params**: Define one or more value parameters.
|
||||
Each parameter has a name and a value.
|
||||
|
||||
"""
|
||||
st = time.monotonic()
|
||||
st = time.time()
|
||||
cmd = [PROFILE_CMD, self.index_name, ""]
|
||||
if limited:
|
||||
cmd.append("LIMITED")
|
||||
@@ -631,20 +613,20 @@ class SearchCommands:
|
||||
res = self.execute_command(*cmd)
|
||||
|
||||
return self._parse_results(
|
||||
PROFILE_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
|
||||
PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0
|
||||
)
|
||||
|
||||
def spellcheck(self, query, distance=None, include=None, exclude=None):
|
||||
"""
|
||||
Issue a spellcheck query
|
||||
|
||||
Args:
|
||||
### Parameters
|
||||
|
||||
query: search query.
|
||||
distance: the maximal Levenshtein distance for spelling
|
||||
**query**: search query.
|
||||
**distance**: the maximal Levenshtein distance for spelling
|
||||
suggestions (default: 1, max: 4).
|
||||
include: specifies an inclusion custom dictionary.
|
||||
exclude: specifies an exclusion custom dictionary.
|
||||
**include**: specifies an inclusion custom dictionary.
|
||||
**exclude**: specifies an exclusion custom dictionary.
|
||||
|
||||
For more information see `FT.SPELLCHECK <https://redis.io/commands/ft.spellcheck>`_.
|
||||
""" # noqa
|
||||
@@ -702,10 +684,6 @@ class SearchCommands:
|
||||
cmd = [DICT_DUMP_CMD, name]
|
||||
return self.execute_command(*cmd)
|
||||
|
||||
@deprecated_function(
|
||||
version="8.0.0",
|
||||
reason="deprecated since Redis 8.0, call config_set from core module instead",
|
||||
)
|
||||
def config_set(self, option: str, value: str) -> bool:
|
||||
"""Set runtime configuration option.
|
||||
|
||||
@@ -720,10 +698,6 @@ class SearchCommands:
|
||||
raw = self.execute_command(*cmd)
|
||||
return raw == "OK"
|
||||
|
||||
@deprecated_function(
|
||||
version="8.0.0",
|
||||
reason="deprecated since Redis 8.0, call config_get from core module instead",
|
||||
)
|
||||
def config_get(self, option: str) -> str:
|
||||
"""Get runtime configuration option value.
|
||||
|
||||
@@ -950,24 +924,19 @@ class AsyncSearchCommands(SearchCommands):
|
||||
For more information see `FT.SEARCH <https://redis.io/commands/ft.search>`_.
|
||||
""" # noqa
|
||||
args, query = self._mk_query_args(query, query_params=query_params)
|
||||
st = time.monotonic()
|
||||
|
||||
options = {}
|
||||
if get_protocol_version(self.client) not in ["3", 3]:
|
||||
options[NEVER_DECODE] = True
|
||||
|
||||
res = await self.execute_command(SEARCH_CMD, *args, **options)
|
||||
st = time.time()
|
||||
res = await self.execute_command(SEARCH_CMD, *args)
|
||||
|
||||
if isinstance(res, Pipeline):
|
||||
return res
|
||||
|
||||
return self._parse_results(
|
||||
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
|
||||
SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0
|
||||
)
|
||||
|
||||
async def aggregate(
|
||||
self,
|
||||
query: Union[AggregateResult, Cursor],
|
||||
query: Union[str, Query],
|
||||
query_params: Dict[str, Union[str, int, float]] = None,
|
||||
):
|
||||
"""
|
||||
@@ -1025,10 +994,6 @@ class AsyncSearchCommands(SearchCommands):
|
||||
|
||||
return self._parse_results(SPELLCHECK_CMD, res)
|
||||
|
||||
@deprecated_function(
|
||||
version="8.0.0",
|
||||
reason="deprecated since Redis 8.0, call config_set from core module instead",
|
||||
)
|
||||
async def config_set(self, option: str, value: str) -> bool:
|
||||
"""Set runtime configuration option.
|
||||
|
||||
@@ -1043,10 +1008,6 @@ class AsyncSearchCommands(SearchCommands):
|
||||
raw = await self.execute_command(*cmd)
|
||||
return raw == "OK"
|
||||
|
||||
@deprecated_function(
|
||||
version="8.0.0",
|
||||
reason="deprecated since Redis 8.0, call config_get from core module instead",
|
||||
)
|
||||
async def config_get(self, option: str) -> str:
|
||||
"""Get runtime configuration option value.
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
# Value for the default dialect to be used as a part of
|
||||
# Search or Aggregate query.
|
||||
DEFAULT_DIALECT = 2
|
||||
@@ -4,10 +4,6 @@ from redis import DataError
|
||||
|
||||
|
||||
class Field:
|
||||
"""
|
||||
A class representing a field in a document.
|
||||
"""
|
||||
|
||||
NUMERIC = "NUMERIC"
|
||||
TEXT = "TEXT"
|
||||
WEIGHT = "WEIGHT"
|
||||
@@ -17,9 +13,6 @@ class Field:
|
||||
SORTABLE = "SORTABLE"
|
||||
NOINDEX = "NOINDEX"
|
||||
AS = "AS"
|
||||
GEOSHAPE = "GEOSHAPE"
|
||||
INDEX_MISSING = "INDEXMISSING"
|
||||
INDEX_EMPTY = "INDEXEMPTY"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -27,24 +20,8 @@ class Field:
|
||||
args: List[str] = None,
|
||||
sortable: bool = False,
|
||||
no_index: bool = False,
|
||||
index_missing: bool = False,
|
||||
index_empty: bool = False,
|
||||
as_name: str = None,
|
||||
):
|
||||
"""
|
||||
Create a new field object.
|
||||
|
||||
Args:
|
||||
name: The name of the field.
|
||||
args:
|
||||
sortable: If `True`, the field will be sortable.
|
||||
no_index: If `True`, the field will not be indexed.
|
||||
index_missing: If `True`, it will be possible to search for documents that
|
||||
have this field missing.
|
||||
index_empty: If `True`, it will be possible to search for documents that
|
||||
have this field empty.
|
||||
as_name: If provided, this alias will be used for the field.
|
||||
"""
|
||||
if args is None:
|
||||
args = []
|
||||
self.name = name
|
||||
@@ -56,10 +33,6 @@ class Field:
|
||||
self.args_suffix.append(Field.SORTABLE)
|
||||
if no_index:
|
||||
self.args_suffix.append(Field.NOINDEX)
|
||||
if index_missing:
|
||||
self.args_suffix.append(Field.INDEX_MISSING)
|
||||
if index_empty:
|
||||
self.args_suffix.append(Field.INDEX_EMPTY)
|
||||
|
||||
if no_index and not sortable:
|
||||
raise ValueError("Non-Sortable non-Indexable fields are ignored")
|
||||
@@ -118,21 +91,6 @@ class NumericField(Field):
|
||||
Field.__init__(self, name, args=[Field.NUMERIC], **kwargs)
|
||||
|
||||
|
||||
class GeoShapeField(Field):
|
||||
"""
|
||||
GeoShapeField is used to enable within/contain indexing/searching
|
||||
"""
|
||||
|
||||
SPHERICAL = "SPHERICAL"
|
||||
FLAT = "FLAT"
|
||||
|
||||
def __init__(self, name: str, coord_system=None, **kwargs):
|
||||
args = [Field.GEOSHAPE]
|
||||
if coord_system:
|
||||
args.append(coord_system)
|
||||
Field.__init__(self, name, args=args, **kwargs)
|
||||
|
||||
|
||||
class GeoField(Field):
|
||||
"""
|
||||
GeoField is used to define a geo-indexing field in a schema definition
|
||||
@@ -181,7 +139,7 @@ class VectorField(Field):
|
||||
|
||||
``name`` is the name of the field.
|
||||
|
||||
``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA".
|
||||
``algorithm`` can be "FLAT" or "HNSW".
|
||||
|
||||
``attributes`` each algorithm can have specific attributes. Some of them
|
||||
are mandatory and some of them are optional. See
|
||||
@@ -194,10 +152,10 @@ class VectorField(Field):
|
||||
if sort or noindex:
|
||||
raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.")
|
||||
|
||||
if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]:
|
||||
if algorithm.upper() not in ["FLAT", "HNSW"]:
|
||||
raise DataError(
|
||||
"Realtime vector indexing supporting 3 Indexing Methods:"
|
||||
"'FLAT', 'HNSW', and 'SVS-VAMANA'."
|
||||
"Realtime vector indexing supporting 2 Indexing Methods:"
|
||||
"'FLAT' and 'HNSW'."
|
||||
)
|
||||
|
||||
attr_li = []
|
||||
|
||||
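For reference, a vector field declared with the HNSW algorithm (accepted on both sides of this diff) looks roughly like the sketch below; the field name, dimension, and metric are illustrative:

from redis.commands.search.field import VectorField

embedding = VectorField(
    "embedding",
    "HNSW",                        # one of the accepted algorithms
    {
        "TYPE": "FLOAT32",
        "DIM": 384,                # dimensionality of the stored vectors
        "DISTANCE_METRIC": "COSINE",
    },
)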
@@ -1,14 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ProfileInformation:
|
||||
"""
|
||||
Wrapper around FT.PROFILE response
|
||||
"""
|
||||
|
||||
def __init__(self, info: Any) -> None:
|
||||
self._info: Any = info
|
||||
|
||||
@property
|
||||
def info(self) -> Any:
|
||||
return self._info
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from redis.commands.search.dialect import DEFAULT_DIALECT
|
||||
|
||||
|
||||
class Query:
|
||||
"""
|
||||
@@ -37,12 +35,11 @@ class Query:
|
||||
self._in_order: bool = False
|
||||
self._sortby: Optional[SortbyField] = None
|
||||
self._return_fields: List = []
|
||||
self._return_fields_decode_as: dict = {}
|
||||
self._summarize_fields: List = []
|
||||
self._highlight_fields: List = []
|
||||
self._language: Optional[str] = None
|
||||
self._expander: Optional[str] = None
|
||||
self._dialect: int = DEFAULT_DIALECT
|
||||
self._dialect: Optional[int] = None
|
||||
|
||||
def query_string(self) -> str:
|
||||
"""Return the query string of this query only."""
|
||||
@@ -56,27 +53,13 @@ class Query:
|
||||
|
||||
def return_fields(self, *fields) -> "Query":
|
||||
"""Add fields to return fields."""
|
||||
for field in fields:
|
||||
self.return_field(field)
|
||||
self._return_fields += fields
|
||||
return self
|
||||
|
||||
def return_field(
|
||||
self,
|
||||
field: str,
|
||||
as_field: Optional[str] = None,
|
||||
decode_field: Optional[bool] = True,
|
||||
encoding: Optional[str] = "utf8",
|
||||
) -> "Query":
|
||||
"""
|
||||
Add a field to the list of fields to return.
|
||||
|
||||
- **field**: The field to include in query results
|
||||
- **as_field**: The alias for the field
|
||||
- **decode_field**: Whether to decode the field from bytes to string
|
||||
- **encoding**: The encoding to use when decoding the field
|
||||
"""
|
||||
def return_field(self, field: str, as_field: Optional[str] = None) -> "Query":
|
||||
"""Add field to return fields (Optional: add 'AS' name
|
||||
to the field)."""
|
||||
self._return_fields.append(field)
|
||||
self._return_fields_decode_as[field] = encoding if decode_field else None
|
||||
if as_field is not None:
|
||||
self._return_fields += ("AS", as_field)
|
||||
return self
|
||||
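The richer ``return_field`` on one side of this hunk lets callers keep selected fields as raw bytes or decode them with a specific encoding; a sketch under that API (field names are invented):

from redis.commands.search.query import Query

q = (
    Query("*")
    .return_field("title")                          # decoded with the default utf8
    .return_field("payload", decode_field=False)    # kept as raw bytes
    .return_field("summary", as_field="s")          # aliased via AS
)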
@@ -179,8 +162,6 @@ class Query:
|
||||
Use a different scoring function to evaluate document relevance.
|
||||
Default is `TFIDF`.
|
||||
|
||||
Since Redis 8.0 default was changed to BM25STD.
|
||||
|
||||
:param scorer: The scoring function to use
|
||||
(e.g. `TFIDF.DOCNORM` or `BM25`)
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from ._util import to_string
|
||||
from .document import Document
|
||||
|
||||
@@ -11,19 +9,11 @@ class Result:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
res,
|
||||
hascontent,
|
||||
duration=0,
|
||||
has_payload=False,
|
||||
with_scores=False,
|
||||
field_encodings: Optional[dict] = None,
|
||||
self, res, hascontent, duration=0, has_payload=False, with_scores=False
|
||||
):
|
||||
"""
|
||||
- duration: the execution time of the query
|
||||
- has_payload: whether the query has payloads
|
||||
- with_scores: whether the query has scores
|
||||
- field_encodings: a dictionary of field encodings if any is provided
|
||||
- **snippets**: An optional dictionary of the form
|
||||
{field: snippet_size} for snippet formatting
|
||||
"""
|
||||
|
||||
self.total = res[0]
|
||||
@@ -49,22 +39,18 @@ class Result:
|
||||
|
||||
fields = {}
|
||||
if hascontent and res[i + fields_offset] is not None:
|
||||
keys = map(to_string, res[i + fields_offset][::2])
|
||||
values = res[i + fields_offset][1::2]
|
||||
|
||||
for key, value in zip(keys, values):
|
||||
if field_encodings is None or key not in field_encodings:
|
||||
fields[key] = to_string(value)
|
||||
continue
|
||||
|
||||
encoding = field_encodings[key]
|
||||
|
||||
# If the encoding is None, we don't need to decode the value
|
||||
if encoding is None:
|
||||
fields[key] = value
|
||||
else:
|
||||
fields[key] = to_string(value, encoding=encoding)
|
||||
|
||||
fields = (
|
||||
dict(
|
||||
dict(
|
||||
zip(
|
||||
map(to_string, res[i + fields_offset][::2]),
|
||||
map(to_string, res[i + fields_offset][1::2]),
|
||||
)
|
||||
)
|
||||
)
|
||||
if hascontent
|
||||
else {}
|
||||
)
|
||||
try:
|
||||
del fields["id"]
|
||||
except KeyError:
|
||||
|
||||
@@ -11,35 +11,16 @@ class SentinelCommands:
|
||||
"""Redis Sentinel's SENTINEL command."""
|
||||
warnings.warn(DeprecationWarning("Use the individual sentinel_* methods"))
|
||||
|
||||
def sentinel_get_master_addr_by_name(self, service_name, return_responses=False):
|
||||
"""
|
||||
Returns a (host, port) pair for the given ``service_name`` when return_responses is True,
|
||||
otherwise returns a boolean value that indicates if the command was successful.
|
||||
"""
|
||||
return self.execute_command(
|
||||
"SENTINEL GET-MASTER-ADDR-BY-NAME",
|
||||
service_name,
|
||||
once=True,
|
||||
return_responses=return_responses,
|
||||
)
|
||||
def sentinel_get_master_addr_by_name(self, service_name):
|
||||
"""Returns a (host, port) pair for the given ``service_name``"""
|
||||
return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name)
|
||||
|
||||
def sentinel_master(self, service_name, return_responses=False):
|
||||
"""
|
||||
Returns a dictionary containing the specified master's state when return_responses is True,
|
||||
otherwise returns a boolean value that indicates if the command was successful.
|
||||
"""
|
||||
return self.execute_command(
|
||||
"SENTINEL MASTER", service_name, return_responses=return_responses
|
||||
)
|
||||
def sentinel_master(self, service_name):
|
||||
"""Returns a dictionary containing the specified masters state."""
|
||||
return self.execute_command("SENTINEL MASTER", service_name)
|
||||
|
||||
def sentinel_masters(self):
|
||||
"""
|
||||
Returns a list of dictionaries containing each master's state.
|
||||
|
||||
Important: This function is called by the Sentinel implementation and is
|
||||
called directly on the Redis standalone client for sentinels,
|
||||
so it doesn't support the "once" and "return_responses" options.
|
||||
"""
|
||||
"""Returns a list of dictionaries containing each master's state."""
|
||||
return self.execute_command("SENTINEL MASTERS")
|
||||
|
||||
def sentinel_monitor(self, name, ip, port, quorum):
|
||||
@@ -50,27 +31,16 @@ class SentinelCommands:
|
||||
"""Remove a master from Sentinel's monitoring"""
|
||||
return self.execute_command("SENTINEL REMOVE", name)
|
||||
|
||||
def sentinel_sentinels(self, service_name, return_responses=False):
|
||||
"""
|
||||
Returns a list of sentinels for ``service_name``, when return_responses is True,
|
||||
otherwise returns a boolean value that indicates if the command was successful.
|
||||
"""
|
||||
return self.execute_command(
|
||||
"SENTINEL SENTINELS", service_name, return_responses=return_responses
|
||||
)
|
||||
def sentinel_sentinels(self, service_name):
|
||||
"""Returns a list of sentinels for ``service_name``"""
|
||||
return self.execute_command("SENTINEL SENTINELS", service_name)
|
||||
|
||||
def sentinel_set(self, name, option, value):
|
||||
"""Set Sentinel monitoring parameters for a given master"""
|
||||
return self.execute_command("SENTINEL SET", name, option, value)
|
||||
|
||||
def sentinel_slaves(self, service_name):
|
||||
"""
|
||||
Returns a list of slaves for ``service_name``
|
||||
|
||||
Important: This function is called by the Sentinel implementation and is
|
||||
called directly on the Redis standalone client for sentinels,
|
||||
so it doesn't support the "once" and "return_responses" options.
|
||||
"""
|
||||
"""Returns a list of slaves for ``service_name``"""
|
||||
return self.execute_command("SENTINEL SLAVES", service_name)
|
||||
|
||||
def sentinel_reset(self, pattern):
|
||||
|
||||
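These low-level SENTINEL commands are what the higher-level ``Sentinel`` helper builds on; a typical application goes through that helper rather than calling them directly. A minimal sketch (the sentinel address and service name are placeholders):

from redis.sentinel import Sentinel

sentinel = Sentinel([("localhost", 26379)], socket_timeout=0.5)
host, port = sentinel.discover_master("mymaster")   # uses SENTINEL GET-MASTER-ADDR-BY-NAME
master = sentinel.master_for("mymaster", socket_timeout=0.5)
master.set("greeting", "hello")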
@@ -84,7 +84,7 @@ class TimeSeries(TimeSeriesCommands):
|
||||
startup_nodes=self.client.nodes_manager.startup_nodes,
|
||||
result_callbacks=self.client.result_callbacks,
|
||||
cluster_response_callbacks=self.client.cluster_response_callbacks,
|
||||
cluster_error_retry_attempts=self.client.retry.get_retries(),
|
||||
cluster_error_retry_attempts=self.client.cluster_error_retry_attempts,
|
||||
read_from_replicas=self.client.read_from_replicas,
|
||||
reinitialize_steps=self.client.reinitialize_steps,
|
||||
lock=self.client._lock,
|
||||
|
||||
File diff suppressed because it is too large
@@ -6,7 +6,7 @@ class TSInfo:
|
||||
"""
|
||||
Hold information and statistics on the time-series.
|
||||
Can be created using ``tsinfo`` command
|
||||
https://redis.io/docs/latest/commands/ts.info/
|
||||
https://oss.redis.com/redistimeseries/commands/#tsinfo.
|
||||
"""
|
||||
|
||||
rules = []
|
||||
@@ -57,7 +57,7 @@ class TSInfo:
|
||||
Policy that will define handling of duplicate samples.
|
||||
|
||||
You can read more about it at
|
||||
https://redis.io/docs/latest/develop/data-types/timeseries/configuration/#duplicate_policy
|
||||
https://oss.redis.com/redistimeseries/configuration/#duplicate_policy
|
||||
"""
|
||||
response = dict(zip(map(nativestr, args[::2]), args[1::2]))
|
||||
self.rules = response.get("rules")
|
||||
@@ -78,7 +78,7 @@ class TSInfo:
|
||||
self.chunk_size = response["chunkSize"]
|
||||
if "duplicatePolicy" in response:
|
||||
self.duplicate_policy = response["duplicatePolicy"]
|
||||
if isinstance(self.duplicate_policy, bytes):
|
||||
if type(self.duplicate_policy) == bytes:
|
||||
self.duplicate_policy = self.duplicate_policy.decode()
|
||||
|
||||
def get(self, item):
|
||||
|
||||
@@ -5,7 +5,7 @@ def list_to_dict(aList):
|
||||
return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))}
|
||||
|
||||
|
||||
def parse_range(response, **kwargs):
|
||||
def parse_range(response):
|
||||
"""Parse range response. Used by TS.RANGE and TS.REVRANGE."""
|
||||
return [tuple((r[0], float(r[1]))) for r in response]
|
||||
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
import json
|
||||
|
||||
from redis._parsers.helpers import pairs_to_dict
|
||||
from redis.commands.vectorset.utils import (
|
||||
parse_vemb_result,
|
||||
parse_vlinks_result,
|
||||
parse_vsim_result,
|
||||
)
|
||||
|
||||
from ..helpers import get_protocol_version
|
||||
from .commands import (
|
||||
VEMB_CMD,
|
||||
VGETATTR_CMD,
|
||||
VINFO_CMD,
|
||||
VLINKS_CMD,
|
||||
VSIM_CMD,
|
||||
VectorSetCommands,
|
||||
)
|
||||
|
||||
|
||||
class VectorSet(VectorSetCommands):
|
||||
def __init__(self, client, **kwargs):
|
||||
"""Create a new VectorSet client."""
|
||||
# Set the module commands' callbacks
|
||||
self._MODULE_CALLBACKS = {
|
||||
VEMB_CMD: parse_vemb_result,
|
||||
VGETATTR_CMD: lambda r: r and json.loads(r) or None,
|
||||
}
|
||||
|
||||
self._RESP2_MODULE_CALLBACKS = {
|
||||
VINFO_CMD: lambda r: r and pairs_to_dict(r) or None,
|
||||
VSIM_CMD: parse_vsim_result,
|
||||
VLINKS_CMD: parse_vlinks_result,
|
||||
}
|
||||
self._RESP3_MODULE_CALLBACKS = {}
|
||||
|
||||
self.client = client
|
||||
self.execute_command = client.execute_command
|
||||
|
||||
if get_protocol_version(self.client) in ["3", 3]:
|
||||
self._MODULE_CALLBACKS.update(self._RESP3_MODULE_CALLBACKS)
|
||||
else:
|
||||
self._MODULE_CALLBACKS.update(self._RESP2_MODULE_CALLBACKS)
|
||||
|
||||
for k, v in self._MODULE_CALLBACKS.items():
|
||||
self.client.set_response_callback(k, v)
|
||||
@@ -1,374 +0,0 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Dict, List, Optional, Union
|
||||
|
||||
from redis.client import NEVER_DECODE
|
||||
from redis.commands.helpers import get_protocol_version
|
||||
from redis.exceptions import DataError
|
||||
from redis.typing import CommandsProtocol, EncodableT, KeyT, Number
|
||||
|
||||
VADD_CMD = "VADD"
|
||||
VSIM_CMD = "VSIM"
|
||||
VREM_CMD = "VREM"
|
||||
VDIM_CMD = "VDIM"
|
||||
VCARD_CMD = "VCARD"
|
||||
VEMB_CMD = "VEMB"
|
||||
VLINKS_CMD = "VLINKS"
|
||||
VINFO_CMD = "VINFO"
|
||||
VSETATTR_CMD = "VSETATTR"
|
||||
VGETATTR_CMD = "VGETATTR"
|
||||
VRANDMEMBER_CMD = "VRANDMEMBER"
|
||||
|
||||
|
||||
class QuantizationOptions(Enum):
|
||||
"""Quantization options for the VADD command."""
|
||||
|
||||
NOQUANT = "NOQUANT"
|
||||
BIN = "BIN"
|
||||
Q8 = "Q8"
|
||||
|
||||
|
||||
class CallbacksOptions(Enum):
|
||||
"""Options that can be set for the commands callbacks"""
|
||||
|
||||
RAW = "RAW"
|
||||
WITHSCORES = "WITHSCORES"
|
||||
ALLOW_DECODING = "ALLOW_DECODING"
|
||||
RESP3 = "RESP3"
|
||||
|
||||
|
||||
class VectorSetCommands(CommandsProtocol):
|
||||
"""Redis VectorSet commands"""
|
||||
|
||||
def vadd(
|
||||
self,
|
||||
key: KeyT,
|
||||
vector: Union[List[float], bytes],
|
||||
element: str,
|
||||
reduce_dim: Optional[int] = None,
|
||||
cas: Optional[bool] = False,
|
||||
quantization: Optional[QuantizationOptions] = None,
|
||||
ef: Optional[Number] = None,
|
||||
attributes: Optional[Union[dict, str]] = None,
|
||||
numlinks: Optional[int] = None,
|
||||
) -> Union[Awaitable[int], int]:
|
||||
"""
|
||||
Add vector ``vector`` for element ``element`` to a vector set ``key``.
|
||||
|
||||
``reduce_dim`` sets the dimensions to reduce the vector to.
|
||||
If not provided, the vector is not reduced.
|
||||
|
||||
``cas`` is a boolean flag that indicates whether to use CAS (check-and-set style)
|
||||
when adding the vector. If not provided, CAS is not used.
|
||||
|
||||
``quantization`` sets the quantization type to use.
|
||||
If not provided, int8 quantization is used.
|
||||
The options are:
|
||||
- NOQUANT: No quantization
|
||||
- BIN: Binary quantization
|
||||
- Q8: Signed 8-bit quantization
|
||||
|
||||
``ef`` sets the exploration factor to use.
|
||||
If not provided, the default exploration factor is used.
|
||||
|
||||
``attributes`` is a dictionary or json string that contains the attributes to set for the vector.
|
||||
If not provided, no attributes are set.
|
||||
|
||||
``numlinks`` sets the number of links to create for the vector.
|
||||
If not provided, the default number of links is used.
|
||||
|
||||
For more information see https://redis.io/commands/vadd
|
||||
"""
|
||||
if not vector or not element:
|
||||
raise DataError("Both vector and element must be provided")
|
||||
|
||||
pieces = []
|
||||
if reduce_dim:
|
||||
pieces.extend(["REDUCE", reduce_dim])
|
||||
|
||||
values_pieces = []
|
||||
if isinstance(vector, bytes):
|
||||
values_pieces.extend(["FP32", vector])
|
||||
else:
|
||||
values_pieces.extend(["VALUES", len(vector)])
|
||||
values_pieces.extend(vector)
|
||||
pieces.extend(values_pieces)
|
||||
|
||||
pieces.append(element)
|
||||
|
||||
if cas:
|
||||
pieces.append("CAS")
|
||||
|
||||
if quantization:
|
||||
pieces.append(quantization.value)
|
||||
|
||||
if ef:
|
||||
pieces.extend(["EF", ef])
|
||||
|
||||
if attributes:
|
||||
if isinstance(attributes, dict):
|
||||
# transform attributes to json string
|
||||
attributes_json = json.dumps(attributes)
|
||||
else:
|
||||
attributes_json = attributes
|
||||
pieces.extend(["SETATTR", attributes_json])
|
||||
|
||||
if numlinks:
|
||||
pieces.extend(["M", numlinks])
|
||||
|
||||
return self.execute_command(VADD_CMD, key, *pieces)
|
||||
|
||||
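A minimal sketch of adding a vector with a few of the options handled above, assuming the side of the diff that ships the vectorset module (key, element, and attribute values are invented):

import redis
from redis.commands.vectorset.commands import QuantizationOptions

r = redis.Redis()
r.vset().vadd(
    "points",                         # vector set key
    [0.1, 0.2, 0.3],                  # VALUES form; passing bytes would use FP32
    "element:1",
    quantization=QuantizationOptions.Q8,
    attributes={"genre": "sci-fi"},   # serialized to JSON and sent via SETATTR
)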
def vsim(
|
||||
self,
|
||||
key: KeyT,
|
||||
input: Union[List[float], bytes, str],
|
||||
with_scores: Optional[bool] = False,
|
||||
count: Optional[int] = None,
|
||||
ef: Optional[Number] = None,
|
||||
filter: Optional[str] = None,
|
||||
filter_ef: Optional[str] = None,
|
||||
truth: Optional[bool] = False,
|
||||
no_thread: Optional[bool] = False,
|
||||
epsilon: Optional[Number] = None,
|
||||
) -> Union[
|
||||
Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]],
|
||||
Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]],
|
||||
]:
|
||||
"""
|
||||
Compare a vector or element ``input`` with the other vectors in a vector set ``key``.
|
||||
|
||||
``with_scores`` sets if the results should be returned with the
|
||||
similarity scores of the elements in the result.
|
||||
|
||||
``count`` sets the number of results to return.
|
||||
|
||||
``ef`` sets the exploration factor.
|
||||
|
||||
``filter`` sets filter that should be applied for the search.
|
||||
|
||||
``filter_ef`` sets the max filtering effort.
|
||||
|
||||
``truth`` when enabled forces the command to perform linear scan.
|
||||
|
||||
``no_thread`` when enabled forces the command to execute the search
|
||||
on the data structure in the main thread.
|
||||
|
||||
``epsilon`` floating point between 0 and 1, if specified will return
|
||||
only elements with distance no further than the specified one.
|
||||
|
||||
For more information see https://redis.io/commands/vsim
|
||||
"""
|
||||
|
||||
if not input:
|
||||
raise DataError("'input' should be provided")
|
||||
|
||||
pieces = []
|
||||
options = {}
|
||||
|
||||
if isinstance(input, bytes):
|
||||
pieces.extend(["FP32", input])
|
||||
elif isinstance(input, list):
|
||||
pieces.extend(["VALUES", len(input)])
|
||||
pieces.extend(input)
|
||||
else:
|
||||
pieces.extend(["ELE", input])
|
||||
|
||||
if with_scores:
|
||||
pieces.append("WITHSCORES")
|
||||
options[CallbacksOptions.WITHSCORES.value] = True
|
||||
|
||||
if count:
|
||||
pieces.extend(["COUNT", count])
|
||||
|
||||
if epsilon:
|
||||
pieces.extend(["EPSILON", epsilon])
|
||||
|
||||
if ef:
|
||||
pieces.extend(["EF", ef])
|
||||
|
||||
if filter:
|
||||
pieces.extend(["FILTER", filter])
|
||||
|
||||
if filter_ef:
|
||||
pieces.extend(["FILTER-EF", filter_ef])
|
||||
|
||||
if truth:
|
||||
pieces.append("TRUTH")
|
||||
|
||||
if no_thread:
|
||||
pieces.append("NOTHREAD")
|
||||
|
||||
return self.execute_command(VSIM_CMD, key, *pieces, **options)
|
||||
|
||||
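And a sketch of the matching similarity query; the filter expression and counts are illustrative only:

import redis

r = redis.Redis()
# Nearest neighbours of a query vector, with scores and a server-side filter.
hits = r.vset().vsim(
    "points",
    [0.1, 0.2, 0.25],
    with_scores=True,         # returns a mapping of element -> score
    count=5,
    filter=".genre == 'sci-fi'",
)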
def vdim(self, key: KeyT) -> Union[Awaitable[int], int]:
|
||||
"""
|
||||
Get the dimension of a vector set.
|
||||
|
||||
In the case of vectors that were populated using the `REDUCE`
|
||||
option, for random projection, the vector set will report the size of
|
||||
the projected (reduced) dimension.
|
||||
|
||||
Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.
|
||||
|
||||
For more information see https://redis.io/commands/vdim
|
||||
"""
|
||||
return self.execute_command(VDIM_CMD, key)
|
||||
|
||||
def vcard(self, key: KeyT) -> Union[Awaitable[int], int]:
|
||||
"""
|
||||
Get the cardinality (the number of elements) of a vector set with key ``key``.
|
||||
|
||||
Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.
|
||||
|
||||
For more information see https://redis.io/commands/vcard
|
||||
"""
|
||||
return self.execute_command(VCARD_CMD, key)
|
||||
|
||||
def vrem(self, key: KeyT, element: str) -> Union[Awaitable[int], int]:
|
||||
"""
|
||||
Remove an element from a vector set.
|
||||
|
||||
For more information see https://redis.io/commands/vrem
|
||||
"""
|
||||
return self.execute_command(VREM_CMD, key, element)
|
||||
|
||||
def vemb(
|
||||
self, key: KeyT, element: str, raw: Optional[bool] = False
|
||||
) -> Union[
|
||||
Awaitable[Optional[Union[List[EncodableT], Dict[str, EncodableT]]]],
|
||||
Optional[Union[List[EncodableT], Dict[str, EncodableT]]],
|
||||
]:
|
||||
"""
|
||||
Get the approximated vector of an element ``element`` from vector set ``key``.
|
||||
|
||||
``raw`` is a boolean flag that indicates whether to return the
|
||||
internal representation used by the vector.
|
||||
|
||||
|
||||
For more information see https://redis.io/commands/vembed
|
||||
"""
|
||||
options = {}
|
||||
pieces = []
|
||||
pieces.extend([key, element])
|
||||
|
||||
if get_protocol_version(self.client) in ["3", 3]:
|
||||
options[CallbacksOptions.RESP3.value] = True
|
||||
|
||||
if raw:
|
||||
pieces.append("RAW")
|
||||
|
||||
options[NEVER_DECODE] = True
|
||||
if (
|
||||
hasattr(self.client, "connection_pool")
|
||||
and self.client.connection_pool.connection_kwargs["decode_responses"]
|
||||
) or (
|
||||
hasattr(self.client, "nodes_manager")
|
||||
and self.client.nodes_manager.connection_kwargs["decode_responses"]
|
||||
):
|
||||
# allow decoding in the postprocessing callback
|
||||
# if the user set decode_responses=True
|
||||
# in the connection pool
|
||||
options[CallbacksOptions.ALLOW_DECODING.value] = True
|
||||
|
||||
options[CallbacksOptions.RAW.value] = True
|
||||
|
||||
return self.execute_command(VEMB_CMD, *pieces, **options)
|
||||
|
||||
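The RAW branch above returns the quantized internal form plus metadata, while the plain call returns the approximated floats. A short sketch (key and element are placeholders):

import redis

r = redis.Redis()
vec = r.vset().vemb("points", "element:1")             # list of floats
raw = r.vset().vemb("points", "element:1", raw=True)
# raw is a dict like {"quantization": ..., "raw": <bytes>, "l2": ..., "range": ...}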
def vlinks(
|
||||
self, key: KeyT, element: str, with_scores: Optional[bool] = False
|
||||
) -> Union[
|
||||
Awaitable[
|
||||
Optional[
|
||||
List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]
|
||||
]
|
||||
],
|
||||
Optional[List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]],
|
||||
]:
|
||||
"""
|
||||
Returns the neighbors of element ``element`` in the vector set ``key``, one list per level the element exists in.
|
||||
|
||||
The result is a list of lists, where each list contains the neighbors for one level.
|
||||
If the element does not exist, or if the vector set does not exist, None is returned.
|
||||
|
||||
If the ``WITHSCORES`` option is provided, the result is a list of dicts,
|
||||
where each dict contains the neighbors for one level, with the scores as values.
|
||||
|
||||
For more information see https://redis.io/commands/vlinks
|
||||
"""
|
||||
options = {}
|
||||
pieces = []
|
||||
pieces.extend([key, element])
|
||||
|
||||
if with_scores:
|
||||
pieces.append("WITHSCORES")
|
||||
options[CallbacksOptions.WITHSCORES.value] = True
|
||||
|
||||
return self.execute_command(VLINKS_CMD, *pieces, **options)
|
||||
|
||||
def vinfo(self, key: KeyT) -> Union[Awaitable[dict], dict]:
|
||||
"""
|
||||
Get information about a vector set.
|
||||
|
||||
For more information see https://redis.io/commands/vinfo
|
||||
"""
|
||||
return self.execute_command(VINFO_CMD, key)
|
||||
|
||||
def vsetattr(
|
||||
self, key: KeyT, element: str, attributes: Optional[Union[dict, str]] = None
|
||||
) -> Union[Awaitable[int], int]:
|
||||
"""
|
||||
Associate or remove JSON attributes ``attributes`` of element ``element``
|
||||
for vector set ``key``.
|
||||
|
||||
For more information see https://redis.io/commands/vsetattr
|
||||
"""
|
||||
if attributes is None:
|
||||
attributes_json = "{}"
|
||||
elif isinstance(attributes, dict):
|
||||
# transform attributes to json string
|
||||
attributes_json = json.dumps(attributes)
|
||||
else:
|
||||
attributes_json = attributes
|
||||
|
||||
return self.execute_command(VSETATTR_CMD, key, element, attributes_json)
|
||||
|
||||
def vgetattr(
|
||||
self, key: KeyT, element: str
|
||||
) -> Union[Optional[Awaitable[dict]], Optional[dict]]:
|
||||
"""
|
||||
Retrieve the JSON attributes of an element ``element`` for vector set ``key``.
|
||||
|
||||
If the element does not exist, or if the vector set does not exist, None is
|
||||
returned.
|
||||
|
||||
For more information see https://redis.io/commands/vgetattr
|
||||
"""
|
||||
return self.execute_command(VGETATTR_CMD, key, element)
|
||||
|
||||
def vrandmember(
|
||||
self, key: KeyT, count: Optional[int] = None
|
||||
) -> Union[
|
||||
Awaitable[Optional[Union[List[str], str]]], Optional[Union[List[str], str]]
|
||||
]:
|
||||
"""
|
||||
Returns random elements from a vector set ``key``.
|
||||
|
||||
``count`` is the number of elements to return.
|
||||
If ``count`` is not provided, a single element is returned as a single string.
|
||||
If ``count`` is positive (smaller than the number of elements
|
||||
in the vector set), the command returns a list with up to ``count``
|
||||
distinct elements from the vector set
|
||||
If ``count`` is negative, the command returns a list with ``count`` random elements,
|
||||
potentially with duplicates.
|
||||
If ``count`` is greater than the number of elements in the vector set,
|
||||
the entire set is returned as a list.
|
||||
|
||||
If the vector set does not exist, ``None`` is returned.
|
||||
|
||||
For more information see https://redis.io/commands/vrandmember
|
||||
"""
|
||||
pieces = []
|
||||
pieces.append(key)
|
||||
if count is not None:
|
||||
pieces.append(count)
|
||||
return self.execute_command(VRANDMEMBER_CMD, *pieces)
|
||||
@@ -1,94 +0,0 @@
|
||||
from redis._parsers.helpers import pairs_to_dict
|
||||
from redis.commands.vectorset.commands import CallbacksOptions
|
||||
|
||||
|
||||
def parse_vemb_result(response, **options):
|
||||
"""
|
||||
Handle the VEMB result, since the command can return different result
|
||||
structures depending on input options and on quantization type of the vector set.
|
||||
|
||||
Parsing VEMB result into:
|
||||
- List[Union[bytes, Union[int, float]]]
|
||||
- Dict[str, Union[bytes, str, float]]
|
||||
"""
|
||||
if response is None:
|
||||
return response
|
||||
|
||||
if options.get(CallbacksOptions.RAW.value):
|
||||
result = {}
|
||||
result["quantization"] = (
|
||||
response[0].decode("utf-8")
|
||||
if options.get(CallbacksOptions.ALLOW_DECODING.value)
|
||||
else response[0]
|
||||
)
|
||||
result["raw"] = response[1]
|
||||
result["l2"] = float(response[2])
|
||||
if len(response) > 3:
|
||||
result["range"] = float(response[3])
|
||||
return result
|
||||
else:
|
||||
if options.get(CallbacksOptions.RESP3.value):
|
||||
return response
|
||||
|
||||
result = []
|
||||
for i in range(len(response)):
|
||||
try:
|
||||
result.append(int(response[i]))
|
||||
except ValueError:
|
||||
# if the value is not an integer, it should be a float
|
||||
result.append(float(response[i]))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parse_vlinks_result(response, **options):
|
||||
"""
|
||||
Handle the VLINKS result, since the command can return different result
|
||||
structures depending on input options.
|
||||
Parsing VLINKS result into:
|
||||
- List[List[str]]
|
||||
- List[Dict[str, Number]]
|
||||
"""
|
||||
if response is None:
|
||||
return response
|
||||
|
||||
if options.get(CallbacksOptions.WITHSCORES.value):
|
||||
result = []
|
||||
# Redis will return a list of list of strings.
|
||||
# This list has to be transformed to a list of dicts
|
||||
for level_item in response:
|
||||
level_data_dict = {}
|
||||
for key, value in pairs_to_dict(level_item).items():
|
||||
value = float(value)
|
||||
level_data_dict[key] = value
|
||||
result.append(level_data_dict)
|
||||
return result
|
||||
else:
|
||||
# return the list of elements for each level
|
||||
# list of lists
|
||||
return response
|
||||
|
||||
|
||||
def parse_vsim_result(response, **options):
|
||||
"""
|
||||
Handle the VSIM result, since the command can return different result
|
||||
structures depending on input options.
|
||||
Parsing VSIM result into:
|
||||
- List[List[str]]
|
||||
- List[Dict[str, Number]]
|
||||
"""
|
||||
if response is None:
|
||||
return response
|
||||
|
||||
if options.get(CallbacksOptions.WITHSCORES.value):
|
||||
# Redis will return a list of list of pairs.
|
||||
# This list has to be transformed to a dict
|
||||
result_dict = {}
|
||||
for key, value in pairs_to_dict(response).items():
|
||||
value = float(value)
|
||||
result_dict[key] = value
|
||||
return result_dict
|
||||
else:
|
||||
# return the list of elements for each level
|
||||
# list of lists
|
||||
return response
|
||||
venv/lib/python3.12/site-packages/redis/compat.py (new file, 6 lines)
@@ -0,0 +1,6 @@
|
||||
# flake8: noqa
|
||||
try:
|
||||
from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import]
|
||||
except ImportError:
|
||||
from typing_extensions import Literal # lgtm [py/unused-import]
|
||||
from typing_extensions import Protocol, TypedDict
|
||||
File diff suppressed because it is too large
@@ -1,8 +1,4 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
class CredentialProvider:
|
||||
@@ -13,38 +9,6 @@ class CredentialProvider:
|
||||
def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
|
||||
raise NotImplementedError("get_credentials must be implemented")
|
||||
|
||||
async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
|
||||
logger.warning(
|
||||
"This method is added for backward compatability. "
|
||||
"Please override it in your implementation."
|
||||
)
|
||||
return self.get_credentials()
|
||||
|
||||
|
||||
class StreamingCredentialProvider(CredentialProvider, ABC):
|
||||
"""
|
||||
Credential provider that streams credentials in the background.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def on_next(self, callback: Callable[[Any], None]):
|
||||
"""
|
||||
Specifies the callback that should be invoked
|
||||
when the next credentials will be retrieved.
|
||||
|
||||
:param callback: Callback with
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_error(self, callback: Callable[[Exception], None]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_streaming(self) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class UsernamePasswordCredentialProvider(CredentialProvider):
|
||||
"""
|
||||
@@ -60,6 +24,3 @@ class UsernamePasswordCredentialProvider(CredentialProvider):
|
||||
if self.username:
|
||||
return self.username, self.password
|
||||
return (self.password,)
|
||||
|
||||
async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]:
|
||||
return self.get_credentials()
|
||||
|
||||
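Both variants of ``CredentialProvider`` above expect ``get_credentials`` to return a one- or two-element tuple. A hedged sketch of a custom provider that pulls a rotating password from an external source; ``fetch_token`` is a stand-in for whatever secret store the application uses, not part of redis-py:

import redis
from redis.credentials import CredentialProvider


def fetch_token() -> str:
    # Placeholder for a real token/secret lookup.
    return "s3cr3t"


class TokenCredentialProvider(CredentialProvider):
    """Illustrative provider that re-reads the password on every connection."""

    def __init__(self, username: str):
        self.username = username

    def get_credentials(self):
        return self.username, fetch_token()


r = redis.Redis(credential_provider=TokenCredentialProvider("app-user"))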
@@ -1,394 +0,0 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from redis.auth.token import TokenInterface
|
||||
from redis.credentials import CredentialProvider, StreamingCredentialProvider
|
||||
|
||||
|
||||
class EventListenerInterface(ABC):
|
||||
"""
|
||||
Represents a listener for given event object.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def listen(self, event: object):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncEventListenerInterface(ABC):
|
||||
"""
|
||||
Represents an async listener for given event object.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def listen(self, event: object):
|
||||
pass
|
||||
|
||||
|
||||
class EventDispatcherInterface(ABC):
|
||||
"""
|
||||
Represents a dispatcher that dispatches events to listeners
|
||||
associated with given event.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def dispatch(self, event: object):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def dispatch_async(self, event: object):
|
||||
pass
|
||||
|
||||
|
||||
class EventException(Exception):
|
||||
"""
|
||||
Exception wrapper that adds an event object into exception context.
|
||||
"""
|
||||
|
||||
def __init__(self, exception: Exception, event: object):
|
||||
self.exception = exception
|
||||
self.event = event
|
||||
super().__init__(exception)
|
||||
|
||||
|
||||
class EventDispatcher(EventDispatcherInterface):
|
||||
# TODO: Make dispatcher to accept external mappings.
|
||||
def __init__(self):
|
||||
"""
|
||||
Mapping should be extended for any new events or listeners to be added.
|
||||
"""
|
||||
self._event_listeners_mapping = {
|
||||
AfterConnectionReleasedEvent: [
|
||||
ReAuthConnectionListener(),
|
||||
],
|
||||
AfterPooledConnectionsInstantiationEvent: [
|
||||
RegisterReAuthForPooledConnections()
|
||||
],
|
||||
AfterSingleConnectionInstantiationEvent: [
|
||||
RegisterReAuthForSingleConnection()
|
||||
],
|
||||
AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()],
|
||||
AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()],
|
||||
AsyncAfterConnectionReleasedEvent: [
|
||||
AsyncReAuthConnectionListener(),
|
||||
],
|
||||
}
|
||||
|
||||
def dispatch(self, event: object):
|
||||
listeners = self._event_listeners_mapping.get(type(event))
|
||||
|
||||
for listener in listeners:
|
||||
listener.listen(event)
|
||||
|
||||
async def dispatch_async(self, event: object):
|
||||
listeners = self._event_listeners_mapping.get(type(event))
|
||||
|
||||
for listener in listeners:
|
||||
await listener.listen(event)
|
||||
|
||||
|
||||
class AfterConnectionReleasedEvent:
|
||||
"""
|
||||
Event that will be fired before each command execution.
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
self._connection = connection
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
return self._connection
|
||||
|
||||
|
||||
class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ClientType(Enum):
|
||||
SYNC = ("sync",)
|
||||
ASYNC = ("async",)
|
||||
|
||||
|
||||
class AfterPooledConnectionsInstantiationEvent:
|
||||
"""
|
||||
Event that will be fired after pooled connection instances was created.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_pools: List,
|
||||
client_type: ClientType,
|
||||
credential_provider: Optional[CredentialProvider] = None,
|
||||
):
|
||||
self._connection_pools = connection_pools
|
||||
self._client_type = client_type
|
||||
self._credential_provider = credential_provider
|
||||
|
||||
@property
|
||||
def connection_pools(self):
|
||||
return self._connection_pools
|
||||
|
||||
@property
|
||||
def client_type(self) -> ClientType:
|
||||
return self._client_type
|
||||
|
||||
@property
|
||||
def credential_provider(self) -> Union[CredentialProvider, None]:
|
||||
return self._credential_provider
|
||||
|
||||
|
||||
class AfterSingleConnectionInstantiationEvent:
|
||||
"""
|
||||
Event that will be fired after single connection instances was created.
|
||||
|
||||
:param connection_lock: For sync client thread-lock should be provided,
|
||||
for async asyncio.Lock
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection,
|
||||
client_type: ClientType,
|
||||
connection_lock: Union[threading.RLock, asyncio.Lock],
|
||||
):
|
||||
self._connection = connection
|
||||
self._client_type = client_type
|
||||
self._connection_lock = connection_lock
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
return self._connection
|
||||
|
||||
@property
|
||||
def client_type(self) -> ClientType:
|
||||
return self._client_type
|
||||
|
||||
@property
|
||||
def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
|
||||
return self._connection_lock
|
||||
|
||||
|
||||
class AfterPubSubConnectionInstantiationEvent:
|
||||
def __init__(
|
||||
self,
|
||||
pubsub_connection,
|
||||
connection_pool,
|
||||
client_type: ClientType,
|
||||
connection_lock: Union[threading.RLock, asyncio.Lock],
|
||||
):
|
||||
self._pubsub_connection = pubsub_connection
|
||||
self._connection_pool = connection_pool
|
||||
self._client_type = client_type
|
||||
self._connection_lock = connection_lock
|
||||
|
||||
@property
|
||||
def pubsub_connection(self):
|
||||
return self._pubsub_connection
|
||||
|
||||
@property
|
||||
def connection_pool(self):
|
||||
return self._connection_pool
|
||||
|
||||
@property
|
||||
def client_type(self) -> ClientType:
|
||||
return self._client_type
|
||||
|
||||
@property
|
||||
def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]:
|
||||
return self._connection_lock
|
||||
|
||||
|
||||
class AfterAsyncClusterInstantiationEvent:
|
||||
"""
|
||||
Event that will be fired after async cluster instance was created.
|
||||
|
||||
Async cluster doesn't use connection pools,
|
||||
instead ClusterNode object manages connections.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: dict,
|
||||
credential_provider: Optional[CredentialProvider] = None,
|
||||
):
|
||||
self._nodes = nodes
|
||||
self._credential_provider = credential_provider
|
||||
|
||||
@property
|
||||
def nodes(self) -> dict:
|
||||
return self._nodes
|
||||
|
||||
@property
|
||||
def credential_provider(self) -> Union[CredentialProvider, None]:
|
||||
return self._credential_provider
|
||||
|
||||
|
||||
class ReAuthConnectionListener(EventListenerInterface):
|
||||
"""
|
||||
Listener that performs re-authentication of given connection.
|
||||
"""
|
||||
|
||||
def listen(self, event: AfterConnectionReleasedEvent):
|
||||
event.connection.re_auth()
|
||||
|
||||
|
||||
class AsyncReAuthConnectionListener(AsyncEventListenerInterface):
|
||||
"""
|
||||
Async listener that performs re-authentication of given connection.
|
||||
"""
|
||||
|
||||
async def listen(self, event: AsyncAfterConnectionReleasedEvent):
|
||||
await event.connection.re_auth()
|
||||
|
||||
|
||||
class RegisterReAuthForPooledConnections(EventListenerInterface):
|
||||
"""
|
||||
Listener that registers a re-authentication callback for pooled connections.
|
||||
Required by :class:`StreamingCredentialProvider`.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._event = None
|
||||
|
||||
def listen(self, event: AfterPooledConnectionsInstantiationEvent):
|
||||
if isinstance(event.credential_provider, StreamingCredentialProvider):
|
||||
self._event = event
|
||||
|
||||
if event.client_type == ClientType.SYNC:
|
||||
event.credential_provider.on_next(self._re_auth)
|
||||
event.credential_provider.on_error(self._raise_on_error)
|
||||
else:
|
||||
event.credential_provider.on_next(self._re_auth_async)
|
||||
event.credential_provider.on_error(self._raise_on_error_async)
|
||||
|
||||
def _re_auth(self, token):
|
||||
for pool in self._event.connection_pools:
|
||||
pool.re_auth_callback(token)
|
||||
|
||||
async def _re_auth_async(self, token):
|
||||
for pool in self._event.connection_pools:
|
||||
await pool.re_auth_callback(token)
|
||||
|
||||
def _raise_on_error(self, error: Exception):
|
||||
raise EventException(error, self._event)
|
||||
|
||||
async def _raise_on_error_async(self, error: Exception):
|
||||
raise EventException(error, self._event)
|
||||
|
||||
|
||||
class RegisterReAuthForSingleConnection(EventListenerInterface):
|
||||
"""
|
||||
Listener that registers a re-authentication callback for single connection.
|
||||
Required by :class:`StreamingCredentialProvider`.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._event = None
|
||||
|
||||
def listen(self, event: AfterSingleConnectionInstantiationEvent):
|
||||
if isinstance(
|
||||
event.connection.credential_provider, StreamingCredentialProvider
|
||||
):
|
||||
self._event = event
|
||||
|
||||
if event.client_type == ClientType.SYNC:
|
||||
event.connection.credential_provider.on_next(self._re_auth)
|
||||
event.connection.credential_provider.on_error(self._raise_on_error)
|
||||
else:
|
||||
event.connection.credential_provider.on_next(self._re_auth_async)
|
||||
event.connection.credential_provider.on_error(
|
||||
self._raise_on_error_async
|
||||
)
|
||||
|
||||
def _re_auth(self, token):
|
||||
with self._event.connection_lock:
|
||||
self._event.connection.send_command(
|
||||
"AUTH", token.try_get("oid"), token.get_value()
|
||||
)
|
||||
self._event.connection.read_response()
|
||||
|
||||
async def _re_auth_async(self, token):
|
||||
async with self._event.connection_lock:
|
||||
await self._event.connection.send_command(
|
||||
"AUTH", token.try_get("oid"), token.get_value()
|
||||
)
|
||||
await self._event.connection.read_response()
|
||||
|
||||
def _raise_on_error(self, error: Exception):
|
||||
raise EventException(error, self._event)
|
||||
|
||||
async def _raise_on_error_async(self, error: Exception):
|
||||
raise EventException(error, self._event)
|
||||
|
||||
|
||||
class RegisterReAuthForAsyncClusterNodes(EventListenerInterface):
|
||||
def __init__(self):
|
||||
self._event = None
|
||||
|
||||
def listen(self, event: AfterAsyncClusterInstantiationEvent):
|
||||
if isinstance(event.credential_provider, StreamingCredentialProvider):
|
||||
self._event = event
|
||||
event.credential_provider.on_next(self._re_auth)
|
||||
event.credential_provider.on_error(self._raise_on_error)
|
||||
|
||||
async def _re_auth(self, token: TokenInterface):
|
||||
for key in self._event.nodes:
|
||||
await self._event.nodes[key].re_auth_callback(token)
|
||||
|
||||
async def _raise_on_error(self, error: Exception):
|
||||
raise EventException(error, self._event)
|
||||
|
||||
|
||||
class RegisterReAuthForPubSub(EventListenerInterface):
|
||||
def __init__(self):
|
||||
self._connection = None
|
||||
self._connection_pool = None
|
||||
self._client_type = None
|
||||
self._connection_lock = None
|
||||
self._event = None
|
||||
|
||||
def listen(self, event: AfterPubSubConnectionInstantiationEvent):
|
||||
if isinstance(
|
||||
event.pubsub_connection.credential_provider, StreamingCredentialProvider
|
||||
) and event.pubsub_connection.get_protocol() in [3, "3"]:
|
||||
self._event = event
|
||||
self._connection = event.pubsub_connection
|
||||
self._connection_pool = event.connection_pool
|
||||
self._client_type = event.client_type
|
||||
self._connection_lock = event.connection_lock
|
||||
|
||||
if self._client_type == ClientType.SYNC:
|
||||
self._connection.credential_provider.on_next(self._re_auth)
|
||||
self._connection.credential_provider.on_error(self._raise_on_error)
|
||||
else:
|
||||
self._connection.credential_provider.on_next(self._re_auth_async)
|
||||
self._connection.credential_provider.on_error(
|
||||
self._raise_on_error_async
|
||||
)
|
||||
|
||||
def _re_auth(self, token: TokenInterface):
|
||||
with self._connection_lock:
|
||||
self._connection.send_command(
|
||||
"AUTH", token.try_get("oid"), token.get_value()
|
||||
)
|
||||
self._connection.read_response()
|
||||
|
||||
self._connection_pool.re_auth_callback(token)
|
||||
|
||||
async def _re_auth_async(self, token: TokenInterface):
|
||||
async with self._connection_lock:
|
||||
await self._connection.send_command(
|
||||
"AUTH", token.try_get("oid"), token.get_value()
|
||||
)
|
||||
await self._connection.read_response()
|
||||
|
||||
await self._connection_pool.re_auth_callback(token)
|
||||
|
||||
def _raise_on_error(self, error: Exception):
|
||||
raise EventException(error, self._event)
|
||||
|
||||
async def _raise_on_error_async(self, error: Exception):
|
||||
raise EventException(error, self._event)
|
||||
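The re-auth listeners above all end in the same move: once a fresh token arrives, send AUTH on the live connection while holding the connection lock, then read the reply. A self-contained sketch of that step; FakeConnection and FakeToken are stand-ins, not redis-py classes.

# Illustrative sketch, not part of the diff: the locked AUTH round trip performed by
# the re-auth listeners above when a new token is pushed.
import threading


class FakeToken:
    def try_get(self, key: str) -> str:
        return "user-oid"        # stands in for the token's OID claim

    def get_value(self) -> str:
        return "new-secret"      # stands in for the raw token value


class FakeConnection:
    def send_command(self, *args) -> None:
        print("->", *args)

    def read_response(self) -> bytes:
        return b"OK"


def re_auth(connection, connection_lock, token) -> None:
    # Hold the lock so AUTH cannot interleave with an in-flight command.
    with connection_lock:
        connection.send_command("AUTH", token.try_get("oid"), token.get_value())
        connection.read_response()


re_auth(FakeConnection(), threading.RLock(), FakeToken())
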
@@ -79,24 +79,18 @@ class ModuleError(ResponseError):


class LockError(RedisError, ValueError):
    "Errors acquiring or releasing a lock"

    # NOTE: For backwards compatibility, this class derives from ValueError.
    # This was originally chosen to behave like threading.Lock.

    def __init__(self, message=None, lock_name=None):
        self.message = message
        self.lock_name = lock_name
    pass


class LockNotOwnedError(LockError):
    "Error trying to extend or release a lock that is not owned (anymore)"

    "Error trying to extend or release a lock that is (no longer) owned"
    pass


class ChildDeadlockedError(Exception):
    "Error indicating that a child process is deadlocked after a fork()"

    pass


@@ -221,27 +215,4 @@ class SlotNotCoveredError(RedisClusterException):


class MaxConnectionsError(ConnectionError):
    """
    Raised when a connection pool has reached its max_connections limit.
    This indicates pool exhaustion rather than an actual connection failure.
    """

    pass


class CrossSlotTransactionError(RedisClusterException):
    """
    Raised when a transaction or watch is triggered in a pipeline
    and not all keys or all commands belong to the same slot.
    """

    pass


class InvalidPipelineStack(RedisClusterException):
    """
    Raised on unexpected response length on pipelines. This is
    most likely a handling error on the stack.
    """

    pass
    ...

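MaxConnectionsError above models pool exhaustion rather than a network failure. A small sketch of how it surfaces; depending on the installed redis-py version the raised type may be MaxConnectionsError or a plain ConnectionError, and a Redis server on localhost:6379 is assumed.

# Illustrative sketch, not part of the diff: exhausting a ConnectionPool.
import redis

pool = redis.ConnectionPool(host="localhost", port=6379, max_connections=1)
first = pool.get_connection("PING")      # uses the only slot in the pool

try:
    pool.get_connection("PING")          # second checkout exceeds max_connections
except redis.ConnectionError as exc:     # MaxConnectionsError subclasses ConnectionError
    print("pool exhausted:", exc)
finally:
    pool.release(first)
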
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import threading
|
||||
import time as mod_time
|
||||
import uuid
|
||||
@@ -8,8 +7,6 @@ from typing import Optional, Type
|
||||
from redis.exceptions import LockError, LockNotOwnedError
|
||||
from redis.typing import Number
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Lock:
|
||||
"""
|
||||
@@ -85,7 +82,6 @@ class Lock:
|
||||
blocking: bool = True,
|
||||
blocking_timeout: Optional[Number] = None,
|
||||
thread_local: bool = True,
|
||||
raise_on_release_error: bool = True,
|
||||
):
|
||||
"""
|
||||
Create a new Lock instance named ``name`` using the Redis client
|
||||
@@ -129,11 +125,6 @@ class Lock:
|
||||
thread-1 would see the token value as "xyz" and would be
|
||||
able to successfully release the thread-2's lock.
|
||||
|
||||
``raise_on_release_error`` indicates whether to raise an exception when
|
||||
the lock is no longer owned when exiting the context manager. By default,
|
||||
this is True, meaning an exception will be raised. If False, the warning
|
||||
will be logged and the exception will be suppressed.
|
||||
|
||||
In some use cases it's necessary to disable thread local storage. For
|
||||
example, if you have code where one thread acquires a lock and passes
|
||||
that lock instance to a worker thread to release later. If thread
|
||||
@@ -149,7 +140,6 @@ class Lock:
|
||||
self.blocking = blocking
|
||||
self.blocking_timeout = blocking_timeout
|
||||
self.thread_local = bool(thread_local)
|
||||
self.raise_on_release_error = raise_on_release_error
|
||||
self.local = threading.local() if self.thread_local else SimpleNamespace()
|
||||
self.local.token = None
|
||||
self.register_scripts()
|
||||
@@ -167,10 +157,7 @@ class Lock:
|
||||
def __enter__(self) -> "Lock":
|
||||
if self.acquire():
|
||||
return self
|
||||
raise LockError(
|
||||
"Unable to acquire lock within the time specified",
|
||||
lock_name=self.name,
|
||||
)
|
||||
raise LockError("Unable to acquire lock within the time specified")
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
@@ -178,14 +165,7 @@ class Lock:
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
try:
|
||||
self.release()
|
||||
except LockError:
|
||||
if self.raise_on_release_error:
|
||||
raise
|
||||
logger.warning(
|
||||
"Lock was unlocked or no longer owned when exiting context manager."
|
||||
)
|
||||
self.release()
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
@@ -268,10 +248,7 @@ class Lock:
|
||||
"""
|
||||
expected_token = self.local.token
|
||||
if expected_token is None:
|
||||
raise LockError(
|
||||
"Cannot release a lock that's not owned or is already unlocked.",
|
||||
lock_name=self.name,
|
||||
)
|
||||
raise LockError("Cannot release an unlocked lock")
|
||||
self.local.token = None
|
||||
self.do_release(expected_token)
|
||||
|
||||
@@ -279,12 +256,9 @@ class Lock:
|
||||
if not bool(
|
||||
self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
|
||||
):
|
||||
raise LockNotOwnedError(
|
||||
"Cannot release a lock that's no longer owned",
|
||||
lock_name=self.name,
|
||||
)
|
||||
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
|
||||
|
||||
def extend(self, additional_time: Number, replace_ttl: bool = False) -> bool:
|
||||
def extend(self, additional_time: int, replace_ttl: bool = False) -> bool:
|
||||
"""
|
||||
Adds more time to an already acquired lock.
|
||||
|
||||
@@ -296,12 +270,12 @@ class Lock:
|
||||
`additional_time`.
|
||||
"""
|
||||
if self.local.token is None:
|
||||
raise LockError("Cannot extend an unlocked lock", lock_name=self.name)
|
||||
raise LockError("Cannot extend an unlocked lock")
|
||||
if self.timeout is None:
|
||||
raise LockError("Cannot extend a lock with no timeout", lock_name=self.name)
|
||||
raise LockError("Cannot extend a lock with no timeout")
|
||||
return self.do_extend(additional_time, replace_ttl)
|
||||
|
||||
def do_extend(self, additional_time: Number, replace_ttl: bool) -> bool:
|
||||
def do_extend(self, additional_time: int, replace_ttl: bool) -> bool:
|
||||
additional_time = int(additional_time * 1000)
|
||||
if not bool(
|
||||
self.lua_extend(
|
||||
@@ -310,10 +284,7 @@ class Lock:
|
||||
client=self.redis,
|
||||
)
|
||||
):
|
||||
raise LockNotOwnedError(
|
||||
"Cannot extend a lock that's no longer owned",
|
||||
lock_name=self.name,
|
||||
)
|
||||
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
|
||||
return True
|
||||
|
||||
def reacquire(self) -> bool:
|
||||
@@ -321,12 +292,9 @@ class Lock:
|
||||
Resets a TTL of an already acquired lock back to a timeout value.
|
||||
"""
|
||||
if self.local.token is None:
|
||||
raise LockError("Cannot reacquire an unlocked lock", lock_name=self.name)
|
||||
raise LockError("Cannot reacquire an unlocked lock")
|
||||
if self.timeout is None:
|
||||
raise LockError(
|
||||
"Cannot reacquire a lock with no timeout",
|
||||
lock_name=self.name,
|
||||
)
|
||||
raise LockError("Cannot reacquire a lock with no timeout")
|
||||
return self.do_reacquire()
|
||||
|
||||
def do_reacquire(self) -> bool:
|
||||
@@ -336,8 +304,5 @@ class Lock:
|
||||
keys=[self.name], args=[self.local.token, timeout], client=self.redis
|
||||
)
|
||||
):
|
||||
raise LockNotOwnedError(
|
||||
"Cannot reacquire a lock that's no longer owned",
|
||||
lock_name=self.name,
|
||||
)
|
||||
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
|
||||
return True
|
||||
|
||||
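The Lock changes above mostly thread a lock_name into the error messages and add raise_on_release_error handling in __exit__. For orientation, a short usage sketch of the same class; it assumes a Redis server on localhost and only uses long-standing Lock parameters.

# Illustrative sketch, not part of the diff: acquiring and extending a Lock and
# handling the lock-specific exceptions diffed above.
import redis
from redis.exceptions import LockError, LockNotOwnedError

r = redis.Redis()

try:
    # timeout is the lock's TTL; blocking_timeout bounds how long acquire() waits.
    with r.lock("resource-42", timeout=5, blocking_timeout=1) as lock:
        lock.extend(5)                   # push the TTL out while still holding it
except LockNotOwnedError:
    print("lock expired before it could be extended or released")
except LockError as exc:
    print("could not acquire the lock:", exc)
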
@@ -15,7 +15,6 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.hashes import SHA1, Hash
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from cryptography.x509 import ocsp

from redis.exceptions import AuthorizationError, ConnectionError


@@ -57,12 +56,12 @@ def _check_certificate(issuer_cert, ocsp_bytes, validate=True):
    if ocsp_response.response_status == ocsp.OCSPResponseStatus.SUCCESSFUL:
        if ocsp_response.certificate_status != ocsp.OCSPCertStatus.GOOD:
            raise ConnectionError(
                f"Received an {str(ocsp_response.certificate_status).split('.')[1]} "
                f'Received an {str(ocsp_response.certificate_status).split(".")[1]} '
                "ocsp certificate status"
            )
    else:
        raise ConnectionError(
            "failed to retrieve a successful response from the ocsp responder"
            "failed to retrieve a sucessful response from the ocsp responder"
        )

    if ocsp_response.this_update >= datetime.datetime.now():
@@ -140,7 +139,7 @@ def _get_pubkey_hash(certificate):


def ocsp_staple_verifier(con, ocsp_bytes, expected=None):
    """An implementation of a function for set_ocsp_client_callback in PyOpenSSL.
    """An implemention of a function for set_ocsp_client_callback in PyOpenSSL.

    This function validates that the provide ocsp_bytes response is valid,
    and matches the expected, stapled responses.
@@ -267,7 +266,7 @@ class OCSPVerifier:
        return url

    def check_certificate(self, server, cert, issuer_url):
        """Checks the validity of an ocsp server for an issuer"""
        """Checks the validitity of an ocsp server for an issuer"""

        r = requests.get(issuer_url)
        if not r.ok:

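The OCSP helpers above are exercised through the client's SSL options rather than called directly. A hedged sketch of turning the check on; ssl_validate_ocsp is the relevant redis-py option, the host and certificate path are placeholders, and a TLS-enabled server plus the cryptography and requests extras are assumed.

# Illustrative sketch, not part of the diff: enabling the OCSP verification that
# ocsp.py implements. Host and CA path are placeholders.
import redis

r = redis.Redis(
    host="redis.example.com",
    port=6379,
    ssl=True,
    ssl_ca_certs="/path/to/ca.pem",
    ssl_validate_ocsp=True,      # runs the OCSPVerifier against the peer certificate
)
r.ping()
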
@@ -1,27 +1,17 @@
import abc
import socket
from time import sleep
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Tuple, Type, TypeVar

from redis.exceptions import ConnectionError, TimeoutError

T = TypeVar("T")
E = TypeVar("E", bound=Exception, covariant=True)

if TYPE_CHECKING:
    from redis.backoff import AbstractBackoff


class AbstractRetry(Generic[E], abc.ABC):
class Retry:
    """Retry a specific number of times after a failure"""

    _supported_errors: Tuple[Type[E], ...]

    def __init__(
        self,
        backoff: "AbstractBackoff",
        retries: int,
        supported_errors: Tuple[Type[E], ...],
        backoff,
        retries,
        supported_errors=(ConnectionError, TimeoutError, socket.timeout),
    ):
        """
        Initialize a `Retry` object with a `Backoff` object
@@ -34,14 +24,7 @@ class AbstractRetry(Generic[E], abc.ABC):
        self._retries = retries
        self._supported_errors = supported_errors

    @abc.abstractmethod
    def __eq__(self, other: Any) -> bool:
        return NotImplemented

    def __hash__(self) -> int:
        return hash((self._backoff, self._retries, frozenset(self._supported_errors)))

    def update_supported_errors(self, specified_errors: Iterable[Type[E]]) -> None:
    def update_supported_errors(self, specified_errors: list):
        """
        Updates the supported errors with the specified error types
        """
@@ -49,49 +32,7 @@ class AbstractRetry(Generic[E], abc.ABC):
            set(self._supported_errors + tuple(specified_errors))
        )

    def get_retries(self) -> int:
        """
        Get the number of retries.
        """
        return self._retries

    def update_retries(self, value: int) -> None:
        """
        Set the number of retries.
        """
        self._retries = value


class Retry(AbstractRetry[Exception]):
    __hash__ = AbstractRetry.__hash__

    def __init__(
        self,
        backoff: "AbstractBackoff",
        retries: int,
        supported_errors: Tuple[Type[Exception], ...] = (
            ConnectionError,
            TimeoutError,
            socket.timeout,
        ),
    ):
        super().__init__(backoff, retries, supported_errors)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Retry):
            return NotImplemented

        return (
            self._backoff == other._backoff
            and self._retries == other._retries
            and set(self._supported_errors) == set(other._supported_errors)
        )

    def call_with_retry(
        self,
        do: Callable[[], T],
        fail: Callable[[Exception], Any],
    ) -> T:
    def call_with_retry(self, do, fail):
        """
        Execute an operation that might fail and returns its result, or
        raise the exception that was thrown depending on the `Backoff` object.

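The Retry class diffed above is normally paired with a backoff policy and handed to the client. A short usage sketch with redis-py's public API; it assumes a local Redis server.

# Illustrative sketch, not part of the diff: wiring Retry into a client with an
# exponential backoff policy.
import redis
from redis.backoff import ExponentialBackoff
from redis.exceptions import ConnectionError, TimeoutError
from redis.retry import Retry

retry = Retry(ExponentialBackoff(), 3)   # up to 3 retries, exponential sleeps between

r = redis.Redis(
    host="localhost",
    retry=retry,
    retry_on_error=[ConnectionError, TimeoutError],
)
r.ping()

# The same object can also drive arbitrary callables directly:
retry.call_with_retry(lambda: r.ping(), lambda error: print("attempt failed:", error))
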
@@ -5,12 +5,8 @@ from typing import Optional
|
||||
from redis.client import Redis
|
||||
from redis.commands import SentinelCommands
|
||||
from redis.connection import Connection, ConnectionPool, SSLConnection
|
||||
from redis.exceptions import (
|
||||
ConnectionError,
|
||||
ReadOnlyError,
|
||||
ResponseError,
|
||||
TimeoutError,
|
||||
)
|
||||
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
|
||||
from redis.utils import str_if_bytes
|
||||
|
||||
|
||||
class MasterNotFoundError(ConnectionError):
|
||||
@@ -28,10 +24,7 @@ class SentinelManagedConnection(Connection):
|
||||
|
||||
def __repr__(self):
|
||||
pool = self.connection_pool
|
||||
s = (
|
||||
f"<{type(self).__module__}.{type(self).__name__}"
|
||||
f"(service={pool.service_name}%s)>"
|
||||
)
|
||||
s = f"{type(self).__name__}<service={pool.service_name}%s>"
|
||||
if self.host:
|
||||
host_info = f",host={self.host},port={self.port}"
|
||||
s = s % host_info
|
||||
@@ -39,11 +32,11 @@ class SentinelManagedConnection(Connection):
|
||||
|
||||
def connect_to(self, address):
|
||||
self.host, self.port = address
|
||||
|
||||
self.connect_check_health(
|
||||
check_health=self.connection_pool.check_connection,
|
||||
retry_socket_connect=False,
|
||||
)
|
||||
super().connect()
|
||||
if self.connection_pool.check_connection:
|
||||
self.send_command("PING")
|
||||
if str_if_bytes(self.read_response()) != "PONG":
|
||||
raise ConnectionError("PING failed")
|
||||
|
||||
def _connect_retry(self):
|
||||
if self._sock:
|
||||
@@ -149,11 +142,9 @@ class SentinelConnectionPool(ConnectionPool):
|
||||
def __init__(self, service_name, sentinel_manager, **kwargs):
|
||||
kwargs["connection_class"] = kwargs.get(
|
||||
"connection_class",
|
||||
(
|
||||
SentinelManagedSSLConnection
|
||||
if kwargs.pop("ssl", False)
|
||||
else SentinelManagedConnection
|
||||
),
|
||||
SentinelManagedSSLConnection
|
||||
if kwargs.pop("ssl", False)
|
||||
else SentinelManagedConnection,
|
||||
)
|
||||
self.is_master = kwargs.pop("is_master", True)
|
||||
self.check_connection = kwargs.pop("check_connection", False)
|
||||
@@ -171,10 +162,7 @@ class SentinelConnectionPool(ConnectionPool):
|
||||
|
||||
def __repr__(self):
|
||||
role = "master" if self.is_master else "slave"
|
||||
return (
|
||||
f"<{type(self).__module__}.{type(self).__name__}"
|
||||
f"(service={self.service_name}({role}))>"
|
||||
)
|
||||
return f"{type(self).__name__}<service={self.service_name}({role})"
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
@@ -233,7 +221,6 @@ class Sentinel(SentinelCommands):
|
||||
sentinels,
|
||||
min_other_sentinels=0,
|
||||
sentinel_kwargs=None,
|
||||
force_master_ip=None,
|
||||
**connection_kwargs,
|
||||
):
|
||||
# if sentinel_kwargs isn't defined, use the socket_* options from
|
||||
@@ -250,7 +237,6 @@ class Sentinel(SentinelCommands):
|
||||
]
|
||||
self.min_other_sentinels = min_other_sentinels
|
||||
self.connection_kwargs = connection_kwargs
|
||||
self._force_master_ip = force_master_ip
|
||||
|
||||
def execute_command(self, *args, **kwargs):
|
||||
"""
|
||||
@@ -258,27 +244,16 @@ class Sentinel(SentinelCommands):
|
||||
once - If set to True, then execute the resulting command on a single
|
||||
node at random, rather than across the entire sentinel cluster.
|
||||
"""
|
||||
once = bool(kwargs.pop("once", False))
|
||||
|
||||
# Check if command is supposed to return the original
|
||||
# responses instead of boolean value.
|
||||
return_responses = bool(kwargs.pop("return_responses", False))
|
||||
once = bool(kwargs.get("once", False))
|
||||
if "once" in kwargs.keys():
|
||||
kwargs.pop("once")
|
||||
|
||||
if once:
|
||||
response = random.choice(self.sentinels).execute_command(*args, **kwargs)
|
||||
if return_responses:
|
||||
return [response]
|
||||
else:
|
||||
return True if response else False
|
||||
|
||||
responses = []
|
||||
for sentinel in self.sentinels:
|
||||
responses.append(sentinel.execute_command(*args, **kwargs))
|
||||
|
||||
if return_responses:
|
||||
return responses
|
||||
|
||||
return all(responses)
|
||||
random.choice(self.sentinels).execute_command(*args, **kwargs)
|
||||
else:
|
||||
for sentinel in self.sentinels:
|
||||
sentinel.execute_command(*args, **kwargs)
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
sentinel_addresses = []
|
||||
@@ -286,10 +261,7 @@ class Sentinel(SentinelCommands):
|
||||
sentinel_addresses.append(
|
||||
"{host}:{port}".format_map(sentinel.connection_pool.connection_kwargs)
|
||||
)
|
||||
return (
|
||||
f"<{type(self).__module__}.{type(self).__name__}"
|
||||
f"(sentinels=[{','.join(sentinel_addresses)}])>"
|
||||
)
|
||||
return f'{type(self).__name__}<sentinels=[{",".join(sentinel_addresses)}]>'
|
||||
|
||||
def check_master_state(self, state, service_name):
|
||||
if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
|
||||
@@ -321,13 +293,7 @@ class Sentinel(SentinelCommands):
|
||||
sentinel,
|
||||
self.sentinels[0],
|
||||
)
|
||||
|
||||
ip = (
|
||||
self._force_master_ip
|
||||
if self._force_master_ip is not None
|
||||
else state["ip"]
|
||||
)
|
||||
return ip, state["port"]
|
||||
return state["ip"], state["port"]
|
||||
|
||||
error_info = ""
|
||||
if len(collected_errors) > 0:
|
||||
@@ -364,8 +330,6 @@ class Sentinel(SentinelCommands):
|
||||
):
|
||||
"""
|
||||
Returns a redis client instance for the ``service_name`` master.
|
||||
Sentinel client will detect failover and reconnect Redis clients
|
||||
automatically.
|
||||
|
||||
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
|
||||
used to retrieve the master's address before establishing a new
|
||||
|
||||
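For context on the Sentinel diff above, the typical discovery flow looks like this; it assumes a sentinel on localhost:26379 monitoring a service named mymaster.

# Illustrative sketch, not part of the diff: master/replica discovery through Sentinel.
from redis.sentinel import Sentinel

sentinel = Sentinel([("localhost", 26379)], socket_timeout=0.5)

master = sentinel.master_for("mymaster", socket_timeout=0.5)    # writable client
replica = sentinel.slave_for("mymaster", socket_timeout=0.5)    # read-only client

master.set("greeting", "hello")
print(replica.get("greeting"))
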
@@ -7,18 +7,21 @@ from typing import (
    Awaitable,
    Iterable,
    Mapping,
    Protocol,
    Type,
    TypeVar,
    Union,
)

from redis.compat import Protocol

if TYPE_CHECKING:
    from redis._parsers import Encoder
    from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool
    from redis.connection import ConnectionPool


Number = Union[int, float]
EncodedT = Union[bytes, bytearray, memoryview]
EncodedT = Union[bytes, memoryview]
DecodedT = Union[str, int, float]
EncodableT = Union[EncodedT, DecodedT]
AbsExpiryT = Union[int, datetime]
@@ -30,7 +33,6 @@ KeyT = _StringLikeT  # Main redis key space
PatternT = _StringLikeT  # Patterns matched against keys, fields etc
FieldT = EncodableT  # Fields within hash tables, streams and geo commands
KeysT = Union[KeyT, Iterable[KeyT]]
ResponseT = Union[Awaitable[Any], Any]
ChannelT = _StringLikeT
GroupT = _StringLikeT  # Consumer group
ConsumerT = _StringLikeT  # Consumer name
@@ -50,8 +52,14 @@ ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Except


class CommandsProtocol(Protocol):
    def execute_command(self, *args, **options) -> ResponseT: ...
    connection_pool: Union["AsyncConnectionPool", "ConnectionPool"]

    def execute_command(self, *args, **options):
        ...


class ClusterCommandsProtocol(CommandsProtocol):
class ClusterCommandsProtocol(CommandsProtocol, Protocol):
    encoder: "Encoder"

    def execute_command(self, *args, **options) -> Union[Any, Awaitable]:
        ...

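CommandsProtocol above is a structural Protocol: any object with a matching execute_command method satisfies it, no inheritance required. A self-contained sketch of the same idea; SupportsExecuteCommand and Pinger are made-up names.

# Illustrative sketch, not part of the diff: structural typing in the style of
# CommandsProtocol, with stand-in names.
from typing import Any, Protocol


class SupportsExecuteCommand(Protocol):
    def execute_command(self, *args: Any, **options: Any) -> Any: ...


class Pinger:
    # No inheritance from the Protocol; the matching method shape is enough.
    def execute_command(self, *args: Any, **options: Any) -> Any:
        return "PONG" if args and args[0] == "PING" else None


def ping(client: SupportsExecuteCommand) -> Any:
    return client.execute_command("PING")


print(ping(Pinger()))   # PONG
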
@@ -1,26 +1,18 @@
|
||||
import datetime
|
||||
import logging
|
||||
import textwrap
|
||||
from collections.abc import Callable
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union
|
||||
|
||||
from redis.exceptions import DataError
|
||||
from redis.typing import AbsExpiryT, EncodableT, ExpiryT
|
||||
from typing import Any, Dict, Mapping, Union
|
||||
|
||||
try:
|
||||
import hiredis # noqa
|
||||
|
||||
# Only support Hiredis >= 3.0:
|
||||
hiredis_version = hiredis.__version__.split(".")
|
||||
HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or (
|
||||
int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2
|
||||
)
|
||||
if not HIREDIS_AVAILABLE:
|
||||
raise ImportError("hiredis package should be >= 3.2.0")
|
||||
# Only support Hiredis >= 1.0:
|
||||
HIREDIS_AVAILABLE = not hiredis.__version__.startswith("0.")
|
||||
HIREDIS_PACK_AVAILABLE = hasattr(hiredis, "pack_command")
|
||||
except ImportError:
|
||||
HIREDIS_AVAILABLE = False
|
||||
HIREDIS_PACK_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import ssl # noqa
|
||||
@@ -36,7 +28,10 @@ try:
|
||||
except ImportError:
|
||||
CRYPTOGRAPHY_AVAILABLE = False
|
||||
|
||||
from importlib import metadata
|
||||
if sys.version_info >= (3, 8):
|
||||
from importlib import metadata
|
||||
else:
|
||||
import importlib_metadata as metadata
|
||||
|
||||
|
||||
def from_url(url, **kwargs):
|
||||
@@ -131,74 +126,6 @@ def deprecated_function(reason="", version="", name=None):
|
||||
return decorator
|
||||
|
||||
|
||||
def warn_deprecated_arg_usage(
|
||||
arg_name: Union[list, str],
|
||||
function_name: str,
|
||||
reason: str = "",
|
||||
version: str = "",
|
||||
stacklevel: int = 2,
|
||||
):
|
||||
import warnings
|
||||
|
||||
msg = (
|
||||
f"Call to '{function_name}' function with deprecated"
|
||||
f" usage of input argument/s '{arg_name}'."
|
||||
)
|
||||
if reason:
|
||||
msg += f" ({reason})"
|
||||
if version:
|
||||
msg += f" -- Deprecated since version {version}."
|
||||
warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel)
|
||||
|
||||
|
||||
C = TypeVar("C", bound=Callable)
|
||||
|
||||
|
||||
def deprecated_args(
|
||||
args_to_warn: list = ["*"],
|
||||
allowed_args: list = [],
|
||||
reason: str = "",
|
||||
version: str = "",
|
||||
) -> Callable[[C], C]:
|
||||
"""
|
||||
Decorator to mark specified args of a function as deprecated.
|
||||
If '*' is in args_to_warn, all arguments will be marked as deprecated.
|
||||
"""
|
||||
|
||||
def decorator(func: C) -> C:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Get function argument names
|
||||
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
|
||||
|
||||
provided_args = dict(zip(arg_names, args))
|
||||
provided_args.update(kwargs)
|
||||
|
||||
provided_args.pop("self", None)
|
||||
for allowed_arg in allowed_args:
|
||||
provided_args.pop(allowed_arg, None)
|
||||
|
||||
for arg in args_to_warn:
|
||||
if arg == "*" and len(provided_args) > 0:
|
||||
warn_deprecated_arg_usage(
|
||||
list(provided_args.keys()),
|
||||
func.__name__,
|
||||
reason,
|
||||
version,
|
||||
stacklevel=3,
|
||||
)
|
||||
elif arg in provided_args:
|
||||
warn_deprecated_arg_usage(
|
||||
arg, func.__name__, reason, version, stacklevel=3
|
||||
)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
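From the caller's side, the deprecated_args decorator above turns use of a flagged argument into a DeprecationWarning while still running the function. A sketch, assuming a redis-py version that ships the decorator; connect() and its charset argument are hypothetical.

# Illustrative sketch, not part of the diff: observing the warning emitted by
# deprecated_args for a hypothetical function.
import warnings

from redis.utils import deprecated_args


@deprecated_args(args_to_warn=["charset"], reason="use encoding instead", version="4.0.0")
def connect(host="localhost", charset=None, encoding="utf-8"):
    return f"{host} ({encoding})"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    connect(charset="latin-1")             # flagged argument triggers the warning
    print(caught[0].category.__name__)     # DeprecationWarning
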
def _set_info_logger():
|
||||
"""
|
||||
Set up a logger that log info logs to stdout.
|
||||
@@ -218,97 +145,3 @@ def get_lib_version():
|
||||
except metadata.PackageNotFoundError:
|
||||
libver = "99.99.99"
|
||||
return libver
|
||||
|
||||
|
||||
def format_error_message(host_error: str, exception: BaseException) -> str:
|
||||
if not exception.args:
|
||||
return f"Error connecting to {host_error}."
|
||||
elif len(exception.args) == 1:
|
||||
return f"Error {exception.args[0]} connecting to {host_error}."
|
||||
else:
|
||||
return (
|
||||
f"Error {exception.args[0]} connecting to {host_error}. "
|
||||
f"{exception.args[1]}."
|
||||
)
|
||||
|
||||
|
||||
def compare_versions(version1: str, version2: str) -> int:
|
||||
"""
|
||||
Compare two versions.
|
||||
|
||||
:return: -1 if version1 > version2
|
||||
0 if both versions are equal
|
||||
1 if version1 < version2
|
||||
"""
|
||||
|
||||
num_versions1 = list(map(int, version1.split(".")))
|
||||
num_versions2 = list(map(int, version2.split(".")))
|
||||
|
||||
if len(num_versions1) > len(num_versions2):
|
||||
diff = len(num_versions1) - len(num_versions2)
|
||||
for _ in range(diff):
|
||||
num_versions2.append(0)
|
||||
elif len(num_versions1) < len(num_versions2):
|
||||
diff = len(num_versions2) - len(num_versions1)
|
||||
for _ in range(diff):
|
||||
num_versions1.append(0)
|
||||
|
||||
for i, ver in enumerate(num_versions1):
|
||||
if num_versions1[i] > num_versions2[i]:
|
||||
return -1
|
||||
elif num_versions1[i] < num_versions2[i]:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
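Note the inverted contract of compare_versions above: it returns -1 when the first version is the newer one. A quick sketch, assuming the helper is present in the installed redis.utils.

# Illustrative sketch, not part of the diff: the return convention of compare_versions.
from redis.utils import compare_versions

print(compare_versions("7.4.0", "7.2.0"))    # -1: version1 is newer
print(compare_versions("7.2", "7.2.0"))      # 0: shorter versions are zero-padded
print(compare_versions("6.2.6", "7.0.0"))    # 1: version1 is older
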
def ensure_string(key):
|
||||
if isinstance(key, bytes):
|
||||
return key.decode("utf-8")
|
||||
elif isinstance(key, str):
|
||||
return key
|
||||
else:
|
||||
raise TypeError("Key must be either a string or bytes")
|
||||
|
||||
|
||||
def extract_expire_flags(
|
||||
ex: Optional[ExpiryT] = None,
|
||||
px: Optional[ExpiryT] = None,
|
||||
exat: Optional[AbsExpiryT] = None,
|
||||
pxat: Optional[AbsExpiryT] = None,
|
||||
) -> List[EncodableT]:
|
||||
exp_options: list[EncodableT] = []
|
||||
if ex is not None:
|
||||
exp_options.append("EX")
|
||||
if isinstance(ex, datetime.timedelta):
|
||||
exp_options.append(int(ex.total_seconds()))
|
||||
elif isinstance(ex, int):
|
||||
exp_options.append(ex)
|
||||
elif isinstance(ex, str) and ex.isdigit():
|
||||
exp_options.append(int(ex))
|
||||
else:
|
||||
raise DataError("ex must be datetime.timedelta or int")
|
||||
elif px is not None:
|
||||
exp_options.append("PX")
|
||||
if isinstance(px, datetime.timedelta):
|
||||
exp_options.append(int(px.total_seconds() * 1000))
|
||||
elif isinstance(px, int):
|
||||
exp_options.append(px)
|
||||
else:
|
||||
raise DataError("px must be datetime.timedelta or int")
|
||||
elif exat is not None:
|
||||
if isinstance(exat, datetime.datetime):
|
||||
exat = int(exat.timestamp())
|
||||
exp_options.extend(["EXAT", exat])
|
||||
elif pxat is not None:
|
||||
if isinstance(pxat, datetime.datetime):
|
||||
pxat = int(pxat.timestamp() * 1000)
|
||||
exp_options.extend(["PXAT", pxat])
|
||||
|
||||
return exp_options
|
||||
|
||||
|
||||
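extract_expire_flags above normalises the ex/px/exat/pxat keyword arguments into the flag list appended to commands such as SET. A sketch of its output, assuming the helper is present in the installed redis.utils.

# Illustrative sketch, not part of the diff: the flag lists produced by
# extract_expire_flags for a few expiry arguments.
import datetime

from redis.utils import extract_expire_flags

print(extract_expire_flags(ex=10))                                      # ['EX', 10]
print(extract_expire_flags(px=datetime.timedelta(milliseconds=500)))    # ['PX', 500]
print(extract_expire_flags(exat=datetime.datetime(2030, 1, 1)))         # ['EXAT', <unix timestamp>]
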
def truncate_text(txt, max_length=100):
|
||||
return textwrap.shorten(
|
||||
text=txt, width=max_length, placeholder="...", break_long_words=True
|
||||
)
|
||||
|
||||