API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -5,7 +5,7 @@ from .commands import * # noqa
from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo
class AbstractBloom(object):
class AbstractBloom:
"""
The client allows you to interact with RedisBloom and use all of
its functionality.

View File

@@ -1,6 +1,5 @@
from redis.client import NEVER_DECODE
from redis.exceptions import ModuleError
from redis.utils import HIREDIS_AVAILABLE, deprecated_function
from redis.utils import deprecated_function
BF_RESERVE = "BF.RESERVE"
BF_ADD = "BF.ADD"
@@ -139,9 +138,6 @@ class BFCommands:
This command will return successive (iter, data) pairs until (0, NULL) to indicate completion.
For more information see `BF.SCANDUMP <https://redis.io/commands/bf.scandump>`_.
""" # noqa
if HIREDIS_AVAILABLE:
raise ModuleError("This command cannot be used when hiredis is available.")
params = [key, iter]
options = {}
options[NEVER_DECODE] = []
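Since SCANDUMP yields successive (iter, data) chunks until (0, NULL), a dump-and-restore loop can be sketched like this (a minimal sketch, assuming a redis-py client with an existing Bloom filter at key "bf"; "bf-copy" is a hypothetical target key):

from redis import Redis

r = Redis()
chunks = []
it = 0
while True:
    it, data = r.bf().scandump("bf", it)
    if it == 0:
        break  # the server signals completion with (0, NULL)
    chunks.append((it, data))

for it, data in chunks:
    # LOADCHUNK restores each dumped chunk into another key
    r.bf().loadchunk("bf-copy", it, data)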

View File

@@ -1,7 +1,7 @@
from ..helpers import nativestr
class BFInfo(object):
class BFInfo:
capacity = None
size = None
filterNum = None
@@ -26,7 +26,7 @@ class BFInfo(object):
return getattr(self, item)
class CFInfo(object):
class CFInfo:
size = None
bucketNum = None
filterNum = None
@@ -57,7 +57,7 @@ class CFInfo(object):
return getattr(self, item)
class CMSInfo(object):
class CMSInfo:
width = None
depth = None
count = None
@@ -72,7 +72,7 @@ class CMSInfo(object):
return getattr(self, item)
class TopKInfo(object):
class TopKInfo:
k = None
width = None
depth = None
@@ -89,7 +89,7 @@ class TopKInfo(object):
return getattr(self, item)
class TDigestInfo(object):
class TDigestInfo:
compression = None
capacity = None
merged_nodes = None

View File

@@ -7,13 +7,13 @@ from typing import (
Iterable,
Iterator,
List,
Literal,
Mapping,
NoReturn,
Optional,
Union,
)
from redis.compat import Literal
from redis.crc import key_slot
from redis.exceptions import RedisClusterException, RedisError
from redis.typing import (
@@ -23,6 +23,7 @@ from redis.typing import (
KeysT,
KeyT,
PatternT,
ResponseT,
)
from .core import (
@@ -30,21 +31,18 @@ from .core import (
AsyncACLCommands,
AsyncDataAccessCommands,
AsyncFunctionCommands,
AsyncGearsCommands,
AsyncManagementCommands,
AsyncModuleCommands,
AsyncScriptCommands,
DataAccessCommands,
FunctionCommands,
GearsCommands,
ManagementCommands,
ModuleCommands,
PubSubCommands,
ResponseT,
ScriptCommands,
)
from .helpers import list_or_args
from .redismodules import RedisModuleCommands
from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands
if TYPE_CHECKING:
from redis.asyncio.cluster import TargetNodesT
@@ -225,7 +223,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
The keys are first split up into slots
and then a DEL command is sent for every slot
Non-existant keys are ignored.
Non-existent keys are ignored.
Returns the number of keys that were deleted.
For more information see https://redis.io/commands/del
@@ -240,7 +238,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
The keys are first split up into slots
and then a TOUCH command is sent for every slot
Non-existant keys are ignored.
Non-existent keys are ignored.
Returns the number of keys that were touched.
For more information see https://redis.io/commands/touch
@@ -254,7 +252,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol):
The keys are first split up into slots
and then an UNLINK command is sent for every slot
Non-existant keys are ignored.
Non-existent keys are ignored.
Returns the number of keys that were unlinked.
For more information see https://redis.io/commands/unlink
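As described above, these multi-key helpers fan the keys out by hash slot. A minimal usage sketch, assuming a reachable cluster node at localhost:7000:

from redis.cluster import RedisCluster

rc = RedisCluster(host="localhost", port=7000)
rc.set("user:1", "a")
rc.set("user:2", "b")
# keys are grouped per slot, one DEL is sent per slot, counts are summed
deleted = rc.delete("user:1", "user:2", "missing-key")  # -> 2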
@@ -595,7 +593,7 @@ class ClusterManagementCommands(ManagementCommands):
"CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node
)
elif state.upper() == "STABLE":
raise RedisError('For "stable" state please use ' "cluster_setslot_stable")
raise RedisError('For "stable" state please use cluster_setslot_stable')
else:
raise RedisError(f"Invalid slot state: {state}")
@@ -693,12 +691,6 @@ class ClusterManagementCommands(ManagementCommands):
self.read_from_replicas = False
return self.execute_command("READWRITE", target_nodes=target_nodes)
def gears_refresh_cluster(self, **kwargs) -> ResponseT:
"""
On an OSS cluster, before executing any gears function, you must call this command. # noqa
"""
return self.execute_command("REDISGEARS_2.REFRESHCLUSTER", **kwargs)
class AsyncClusterManagementCommands(
ClusterManagementCommands, AsyncManagementCommands
@@ -874,7 +866,6 @@ class RedisClusterCommands(
ClusterDataAccessCommands,
ScriptCommands,
FunctionCommands,
GearsCommands,
ModuleCommands,
RedisModuleCommands,
):
@@ -905,8 +896,8 @@ class AsyncRedisClusterCommands(
AsyncClusterDataAccessCommands,
AsyncScriptCommands,
AsyncFunctionCommands,
AsyncGearsCommands,
AsyncModuleCommands,
AsyncRedisModuleCommands,
):
"""
A class for all Redis Cluster commands

File diff suppressed because it is too large

View File

@@ -1,263 +0,0 @@
import warnings
from ..helpers import quote_string, random_string, stringify_param_value
from .commands import AsyncGraphCommands, GraphCommands
from .edge import Edge # noqa
from .node import Node # noqa
from .path import Path # noqa
DB_LABELS = "DB.LABELS"
DB_RELATIONSHIPTYPES = "DB.RELATIONSHIPTYPES"
DB_PROPERTYKEYS = "DB.PROPERTYKEYS"
class Graph(GraphCommands):
"""
Graph, collection of nodes and edges.
"""
def __init__(self, client, name=None):
"""
Create a new graph.
"""
warnings.warn(
DeprecationWarning(
"RedisGraph support is deprecated as of Redis Stack 7.2 \
(https://redis.com/blog/redisgraph-eol/)"
)
)
self.NAME = name if name is not None else random_string()  # Graph key
self.client = client
self.execute_command = client.execute_command
self.nodes = {}
self.edges = []
self._labels = [] # List of node labels.
self._properties = [] # List of properties.
self._relationship_types = [] # List of relation types.
self.version = 0 # Graph version
@property
def name(self):
return self.NAME
def _clear_schema(self):
self._labels = []
self._properties = []
self._relationship_types = []
def _refresh_schema(self):
self._clear_schema()
self._refresh_labels()
self._refresh_relations()
self._refresh_attributes()
def _refresh_labels(self):
lbls = self.labels()
# Unpack data.
self._labels = [l[0] for l in lbls]
def _refresh_relations(self):
rels = self.relationship_types()
# Unpack data.
self._relationship_types = [r[0] for r in rels]
def _refresh_attributes(self):
props = self.property_keys()
# Unpack data.
self._properties = [p[0] for p in props]
def get_label(self, idx):
"""
Returns a label by its index
Args:
idx:
The index of the label
"""
try:
label = self._labels[idx]
except IndexError:
# Refresh labels.
self._refresh_labels()
label = self._labels[idx]
return label
def get_relation(self, idx):
"""
Returns a relationship type by its index
Args:
idx:
The index of the relation
"""
try:
relationship_type = self._relationship_types[idx]
except IndexError:
# Refresh relationship types.
self._refresh_relations()
relationship_type = self._relationship_types[idx]
return relationship_type
def get_property(self, idx):
"""
Returns a property by its index
Args:
idx:
The index of the property
"""
try:
p = self._properties[idx]
except IndexError:
# Refresh properties.
self._refresh_attributes()
p = self._properties[idx]
return p
def add_node(self, node):
"""
Adds a node to the graph.
"""
if node.alias is None:
node.alias = random_string()
self.nodes[node.alias] = node
def add_edge(self, edge):
"""
Adds an edge to the graph.
"""
if edge.src_node.alias not in self.nodes or edge.dest_node.alias not in self.nodes:
raise AssertionError("Both edge endpoints must be in the graph")
self.edges.append(edge)
def _build_params_header(self, params):
if params is None:
return ""
if not isinstance(params, dict):
raise TypeError("'params' must be a dict")
# Header starts with "CYPHER"
params_header = "CYPHER "
for key, value in params.items():
params_header += str(key) + "=" + stringify_param_value(value) + " "
return params_header
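For illustration, the header produced for a small params dict (g is a hypothetical Graph instance):

g._build_params_header({"name": "Alice", "age": 32})
# -> 'CYPHER name="Alice" age=32 '
# strings are quoted via quote_string; the trailing space separates
# the header from the query that follows it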
# Procedures.
def call_procedure(self, procedure, *args, read_only=False, **kwargs):
args = [quote_string(arg) for arg in args]
q = f"CALL {procedure}({','.join(args)})"
y = kwagrs.get("y", None)
if y is not None:
q += f"YIELD {','.join(y)}"
return self.query(q, read_only=read_only)
def labels(self):
return self.call_procedure(DB_LABELS, read_only=True).result_set
def relationship_types(self):
return self.call_procedure(DB_RELATIONSHIPTYPES, read_only=True).result_set
def property_keys(self):
return self.call_procedure(DB_PROPERTYKEYS, read_only=True).result_set
class AsyncGraph(Graph, AsyncGraphCommands):
"""Async version for Graph"""
async def _refresh_labels(self):
lbls = await self.labels()
# Unpack data.
self._labels = [l[0] for l in lbls]
async def _refresh_attributes(self):
props = await self.property_keys()
# Unpack data.
self._properties = [p[0] for p in props]
async def _refresh_relations(self):
rels = await self.relationship_types()
# Unpack data.
self._relationship_types = [r[0] for r in rels]
async def get_label(self, idx):
"""
Returns a label by its index
Args:
idx:
The index of the label
"""
try:
label = self._labels[idx]
except IndexError:
# Refresh labels.
await self._refresh_labels()
label = self._labels[idx]
return label
async def get_property(self, idx):
"""
Returns a property by its index
Args:
idx:
The index of the property
"""
try:
p = self._properties[idx]
except IndexError:
# Refresh properties.
await self._refresh_attributes()
p = self._properties[idx]
return p
async def get_relation(self, idx):
"""
Returns a relationship type by its index
Args:
idx:
The index of the relation
"""
try:
relationship_type = self._relationship_types[idx]
except IndexError:
# Refresh relationship types.
await self._refresh_relations()
relationship_type = self._relationship_types[idx]
return relationship_type
async def call_procedure(self, procedure, *args, read_only=False, **kwargs):
args = [quote_string(arg) for arg in args]
q = f"CALL {procedure}({','.join(args)})"
y = kwagrs.get("y", None)
if y is not None:
f"YIELD {','.join(y)}"
return await self.query(q, read_only=read_only)
async def labels(self):
return (await self.call_procedure(DB_LABELS, read_only=True)).result_set
async def property_keys(self):
return (await self.call_procedure(DB_PROPERTYKEYS, read_only=True)).result_set
async def relationship_types(self):
return (
await self.call_procedure(DB_RELATIONSHIPTYPES, read_only=True)
).result_set

View File

@@ -1,313 +0,0 @@
from redis import DataError
from redis.exceptions import ResponseError
from .exceptions import VersionMismatchException
from .execution_plan import ExecutionPlan
from .query_result import AsyncQueryResult, QueryResult
PROFILE_CMD = "GRAPH.PROFILE"
RO_QUERY_CMD = "GRAPH.RO_QUERY"
QUERY_CMD = "GRAPH.QUERY"
DELETE_CMD = "GRAPH.DELETE"
SLOWLOG_CMD = "GRAPH.SLOWLOG"
CONFIG_CMD = "GRAPH.CONFIG"
LIST_CMD = "GRAPH.LIST"
EXPLAIN_CMD = "GRAPH.EXPLAIN"
class GraphCommands:
"""RedisGraph Commands"""
def commit(self):
"""
Create entire graph.
"""
if len(self.nodes) == 0 and len(self.edges) == 0:
return None
query = "CREATE "
for _, node in self.nodes.items():
query += str(node) + ","
query += ",".join([str(edge) for edge in self.edges])
# Discard leading comma.
if query[-1] == ",":
query = query[:-1]
return self.query(query)
def query(self, q, params=None, timeout=None, read_only=False, profile=False):
"""
Executes a query against the graph.
For more information see `GRAPH.QUERY <https://redis.io/commands/graph.query>`_. # noqa
Args:
q : str
The query.
params : dict
Query parameters.
timeout : int
Maximum runtime for read queries in milliseconds.
read_only : bool
Executes a readonly query if set to True.
profile : bool
Return details on results produced by and time
spent in each operation.
"""
# maintain original 'q'
query = q
# handle query parameters
query = self._build_params_header(params) + query
# construct query command
# ask for compact result-set format
# specify known graph version
if profile:
cmd = PROFILE_CMD
else:
cmd = RO_QUERY_CMD if read_only else QUERY_CMD
command = [cmd, self.name, query, "--compact"]
# include timeout if specified
if isinstance(timeout, int):
command.extend(["timeout", timeout])
elif timeout is not None:
raise Exception("Timeout argument must be a positive integer")
# issue query
try:
response = self.execute_command(*command)
return QueryResult(self, response, profile)
except ResponseError as e:
if "unknown command" in str(e) and read_only:
# `GRAPH.RO_QUERY` is unavailable in older versions.
return self.query(q, params, timeout, read_only=False)
raise e
except VersionMismatchException as e:
# client view over the graph schema is out of sync
# set client version and refresh local schema
self.version = e.version
self._refresh_schema()
# re-issue query
return self.query(q, params, timeout, read_only)
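A minimal usage sketch of query(), assuming a Redis server with the (now removed) graph module loaded:

from redis import Redis

r = Redis()
g = r.graph("social")  # accessor removed elsewhere in this same commit
g.query("CREATE (:person {name: $name})", params={"name": "Alice"})
res = g.query("MATCH (p:person) RETURN p.name", read_only=True, timeout=1000)
print(res.result_set)  # e.g. [['Alice']]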
def merge(self, pattern):
"""
Merge pattern.
"""
query = "MERGE "
query += str(pattern)
return self.query(query)
def delete(self):
"""
Deletes graph.
For more information see `DELETE <https://redis.io/commands/graph.delete>`_. # noqa
"""
self._clear_schema()
return self.execute_command(DELETE_CMD, self.name)
# declared here, to override the built in redis.db.flush()
def flush(self):
"""
Commit the graph and reset the edges and the nodes to zero length.
"""
self.commit()
self.nodes = {}
self.edges = []
def bulk(self, **kwargs):
"""Internal only. Not supported."""
raise NotImplementedError(
"GRAPH.BULK is internal only. "
"Use https://github.com/redisgraph/redisgraph-bulk-loader."
)
def profile(self, query):
"""
Execute a query and produce an execution plan augmented with metrics
for each operation's execution. Return a string representation of a
query execution plan, with details on results produced by and time
spent in each operation.
For more information see `GRAPH.PROFILE <https://redis.io/commands/graph.profile>`_. # noqa
"""
return self.query(query, profile=True)
def slowlog(self):
"""
Get a list containing up to 10 of the slowest queries issued
against the given graph ID.
For more information see `GRAPH.SLOWLOG <https://redis.io/commands/graph.slowlog>`_. # noqa
Each item in the list has the following structure:
1. A unix timestamp at which the log entry was processed.
2. The issued command.
3. The issued query.
4. The amount of time needed for its execution, in milliseconds.
"""
return self.execute_command(SLOWLOG_CMD, self.name)
def config(self, name, value=None, set=False):
"""
Retrieve or update a RedisGraph configuration.
For more information see `GRAPH.CONFIG <https://redis.io/commands/graph.config-get/>`_. # noqa
Args:
name : str
The name of the configuration
value :
The value we want to set (can be used only when `set` is on)
set : bool
Turn on to set a configuration. Default behavior is get.
"""
params = ["SET" if set else "GET", name]
if value is not None:
if set:
params.append(value)
else:
raise DataError(
"``value`` can be provided only when ``set`` is True"
) # noqa
return self.execute_command(CONFIG_CMD, *params)
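Illustrative GET vs SET calls (RESULTSET_SIZE is a real RedisGraph configuration name; g as above):

g.config("RESULTSET_SIZE")                  # GET the current value
g.config("RESULTSET_SIZE", 1000, set=True)  # SET a new value
g.config("RESULTSET_SIZE", 1000)            # raises DataError: set is False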
def list_keys(self):
"""
Lists all graph keys in the keyspace.
For more information see `GRAPH.LIST <https://redis.io/commands/graph.list>`_. # noqa
"""
return self.execute_command(LIST_CMD)
def execution_plan(self, query, params=None):
"""
Get the execution plan for given query,
GRAPH.EXPLAIN returns an array of operations.
Args:
query: the query that will be executed
params: query parameters
"""
query = self._build_params_header(params) + query
plan = self.execute_command(EXPLAIN_CMD, self.name, query)
if isinstance(plan[0], bytes):
plan = [b.decode() for b in plan]
return "\n".join(plan)
def explain(self, query, params=None):
"""
Get the execution plan for given query,
GRAPH.EXPLAIN returns ExecutionPlan object.
For more information see `GRAPH.EXPLAIN <https://redis.io/commands/graph.explain>`_. # noqa
Args:
query: the query that will be executed
params: query parameters
"""
query = self._build_params_header(params) + query
plan = self.execute_command(EXPLAIN_CMD, self.name, query)
return ExecutionPlan(plan)
class AsyncGraphCommands(GraphCommands):
async def query(self, q, params=None, timeout=None, read_only=False, profile=False):
"""
Executes a query against the graph.
For more information see `GRAPH.QUERY <https://oss.redis.com/redisgraph/master/commands/#graphquery>`_. # noqa
Args:
q : str
The query.
params : dict
Query parameters.
timeout : int
Maximum runtime for read queries in milliseconds.
read_only : bool
Executes a readonly query if set to True.
profile : bool
Return details on results produced by and time
spent in each operation.
"""
# maintain original 'q'
query = q
# handle query parameters
query = self._build_params_header(params) + query
# construct query command
# ask for compact result-set format
# specify known graph version
if profile:
cmd = PROFILE_CMD
else:
cmd = RO_QUERY_CMD if read_only else QUERY_CMD
command = [cmd, self.name, query, "--compact"]
# include timeout if specified
if isinstance(timeout, int):
command.extend(["timeout", timeout])
elif timeout is not None:
raise Exception("Timeout argument must be a positive integer")
# issue query
try:
response = await self.execute_command(*command)
return await AsyncQueryResult().initialize(self, response, profile)
except ResponseError as e:
if "unknown command" in str(e) and read_only:
# `GRAPH.RO_QUERY` is unavailable in older versions.
return await self.query(q, params, timeout, read_only=False)
raise e
except VersionMismatchException as e:
# client view over the graph schema is out of sync
# set client version and refresh local schema
self.version = e.version
self._refresh_schema()
# re-issue query
return await self.query(q, params, timeout, read_only)
async def execution_plan(self, query, params=None):
"""
Get the execution plan for given query,
GRAPH.EXPLAIN returns an array of operations.
Args:
query: the query that will be executed
params: query parameters
"""
query = self._build_params_header(params) + query
plan = await self.execute_command(EXPLAIN_CMD, self.name, query)
if isinstance(plan[0], bytes):
plan = [b.decode() for b in plan]
return "\n".join(plan)
async def explain(self, query, params=None):
"""
Get the execution plan for given query,
GRAPH.EXPLAIN returns ExecutionPlan object.
Args:
query: the query that will be executed
params: query parameters
"""
query = self._build_params_header(params) + query
plan = await self.execute_command(EXPLAIN_CMD, self.name, query)
return ExecutionPlan(plan)
async def flush(self):
"""
Commit the graph and reset the edges and the nodes to zero length.
"""
await self.commit()
self.nodes = {}
self.edges = []

View File

@@ -1,91 +0,0 @@
from ..helpers import quote_string
from .node import Node
class Edge:
"""
An edge connecting two nodes.
"""
def __init__(self, src_node, relation, dest_node, edge_id=None, properties=None):
"""
Create a new edge.
"""
if src_node is None or dest_node is None:
# NOTE(bors-42): It makes sense to change AssertionError to
# ValueError here
raise AssertionError("Both src_node & dest_node must be provided")
self.id = edge_id
self.relation = relation or ""
self.properties = properties or {}
self.src_node = src_node
self.dest_node = dest_node
def to_string(self):
res = ""
if self.properties:
props = ",".join(
key + ":" + str(quote_string(val))
for key, val in sorted(self.properties.items())
)
res += "{" + props + "}"
return res
def __str__(self):
# Source node.
if isinstance(self.src_node, Node):
res = str(self.src_node)
else:
res = "()"
# Edge
res += "-["
if self.relation:
res += ":" + self.relation
if self.properties:
props = ",".join(
key + ":" + str(quote_string(val))
for key, val in sorted(self.properties.items())
)
res += "{" + props + "}"
res += "]->"
# Dest node.
if isinstance(self.dest_node, Node):
res += str(self.dest_node)
else:
res += "()"
return res
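Illustrative rendering of the pattern built above, using the Node and Edge classes from this module:

a = Node(alias="a", label="person", properties={"name": "Alice"})
b = Node(alias="b", label="person")
e = Edge(a, "knows", b, properties={"since": 2020})
str(e)  # -> '(a:person{name:"Alice"})-[:knows{since:2020}]->(b:person)'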
def __eq__(self, rhs):
# Type checking
if not isinstance(rhs, Edge):
return False
# Quick positive check, if both IDs are set.
if self.id is not None and rhs.id is not None and self.id == rhs.id:
return True
# Source and destination nodes should match.
if self.src_node != rhs.src_node:
return False
if self.dest_node != rhs.dest_node:
return False
# Relation should match.
if self.relation != rhs.relation:
return False
# Quick check for number of properties.
if len(self.properties) != len(rhs.properties):
return False
# Compare properties.
if self.properties != rhs.properties:
return False
return True

View File

@@ -1,3 +0,0 @@
class VersionMismatchException(Exception):
def __init__(self, version):
self.version = version

View File

@@ -1,211 +0,0 @@
import re
class ProfileStats:
"""
ProfileStats, runtime execution statistics of operation.
"""
def __init__(self, records_produced, execution_time):
self.records_produced = records_produced
self.execution_time = execution_time
class Operation:
"""
Operation, single operation within execution plan.
"""
def __init__(self, name, args=None, profile_stats=None):
"""
Create a new operation.
Args:
name: string that represents the name of the operation
args: operation arguments
profile_stats: profile statistics
"""
self.name = name
self.args = args
self.profile_stats = profile_stats
self.children = []
def append_child(self, child):
if not isinstance(child, Operation) or self is child:
raise Exception("child must be Operation")
self.children.append(child)
return self
def child_count(self):
return len(self.children)
def __eq__(self, o: object) -> bool:
if not isinstance(o, Operation):
return False
return self.name == o.name and self.args == o.args
def __str__(self) -> str:
args_str = "" if self.args is None else " | " + self.args
return f"{self.name}{args_str}"
class ExecutionPlan:
"""
ExecutionPlan, collection of operations.
"""
def __init__(self, plan):
"""
Create a new execution plan.
Args:
plan: array of strings that represents the collection operations
the output from GRAPH.EXPLAIN
"""
if not isinstance(plan, list):
raise Exception("plan must be an array")
if isinstance(plan[0], bytes):
plan = [b.decode() for b in plan]
self.plan = plan
self.structured_plan = self._operation_tree()
def _compare_operations(self, root_a, root_b):
"""
Compare execution plan operation tree
Return: True if operation trees are equal, False otherwise
"""
# compare current root
if root_a != root_b:
return False
# make sure root have the same number of children
if root_a.child_count() != root_b.child_count():
return False
# recursively compare children
for i in range(root_a.child_count()):
if not self._compare_operations(root_a.children[i], root_b.children[i]):
return False
return True
def __str__(self) -> str:
def aggregate_str(str_children):
return "\n".join(
[
" " + line
for str_child in str_children
for line in str_child.splitlines()
]
)
def combine_str(x, y):
return f"{x}\n{y}"
return self._operation_traverse(
self.structured_plan, str, aggregate_str, combine_str
)
def __eq__(self, o: object) -> bool:
"""Compares two execution plans
Return: True if the two plans are equal False otherwise
"""
# make sure 'o' is an execution-plan
if not isinstance(o, ExecutionPlan):
return False
# get root for both plans
root_a = self.structured_plan
root_b = o.structured_plan
# compare execution trees
return self._compare_operations(root_a, root_b)
def _operation_traverse(self, op, op_f, aggregate_f, combine_f):
"""
Traverse operation tree recursively applying functions
Args:
op: operation to traverse
op_f: function applied for each operation
aggregate_f: aggregation function applied for all children of a single operation
combine_f: combine function applied for the operation result and the children result
""" # noqa
# apply op_f for each operation
op_res = op_f(op)
if len(op.children) == 0:
return op_res # no children return
else:
# apply _operation_traverse recursively
children = [
self._operation_traverse(child, op_f, aggregate_f, combine_f)
for child in op.children
]
# combine the operation result with the children aggregated result
return combine_f(op_res, aggregate_f(children))
def _operation_tree(self):
"""Build the operation tree from the string representation"""
# initial state
i = 0
level = 0
stack = []
current = None
def _create_operation(args):
profile_stats = None
name = args[0].strip()
args.pop(0)
if len(args) > 0 and "Records produced" in args[-1]:
records_produced = int(
re.search("Records produced: (\\d+)", args[-1]).group(1)
)
execution_time = float(
re.search("Execution time: (\\d+.\\d+) ms", args[-1]).group(1)
)
profile_stats = ProfileStats(records_produced, execution_time)
args.pop(-1)
return Operation(
name, None if len(args) == 0 else args[0].strip(), profile_stats
)
# iterate plan operations
while i < len(self.plan):
current_op = self.plan[i]
op_level = current_op.count("    ")
if op_level == level:
# if the operation level equal to the current level
# set the current operation and move next
child = _create_operation(current_op.split("|"))
if current:
current = stack.pop()
current.append_child(child)
current = child
i += 1
elif op_level == level + 1:
# if the operation is a child of the current operation
# add it as child and set as current operation
child = _create_operation(current_op.split("|"))
current.append_child(child)
stack.append(current)
current = child
level += 1
i += 1
elif op_level < level:
# if the operation is not a child of the current operation
# go back to its parent operation
levels_back = level - op_level + 1
for _ in range(levels_back):
current = stack.pop()
level -= levels_back
else:
raise Exception("corrupted plan")
return stack[0]
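A small worked example of the parser above, assuming the plan uses four spaces of indentation per nesting level (as GRAPH.EXPLAIN emits):

plan = [
    "Results",
    "    Project",
    "        Node By Label Scan | (p:person)",
]
tree = ExecutionPlan(plan)
print(tree)
# Results
#     Project
#         Node By Label Scan | (p:person)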

View File

@@ -1,88 +0,0 @@
from ..helpers import quote_string
class Node:
"""
A node within the graph.
"""
def __init__(self, node_id=None, alias=None, label=None, properties=None):
"""
Create a new node.
"""
self.id = node_id
self.alias = alias
if isinstance(label, list):
label = [inner_label for inner_label in label if inner_label != ""]
if (
label is None
or label == ""
or (isinstance(label, list) and len(label) == 0)
):
self.label = None
self.labels = None
elif isinstance(label, str):
self.label = label
self.labels = [label]
elif isinstance(label, list) and all(
[isinstance(inner_label, str) for inner_label in label]
):
self.label = label[0]
self.labels = label
else:
raise AssertionError(
"label should be either None, string or a list of strings"
)
self.properties = properties or {}
def to_string(self):
res = ""
if self.properties:
props = ",".join(
key + ":" + str(quote_string(val))
for key, val in sorted(self.properties.items())
)
res += "{" + props + "}"
return res
def __str__(self):
res = "("
if self.alias:
res += self.alias
if self.labels:
res += ":" + ":".join(self.labels)
if self.properties:
props = ",".join(
key + ":" + str(quote_string(val))
for key, val in sorted(self.properties.items())
)
res += "{" + props + "}"
res += ")"
return res
def __eq__(self, rhs):
# Type checking
if not isinstance(rhs, Node):
return False
# Quick positive check, if both IDs are set.
if self.id is not None and rhs.id is not None and self.id != rhs.id:
return False
# Label should match.
if self.label != rhs.label:
return False
# Quick check for number of properties.
if len(self.properties) != len(rhs.properties):
return False
# Compare properties.
if self.properties != rhs.properties:
return False
return True

View File

@@ -1,78 +0,0 @@
from .edge import Edge
from .node import Node
class Path:
def __init__(self, nodes, edges):
if not (isinstance(nodes, list) and isinstance(edges, list)):
raise TypeError("nodes and edges must be list")
self._nodes = nodes
self._edges = edges
self.append_type = Node
@classmethod
def new_empty_path(cls):
return cls([], [])
def nodes(self):
return self._nodes
def edges(self):
return self._edges
def get_node(self, index):
return self._nodes[index]
def get_relationship(self, index):
return self._edges[index]
def first_node(self):
return self._nodes[0]
def last_node(self):
return self._nodes[-1]
def edge_count(self):
return len(self._edges)
def nodes_count(self):
return len(self._nodes)
def add_node(self, node):
if not isinstance(node, self.append_type):
raise AssertionError("Add Edge before adding Node")
self._nodes.append(node)
self.append_type = Edge
return self
def add_edge(self, edge):
if not isinstance(edge, self.append_type):
raise AssertionError("Add Node before adding Edge")
self._edges.append(edge)
self.append_type = Node
return self
def __eq__(self, other):
# Type checking
if not isinstance(other, Path):
return False
return self.nodes() == other.nodes() and self.edges() == other.edges()
def __str__(self):
res = "<"
edge_count = self.edge_count()
for i in range(0, edge_count):
node_id = self.get_node(i).id
res += "(" + str(node_id) + ")"
edge = self.get_relationship(i)
res += (
"-[" + str(int(edge.id)) + "]->"
if edge.src_node == node_id
else "<-[" + str(int(edge.id)) + "]-"
)
node_id = self.get_node(edge_count).id
res += "(" + str(node_id) + ")"
res += ">"
return res
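Illustrative output of the traversal above. Note that edges produced by the result parser store endpoint ids rather than Node objects, which is what the src_node comparison relies on:

n0, n1 = Node(node_id=0), Node(node_id=1)
e = Edge(0, "knows", 1, edge_id=5)  # endpoints given as node ids
p = Path.new_empty_path().add_node(n0).add_edge(e).add_node(n1)
str(p)  # -> '<(0)-[5]->(1)>'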

View File

@@ -1,573 +0,0 @@
import sys
from collections import OrderedDict
from distutils.util import strtobool
# from prettytable import PrettyTable
from redis import ResponseError
from .edge import Edge
from .exceptions import VersionMismatchException
from .node import Node
from .path import Path
LABELS_ADDED = "Labels added"
LABELS_REMOVED = "Labels removed"
NODES_CREATED = "Nodes created"
NODES_DELETED = "Nodes deleted"
RELATIONSHIPS_DELETED = "Relationships deleted"
PROPERTIES_SET = "Properties set"
PROPERTIES_REMOVED = "Properties removed"
RELATIONSHIPS_CREATED = "Relationships created"
INDICES_CREATED = "Indices created"
INDICES_DELETED = "Indices deleted"
CACHED_EXECUTION = "Cached execution"
INTERNAL_EXECUTION_TIME = "internal execution time"
STATS = [
LABELS_ADDED,
LABELS_REMOVED,
NODES_CREATED,
PROPERTIES_SET,
PROPERTIES_REMOVED,
RELATIONSHIPS_CREATED,
NODES_DELETED,
RELATIONSHIPS_DELETED,
INDICES_CREATED,
INDICES_DELETED,
CACHED_EXECUTION,
INTERNAL_EXECUTION_TIME,
]
class ResultSetColumnTypes:
COLUMN_UNKNOWN = 0
COLUMN_SCALAR = 1
COLUMN_NODE = 2 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa
COLUMN_RELATION = 3 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa
class ResultSetScalarTypes:
VALUE_UNKNOWN = 0
VALUE_NULL = 1
VALUE_STRING = 2
VALUE_INTEGER = 3
VALUE_BOOLEAN = 4
VALUE_DOUBLE = 5
VALUE_ARRAY = 6
VALUE_EDGE = 7
VALUE_NODE = 8
VALUE_PATH = 9
VALUE_MAP = 10
VALUE_POINT = 11
class QueryResult:
def __init__(self, graph, response, profile=False):
"""
A class that represents a result of the query operation.
Args:
graph:
The graph on which the query was executed.
response:
The response from the server.
profile:
A boolean indicating if the query command was "GRAPH.PROFILE"
"""
self.graph = graph
self.header = []
self.result_set = []
# in case of an error an exception will be raised
self._check_for_errors(response)
if len(response) == 1:
self.parse_statistics(response[0])
elif profile:
self.parse_profile(response)
else:
# statistics are in the last element of the response
self.parse_statistics(response[-1]) # Last element.
self.parse_results(response)
def _check_for_errors(self, response):
"""
Check if the response contains an error.
"""
if isinstance(response[0], ResponseError):
error = response[0]
if str(error) == "version mismatch":
version = response[1]
error = VersionMismatchException(version)
raise error
# If we encountered a run-time error, the last response
# element will be an exception
if isinstance(response[-1], ResponseError):
raise response[-1]
def parse_results(self, raw_result_set):
"""
Parse the query execution result returned from the server.
"""
self.header = self.parse_header(raw_result_set)
# Empty header.
if len(self.header) == 0:
return
self.result_set = self.parse_records(raw_result_set)
def parse_statistics(self, raw_statistics):
"""
Parse the statistics returned in the response.
"""
self.statistics = {}
# decode statistics
for idx, stat in enumerate(raw_statistics):
if isinstance(stat, bytes):
raw_statistics[idx] = stat.decode()
for s in STATS:
v = self._get_value(s, raw_statistics)
if v is not None:
self.statistics[s] = v
def parse_header(self, raw_result_set):
"""
Parse the header of the result.
"""
# An array of column name/column type pairs.
header = raw_result_set[0]
return header
def parse_records(self, raw_result_set):
"""
Parses the result set and returns a list of records.
"""
records = [
[
self.parse_record_types[self.header[idx][0]](cell)
for idx, cell in enumerate(row)
]
for row in raw_result_set[1]
]
return records
def parse_entity_properties(self, props):
"""
Parse node / edge properties.
"""
# [[name, value type, value] X N]
properties = {}
for prop in props:
prop_name = self.graph.get_property(prop[0])
prop_value = self.parse_scalar(prop[1:])
properties[prop_name] = prop_value
return properties
def parse_string(self, cell):
"""
Parse the cell as a string.
"""
if isinstance(cell, bytes):
return cell.decode()
elif not isinstance(cell, str):
return str(cell)
else:
return cell
def parse_node(self, cell):
"""
Parse the cell to a node.
"""
# Node ID (integer),
# [label string offset (integer)],
# [[name, value type, value] X N]
node_id = int(cell[0])
labels = None
if len(cell[1]) > 0:
labels = []
for inner_label in cell[1]:
labels.append(self.graph.get_label(inner_label))
properties = self.parse_entity_properties(cell[2])
return Node(node_id=node_id, label=labels, properties=properties)
def parse_edge(self, cell):
"""
Parse the cell to an edge.
"""
# Edge ID (integer),
# reltype string offset (integer),
# src node ID offset (integer),
# dest node ID offset (integer),
# [[name, value, value type] X N]
edge_id = int(cell[0])
relation = self.graph.get_relation(cell[1])
src_node_id = int(cell[2])
dest_node_id = int(cell[3])
properties = self.parse_entity_properties(cell[4])
return Edge(
src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties
)
def parse_path(self, cell):
"""
Parse the cell to a path.
"""
nodes = self.parse_scalar(cell[0])
edges = self.parse_scalar(cell[1])
return Path(nodes, edges)
def parse_map(self, cell):
"""
Parse the cell as a map.
"""
m = OrderedDict()
n_entries = len(cell)
# A map is an array of key value pairs.
# 1. key (string)
# 2. array: (value type, value)
for i in range(0, n_entries, 2):
key = self.parse_string(cell[i])
m[key] = self.parse_scalar(cell[i + 1])
return m
def parse_point(self, cell):
"""
Parse the cell to point.
"""
p = {}
# A point is received as an array of the form: [latitude, longitude]
# It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa
p["latitude"] = float(cell[0])
p["longitude"] = float(cell[1])
return p
def parse_null(self, cell):
"""
Parse a null value.
"""
return None
def parse_integer(self, cell):
"""
Parse the integer value from the cell.
"""
return int(cell)
def parse_boolean(self, value):
"""
Parse the cell value as a boolean.
"""
value = value.decode() if isinstance(value, bytes) else value
try:
scalar = bool(strtobool(value))
except ValueError:
sys.stderr.write("unknown boolean type\n")
scalar = None
return scalar
def parse_double(self, cell):
"""
Parse the cell as a double.
"""
return float(cell)
def parse_array(self, value):
"""
Parse an array of values.
"""
scalar = [self.parse_scalar(value[i]) for i in range(len(value))]
return scalar
def parse_unknown(self, cell):
"""
Parse a cell of unknown type.
"""
sys.stderr.write("Unknown type\n")
return None
def parse_scalar(self, cell):
"""
Parse a scalar value from a cell in the result set.
"""
scalar_type = int(cell[0])
value = cell[1]
scalar = self.parse_scalar_types[scalar_type](value)
return scalar
def parse_profile(self, response):
self.result_set = [x[0 : x.index(",")].strip() for x in response]
def is_empty(self):
return len(self.result_set) == 0
@staticmethod
def _get_value(prop, statistics):
for stat in statistics:
if prop in stat:
return float(stat.split(": ")[1].split(" ")[0])
return None
def _get_stat(self, stat):
return self.statistics[stat] if stat in self.statistics else 0
@property
def labels_added(self):
"""Returns the number of labels added in the query"""
return self._get_stat(LABELS_ADDED)
@property
def labels_removed(self):
"""Returns the number of labels removed in the query"""
return self._get_stat(LABELS_REMOVED)
@property
def nodes_created(self):
"""Returns the number of nodes created in the query"""
return self._get_stat(NODES_CREATED)
@property
def nodes_deleted(self):
"""Returns the number of nodes deleted in the query"""
return self._get_stat(NODES_DELETED)
@property
def properties_set(self):
"""Returns the number of properties set in the query"""
return self._get_stat(PROPERTIES_SET)
@property
def properties_removed(self):
"""Returns the number of properties removed in the query"""
return self._get_stat(PROPERTIES_REMOVED)
@property
def relationships_created(self):
"""Returns the number of relationships created in the query"""
return self._get_stat(RELATIONSHIPS_CREATED)
@property
def relationships_deleted(self):
"""Returns the number of relationships deleted in the query"""
return self._get_stat(RELATIONSHIPS_DELETED)
@property
def indices_created(self):
"""Returns the number of indices created in the query"""
return self._get_stat(INDICES_CREATED)
@property
def indices_deleted(self):
"""Returns the number of indices deleted in the query"""
return self._get_stat(INDICES_DELETED)
@property
def cached_execution(self):
"""Returns whether or not the query execution plan was cached"""
return self._get_stat(CACHED_EXECUTION) == 1
@property
def run_time_ms(self):
"""Returns the server execution time of the query"""
return self._get_stat(INTERNAL_EXECUTION_TIME)
@property
def parse_scalar_types(self):
return {
ResultSetScalarTypes.VALUE_NULL: self.parse_null,
ResultSetScalarTypes.VALUE_STRING: self.parse_string,
ResultSetScalarTypes.VALUE_INTEGER: self.parse_integer,
ResultSetScalarTypes.VALUE_BOOLEAN: self.parse_boolean,
ResultSetScalarTypes.VALUE_DOUBLE: self.parse_double,
ResultSetScalarTypes.VALUE_ARRAY: self.parse_array,
ResultSetScalarTypes.VALUE_NODE: self.parse_node,
ResultSetScalarTypes.VALUE_EDGE: self.parse_edge,
ResultSetScalarTypes.VALUE_PATH: self.parse_path,
ResultSetScalarTypes.VALUE_MAP: self.parse_map,
ResultSetScalarTypes.VALUE_POINT: self.parse_point,
ResultSetScalarTypes.VALUE_UNKNOWN: self.parse_unknown,
}
@property
def parse_record_types(self):
return {
ResultSetColumnTypes.COLUMN_SCALAR: self.parse_scalar,
ResultSetColumnTypes.COLUMN_NODE: self.parse_node,
ResultSetColumnTypes.COLUMN_RELATION: self.parse_edge,
ResultSetColumnTypes.COLUMN_UNKNOWN: self.parse_unknown,
}
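Illustrative read of the parsed statistics (assuming a Graph g as in the earlier sketch):

res = g.query("CREATE (:person {name: 'Bob'})")
res.nodes_created     # -> 1.0  (stats are parsed as floats)
res.properties_set    # -> 1.0
res.run_time_ms       # server-side execution time in ms
res.cached_execution  # True once the execution plan is cached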
class AsyncQueryResult(QueryResult):
"""
Async version for the QueryResult class - a class that
represents a result of the query operation.
"""
def __init__(self):
"""
To initialize the class, you must call self.initialize()
"""
pass
async def initialize(self, graph, response, profile=False):
"""
Initializes the class.
Args:
graph:
The graph on which the query was executed.
response:
The response from the server.
profile:
A boolean indicating if the query command was "GRAPH.PROFILE"
"""
self.graph = graph
self.header = []
self.result_set = []
# in case of an error an exception will be raised
self._check_for_errors(response)
if len(response) == 1:
self.parse_statistics(response[0])
elif profile:
self.parse_profile(response)
else:
# statistics are in the last element of the response
self.parse_statistics(response[-1]) # Last element.
await self.parse_results(response)
return self
async def parse_node(self, cell):
"""
Parses a node from the cell.
"""
# Node ID (integer),
# [label string offset (integer)],
# [[name, value type, value] X N]
labels = None
if len(cell[1]) > 0:
labels = []
for inner_label in cell[1]:
labels.append(await self.graph.get_label(inner_label))
properties = await self.parse_entity_properties(cell[2])
node_id = int(cell[0])
return Node(node_id=node_id, label=labels, properties=properties)
async def parse_scalar(self, cell):
"""
Parses a scalar value from the server response.
"""
scalar_type = int(cell[0])
value = cell[1]
try:
scalar = await self.parse_scalar_types[scalar_type](value)
except TypeError:
# Not all of the functions are async
scalar = self.parse_scalar_types[scalar_type](value)
return scalar
async def parse_records(self, raw_result_set):
"""
Parses the result set and returns a list of records.
"""
records = []
for row in raw_result_set[1]:
record = [
await self.parse_record_types[self.header[idx][0]](cell)
for idx, cell in enumerate(row)
]
records.append(record)
return records
async def parse_results(self, raw_result_set):
"""
Parse the query execution result returned from the server.
"""
self.header = self.parse_header(raw_result_set)
# Empty header.
if len(self.header) == 0:
return
self.result_set = await self.parse_records(raw_result_set)
async def parse_entity_properties(self, props):
"""
Parse node / edge properties.
"""
# [[name, value type, value] X N]
properties = {}
for prop in props:
prop_name = await self.graph.get_property(prop[0])
prop_value = await self.parse_scalar(prop[1:])
properties[prop_name] = prop_value
return properties
async def parse_edge(self, cell):
"""
Parse the cell to an edge.
"""
# Edge ID (integer),
# reltype string offset (integer),
# src node ID offset (integer),
# dest node ID offset (integer),
# [[name, value, value type] X N]
edge_id = int(cell[0])
relation = await self.graph.get_relation(cell[1])
src_node_id = int(cell[2])
dest_node_id = int(cell[3])
properties = await self.parse_entity_properties(cell[4])
return Edge(
src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties
)
async def parse_path(self, cell):
"""
Parse the cell to a path.
"""
nodes = await self.parse_scalar(cell[0])
edges = await self.parse_scalar(cell[1])
return Path(nodes, edges)
async def parse_map(self, cell):
"""
Parse the cell to a map.
"""
m = OrderedDict()
n_entries = len(cell)
# A map is an array of key value pairs.
# 1. key (string)
# 2. array: (value type, value)
for i in range(0, n_entries, 2):
key = self.parse_string(cell[i])
m[key] = await self.parse_scalar(cell[i + 1])
return m
async def parse_array(self, value):
"""
Parse array value.
"""
scalar = [await self.parse_scalar(value[i]) for i in range(len(value))]
return scalar

View File

@@ -43,19 +43,32 @@ def parse_to_list(response):
"""Optimistically parse the response to a list."""
res = []
special_values = {"infinity", "nan", "-infinity"}
if response is None:
return res
for item in response:
if item is None:
res.append(None)
continue
try:
res.append(int(item))
except ValueError:
try:
res.append(float(item))
except ValueError:
res.append(nativestr(item))
item_str = nativestr(item)
except TypeError:
res.append(None)
continue
if isinstance(item_str, str) and item_str.lower() in special_values:
res.append(item_str) # Keep as string
else:
try:
res.append(int(item))
except ValueError:
try:
res.append(float(item))
except ValueError:
res.append(item_str)
return res
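Illustrative behavior of the updated parser:

from redis.commands.helpers import parse_to_list

parse_to_list([b"3", b"2.5", b"nan", None, b"hello"])
# -> [3, 2.5, 'nan', None, 'hello']
# 'nan'/'infinity'/'-infinity' are kept as strings instead of
# being converted to float('nan') / float('inf')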
@@ -64,6 +77,11 @@ def parse_list_to_dict(response):
for i in range(0, len(response), 2):
if isinstance(response[i], list):
res["Child iterators"].append(parse_list_to_dict(response[i]))
try:
if isinstance(response[i + 1], list):
res["Child iterators"].append(parse_list_to_dict(response[i + 1]))
except IndexError:
pass
elif isinstance(response[i + 1], list):
res["Child iterators"] = [parse_list_to_dict(response[i + 1])]
else:
@@ -74,25 +92,6 @@ def parse_list_to_dict(response):
return res
def parse_to_dict(response):
if response is None:
return {}
res = {}
for det in response:
if isinstance(det[1], list):
res[det[0]] = parse_list_to_dict(det[1])
else:
try: # try to set the attribute. may be provided without value
try: # try to convert the value to float
res[det[0]] = float(det[1])
except (TypeError, ValueError):
res[det[0]] = det[1]
except IndexError:
pass
return res
def random_string(length=10):
"""
Returns a random N character long string.
@@ -102,26 +101,6 @@ def random_string(length=10):
)
def quote_string(v):
"""
RedisGraph strings must be quoted,
quote_string wraps the given v with quotes in case
v is a string.
"""
if isinstance(v, bytes):
v = v.decode()
elif not isinstance(v, str):
return v
if len(v) == 0:
return '""'
v = v.replace("\\", "\\\\")
v = v.replace('"', '\\"')
return f'"{v}"'
def decode_dict_keys(obj):
"""Decode the keys of the given dictionary with utf-8."""
newobj = copy.copy(obj)
@@ -132,33 +111,6 @@ def decode_dict_keys(obj):
return newobj
def stringify_param_value(value):
"""
Turn a parameter value into a string suitable for the params header of
a Cypher command.
You may pass any value that would be accepted by `json.dumps()`.
Ways in which output differs from that of `str()`:
* Strings are quoted.
* None --> "null".
* In dictionaries, keys are _not_ quoted.
:param value: The parameter value to be turned into a string.
:return: string
"""
if isinstance(value, str):
return quote_string(value)
elif value is None:
return "null"
elif isinstance(value, (list, tuple)):
return f'[{",".join(map(stringify_param_value, value))}]'
elif isinstance(value, dict):
return f'{{{",".join(f"{k}:{stringify_param_value(v)}" for k, v in value.items())}}}' # noqa
else:
return str(value)
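Illustrative output, showing how it differs from str() (before this commit, importable from redis.commands.helpers):

stringify_param_value({"name": "Alice", "ids": [1, None]})
# -> '{name:"Alice",ids:[1,null]}'
# keys are unquoted, strings are quoted, None becomes null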
def get_protocol_version(client):
if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis):
return client.connection_pool.connection_kwargs.get("protocol")

View File

@@ -120,7 +120,7 @@ class JSON(JSONCommands):
startup_nodes=self.client.nodes_manager.startup_nodes,
result_callbacks=self.client.result_callbacks,
cluster_response_callbacks=self.client.cluster_response_callbacks,
cluster_error_retry_attempts=self.client.cluster_error_retry_attempts,
cluster_error_retry_attempts=self.client.retry.get_retries(),
read_from_replicas=self.client.read_from_replicas,
reinitialize_steps=self.client.reinitialize_steps,
lock=self.client._lock,

View File

@@ -1,3 +1,5 @@
from typing import Any, Dict, List, Union
from typing import List, Mapping, Union
JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
JsonType = Union[
str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"]
]

View File

@@ -15,7 +15,7 @@ class JSONCommands:
def arrappend(
self, name: str, path: Optional[str] = Path.root_path(), *args: List[JsonType]
) -> List[Union[int, None]]:
) -> List[Optional[int]]:
"""Append the objects ``args`` to the array under the
``path`` in key ``name``.
@@ -33,7 +33,7 @@ class JSONCommands:
scalar: int,
start: Optional[int] = None,
stop: Optional[int] = None,
) -> List[Union[int, None]]:
) -> List[Optional[int]]:
"""
Return the index of ``scalar`` in the JSON array under ``path`` at key
``name``.
@@ -49,11 +49,11 @@ class JSONCommands:
if stop is not None:
pieces.append(stop)
return self.execute_command("JSON.ARRINDEX", *pieces)
return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name])
def arrinsert(
self, name: str, path: str, index: int, *args: List[JsonType]
) -> List[Union[int, None]]:
) -> List[Optional[int]]:
"""Insert the objects ``args`` to the array at index ``index``
under the ``path`` in key ``name``.
@@ -66,20 +66,20 @@ class JSONCommands:
def arrlen(
self, name: str, path: Optional[str] = Path.root_path()
) -> List[Union[int, None]]:
) -> List[Optional[int]]:
"""Return the length of the array JSON value under ``path``
at key ``name``.
For more information see `JSON.ARRLEN <https://redis.io/commands/json.arrlen>`_.
""" # noqa
return self.execute_command("JSON.ARRLEN", name, str(path))
return self.execute_command("JSON.ARRLEN", name, str(path), keys=[name])
def arrpop(
self,
name: str,
path: Optional[str] = Path.root_path(),
index: Optional[int] = -1,
) -> List[Union[str, None]]:
) -> List[Optional[str]]:
"""Pop the element at ``index`` in the array JSON value under
``path`` at key ``name``.
@@ -89,7 +89,7 @@ class JSONCommands:
def arrtrim(
self, name: str, path: str, start: int, stop: int
) -> List[Union[int, None]]:
) -> List[Optional[int]]:
"""Trim the array JSON value under ``path`` at key ``name`` to the
inclusive range given by ``start`` and ``stop``.
@@ -102,32 +102,34 @@ class JSONCommands:
For more information see `JSON.TYPE <https://redis.io/commands/json.type>`_.
""" # noqa
return self.execute_command("JSON.TYPE", name, str(path))
return self.execute_command("JSON.TYPE", name, str(path), keys=[name])
def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List:
"""Return the JSON value under ``path`` at key ``name``.
For more information see `JSON.RESP <https://redis.io/commands/json.resp>`_.
""" # noqa
return self.execute_command("JSON.RESP", name, str(path))
return self.execute_command("JSON.RESP", name, str(path), keys=[name])
def objkeys(
self, name: str, path: Optional[str] = Path.root_path()
) -> List[Union[List[str], None]]:
) -> List[Optional[List[str]]]:
"""Return the key names in the dictionary JSON value under ``path`` at
key ``name``.
For more information see `JSON.OBJKEYS <https://redis.io/commands/json.objkeys>`_.
""" # noqa
return self.execute_command("JSON.OBJKEYS", name, str(path))
return self.execute_command("JSON.OBJKEYS", name, str(path), keys=[name])
def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int:
def objlen(
self, name: str, path: Optional[str] = Path.root_path()
) -> List[Optional[int]]:
"""Return the length of the dictionary JSON value under ``path`` at key
``name``.
For more information see `JSON.OBJLEN <https://redis.io/commands/json.objlen>`_.
""" # noqa
return self.execute_command("JSON.OBJLEN", name, str(path))
return self.execute_command("JSON.OBJLEN", name, str(path), keys=[name])
def numincrby(self, name: str, path: str, number: int) -> str:
"""Increment the numeric (integer or floating point) JSON value under
@@ -173,7 +175,7 @@ class JSONCommands:
def get(
self, name: str, *args, no_escape: Optional[bool] = False
) -> List[JsonType]:
) -> Optional[List[JsonType]]:
"""
Get the object stored as a JSON value at key ``name``.
@@ -197,7 +199,7 @@ class JSONCommands:
# Handle case where key doesn't exist. The JSONDecoder would raise a
# TypeError exception since it can't decode None
try:
return self.execute_command("JSON.GET", *pieces)
return self.execute_command("JSON.GET", *pieces, keys=[name])
except TypeError:
return None
@@ -211,7 +213,7 @@ class JSONCommands:
pieces = []
pieces += keys
pieces.append(str(path))
return self.execute_command("JSON.MGET", *pieces)
return self.execute_command("JSON.MGET", *pieces, keys=keys)
def set(
self,
@@ -312,7 +314,7 @@ class JSONCommands:
"""
with open(file_name, "r") as fp:
with open(file_name) as fp:
file_content = loads(fp.read())
return self.set(name, path, file_content, nx=nx, xx=xx, decode_keys=decode_keys)
@@ -324,7 +326,7 @@ class JSONCommands:
nx: Optional[bool] = False,
xx: Optional[bool] = False,
decode_keys: Optional[bool] = False,
) -> List[Dict[str, bool]]:
) -> Dict[str, bool]:
"""
Iterate over ``root_folder`` and set each JSON file to a value
under ``json_path`` with the file name as the key.
@@ -355,7 +357,7 @@ class JSONCommands:
return set_files_result
def strlen(self, name: str, path: Optional[str] = None) -> List[Union[int, None]]:
def strlen(self, name: str, path: Optional[str] = None) -> List[Optional[int]]:
"""Return the length of the string JSON value under ``path`` at key
``name``.
@@ -364,7 +366,7 @@ class JSONCommands:
pieces = [name]
if path is not None:
pieces.append(str(path))
return self.execute_command("JSON.STRLEN", *pieces)
return self.execute_command("JSON.STRLEN", *pieces, keys=[name])
def toggle(
self, name: str, path: Optional[str] = Path.root_path()
@@ -377,7 +379,7 @@ class JSONCommands:
return self.execute_command("JSON.TOGGLE", name, str(path))
def strappend(
self, name: str, value: str, path: Optional[int] = Path.root_path()
self, name: str, value: str, path: Optional[str] = Path.root_path()
) -> Union[int, List[Optional[int]]]:
"""Append to the string JSON value. If two options are specified after
the key name, the path is determined to be the first. If a single

View File

@@ -1,4 +1,14 @@
from __future__ import annotations
from json import JSONDecoder, JSONEncoder
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .bf import BFBloom, CFBloom, CMSBloom, TDigestBloom, TOPKBloom
from .json import JSON
from .search import AsyncSearch, Search
from .timeseries import TimeSeries
from .vectorset import VectorSet
class RedisModuleCommands:
@@ -6,7 +16,7 @@ class RedisModuleCommands:
modules into the command namespace.
"""
def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()):
def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()) -> JSON:
"""Access the json namespace, providing support for redis json."""
from .json import JSON
@@ -14,7 +24,7 @@ class RedisModuleCommands:
jj = JSON(client=self, encoder=encoder, decoder=decoder)
return jj
def ft(self, index_name="idx"):
def ft(self, index_name="idx") -> Search:
"""Access the search namespace, providing support for redis search."""
from .search import Search
@@ -22,7 +32,7 @@ class RedisModuleCommands:
s = Search(client=self, index_name=index_name)
return s
def ts(self):
def ts(self) -> TimeSeries:
"""Access the timeseries namespace, providing support for
redis timeseries data.
"""
@@ -32,7 +42,7 @@ class RedisModuleCommands:
s = TimeSeries(client=self)
return s
def bf(self):
def bf(self) -> BFBloom:
"""Access the bloom namespace."""
from .bf import BFBloom
@@ -40,7 +50,7 @@ class RedisModuleCommands:
bf = BFBloom(client=self)
return bf
def cf(self):
def cf(self) -> CFBloom:
"""Access the bloom namespace."""
from .bf import CFBloom
@@ -48,7 +58,7 @@ class RedisModuleCommands:
cf = CFBloom(client=self)
return cf
def cms(self):
def cms(self) -> CMSBloom:
"""Access the bloom namespace."""
from .bf import CMSBloom
@@ -56,7 +66,7 @@ class RedisModuleCommands:
cms = CMSBloom(client=self)
return cms
def topk(self):
def topk(self) -> TOPKBloom:
"""Access the bloom namespace."""
from .bf import TOPKBloom
@@ -64,7 +74,7 @@ class RedisModuleCommands:
topk = TOPKBloom(client=self)
return topk
def tdigest(self):
def tdigest(self) -> TDigestBloom:
"""Access the bloom namespace."""
from .bf import TDigestBloom
@@ -72,32 +82,20 @@ class RedisModuleCommands:
tdigest = TDigestBloom(client=self)
return tdigest
def graph(self, index_name="idx"):
"""Access the graph namespace, providing support for
redis graph data.
"""
def vset(self) -> VectorSet:
"""Access the VectorSet commands namespace."""
from .graph import Graph
from .vectorset import VectorSet
g = Graph(client=self, name=index_name)
return g
vset = VectorSet(client=self)
return vset
class AsyncRedisModuleCommands(RedisModuleCommands):
def ft(self, index_name="idx"):
def ft(self, index_name="idx") -> AsyncSearch:
"""Access the search namespace, providing support for redis search."""
from .search import AsyncSearch
s = AsyncSearch(client=self, index_name=index_name)
return s
def graph(self, index_name="idx"):
"""Access the graph namespace, providing support for
redis graph data.
"""
from .graph import AsyncGraph
g = AsyncGraph(client=self, name=index_name)
return g
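A minimal sketch of the accessor surface after this change, assuming a redis-py client with the corresponding modules loaded on the server:

from redis import Redis

r = Redis()
r.json().set("doc", "$", {"hello": "world"})
r.ft("idx").search("hello")
r.ts().add("temperature", "*", 21.5)
r.bf().add("seen", "item-1")
vset = r.vset()  # VectorSet namespace introduced here; graph() is gone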

View File

@@ -1,7 +1,7 @@
def to_string(s):
def to_string(s, encoding: str = "utf-8"):
if isinstance(s, str):
return s
elif isinstance(s, bytes):
return s.decode("utf-8", "ignore")
return s.decode(encoding, "ignore")
else:
return s # Not a string we care about

View File

@@ -1,5 +1,7 @@
from typing import List, Union
from redis.commands.search.dialect import DEFAULT_DIALECT
FIELDNAME = object()
@@ -24,7 +26,7 @@ class Reducer:
NAME = None
def __init__(self, *args: List[str]) -> None:
def __init__(self, *args: str) -> None:
self._args = args
self._field = None
self._alias = None
@@ -110,9 +112,11 @@ class AggregateRequest:
self._with_schema = False
self._verbatim = False
self._cursor = []
self._dialect = None
self._dialect = DEFAULT_DIALECT
self._add_scores = False
self._scorer = "TFIDF"
def load(self, *fields: List[str]) -> "AggregateRequest":
def load(self, *fields: str) -> "AggregateRequest":
"""
Indicate the fields to be returned in the response. These fields are
returned in addition to any others implicitly specified.
@@ -219,7 +223,7 @@ class AggregateRequest:
self._aggregateplan.extend(_limit.build_args())
return self
def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
"""
Indicate how the results should be sorted. This can also be used for
*top-N* style queries
@@ -292,6 +296,24 @@ class AggregateRequest:
self._with_schema = True
return self
def add_scores(self) -> "AggregateRequest":
"""
If set, includes the score as an ordinary field of the row.
"""
self._add_scores = True
return self
def scorer(self, scorer: str) -> "AggregateRequest":
"""
Use a different scoring function to evaluate document relevance.
Default is `TFIDF`.
:param scorer: The scoring function to use
(e.g. `TFIDF.DOCNORM` or `BM25`)
"""
self._scorer = scorer
return self
def verbatim(self) -> "AggregateRequest":
self._verbatim = True
return self
@@ -315,12 +337,19 @@ class AggregateRequest:
if self._verbatim:
ret.append("VERBATIM")
if self._scorer:
ret.extend(["SCORER", self._scorer])
if self._add_scores:
ret.append("ADDSCORES")
if self._cursor:
ret += self._cursor
if self._loadall:
ret.append("LOAD")
ret.append("*")
elif self._loadfields:
ret.append("LOAD")
ret.append(str(len(self._loadfields)))
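
A sketch of how the new ``scorer``/``add_scores`` options surface in the built command (the index data and field names are illustrative):

    from redis.commands.search import reducers
    from redis.commands.search.aggregation import AggregateRequest

    req = (
        AggregateRequest("*")
        .group_by("@category", reducers.count().alias("num"))
        .scorer("BM25")
        .add_scores()
    )
    # req.build_args() now contains ["SCORER", "BM25"] and "ADDSCORES",
    # and a DIALECT of 2 is applied by default via DEFAULT_DIALECT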

View File

@@ -2,13 +2,16 @@ import itertools
import time
from typing import Dict, List, Optional, Union
from redis.client import Pipeline
from redis.client import NEVER_DECODE, Pipeline
from redis.utils import deprecated_function
from ..helpers import get_protocol_version, parse_to_dict
from ..helpers import get_protocol_version
from ._util import to_string
from .aggregation import AggregateRequest, AggregateResult, Cursor
from .document import Document
from .field import Field
from .index_definition import IndexDefinition
from .profile_information import ProfileInformation
from .query import Query
from .result import Result
from .suggestion import SuggestionParser
@@ -20,7 +23,6 @@ ALTER_CMD = "FT.ALTER"
SEARCH_CMD = "FT.SEARCH"
ADD_CMD = "FT.ADD"
ADDHASH_CMD = "FT.ADDHASH"
DROP_CMD = "FT.DROP"
DROPINDEX_CMD = "FT.DROPINDEX"
EXPLAIN_CMD = "FT.EXPLAIN"
EXPLAINCLI_CMD = "FT.EXPLAINCLI"
@@ -32,7 +34,6 @@ SPELLCHECK_CMD = "FT.SPELLCHECK"
DICT_ADD_CMD = "FT.DICTADD"
DICT_DEL_CMD = "FT.DICTDEL"
DICT_DUMP_CMD = "FT.DICTDUMP"
GET_CMD = "FT.GET"
MGET_CMD = "FT.MGET"
CONFIG_CMD = "FT.CONFIG"
TAGVALS_CMD = "FT.TAGVALS"
@@ -65,7 +66,7 @@ class SearchCommands:
def _parse_results(self, cmd, res, **kwargs):
if get_protocol_version(self.client) in ["3", 3]:
return res
return ProfileInformation(res) if cmd == "FT.PROFILE" else res
else:
return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs)
@@ -80,6 +81,7 @@ class SearchCommands:
duration=kwargs["duration"],
has_payload=kwargs["query"]._with_payloads,
with_scores=kwargs["query"]._with_scores,
field_encodings=kwargs["query"]._return_fields_decode_as,
)
def _parse_aggregate(self, res, **kwargs):
@@ -98,7 +100,7 @@ class SearchCommands:
with_scores=query._with_scores,
)
return result, parse_to_dict(res[1])
return result, ProfileInformation(res[1])
def _parse_spellcheck(self, res, **kwargs):
corrections = {}
@@ -151,44 +153,43 @@ class SearchCommands:
def create_index(
self,
fields,
no_term_offsets=False,
no_field_flags=False,
stopwords=None,
definition=None,
fields: List[Field],
no_term_offsets: bool = False,
no_field_flags: bool = False,
stopwords: Optional[List[str]] = None,
definition: Optional[IndexDefinition] = None,
max_text_fields=False,
temporary=None,
no_highlight=False,
no_term_frequencies=False,
skip_initial_scan=False,
no_highlight: bool = False,
no_term_frequencies: bool = False,
skip_initial_scan: bool = False,
):
"""
Create the search index. The index must not already exist.
Creates the search index. The index must not already exist.
### Parameters:
For more information, see https://redis.io/commands/ft.create/
- **fields**: a list of TextField or NumericField objects
- **no_term_offsets**: If true, we will not save term offsets in
the index
- **no_field_flags**: If true, we will not save field flags that
allow searching in specific fields
- **stopwords**: If not None, we create the index with this custom
stopword list. The list can be empty
- **max_text_fields**: If true, we will encode indexes as if there
were more than 32 text fields which allows you to add additional
fields (beyond 32).
- **temporary**: Create a lightweight temporary index which will
expire after the specified period of inactivity (in seconds). The
internal idle timer is reset whenever the index is searched or added to.
- **no_highlight**: If true, disabling highlighting support.
Also implied by no_term_offsets.
- **no_term_frequencies**: If true, we avoid saving the term frequencies
in the index.
- **skip_initial_scan**: If true, we do not scan and index.
For more information see `FT.CREATE <https://redis.io/commands/ft.create>`_.
""" # noqa
Args:
fields: A list of Field objects.
no_term_offsets: If true, term offsets will not be saved in the index.
no_field_flags: If true, field flags that allow searching in specific fields
will not be saved.
stopwords: If provided, the index will be created with this custom stopword
list. The list can be empty.
definition: If provided, the index will be created with this custom index
definition.
max_text_fields: If true, indexes will be encoded as if there were more than
32 text fields, allowing for additional fields beyond 32.
temporary: Creates a lightweight temporary index which will expire after the
specified period of inactivity. The internal idle timer is reset
whenever the index is searched or added to.
no_highlight: If true, disables highlighting support. Also implied by
`no_term_offsets`.
no_term_frequencies: If true, term frequencies will not be saved in the
index.
skip_initial_scan: If true, the initial scan and indexing will be skipped.
"""
args = [CREATE_CMD, self.index_name]
if definition is not None:
args += definition.args
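
A usage sketch for the reworked signature (assumes a connected client; field and prefix names are illustrative):

    import redis
    from redis.commands.search.field import NumericField, TextField
    from redis.commands.search.index_definition import IndexDefinition, IndexType

    r = redis.Redis()
    r.ft("idx").create_index(
        [TextField("title", weight=5.0), NumericField("price", sortable=True)],
        definition=IndexDefinition(prefix=["doc:"], index_type=IndexType.HASH),
        stopwords=[],    # create the index with no stopwords at all
        temporary=3600,  # drop the index after an hour of inactivity
    )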
@@ -252,8 +253,18 @@ class SearchCommands:
For more information see `FT.DROPINDEX <https://redis.io/commands/ft.dropindex>`_.
""" # noqa
delete_str = "DD" if delete_documents else ""
return self.execute_command(DROPINDEX_CMD, self.index_name, delete_str)
args = [DROPINDEX_CMD, self.index_name]
delete_str = (
"DD"
if isinstance(delete_documents, bool) and delete_documents is True
else ""
)
if delete_str:
args.append(delete_str)
return self.execute_command(*args)
def _add_document(
self,
@@ -335,30 +346,30 @@ class SearchCommands:
"""
Add a single document to the index.
### Parameters
Args:
- **doc_id**: the id of the saved document.
- **nosave**: if set to true, we just index the document, and don't
doc_id: the id of the saved document.
nosave: if set to true, we just index the document, and don't
save a copy of it. This means that searches will just
return ids.
- **score**: the document ranking, between 0.0 and 1.0
- **payload**: optional inner-index payload we can save for fast
i access in scoring functions
- **replace**: if True, and the document already is in the index,
we perform an update and reindex the document
- **partial**: if True, the fields specified will be added to the
score: the document ranking, between 0.0 and 1.0
payload: optional inner-index payload we can save for fast
access in scoring functions
replace: if True, and the document already is in the index,
we perform an update and reindex the document
partial: if True, the fields specified will be added to the
existing document.
This has the added benefit that any fields specified
with `no_index`
will not be reindexed again. Implies `replace`
- **language**: Specify the language used for document tokenization.
- **no_create**: if True, the document is only updated and reindexed
language: Specify the language used for document tokenization.
no_create: if True, the document is only updated and reindexed
if it already exists.
If the document does not exist, an error will be
returned. Implies `replace`
- **fields** kwargs dictionary of the document fields to be saved
and/or indexed.
NOTE: Geo points shoule be encoded as strings of "lon,lat"
fields: kwargs dictionary of the document fields to be saved
and/or indexed.
NOTE: Geo points should be encoded as strings of "lon,lat"
""" # noqa
return self._add_document(
doc_id,
@@ -393,6 +404,7 @@ class SearchCommands:
doc_id, conn=None, score=score, language=language, replace=replace
)
@deprecated_function(version="2.0.0", reason="deprecated since redisearch 2.0")
def delete_document(self, doc_id, conn=None, delete_actual_document=False):
"""
Delete a document from index
@@ -427,6 +439,7 @@ class SearchCommands:
return Document(id=id, **fields)
@deprecated_function(version="2.0.0", reason="deprecated since redisearch 2.0")
def get(self, *ids):
"""
Returns the full contents of multiple documents.
@@ -497,14 +510,19 @@ class SearchCommands:
For more information see `FT.SEARCH <https://redis.io/commands/ft.search>`_.
""" # noqa
args, query = self._mk_query_args(query, query_params=query_params)
st = time.time()
res = self.execute_command(SEARCH_CMD, *args)
st = time.monotonic()
options = {}
if get_protocol_version(self.client) not in ["3", 3]:
options[NEVER_DECODE] = True
res = self.execute_command(SEARCH_CMD, *args, **options)
if isinstance(res, Pipeline):
return res
return self._parse_results(
SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
)
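
Callers are unaffected; with a connected client ``r``, the ``duration`` reported on the result is simply monotonic-based now, e.g.:

    from redis.commands.search.query import Query

    res = r.ft("idx").search(Query("hello world").paging(0, 10))
    print(res.total, res.duration)  # duration in milliseconds
    for doc in res.docs:
        print(doc.id)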
def explain(
@@ -524,7 +542,7 @@ class SearchCommands:
def aggregate(
self,
query: Union[str, Query],
query: Union[AggregateRequest, Cursor],
query_params: Dict[str, Union[str, int, float]] = None,
):
"""
@@ -555,7 +573,7 @@ class SearchCommands:
)
def _get_aggregate_result(
self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool
self, raw: List, query: Union[AggregateRequest, Cursor], has_cursor: bool
):
if has_cursor:
if isinstance(query, Cursor):
@@ -578,7 +596,7 @@ class SearchCommands:
def profile(
self,
query: Union[str, Query, AggregateRequest],
query: Union[Query, AggregateRequest],
limited: bool = False,
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
):
@@ -588,13 +606,13 @@ class SearchCommands:
### Parameters
**query**: This can be either an `AggregateRequest`, `Query` or string.
**query**: This can be either an `AggregateRequest` or `Query`.
**limited**: If set to True, removes details of reader iterator.
**query_params**: Define one or more value parameters.
Each parameter has a name and a value.
"""
st = time.time()
st = time.monotonic()
cmd = [PROFILE_CMD, self.index_name, ""]
if limited:
cmd.append("LIMITED")
@@ -613,20 +631,20 @@ class SearchCommands:
res = self.execute_command(*cmd)
return self._parse_results(
PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0
PROFILE_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
)
def spellcheck(self, query, distance=None, include=None, exclude=None):
"""
Issue a spellcheck query
### Parameters
Args:
**query**: search query.
**distance***: the maximal Levenshtein distance for spelling
query: search query.
distance: the maximal Levenshtein distance for spelling
suggestions (default: 1, max: 4).
**include**: specifies an inclusion custom dictionary.
**exclude**: specifies an exclusion custom dictionary.
include: specifies an inclusion custom dictionary.
exclude: specifies an exclusion custom dictionary.
For more information see `FT.SPELLCHECK <https://redis.io/commands/ft.spellcheck>`_.
""" # noqa
@@ -684,6 +702,10 @@ class SearchCommands:
cmd = [DICT_DUMP_CMD, name]
return self.execute_command(*cmd)
@deprecated_function(
version="8.0.0",
reason="deprecated since Redis 8.0, call config_set from core module instead",
)
def config_set(self, option: str, value: str) -> bool:
"""Set runtime configuration option.
@@ -698,6 +720,10 @@ class SearchCommands:
raw = self.execute_command(*cmd)
return raw == "OK"
@deprecated_function(
version="8.0.0",
reason="deprecated since Redis 8.0, call config_get from core module instead",
)
def config_get(self, option: str) -> str:
"""Get runtime configuration option value.
@@ -924,19 +950,24 @@ class AsyncSearchCommands(SearchCommands):
For more information see `FT.SEARCH <https://redis.io/commands/ft.search>`_.
""" # noqa
args, query = self._mk_query_args(query, query_params=query_params)
st = time.time()
res = await self.execute_command(SEARCH_CMD, *args)
st = time.monotonic()
options = {}
if get_protocol_version(self.client) not in ["3", 3]:
options[NEVER_DECODE] = True
res = await self.execute_command(SEARCH_CMD, *args, **options)
if isinstance(res, Pipeline):
return res
return self._parse_results(
SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0
SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0
)
async def aggregate(
self,
query: Union[str, Query],
query: Union[AggregateResult, Cursor],
query_params: Dict[str, Union[str, int, float]] = None,
):
"""
@@ -994,6 +1025,10 @@ class AsyncSearchCommands(SearchCommands):
return self._parse_results(SPELLCHECK_CMD, res)
@deprecated_function(
version="8.0.0",
reason="deprecated since Redis 8.0, call config_set from core module instead",
)
async def config_set(self, option: str, value: str) -> bool:
"""Set runtime configuration option.
@@ -1008,6 +1043,10 @@ class AsyncSearchCommands(SearchCommands):
raw = await self.execute_command(*cmd)
return raw == "OK"
@deprecated_function(
version="8.0.0",
reason="deprecated since Redis 8.0, call config_get from core module instead",
)
async def config_get(self, option: str) -> str:
"""Get runtime configuration option value.

View File

@@ -0,0 +1,3 @@
# Default dialect value, used as part of Search and Aggregate queries.
DEFAULT_DIALECT = 2
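
Queries and aggregations now default to dialect 2; a per-query override remains available:

    from redis.commands.search.query import Query

    q = Query("@title:hello")  # DIALECT 2 is applied implicitly
    q = q.dialect(3)           # opt into another dialect for this query only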

View File

@@ -4,6 +4,10 @@ from redis import DataError
class Field:
"""
A class representing a field in a document.
"""
NUMERIC = "NUMERIC"
TEXT = "TEXT"
WEIGHT = "WEIGHT"
@@ -13,6 +17,9 @@ class Field:
SORTABLE = "SORTABLE"
NOINDEX = "NOINDEX"
AS = "AS"
GEOSHAPE = "GEOSHAPE"
INDEX_MISSING = "INDEXMISSING"
INDEX_EMPTY = "INDEXEMPTY"
def __init__(
self,
@@ -20,8 +27,24 @@ class Field:
args: List[str] = None,
sortable: bool = False,
no_index: bool = False,
index_missing: bool = False,
index_empty: bool = False,
as_name: str = None,
):
"""
Create a new field object.
Args:
name: The name of the field.
args: A list of raw schema arguments to append to the field definition.
sortable: If `True`, the field will be sortable.
no_index: If `True`, the field will not be indexed.
index_missing: If `True`, it will be possible to search for documents that
have this field missing.
index_empty: If `True`, it will be possible to search for documents that
have this field empty.
as_name: If provided, this alias will be used for the field.
"""
if args is None:
args = []
self.name = name
@@ -33,6 +56,10 @@ class Field:
self.args_suffix.append(Field.SORTABLE)
if no_index:
self.args_suffix.append(Field.NOINDEX)
if index_missing:
self.args_suffix.append(Field.INDEX_MISSING)
if index_empty:
self.args_suffix.append(Field.INDEX_EMPTY)
if no_index and not sortable:
raise ValueError("Non-Sortable non-Indexable fields are ignored")
@@ -91,6 +118,21 @@ class NumericField(Field):
Field.__init__(self, name, args=[Field.NUMERIC], **kwargs)
class GeoShapeField(Field):
"""
GeoShapeField is used to enable within/contain indexing/searching
"""
SPHERICAL = "SPHERICAL"
FLAT = "FLAT"
def __init__(self, name: str, coord_system=None, **kwargs):
args = [Field.GEOSHAPE]
if coord_system:
args.append(coord_system)
Field.__init__(self, name, args=args, **kwargs)
class GeoField(Field):
"""
GeoField is used to define a geo-indexing field in a schema definition
@@ -139,7 +181,7 @@ class VectorField(Field):
``name`` is the name of the field.
``algorithm`` can be "FLAT" or "HNSW".
``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA".
``attributes`` each algorithm can have specific attributes. Some of them
are mandatory and some of them are optional. See
@@ -152,10 +194,10 @@ class VectorField(Field):
if sort or noindex:
raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.")
if algorithm.upper() not in ["FLAT", "HNSW"]:
if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]:
raise DataError(
"Realtime vector indexing supporting 2 Indexing Methods:"
"'FLAT' and 'HNSW'."
"Realtime vector indexing supporting 3 Indexing Methods:"
"'FLAT', 'HNSW', and 'SVS-VAMANA'."
)
attr_li = []
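
A schema sketch combining the new options (the SVS-VAMANA attribute names below follow the FLAT/HNSW convention and are illustrative; check the server docs for the authoritative list):

    from redis.commands.search.field import GeoShapeField, TagField, VectorField

    schema = (
        TagField("color", index_missing=True, index_empty=True),
        GeoShapeField("zone", coord_system=GeoShapeField.SPHERICAL),
        VectorField(
            "embedding",
            "SVS-VAMANA",
            {"TYPE": "FLOAT32", "DIM": 128, "DISTANCE_METRIC": "COSINE"},
        ),
    )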

View File

@@ -0,0 +1,14 @@
from typing import Any
class ProfileInformation:
"""
Wrapper around FT.PROFILE response
"""
def __init__(self, info: Any) -> None:
self._info: Any = info
@property
def info(self) -> Any:
return self._info

View File

@@ -1,5 +1,7 @@
from typing import List, Optional, Union
from redis.commands.search.dialect import DEFAULT_DIALECT
class Query:
"""
@@ -35,11 +37,12 @@ class Query:
self._in_order: bool = False
self._sortby: Optional[SortbyField] = None
self._return_fields: List = []
self._return_fields_decode_as: dict = {}
self._summarize_fields: List = []
self._highlight_fields: List = []
self._language: Optional[str] = None
self._expander: Optional[str] = None
self._dialect: Optional[int] = None
self._dialect: int = DEFAULT_DIALECT
def query_string(self) -> str:
"""Return the query string of this query only."""
@@ -53,13 +56,27 @@ class Query:
def return_fields(self, *fields) -> "Query":
"""Add fields to return fields."""
self._return_fields += fields
for field in fields:
self.return_field(field)
return self
def return_field(self, field: str, as_field: Optional[str] = None) -> "Query":
"""Add field to return fields (Optional: add 'AS' name
to the field)."""
def return_field(
self,
field: str,
as_field: Optional[str] = None,
decode_field: Optional[bool] = True,
encoding: Optional[str] = "utf8",
) -> "Query":
"""
Add a field to the list of fields to return.
- **field**: The field to include in query results
- **as_field**: The alias for the field
- **decode_field**: Whether to decode the field from bytes to string
- **encoding**: The encoding to use when decoding the field
"""
self._return_fields.append(field)
self._return_fields_decode_as[field] = encoding if decode_field else None
if as_field is not None:
self._return_fields += ("AS", as_field)
return self
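
A sketch of the new per-field decoding control:

    from redis.commands.search.query import Query

    q = (
        Query("*")
        .return_field("title")                     # decoded with the default utf8
        .return_field("blob", decode_field=False)  # left as raw bytes
        .return_field("nom", encoding="latin-1", as_field="name")
    )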
@@ -162,6 +179,8 @@ class Query:
Use a different scoring function to evaluate document relevance.
Default is `TFIDF`.
Since Redis 8.0, the default has been `BM25STD`.
:param scorer: The scoring function to use
(e.g. `TFIDF.DOCNORM` or `BM25`)
"""

View File

@@ -1,3 +1,5 @@
from typing import Optional
from ._util import to_string
from .document import Document
@@ -9,11 +11,19 @@ class Result:
"""
def __init__(
self, res, hascontent, duration=0, has_payload=False, with_scores=False
self,
res,
hascontent,
duration=0,
has_payload=False,
with_scores=False,
field_encodings: Optional[dict] = None,
):
"""
- **snippets**: An optional dictionary of the form
{field: snippet_size} for snippet formatting
- duration: the execution time of the query
- has_payload: whether the query has payloads
- with_scores: whether the query has scores
- field_encodings: a dictionary of field encodings if any is provided
"""
self.total = res[0]
@@ -39,18 +49,22 @@ class Result:
fields = {}
if hascontent and res[i + fields_offset] is not None:
fields = (
dict(
dict(
zip(
map(to_string, res[i + fields_offset][::2]),
map(to_string, res[i + fields_offset][1::2]),
)
)
)
if hascontent
else {}
)
keys = map(to_string, res[i + fields_offset][::2])
values = res[i + fields_offset][1::2]
for key, value in zip(keys, values):
if field_encodings is None or key not in field_encodings:
fields[key] = to_string(value)
continue
encoding = field_encodings[key]
# If the encoding is None, we don't need to decode the value
if encoding is None:
fields[key] = value
else:
fields[key] = to_string(value, encoding=encoding)
try:
del fields["id"]
except KeyError:

View File

@@ -11,16 +11,35 @@ class SentinelCommands:
"""Redis Sentinel's SENTINEL command."""
warnings.warn(DeprecationWarning("Use the individual sentinel_* methods"))
def sentinel_get_master_addr_by_name(self, service_name):
"""Returns a (host, port) pair for the given ``service_name``"""
return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name)
def sentinel_get_master_addr_by_name(self, service_name, return_responses=False):
"""
Returns a (host, port) pair for the given ``service_name`` when return_responses is True,
otherwise returns a boolean value that indicates if the command was successful.
"""
return self.execute_command(
"SENTINEL GET-MASTER-ADDR-BY-NAME",
service_name,
once=True,
return_responses=return_responses,
)
def sentinel_master(self, service_name):
"""Returns a dictionary containing the specified masters state."""
return self.execute_command("SENTINEL MASTER", service_name)
def sentinel_master(self, service_name, return_responses=False):
"""
Returns a dictionary containing the specified master's state when return_responses is True,
otherwise returns a boolean value that indicates if the command was successful.
"""
return self.execute_command(
"SENTINEL MASTER", service_name, return_responses=return_responses
)
def sentinel_masters(self):
"""Returns a list of dictionaries containing each master's state."""
"""
Returns a list of dictionaries containing each master's state.
Important: This function is called by the Sentinel implementation and is
called directly on the Redis standalone client for sentinels,
so it doesn't support the "once" and "return_responses" options.
"""
return self.execute_command("SENTINEL MASTERS")
def sentinel_monitor(self, name, ip, port, quorum):
@@ -31,16 +50,27 @@ class SentinelCommands:
"""Remove a master from Sentinel's monitoring"""
return self.execute_command("SENTINEL REMOVE", name)
def sentinel_sentinels(self, service_name):
"""Returns a list of sentinels for ``service_name``"""
return self.execute_command("SENTINEL SENTINELS", service_name)
def sentinel_sentinels(self, service_name, return_responses=False):
"""
Returns a list of sentinels for ``service_name`` when return_responses is True,
otherwise returns a boolean value that indicates if the command was successful.
"""
return self.execute_command(
"SENTINEL SENTINELS", service_name, return_responses=return_responses
)
def sentinel_set(self, name, option, value):
"""Set Sentinel monitoring parameters for a given master"""
return self.execute_command("SENTINEL SET", name, option, value)
def sentinel_slaves(self, service_name):
"""Returns a list of slaves for ``service_name``"""
"""
Returns a list of slaves for ``service_name``
Important: This function is called by the Sentinel implementation and is
called directly on the Redis standalone client for sentinels,
so it doesn't support the "once" and "return_responses" options.
"""
return self.execute_command("SENTINEL SLAVES", service_name)
def sentinel_reset(self, pattern):

View File

@@ -84,7 +84,7 @@ class TimeSeries(TimeSeriesCommands):
startup_nodes=self.client.nodes_manager.startup_nodes,
result_callbacks=self.client.result_callbacks,
cluster_response_callbacks=self.client.cluster_response_callbacks,
cluster_error_retry_attempts=self.client.cluster_error_retry_attempts,
cluster_error_retry_attempts=self.client.retry.get_retries(),
read_from_replicas=self.client.read_from_replicas,
reinitialize_steps=self.client.reinitialize_steps,
lock=self.client._lock,

View File

@@ -6,7 +6,7 @@ class TSInfo:
"""
Hold information and statistics on the time-series.
Can be created using ``tsinfo`` command
https://oss.redis.com/redistimeseries/commands/#tsinfo.
https://redis.io/docs/latest/commands/ts.info/
"""
rules = []
@@ -57,7 +57,7 @@ class TSInfo:
Policy that will define handling of duplicate samples.
Can read more about on
https://oss.redis.com/redistimeseries/configuration/#duplicate_policy
https://redis.io/docs/latest/develop/data-types/timeseries/configuration/#duplicate_policy
"""
response = dict(zip(map(nativestr, args[::2]), args[1::2]))
self.rules = response.get("rules")
@@ -78,7 +78,7 @@ class TSInfo:
self.chunk_size = response["chunkSize"]
if "duplicatePolicy" in response:
self.duplicate_policy = response["duplicatePolicy"]
if type(self.duplicate_policy) == bytes:
if isinstance(self.duplicate_policy, bytes):
self.duplicate_policy = self.duplicate_policy.decode()
def get(self, item):

View File

@@ -5,7 +5,7 @@ def list_to_dict(aList):
return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))}
def parse_range(response):
def parse_range(response, **kwargs):
"""Parse range response. Used by TS.RANGE and TS.REVRANGE."""
return [tuple((r[0], float(r[1]))) for r in response]

View File

@@ -0,0 +1,46 @@
import json
from redis._parsers.helpers import pairs_to_dict
from redis.commands.vectorset.utils import (
parse_vemb_result,
parse_vlinks_result,
parse_vsim_result,
)
from ..helpers import get_protocol_version
from .commands import (
VEMB_CMD,
VGETATTR_CMD,
VINFO_CMD,
VLINKS_CMD,
VSIM_CMD,
VectorSetCommands,
)
class VectorSet(VectorSetCommands):
def __init__(self, client, **kwargs):
"""Create a new VectorSet client."""
# Set the module commands' callbacks
self._MODULE_CALLBACKS = {
VEMB_CMD: parse_vemb_result,
VGETATTR_CMD: lambda r: r and json.loads(r) or None,
}
self._RESP2_MODULE_CALLBACKS = {
VINFO_CMD: lambda r: r and pairs_to_dict(r) or None,
VSIM_CMD: parse_vsim_result,
VLINKS_CMD: parse_vlinks_result,
}
self._RESP3_MODULE_CALLBACKS = {}
self.client = client
self.execute_command = client.execute_command
if get_protocol_version(self.client) in ["3", 3]:
self._MODULE_CALLBACKS.update(self._RESP3_MODULE_CALLBACKS)
else:
self._MODULE_CALLBACKS.update(self._RESP2_MODULE_CALLBACKS)
for k, v in self._MODULE_CALLBACKS.items():
self.client.set_response_callback(k, v)

View File

@@ -0,0 +1,374 @@
import json
from enum import Enum
from typing import Awaitable, Dict, List, Optional, Union
from redis.client import NEVER_DECODE
from redis.commands.helpers import get_protocol_version
from redis.exceptions import DataError
from redis.typing import CommandsProtocol, EncodableT, KeyT, Number
VADD_CMD = "VADD"
VSIM_CMD = "VSIM"
VREM_CMD = "VREM"
VDIM_CMD = "VDIM"
VCARD_CMD = "VCARD"
VEMB_CMD = "VEMB"
VLINKS_CMD = "VLINKS"
VINFO_CMD = "VINFO"
VSETATTR_CMD = "VSETATTR"
VGETATTR_CMD = "VGETATTR"
VRANDMEMBER_CMD = "VRANDMEMBER"
class QuantizationOptions(Enum):
"""Quantization options for the VADD command."""
NOQUANT = "NOQUANT"
BIN = "BIN"
Q8 = "Q8"
class CallbacksOptions(Enum):
"""Options that can be set for the commands callbacks"""
RAW = "RAW"
WITHSCORES = "WITHSCORES"
ALLOW_DECODING = "ALLOW_DECODING"
RESP3 = "RESP3"
class VectorSetCommands(CommandsProtocol):
"""Redis VectorSet commands"""
def vadd(
self,
key: KeyT,
vector: Union[List[float], bytes],
element: str,
reduce_dim: Optional[int] = None,
cas: Optional[bool] = False,
quantization: Optional[QuantizationOptions] = None,
ef: Optional[Number] = None,
attributes: Optional[Union[dict, str]] = None,
numlinks: Optional[int] = None,
) -> Union[Awaitable[int], int]:
"""
Add vector ``vector`` for element ``element`` to a vector set ``key``.
``reduce_dim`` sets the dimensions to reduce the vector to.
If not provided, the vector is not reduced.
``cas`` is a boolean flag that indicates whether to use CAS (check-and-set style)
when adding the vector. If not provided, CAS is not used.
``quantization`` sets the quantization type to use.
If not provided, int8 quantization is used.
The options are:
- NOQUANT: No quantization
- BIN: Binary quantization
- Q8: Signed 8-bit quantization
``ef`` sets the exploration factor to use.
If not provided, the default exploration factor is used.
``attributes`` is a dictionary or JSON string that contains the attributes to set for the vector.
If not provided, no attributes are set.
``numlinks`` sets the number of links to create for the vector.
If not provided, the default number of links is used.
For more information see https://redis.io/commands/vadd
"""
if not vector or not element:
raise DataError("Both vector and element must be provided")
pieces = []
if reduce_dim:
pieces.extend(["REDUCE", reduce_dim])
values_pieces = []
if isinstance(vector, bytes):
values_pieces.extend(["FP32", vector])
else:
values_pieces.extend(["VALUES", len(vector)])
values_pieces.extend(vector)
pieces.extend(values_pieces)
pieces.append(element)
if cas:
pieces.append("CAS")
if quantization:
pieces.append(quantization.value)
if ef:
pieces.extend(["EF", ef])
if attributes:
if isinstance(attributes, dict):
# transform attributes to json string
attributes_json = json.dumps(attributes)
else:
attributes_json = attributes
pieces.extend(["SETATTR", attributes_json])
if numlinks:
pieces.extend(["M", numlinks])
return self.execute_command(VADD_CMD, key, *pieces)
def vsim(
self,
key: KeyT,
input: Union[List[float], bytes, str],
with_scores: Optional[bool] = False,
count: Optional[int] = None,
ef: Optional[Number] = None,
filter: Optional[str] = None,
filter_ef: Optional[str] = None,
truth: Optional[bool] = False,
no_thread: Optional[bool] = False,
epsilon: Optional[Number] = None,
) -> Union[
Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]],
Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]],
]:
"""
Compare a vector or element ``input`` with the other vectors in a vector set ``key``.
``with_scores`` sets if the results should be returned with the
similarity scores of the elements in the result.
``count`` sets the number of results to return.
``ef`` sets the exploration factor.
``filter`` sets the filter that should be applied to the search.
``filter_ef`` sets the max filtering effort.
``truth`` when enabled forces the command to perform a linear scan.
``no_thread`` when enabled forces the command to execute the search
on the data structure in the main thread.
``epsilon`` is a floating point value between 0 and 1; if specified, only
elements whose distance is no further than ``epsilon`` are returned.
For more information see https://redis.io/commands/vsim
"""
if not input:
raise DataError("'input' should be provided")
pieces = []
options = {}
if isinstance(input, bytes):
pieces.extend(["FP32", input])
elif isinstance(input, list):
pieces.extend(["VALUES", len(input)])
pieces.extend(input)
else:
pieces.extend(["ELE", input])
if with_scores:
pieces.append("WITHSCORES")
options[CallbacksOptions.WITHSCORES.value] = True
if count:
pieces.extend(["COUNT", count])
if epsilon:
pieces.extend(["EPSILON", epsilon])
if ef:
pieces.extend(["EF", ef])
if filter:
pieces.extend(["FILTER", filter])
if filter_ef:
pieces.extend(["FILTER-EF", filter_ef])
if truth:
pieces.append("TRUTH")
if no_thread:
pieces.append("NOTHREAD")
return self.execute_command(VSIM_CMD, key, *pieces, **options)
def vdim(self, key: KeyT) -> Union[Awaitable[int], int]:
"""
Get the dimension of a vector set.
For vector sets created with the `REDUCE` option (random projection),
the reported size is that of the projected (reduced) dimension.
Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.
For more information see https://redis.io/commands/vdim
"""
return self.execute_command(VDIM_CMD, key)
def vcard(self, key: KeyT) -> Union[Awaitable[int], int]:
"""
Get the cardinality (the number of elements) of the vector set with key ``key``.
Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.
For more information see https://redis.io/commands/vcard
"""
return self.execute_command(VCARD_CMD, key)
def vrem(self, key: KeyT, element: str) -> Union[Awaitable[int], int]:
"""
Remove an element from a vector set.
For more information see https://redis.io/commands/vrem
"""
return self.execute_command(VREM_CMD, key, element)
def vemb(
self, key: KeyT, element: str, raw: Optional[bool] = False
) -> Union[
Awaitable[Optional[Union[List[EncodableT], Dict[str, EncodableT]]]],
Optional[Union[List[EncodableT], Dict[str, EncodableT]]],
]:
"""
Get the approximated vector of an element ``element`` from vector set ``key``.
``raw`` is a boolean flag that indicates whether to return the
internal representation used by the vector.
For more information see https://redis.io/commands/vemb
"""
options = {}
pieces = []
pieces.extend([key, element])
if get_protocol_version(self.client) in ["3", 3]:
options[CallbacksOptions.RESP3.value] = True
if raw:
pieces.append("RAW")
options[NEVER_DECODE] = True
if (
hasattr(self.client, "connection_pool")
and self.client.connection_pool.connection_kwargs["decode_responses"]
) or (
hasattr(self.client, "nodes_manager")
and self.client.nodes_manager.connection_kwargs["decode_responses"]
):
# allow decoding in the postprocessing callback
# if the user set decode_responses=True
# in the connection pool
options[CallbacksOptions.ALLOW_DECODING.value] = True
options[CallbacksOptions.RAW.value] = True
return self.execute_command(VEMB_CMD, *pieces, **options)
def vlinks(
self, key: KeyT, element: str, with_scores: Optional[bool] = False
) -> Union[
Awaitable[
Optional[
List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]
]
],
Optional[List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]],
]:
"""
Returns the neighbors of element ``element`` for each level of the vector set ``key``
in which the element exists.
The result is a list of lists, where each inner list contains the neighbors for one level.
If the element does not exist, or if the vector set does not exist, None is returned.
If ``with_scores`` is set, the result is a list of dicts instead,
where each dict maps one level's neighbors to their scores.
For more information see https://redis.io/commands/vlinks
"""
options = {}
pieces = []
pieces.extend([key, element])
if with_scores:
pieces.append("WITHSCORES")
options[CallbacksOptions.WITHSCORES.value] = True
return self.execute_command(VLINKS_CMD, *pieces, **options)
def vinfo(self, key: KeyT) -> Union[Awaitable[dict], dict]:
"""
Get information about a vector set.
For more information see https://redis.io/commands/vinfo
"""
return self.execute_command(VINFO_CMD, key)
def vsetattr(
self, key: KeyT, element: str, attributes: Optional[Union[dict, str]] = None
) -> Union[Awaitable[int], int]:
"""
Associate or remove JSON attributes ``attributes`` of element ``element``
for vector set ``key``.
For more information see https://redis.io/commands/vsetattr
"""
if attributes is None:
attributes_json = "{}"
elif isinstance(attributes, dict):
# transform attributes to json string
attributes_json = json.dumps(attributes)
else:
attributes_json = attributes
return self.execute_command(VSETATTR_CMD, key, element, attributes_json)
def vgetattr(
self, key: KeyT, element: str
) -> Union[Optional[Awaitable[dict]], Optional[dict]]:
"""
Retrieve the JSON attributes of an element ``element`` for vector set ``key``.
If the element does not exist, or if the vector set does not exist, None is
returned.
For more information see https://redis.io/commands/vgetattr
"""
return self.execute_command(VGETATTR_CMD, key, element)
def vrandmember(
self, key: KeyT, count: Optional[int] = None
) -> Union[
Awaitable[Optional[Union[List[str], str]]], Optional[Union[List[str], str]]
]:
"""
Returns random elements from a vector set ``key``.
``count`` is the number of elements to return.
If ``count`` is not provided, a single element is returned as a string.
If ``count`` is positive (and smaller than the number of elements
in the vector set), the command returns a list of up to ``count``
distinct elements from the vector set.
If ``count`` is negative, the command returns a list of abs(``count``)
random elements, potentially with duplicates.
If ``count`` is greater than the number of elements in the vector set,
the entire set is returned as a list.
If the vector set does not exist, ``None`` is returned.
For more information see https://redis.io/commands/vrandmember
"""
pieces = []
pieces.append(key)
if count is not None:
pieces.append(count)
return self.execute_command(VRANDMEMBER_CMD, *pieces)
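
A usage sketch for the command set above (assumes a Redis server with vector set support; key and element names are illustrative):

    import redis
    from redis.commands.vectorset.commands import QuantizationOptions

    r = redis.Redis(decode_responses=True)
    vs = r.vset()

    vs.vadd("points", [0.1, 0.2, 0.3], "p1",
            quantization=QuantizationOptions.Q8)
    vs.vadd("points", [0.12, 0.2, 0.31], "p2", cas=True)

    vs.vsim("points", "p1", with_scores=True, count=5)  # e.g. {'p2': 0.99, ...}
    vs.vemb("points", "p1")                             # approximated vector
    vs.vdim("points"), vs.vcard("points")               # (3, 2)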

View File

@@ -0,0 +1,94 @@
from redis._parsers.helpers import pairs_to_dict
from redis.commands.vectorset.commands import CallbacksOptions
def parse_vemb_result(response, **options):
"""
Handle the VEMB result, since the command can return different result
structures depending on input options and on the quantization type of the vector set.
Parsing VEMB result into:
- List[Union[bytes, Union[int, float]]]
- Dict[str, Union[bytes, str, float]]
"""
if response is None:
return response
if options.get(CallbacksOptions.RAW.value):
result = {}
result["quantization"] = (
response[0].decode("utf-8")
if options.get(CallbacksOptions.ALLOW_DECODING.value)
else response[0]
)
result["raw"] = response[1]
result["l2"] = float(response[2])
if len(response) > 3:
result["range"] = float(response[3])
return result
else:
if options.get(CallbacksOptions.RESP3.value):
return response
result = []
for i in range(len(response)):
try:
result.append(int(response[i]))
except ValueError:
# if the value is not an integer, it should be a float
result.append(float(response[i]))
return result
def parse_vlinks_result(response, **options):
"""
Handle the VLINKS result, since the command can return different result
structures depending on input options.
Parsing VLINKS result into:
- List[List[str]]
- List[Dict[str, Number]]
"""
if response is None:
return response
if options.get(CallbacksOptions.WITHSCORES.value):
result = []
# Redis will return a list of lists of strings.
# This list has to be transformed into a list of dicts
for level_item in response:
level_data_dict = {}
for key, value in pairs_to_dict(level_item).items():
value = float(value)
level_data_dict[key] = value
result.append(level_data_dict)
return result
else:
# return the list of elements for each level
# list of lists
return response
def parse_vsim_result(response, **options):
"""
Handle the VSIM result, since the command can return different result
structures depending on input options.
Parsing VSIM result into:
- List[str]
- Dict[str, Number]
"""
if response is None:
return response
if options.get(CallbacksOptions.WITHSCORES.value):
# Redis will return a flat list of element/score pairs.
# This list has to be transformed into a dict
result_dict = {}
for key, value in pairs_to_dict(response).items():
value = float(value)
result_dict[key] = value
return result_dict
else:
# return the flat list of similar elements
return response
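
The shape difference the callback options produce, sketched on RESP2-style replies:

    from redis.commands.vectorset.utils import parse_vsim_result

    parse_vsim_result([b"p1", b"p2"])  # no options: list passed through as-is
    parse_vsim_result([b"p1", b"0.98", b"p2", b"0.83"], WITHSCORES=True)
    # -> {b'p1': 0.98, b'p2': 0.83}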