This commit is contained in:
@@ -1,9 +1,4 @@
|
||||
from .base import (
|
||||
AsyncPushNotificationsParser,
|
||||
BaseParser,
|
||||
PushNotificationsParser,
|
||||
_AsyncRESPBase,
|
||||
)
|
||||
from .base import BaseParser, _AsyncRESPBase
|
||||
from .commands import AsyncCommandsParser, CommandsParser
|
||||
from .encoders import Encoder
|
||||
from .hiredis import _AsyncHiredisParser, _HiredisParser
|
||||
@@ -16,12 +11,10 @@ __all__ = [
|
||||
"_AsyncRESPBase",
|
||||
"_AsyncRESP2Parser",
|
||||
"_AsyncRESP3Parser",
|
||||
"AsyncPushNotificationsParser",
|
||||
"CommandsParser",
|
||||
"Encoder",
|
||||
"BaseParser",
|
||||
"_HiredisParser",
|
||||
"_RESP2Parser",
|
||||
"_RESP3Parser",
|
||||
"PushNotificationsParser",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
from abc import ABC
|
||||
from asyncio import IncompleteReadError, StreamReader, TimeoutError
|
||||
from typing import Callable, List, Optional, Protocol, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
|
||||
from asyncio import timeout as async_timeout
|
||||
@@ -9,32 +9,26 @@ else:
|
||||
from async_timeout import timeout as async_timeout
|
||||
|
||||
from ..exceptions import (
|
||||
AskError,
|
||||
AuthenticationError,
|
||||
AuthenticationWrongNumberOfArgsError,
|
||||
BusyLoadingError,
|
||||
ClusterCrossSlotError,
|
||||
ClusterDownError,
|
||||
ConnectionError,
|
||||
ExecAbortError,
|
||||
MasterDownError,
|
||||
ModuleError,
|
||||
MovedError,
|
||||
NoPermissionError,
|
||||
NoScriptError,
|
||||
OutOfMemoryError,
|
||||
ReadOnlyError,
|
||||
RedisError,
|
||||
ResponseError,
|
||||
TryAgainError,
|
||||
)
|
||||
from ..typing import EncodableT
|
||||
from .encoders import Encoder
|
||||
from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer
|
||||
|
||||
MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
|
||||
MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs."
|
||||
NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
|
||||
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
|
||||
MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible."
|
||||
MODULE_EXPORTS_DATA_TYPES_ERROR = (
|
||||
"Error unloading module: the module "
|
||||
"exports one or more module-side data "
|
||||
@@ -78,12 +72,6 @@ class BaseParser(ABC):
|
||||
"READONLY": ReadOnlyError,
|
||||
"NOAUTH": AuthenticationError,
|
||||
"NOPERM": NoPermissionError,
|
||||
"ASK": AskError,
|
||||
"TRYAGAIN": TryAgainError,
|
||||
"MOVED": MovedError,
|
||||
"CLUSTERDOWN": ClusterDownError,
|
||||
"CROSSSLOT": ClusterCrossSlotError,
|
||||
"MASTERDOWN": MasterDownError,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -158,58 +146,6 @@ class AsyncBaseParser(BaseParser):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
|
||||
|
||||
|
||||
class PushNotificationsParser(Protocol):
|
||||
"""Protocol defining RESP3-specific parsing functionality"""
|
||||
|
||||
pubsub_push_handler_func: Callable
|
||||
invalidation_push_handler_func: Optional[Callable] = None
|
||||
|
||||
def handle_pubsub_push_response(self, response):
|
||||
"""Handle pubsub push responses"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def handle_push_response(self, response, **kwargs):
|
||||
if response[0] not in _INVALIDATION_MESSAGE:
|
||||
return self.pubsub_push_handler_func(response)
|
||||
if self.invalidation_push_handler_func:
|
||||
return self.invalidation_push_handler_func(response)
|
||||
|
||||
def set_pubsub_push_handler(self, pubsub_push_handler_func):
|
||||
self.pubsub_push_handler_func = pubsub_push_handler_func
|
||||
|
||||
def set_invalidation_push_handler(self, invalidation_push_handler_func):
|
||||
self.invalidation_push_handler_func = invalidation_push_handler_func
|
||||
|
||||
|
||||
class AsyncPushNotificationsParser(Protocol):
|
||||
"""Protocol defining async RESP3-specific parsing functionality"""
|
||||
|
||||
pubsub_push_handler_func: Callable
|
||||
invalidation_push_handler_func: Optional[Callable] = None
|
||||
|
||||
async def handle_pubsub_push_response(self, response):
|
||||
"""Handle pubsub push responses asynchronously"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def handle_push_response(self, response, **kwargs):
|
||||
"""Handle push responses asynchronously"""
|
||||
if response[0] not in _INVALIDATION_MESSAGE:
|
||||
return await self.pubsub_push_handler_func(response)
|
||||
if self.invalidation_push_handler_func:
|
||||
return await self.invalidation_push_handler_func(response)
|
||||
|
||||
def set_pubsub_push_handler(self, pubsub_push_handler_func):
|
||||
"""Set the pubsub push handler function"""
|
||||
self.pubsub_push_handler_func = pubsub_push_handler_func
|
||||
|
||||
def set_invalidation_push_handler(self, invalidation_push_handler_func):
|
||||
"""Set the invalidation push handler function"""
|
||||
self.invalidation_push_handler_func = invalidation_push_handler_func
|
||||
|
||||
|
||||
class _AsyncRESPBase(AsyncBaseParser):
|
||||
"""Base class for async resp parsing"""
|
||||
|
||||
@@ -246,7 +182,7 @@ class _AsyncRESPBase(AsyncBaseParser):
|
||||
return True
|
||||
try:
|
||||
async with async_timeout(0):
|
||||
return self._stream.at_eof()
|
||||
return await self._stream.read(1)
|
||||
except TimeoutError:
|
||||
return False
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ def parse_info(response):
|
||||
response = str_if_bytes(response)
|
||||
|
||||
def get_value(value):
|
||||
if "," not in value and "=" not in value:
|
||||
if "," not in value or "=" not in value:
|
||||
try:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
@@ -46,18 +46,11 @@ def parse_info(response):
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return value
|
||||
elif "=" not in value:
|
||||
return [get_value(v) for v in value.split(",") if v]
|
||||
else:
|
||||
sub_dict = {}
|
||||
for item in value.split(","):
|
||||
if not item:
|
||||
continue
|
||||
if "=" in item:
|
||||
k, v = item.rsplit("=", 1)
|
||||
sub_dict[k] = get_value(v)
|
||||
else:
|
||||
sub_dict[item] = True
|
||||
k, v = item.rsplit("=", 1)
|
||||
sub_dict[k] = get_value(v)
|
||||
return sub_dict
|
||||
|
||||
for line in response.splitlines():
|
||||
@@ -87,7 +80,7 @@ def parse_memory_stats(response, **kwargs):
|
||||
"""Parse the results of MEMORY STATS"""
|
||||
stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True)
|
||||
for key, value in stats.items():
|
||||
if key.startswith("db.") and isinstance(value, list):
|
||||
if key.startswith("db."):
|
||||
stats[key] = pairs_to_dict(
|
||||
value, decode_keys=True, decode_string_values=True
|
||||
)
|
||||
@@ -275,22 +268,17 @@ def parse_xinfo_stream(response, **options):
|
||||
data = {str_if_bytes(k): v for k, v in response.items()}
|
||||
if not options.get("full", False):
|
||||
first = data.get("first-entry")
|
||||
if first is not None and first[0] is not None:
|
||||
if first is not None:
|
||||
data["first-entry"] = (first[0], pairs_to_dict(first[1]))
|
||||
last = data["last-entry"]
|
||||
if last is not None and last[0] is not None:
|
||||
if last is not None:
|
||||
data["last-entry"] = (last[0], pairs_to_dict(last[1]))
|
||||
else:
|
||||
data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]}
|
||||
if len(data["groups"]) > 0 and isinstance(data["groups"][0], list):
|
||||
if isinstance(data["groups"][0], list):
|
||||
data["groups"] = [
|
||||
pairs_to_dict(group, decode_keys=True) for group in data["groups"]
|
||||
]
|
||||
for g in data["groups"]:
|
||||
if g["consumers"] and g["consumers"][0] is not None:
|
||||
g["consumers"] = [
|
||||
pairs_to_dict(c, decode_keys=True) for c in g["consumers"]
|
||||
]
|
||||
else:
|
||||
data["groups"] = [
|
||||
{str_if_bytes(k): v for k, v in group.items()}
|
||||
@@ -334,7 +322,7 @@ def float_or_none(response):
|
||||
return float(response)
|
||||
|
||||
|
||||
def bool_ok(response, **options):
|
||||
def bool_ok(response):
|
||||
return str_if_bytes(response) == "OK"
|
||||
|
||||
|
||||
@@ -366,12 +354,7 @@ def parse_scan(response, **options):
|
||||
|
||||
def parse_hscan(response, **options):
|
||||
cursor, r = response
|
||||
no_values = options.get("no_values", False)
|
||||
if no_values:
|
||||
payload = r or []
|
||||
else:
|
||||
payload = r and pairs_to_dict(r) or {}
|
||||
return int(cursor), payload
|
||||
return int(cursor), r and pairs_to_dict(r) or {}
|
||||
|
||||
|
||||
def parse_zscan(response, **options):
|
||||
@@ -396,20 +379,13 @@ def parse_slowlog_get(response, **options):
|
||||
# an O(N) complexity) instead of the command.
|
||||
if isinstance(item[3], list):
|
||||
result["command"] = space.join(item[3])
|
||||
|
||||
# These fields are optional, depends on environment.
|
||||
if len(item) >= 6:
|
||||
result["client_address"] = item[4]
|
||||
result["client_name"] = item[5]
|
||||
result["client_address"] = item[4]
|
||||
result["client_name"] = item[5]
|
||||
else:
|
||||
result["complexity"] = item[3]
|
||||
result["command"] = space.join(item[4])
|
||||
|
||||
# These fields are optional, depends on environment.
|
||||
if len(item) >= 7:
|
||||
result["client_address"] = item[5]
|
||||
result["client_name"] = item[6]
|
||||
|
||||
result["client_address"] = item[5]
|
||||
result["client_name"] = item[6]
|
||||
return result
|
||||
|
||||
return [parse_item(item) for item in response]
|
||||
@@ -452,11 +428,9 @@ def parse_cluster_info(response, **options):
|
||||
def _parse_node_line(line):
|
||||
line_items = line.split(" ")
|
||||
node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8]
|
||||
ip = addr.split("@")[0]
|
||||
hostname = addr.split("@")[1].split(",")[1] if "@" in addr and "," in addr else ""
|
||||
addr = addr.split("@")[0]
|
||||
node_dict = {
|
||||
"node_id": node_id,
|
||||
"hostname": hostname,
|
||||
"flags": flags,
|
||||
"master_id": master_id,
|
||||
"last_ping_sent": ping,
|
||||
@@ -469,7 +443,7 @@ def _parse_node_line(line):
|
||||
if len(line_items) >= 9:
|
||||
slots, migrations = _parse_slots(line_items[8:])
|
||||
node_dict["slots"], node_dict["migrations"] = slots, migrations
|
||||
return ip, node_dict
|
||||
return addr, node_dict
|
||||
|
||||
|
||||
def _parse_slots(slot_ranges):
|
||||
@@ -516,7 +490,7 @@ def parse_geosearch_generic(response, **options):
|
||||
except KeyError: # it means the command was sent via execute_command
|
||||
return response
|
||||
|
||||
if not isinstance(response, list):
|
||||
if type(response) != list:
|
||||
response_list = [response]
|
||||
else:
|
||||
response_list = response
|
||||
@@ -676,8 +650,7 @@ def parse_client_info(value):
|
||||
"omem",
|
||||
"tot-mem",
|
||||
}:
|
||||
if int_key in client_info:
|
||||
client_info[int_key] = int(client_info[int_key])
|
||||
client_info[int_key] = int(client_info[int_key])
|
||||
return client_info
|
||||
|
||||
|
||||
@@ -840,28 +813,24 @@ _RedisCallbacksRESP2 = {
|
||||
|
||||
|
||||
_RedisCallbacksRESP3 = {
|
||||
**string_keys_to_dict(
|
||||
"SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set()
|
||||
),
|
||||
**string_keys_to_dict(
|
||||
"ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE "
|
||||
"ZUNION HGETALL XREADGROUP",
|
||||
lambda r, **kwargs: r,
|
||||
),
|
||||
**string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3),
|
||||
"ACL LOG": lambda r: (
|
||||
[
|
||||
{str_if_bytes(key): str_if_bytes(value) for key, value in x.items()}
|
||||
for x in r
|
||||
]
|
||||
if isinstance(r, list)
|
||||
else bool_ok(r)
|
||||
),
|
||||
"ACL LOG": lambda r: [
|
||||
{str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} for x in r
|
||||
]
|
||||
if isinstance(r, list)
|
||||
else bool_ok(r),
|
||||
"COMMAND": parse_command_resp3,
|
||||
"CONFIG GET": lambda r: {
|
||||
str_if_bytes(key) if key is not None else None: (
|
||||
str_if_bytes(value) if value is not None else None
|
||||
)
|
||||
str_if_bytes(key)
|
||||
if key is not None
|
||||
else None: str_if_bytes(value)
|
||||
if value is not None
|
||||
else None
|
||||
for key, value in r.items()
|
||||
},
|
||||
"MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()},
|
||||
@@ -869,11 +838,11 @@ _RedisCallbacksRESP3 = {
|
||||
"SENTINEL MASTERS": parse_sentinel_masters_resp3,
|
||||
"SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3,
|
||||
"SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3,
|
||||
"STRALGO": lambda r, **options: (
|
||||
{str_if_bytes(key): str_if_bytes(value) for key, value in r.items()}
|
||||
if isinstance(r, dict)
|
||||
else str_if_bytes(r)
|
||||
),
|
||||
"STRALGO": lambda r, **options: {
|
||||
str_if_bytes(key): str_if_bytes(value) for key, value in r.items()
|
||||
}
|
||||
if isinstance(r, dict)
|
||||
else str_if_bytes(r),
|
||||
"XINFO CONSUMERS": lambda r: [
|
||||
{str_if_bytes(key): value for key, value in x.items()} for x in r
|
||||
],
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
import asyncio
|
||||
import socket
|
||||
import sys
|
||||
from logging import getLogger
|
||||
from typing import Callable, List, Optional, TypedDict, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
|
||||
from asyncio import timeout as async_timeout
|
||||
else:
|
||||
from async_timeout import timeout as async_timeout
|
||||
|
||||
from redis.compat import TypedDict
|
||||
|
||||
from ..exceptions import ConnectionError, InvalidResponse, RedisError
|
||||
from ..typing import EncodableT
|
||||
from ..utils import HIREDIS_AVAILABLE
|
||||
from .base import (
|
||||
AsyncBaseParser,
|
||||
AsyncPushNotificationsParser,
|
||||
BaseParser,
|
||||
PushNotificationsParser,
|
||||
)
|
||||
from .base import AsyncBaseParser, BaseParser
|
||||
from .socket import (
|
||||
NONBLOCKING_EXCEPTION_ERROR_NUMBERS,
|
||||
NONBLOCKING_EXCEPTIONS,
|
||||
@@ -25,11 +21,6 @@ from .socket import (
|
||||
SERVER_CLOSED_CONNECTION_ERROR,
|
||||
)
|
||||
|
||||
# Used to signal that hiredis-py does not have enough data to parse.
|
||||
# Using `False` or `None` is not reliable, given that the parser can
|
||||
# return `False` or `None` for legitimate reasons from RESP payloads.
|
||||
NOT_ENOUGH_DATA = object()
|
||||
|
||||
|
||||
class _HiredisReaderArgs(TypedDict, total=False):
|
||||
protocolError: Callable[[str], Exception]
|
||||
@@ -38,7 +29,7 @@ class _HiredisReaderArgs(TypedDict, total=False):
|
||||
errors: Optional[str]
|
||||
|
||||
|
||||
class _HiredisParser(BaseParser, PushNotificationsParser):
|
||||
class _HiredisParser(BaseParser):
|
||||
"Parser class for connections using Hiredis"
|
||||
|
||||
def __init__(self, socket_read_size):
|
||||
@@ -46,9 +37,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
|
||||
raise RedisError("Hiredis is not installed")
|
||||
self.socket_read_size = socket_read_size
|
||||
self._buffer = bytearray(socket_read_size)
|
||||
self.pubsub_push_handler_func = self.handle_pubsub_push_response
|
||||
self.invalidation_push_handler_func = None
|
||||
self._hiredis_PushNotificationType = None
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
@@ -56,11 +44,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def handle_pubsub_push_response(self, response):
|
||||
logger = getLogger("push_response")
|
||||
logger.debug("Push response: " + str(response))
|
||||
return response
|
||||
|
||||
def on_connect(self, connection, **kwargs):
|
||||
import hiredis
|
||||
|
||||
@@ -70,32 +53,25 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
|
||||
"protocolError": InvalidResponse,
|
||||
"replyError": self.parse_error,
|
||||
"errors": connection.encoder.encoding_errors,
|
||||
"notEnoughData": NOT_ENOUGH_DATA,
|
||||
}
|
||||
|
||||
if connection.encoder.decode_responses:
|
||||
kwargs["encoding"] = connection.encoder.encoding
|
||||
self._reader = hiredis.Reader(**kwargs)
|
||||
self._next_response = NOT_ENOUGH_DATA
|
||||
|
||||
try:
|
||||
self._hiredis_PushNotificationType = hiredis.PushNotification
|
||||
except AttributeError:
|
||||
# hiredis < 3.2
|
||||
self._hiredis_PushNotificationType = None
|
||||
self._next_response = False
|
||||
|
||||
def on_disconnect(self):
|
||||
self._sock = None
|
||||
self._reader = None
|
||||
self._next_response = NOT_ENOUGH_DATA
|
||||
self._next_response = False
|
||||
|
||||
def can_read(self, timeout):
|
||||
if not self._reader:
|
||||
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
|
||||
|
||||
if self._next_response is NOT_ENOUGH_DATA:
|
||||
if self._next_response is False:
|
||||
self._next_response = self._reader.gets()
|
||||
if self._next_response is NOT_ENOUGH_DATA:
|
||||
if self._next_response is False:
|
||||
return self.read_from_socket(timeout=timeout, raise_on_timeout=False)
|
||||
return True
|
||||
|
||||
@@ -129,24 +105,14 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
|
||||
if custom_timeout:
|
||||
sock.settimeout(self._socket_timeout)
|
||||
|
||||
def read_response(self, disable_decoding=False, push_request=False):
|
||||
def read_response(self, disable_decoding=False):
|
||||
if not self._reader:
|
||||
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
|
||||
|
||||
# _next_response might be cached from a can_read() call
|
||||
if self._next_response is not NOT_ENOUGH_DATA:
|
||||
if self._next_response is not False:
|
||||
response = self._next_response
|
||||
self._next_response = NOT_ENOUGH_DATA
|
||||
if self._hiredis_PushNotificationType is not None and isinstance(
|
||||
response, self._hiredis_PushNotificationType
|
||||
):
|
||||
response = self.handle_push_response(response)
|
||||
if not push_request:
|
||||
return self.read_response(
|
||||
disable_decoding=disable_decoding, push_request=push_request
|
||||
)
|
||||
else:
|
||||
return response
|
||||
self._next_response = False
|
||||
return response
|
||||
|
||||
if disable_decoding:
|
||||
@@ -154,7 +120,7 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
|
||||
else:
|
||||
response = self._reader.gets()
|
||||
|
||||
while response is NOT_ENOUGH_DATA:
|
||||
while response is False:
|
||||
self.read_from_socket()
|
||||
if disable_decoding:
|
||||
response = self._reader.gets(False)
|
||||
@@ -165,16 +131,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
|
||||
# happened
|
||||
if isinstance(response, ConnectionError):
|
||||
raise response
|
||||
elif self._hiredis_PushNotificationType is not None and isinstance(
|
||||
response, self._hiredis_PushNotificationType
|
||||
):
|
||||
response = self.handle_push_response(response)
|
||||
if not push_request:
|
||||
return self.read_response(
|
||||
disable_decoding=disable_decoding, push_request=push_request
|
||||
)
|
||||
else:
|
||||
return response
|
||||
elif (
|
||||
isinstance(response, list)
|
||||
and response
|
||||
@@ -184,7 +140,7 @@ class _HiredisParser(BaseParser, PushNotificationsParser):
|
||||
return response
|
||||
|
||||
|
||||
class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
|
||||
class _AsyncHiredisParser(AsyncBaseParser):
|
||||
"""Async implementation of parser class for connections using Hiredis"""
|
||||
|
||||
__slots__ = ("_reader",)
|
||||
@@ -194,14 +150,6 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
|
||||
raise RedisError("Hiredis is not available.")
|
||||
super().__init__(socket_read_size=socket_read_size)
|
||||
self._reader = None
|
||||
self.pubsub_push_handler_func = self.handle_pubsub_push_response
|
||||
self.invalidation_push_handler_func = None
|
||||
self._hiredis_PushNotificationType = None
|
||||
|
||||
async def handle_pubsub_push_response(self, response):
|
||||
logger = getLogger("push_response")
|
||||
logger.debug("Push response: " + str(response))
|
||||
return response
|
||||
|
||||
def on_connect(self, connection):
|
||||
import hiredis
|
||||
@@ -210,7 +158,6 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
|
||||
kwargs: _HiredisReaderArgs = {
|
||||
"protocolError": InvalidResponse,
|
||||
"replyError": self.parse_error,
|
||||
"notEnoughData": NOT_ENOUGH_DATA,
|
||||
}
|
||||
if connection.encoder.decode_responses:
|
||||
kwargs["encoding"] = connection.encoder.encoding
|
||||
@@ -219,21 +166,13 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
|
||||
self._reader = hiredis.Reader(**kwargs)
|
||||
self._connected = True
|
||||
|
||||
try:
|
||||
self._hiredis_PushNotificationType = getattr(
|
||||
hiredis, "PushNotification", None
|
||||
)
|
||||
except AttributeError:
|
||||
# hiredis < 3.2
|
||||
self._hiredis_PushNotificationType = None
|
||||
|
||||
def on_disconnect(self):
|
||||
self._connected = False
|
||||
|
||||
async def can_read_destructive(self):
|
||||
if not self._connected:
|
||||
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
|
||||
if self._reader.gets() is not NOT_ENOUGH_DATA:
|
||||
if self._reader.gets():
|
||||
return True
|
||||
try:
|
||||
async with async_timeout(0):
|
||||
@@ -251,7 +190,7 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
|
||||
return True
|
||||
|
||||
async def read_response(
|
||||
self, disable_decoding: bool = False, push_request: bool = False
|
||||
self, disable_decoding: bool = False
|
||||
) -> Union[EncodableT, List[EncodableT]]:
|
||||
# If `on_disconnect()` has been called, prohibit any more reads
|
||||
# even if they could happen because data might be present.
|
||||
@@ -259,33 +198,16 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser):
|
||||
if not self._connected:
|
||||
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
|
||||
|
||||
if disable_decoding:
|
||||
response = self._reader.gets(False)
|
||||
else:
|
||||
response = self._reader.gets()
|
||||
|
||||
while response is NOT_ENOUGH_DATA:
|
||||
response = self._reader.gets()
|
||||
while response is False:
|
||||
await self.read_from_socket()
|
||||
if disable_decoding:
|
||||
response = self._reader.gets(False)
|
||||
else:
|
||||
response = self._reader.gets()
|
||||
response = self._reader.gets()
|
||||
|
||||
# if the response is a ConnectionError or the response is a list and
|
||||
# the first item is a ConnectionError, raise it as something bad
|
||||
# happened
|
||||
if isinstance(response, ConnectionError):
|
||||
raise response
|
||||
elif self._hiredis_PushNotificationType is not None and isinstance(
|
||||
response, self._hiredis_PushNotificationType
|
||||
):
|
||||
response = await self.handle_push_response(response)
|
||||
if not push_request:
|
||||
return await self.read_response(
|
||||
disable_decoding=disable_decoding, push_request=push_request
|
||||
)
|
||||
else:
|
||||
return response
|
||||
elif (
|
||||
isinstance(response, list)
|
||||
and response
|
||||
|
||||
@@ -3,26 +3,20 @@ from typing import Any, Union
|
||||
|
||||
from ..exceptions import ConnectionError, InvalidResponse, ResponseError
|
||||
from ..typing import EncodableT
|
||||
from .base import (
|
||||
AsyncPushNotificationsParser,
|
||||
PushNotificationsParser,
|
||||
_AsyncRESPBase,
|
||||
_RESPBase,
|
||||
)
|
||||
from .base import _AsyncRESPBase, _RESPBase
|
||||
from .socket import SERVER_CLOSED_CONNECTION_ERROR
|
||||
|
||||
|
||||
class _RESP3Parser(_RESPBase, PushNotificationsParser):
|
||||
class _RESP3Parser(_RESPBase):
|
||||
"""RESP3 protocol implementation"""
|
||||
|
||||
def __init__(self, socket_read_size):
|
||||
super().__init__(socket_read_size)
|
||||
self.pubsub_push_handler_func = self.handle_pubsub_push_response
|
||||
self.invalidation_push_handler_func = None
|
||||
self.push_handler_func = self.handle_push_response
|
||||
|
||||
def handle_pubsub_push_response(self, response):
|
||||
def handle_push_response(self, response):
|
||||
logger = getLogger("push_response")
|
||||
logger.debug("Push response: " + str(response))
|
||||
logger.info("Push response: " + str(response))
|
||||
return response
|
||||
|
||||
def read_response(self, disable_decoding=False, push_request=False):
|
||||
@@ -91,16 +85,19 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
|
||||
# set response
|
||||
elif byte == b"~":
|
||||
# redis can return unhashable types (like dict) in a set,
|
||||
# so we return sets as list, all the time, for predictability
|
||||
# so we need to first convert to a list, and then try to convert it to a set
|
||||
response = [
|
||||
self._read_response(disable_decoding=disable_decoding)
|
||||
for _ in range(int(response))
|
||||
]
|
||||
try:
|
||||
response = set(response)
|
||||
except TypeError:
|
||||
pass
|
||||
# map response
|
||||
elif byte == b"%":
|
||||
# We cannot use a dict-comprehension to parse stream.
|
||||
# Evaluation order of key:val expression in dict comprehension only
|
||||
# became defined to be left-right in version 3.8
|
||||
# we use this approach and not dict comprehension here
|
||||
# because this dict comprehension fails in python 3.7
|
||||
resp_dict = {}
|
||||
for _ in range(int(response)):
|
||||
key = self._read_response(disable_decoding=disable_decoding)
|
||||
@@ -116,13 +113,13 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
|
||||
)
|
||||
for _ in range(int(response))
|
||||
]
|
||||
response = self.handle_push_response(response)
|
||||
res = self.push_handler_func(response)
|
||||
if not push_request:
|
||||
return self._read_response(
|
||||
disable_decoding=disable_decoding, push_request=push_request
|
||||
)
|
||||
else:
|
||||
return response
|
||||
return res
|
||||
else:
|
||||
raise InvalidResponse(f"Protocol Error: {raw!r}")
|
||||
|
||||
@@ -130,16 +127,18 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
|
||||
response = self.encoder.decode(response)
|
||||
return response
|
||||
|
||||
def set_push_handler(self, push_handler_func):
|
||||
self.push_handler_func = push_handler_func
|
||||
|
||||
class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
|
||||
|
||||
class _AsyncRESP3Parser(_AsyncRESPBase):
|
||||
def __init__(self, socket_read_size):
|
||||
super().__init__(socket_read_size)
|
||||
self.pubsub_push_handler_func = self.handle_pubsub_push_response
|
||||
self.invalidation_push_handler_func = None
|
||||
self.push_handler_func = self.handle_push_response
|
||||
|
||||
async def handle_pubsub_push_response(self, response):
|
||||
def handle_push_response(self, response):
|
||||
logger = getLogger("push_response")
|
||||
logger.debug("Push response: " + str(response))
|
||||
logger.info("Push response: " + str(response))
|
||||
return response
|
||||
|
||||
async def read_response(
|
||||
@@ -215,23 +214,23 @@ class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
|
||||
# set response
|
||||
elif byte == b"~":
|
||||
# redis can return unhashable types (like dict) in a set,
|
||||
# so we always convert to a list, to have predictable return types
|
||||
# so we need to first convert to a list, and then try to convert it to a set
|
||||
response = [
|
||||
(await self._read_response(disable_decoding=disable_decoding))
|
||||
for _ in range(int(response))
|
||||
]
|
||||
try:
|
||||
response = set(response)
|
||||
except TypeError:
|
||||
pass
|
||||
# map response
|
||||
elif byte == b"%":
|
||||
# We cannot use a dict-comprehension to parse stream.
|
||||
# Evaluation order of key:val expression in dict comprehension only
|
||||
# became defined to be left-right in version 3.8
|
||||
resp_dict = {}
|
||||
for _ in range(int(response)):
|
||||
key = await self._read_response(disable_decoding=disable_decoding)
|
||||
resp_dict[key] = await self._read_response(
|
||||
disable_decoding=disable_decoding, push_request=push_request
|
||||
response = {
|
||||
(await self._read_response(disable_decoding=disable_decoding)): (
|
||||
await self._read_response(disable_decoding=disable_decoding)
|
||||
)
|
||||
response = resp_dict
|
||||
for _ in range(int(response))
|
||||
}
|
||||
# push response
|
||||
elif byte == b">":
|
||||
response = [
|
||||
@@ -242,16 +241,19 @@ class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser):
|
||||
)
|
||||
for _ in range(int(response))
|
||||
]
|
||||
response = await self.handle_push_response(response)
|
||||
res = self.push_handler_func(response)
|
||||
if not push_request:
|
||||
return await self._read_response(
|
||||
disable_decoding=disable_decoding, push_request=push_request
|
||||
)
|
||||
else:
|
||||
return response
|
||||
return res
|
||||
else:
|
||||
raise InvalidResponse(f"Protocol Error: {raw!r}")
|
||||
|
||||
if isinstance(response, bytes) and disable_decoding is False:
|
||||
response = self.encoder.decode(response)
|
||||
return response
|
||||
|
||||
def set_push_handler(self, push_handler_func):
|
||||
self.push_handler_func = push_handler_func
|
||||
|
||||
Reference in New Issue
Block a user