diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index a2bb23b76d..3b74b86bd0 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -1,4 +1,4 @@ -import redis +from redis.client import Pipeline as RedisPipeline from ...asyncio.client import Pipeline as AsyncioPipeline from .commands import ( @@ -181,9 +181,17 @@ def pipeline(self, transaction=True, shard_hint=None): return p -class Pipeline(SearchCommands, redis.client.Pipeline): +class Pipeline(SearchCommands, RedisPipeline): """Pipeline for the module.""" + def __init__(self, connection_pool, response_callbacks, transaction, shard_hint): + super().__init__(connection_pool, response_callbacks, transaction, shard_hint) + self.index_name: str = "" -class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline, Pipeline): + +class AsyncPipeline(AsyncSearchCommands, AsyncioPipeline): """AsyncPipeline for the module.""" + + def __init__(self, connection_pool, response_callbacks, transaction, shard_hint): + super().__init__(connection_pool, response_callbacks, transaction, shard_hint) + self.index_name: str = "" diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 00435f626b..b3f5895816 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Optional, Tuple, Union from redis.commands.search.dialect import DEFAULT_DIALECT @@ -27,9 +27,9 @@ class Reducer: NAME = None def __init__(self, *args: str) -> None: - self._args = args - self._field = None - self._alias = None + self._args: Tuple[str, ...] = args + self._field: Optional[str] = None + self._alias: Optional[str] = None def alias(self, alias: str) -> "Reducer": """ @@ -49,13 +49,14 @@ def alias(self, alias: str) -> "Reducer": if alias is FIELDNAME: if not self._field: raise ValueError("Cannot use FIELDNAME alias with no field") - # Chop off initial '@' - alias = self._field[1:] + else: + # Chop off initial '@' + alias = self._field[1:] self._alias = alias return self @property - def args(self) -> List[str]: + def args(self) -> Tuple[str, ...]: return self._args @@ -64,7 +65,7 @@ class SortDirection: This special class is used to indicate sort direction. """ - DIRSTRING = None + DIRSTRING: Optional[str] = None def __init__(self, field: str) -> None: self.field = field @@ -104,17 +105,17 @@ def __init__(self, query: str = "*") -> None: All member methods (except `build_args()`) return the object itself, making them useful for chaining. """ - self._query = query - self._aggregateplan = [] - self._loadfields = [] - self._loadall = False - self._max = 0 - self._with_schema = False - self._verbatim = False - self._cursor = [] - self._dialect = DEFAULT_DIALECT - self._add_scores = False - self._scorer = "TFIDF" + self._query: str = query + self._aggregateplan: List[str] = [] + self._loadfields: List[str] = [] + self._loadall: bool = False + self._max: int = 0 + self._with_schema: bool = False + self._verbatim: bool = False + self._cursor: List[str] = [] + self._dialect: int = DEFAULT_DIALECT + self._add_scores: bool = False + self._scorer: str = "TFIDF" def load(self, *fields: str) -> "AggregateRequest": """ @@ -133,7 +134,7 @@ def load(self, *fields: str) -> "AggregateRequest": return self def group_by( - self, fields: List[str], *reducers: Union[Reducer, List[Reducer]] + self, fields: Union[str, List[str]], *reducers: Reducer ) -> "AggregateRequest": """ Specify by which fields to group the aggregation. @@ -147,7 +148,6 @@ def group_by( `aggregation` module. """ fields = [fields] if isinstance(fields, str) else fields - reducers = [reducers] if isinstance(reducers, Reducer) else reducers ret = ["GROUPBY", str(len(fields)), *fields] for reducer in reducers: @@ -251,12 +251,10 @@ def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest": .sort_by(Desc("@paid"), max=10) ``` """ - if isinstance(fields, (str, SortDirection)): - fields = [fields] fields_args = [] for f in fields: - if isinstance(f, SortDirection): + if isinstance(f, (Asc, Desc)): fields_args += [f.field, f.DIRSTRING] else: fields_args += [f] @@ -356,7 +354,7 @@ def build_args(self) -> List[str]: ret.extend(self._loadfields) if self._dialect: - ret.extend(["DIALECT", self._dialect]) + ret.extend(["DIALECT", str(self._dialect)]) ret.extend(self._aggregateplan) @@ -393,7 +391,7 @@ def __init__(self, rows, cursor: Cursor, schema) -> None: self.cursor = cursor self.schema = schema - def __repr__(self) -> (str, str): + def __repr__(self) -> str: cid = self.cursor.cid if self.cursor else -1 return ( f"<{self.__class__.__name__} at 0x{id(self):x} " diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 80d9b35728..679218c636 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -64,6 +64,44 @@ class SearchCommands: """Search commands.""" + @property + def index_name(self) -> str: + """The name of the search index. Must be implemented by inheriting classes.""" + if not hasattr(self, "_index_name"): + raise AttributeError("index_name must be set by the inheriting class") + return self._index_name + + @index_name.setter + def index_name(self, value: str) -> None: + """Set the name of the search index.""" + self._index_name = value + + @property + def client(self): + """The Redis client. Must be provided by inheriting classes.""" + if not hasattr(self, "_client"): + raise AttributeError("client must be set by the inheriting class") + return self._client + + @client.setter + def client(self, value) -> None: + """Set the Redis client.""" + self._client = value + + @property + def _RESP2_MODULE_CALLBACKS(self): + """Response callbacks for RESP2. Must be provided by inheriting classes.""" + if not hasattr(self, "_resp2_module_callbacks"): + raise AttributeError( + "_RESP2_MODULE_CALLBACKS must be set by the inheriting class" + ) + return self._resp2_module_callbacks + + @_RESP2_MODULE_CALLBACKS.setter + def _RESP2_MODULE_CALLBACKS(self, value) -> None: + """Set the RESP2 module callbacks.""" + self._resp2_module_callbacks = value + def _parse_results(self, cmd, res, **kwargs): if get_protocol_version(self.client) in ["3", 3]: return ProfileInformation(res) if cmd == "FT.PROFILE" else res @@ -221,7 +259,7 @@ def create_index( return self.execute_command(*args) - def alter_schema_add(self, fields: List[str]): + def alter_schema_add(self, fields: Union[Field, List[Field]]): """ Alter the existing search index by adding new fields. The index must already exist. @@ -336,11 +374,11 @@ def add_document( doc_id: str, nosave: bool = False, score: float = 1.0, - payload: bool = None, + payload: Optional[bool] = None, replace: bool = False, partial: bool = False, language: Optional[str] = None, - no_create: str = False, + no_create: bool = False, **fields: List[str], ): """ @@ -464,7 +502,7 @@ def info(self): return self._parse_results(INFO_CMD, res) def get_params_args( - self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None] + self, query_params: Optional[Dict[str, Union[str, int, float, bytes]]] ): if query_params is None: return [] @@ -478,7 +516,7 @@ def get_params_args( return args def _mk_query_args( - self, query, query_params: Union[Dict[str, Union[str, int, float, bytes]], None] + self, query, query_params: Optional[Dict[str, Union[str, int, float, bytes]]] ): args = [self.index_name] @@ -528,7 +566,7 @@ def search( def explain( self, query: Union[str, Query], - query_params: Dict[str, Union[str, int, float]] = None, + query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None, ): """Returns the execution plan for a complex query. @@ -543,7 +581,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa def aggregate( self, query: Union[AggregateRequest, Cursor], - query_params: Dict[str, Union[str, int, float]] = None, + query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None, ): """ Issue an aggregation query. @@ -598,7 +636,7 @@ def profile( self, query: Union[Query, AggregateRequest], limited: bool = False, - query_params: Optional[Dict[str, Union[str, int, float]]] = None, + query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None, ): """ Performs a search or aggregate command and collects performance @@ -936,7 +974,7 @@ async def info(self): async def search( self, query: Union[str, Query], - query_params: Dict[str, Union[str, int, float]] = None, + query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None, ): """ Search the index for a given query, and return a result of documents @@ -968,7 +1006,7 @@ async def search( async def aggregate( self, query: Union[AggregateResult, Cursor], - query_params: Dict[str, Union[str, int, float]] = None, + query_params: Optional[Dict[str, Union[str, int, float, bytes]]] = None, ): """ Issue an aggregation query. diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index ee281fafd0..86d93b8f8a 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from redis.commands.search.dialect import DEFAULT_DIALECT @@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None: self._with_scores: bool = False self._scorer: Optional[str] = None self._filters: List = list() - self._ids: Optional[List[str]] = None + self._ids: Optional[Tuple[str, ...]] = None self._slop: int = -1 self._timeout: Optional[float] = None self._in_order: bool = False @@ -81,7 +81,7 @@ def return_field( self._return_fields += ("AS", as_field) return self - def _mk_field_list(self, fields: List[str]) -> List: + def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List: if not fields: return [] return [fields] if isinstance(fields, str) else list(fields) @@ -126,7 +126,7 @@ def summarize( def highlight( self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None - ) -> None: + ) -> "Query": """ Apply specified markup to matched term(s) within the returned field(s). @@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query": self._scorer = scorer return self - def get_args(self) -> List[str]: + def get_args(self) -> List[Union[str, int, float]]: """Format the redis arguments for this query and return them.""" - args = [self._query_string] + args: List[Union[str, int, float]] = [self._query_string] args += self._get_args_tags() args += self._summarize_fields + self._highlight_fields args += ["LIMIT", self._offset, self._num] return args - def _get_args_tags(self) -> List[str]: - args = [] + def _get_args_tags(self) -> List[Union[str, int, float]]: + args: List[Union[str, int, float]] = [] if self._no_content: args.append("NOCONTENT") if self._fields: @@ -288,14 +288,14 @@ def with_scores(self) -> "Query": self._with_scores = True return self - def limit_fields(self, *fields: List[str]) -> "Query": + def limit_fields(self, *fields: str) -> "Query": """ Limit the search to specific TEXT fields only. - - **fields**: A list of strings; case-sensitive field names + - **fields**: Each element should be a string, case sensitive field name from the defined schema. """ - self._fields = fields + self._fields = list(fields) return self def add_filter(self, flt: "Filter") -> "Query": @@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query": class Filter: - def __init__(self, keyword: str, field: str, *args: List[str]) -> None: + def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None: self.args = [keyword, field] + list(args)