From f245488138952f2f958119cd4ecd1469b839cf9c Mon Sep 17 00:00:00 2001 From: Steve Lorello <42971704+slorello89@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:10:02 -0400 Subject: [PATCH 1/2] supporting literals as tag type (#635) * supporting literals as tag type * fixing key-prefix issue --- aredis_om/model/model.py | 47 ++++++++++++++++++++++++++-------------- tests/test_hash_model.py | 22 +++++++++++++++++++ tests/test_json_model.py | 24 ++++++++++++++++++++ 3 files changed, 77 insertions(+), 16 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index ecacda3..7d7ebb5 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -14,6 +14,7 @@ ClassVar, Dict, List, + Literal, Mapping, Optional, Sequence, @@ -141,10 +142,10 @@ def embedded(cls): def is_supported_container_type(typ: Optional[type]) -> bool: # TODO: Wait, why don't we support indexing sets? - if typ == list or typ == tuple: + if typ == list or typ == tuple or typ == Literal: return True unwrapped = get_origin(typ) - return unwrapped == list or unwrapped == tuple + return unwrapped == list or unwrapped == tuple or unwrapped == Literal def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]): @@ -1414,6 +1415,8 @@ def outer_type_or_annotation(field): if not isinstance(field.annotation, type): raise AttributeError(f"could not extract outer type from field {field}") return field.annotation + elif get_origin(field.annotation) == Literal: + return str else: return field.annotation.__args__[0] @@ -2057,21 +2060,33 @@ def schema_for_type( # find any values marked as indexed. if is_container_type and not is_vector: field_type = get_origin(typ) - embedded_cls = get_args(typ) - if not embedded_cls: - log.warning( - "Model %s defined an empty list or tuple field: %s", cls, name + if field_type == Literal: + path = f"{json_path}.{name}" + return cls.schema_for_type( + path, + name, + name_prefix, + str, + field_info, + parent_type=field_type, + ) + else: + embedded_cls = get_args(typ) + if not embedded_cls: + log.warning( + "Model %s defined an empty list or tuple field: %s", cls, name + ) + return "" + path = f"{json_path}.{name}[*]" + embedded_cls = embedded_cls[0] + return cls.schema_for_type( + path, + name, + name_prefix, + embedded_cls, + field_info, + parent_type=field_type, ) - return "" - embedded_cls = embedded_cls[0] - return cls.schema_for_type( - f"{json_path}.{name}[*]", - name, - name_prefix, - embedded_cls, - field_info, - parent_type=field_type, - ) elif field_is_model: name_prefix = f"{name_prefix}_{name}" if name_prefix else name sub_fields = [] diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index de4bdb8..9533972 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -917,3 +917,25 @@ class TestUpdate(HashModel): rematerialized = await TestUpdate.find(TestUpdate.pk == t.pk).first() assert rematerialized.age == 34 + + +@py_test_mark_asyncio +async def test_literals(): + from typing import Literal + + class TestLiterals(HashModel): + flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple") + + schema = TestLiterals.redisearch_schema() + + key_prefix = TestLiterals.make_key( + TestLiterals._meta.primary_key_pattern.format(pk="") + ) + assert schema == ( + f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | flavor TAG SEPARATOR |" + ) + await Migrator().run() + item = TestLiterals(flavor="pumpkin") + await item.save() + rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first() + assert rematerialized.pk == item.pk diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 47cfd2f..24f0f62 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -1098,6 +1098,7 @@ class ModelWithIntPk(JsonModel): m = await ModelWithIntPk.find(ModelWithIntPk.my_id == 42).first() assert m.my_id == 42 + @py_test_mark_asyncio async def test_pagination(): class Test(JsonModel): @@ -1121,3 +1122,26 @@ async def get_page(cls, offset, limit): res = await Test.get_page(10, 30) assert len(res) == 30 assert res[0].num == 10 + + +@py_test_mark_asyncio +async def test_literals(): + from typing import Literal + + class TestLiterals(JsonModel): + flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple") + + schema = TestLiterals.redisearch_schema() + + key_prefix = TestLiterals.make_key( + TestLiterals._meta.primary_key_pattern.format(pk="") + ) + assert schema == ( + f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | " + "$.flavor AS flavor TAG SEPARATOR |" + ) + await Migrator().run() + item = TestLiterals(flavor="pumpkin") + await item.save() + rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first() + assert rematerialized.pk == item.pk From 424b8424842f125708f86885614a1069e66701f4 Mon Sep 17 00:00:00 2001 From: Savannah Norem Date: Mon, 5 Aug 2024 12:19:40 -0400 Subject: [PATCH 2/2] Test the FindQuery class and how it turns expressions into Redis commands (#642) * added return_fields function, attempting to optionally limit fields returned by find * added call to get query, adding tests * cleaned up test file, added endswith test * cleaned up one more return fields line, added fuzzy matching and cleaned up some tests * remove changes - return_fields doesn't exist here, that'll be it's own branch * removing whitespace from blank lines for linter * fixing linter issues * ignorning erroneous error * making xfix tests more specific * linting fixes * removing validate return fields function as return_fields don't exist here --------- Co-authored-by: slorello89 --- aredis_om/model/encoders.py | 2 +- aredis_om/model/model.py | 11 +- tests/test_find_query.py | 449 ++++++++++++++++++++++++++++++++++++ tests/test_hash_model.py | 24 +- tests/test_json_model.py | 24 +- 5 files changed, 496 insertions(+), 14 deletions(-) create mode 100644 tests/test_find_query.py diff --git a/aredis_om/model/encoders.py b/aredis_om/model/encoders.py index f097a35..93142f4 100644 --- a/aredis_om/model/encoders.py +++ b/aredis_om/model/encoders.py @@ -90,7 +90,7 @@ def jsonable_encoder( sqlalchemy_safe=sqlalchemy_safe, ) if dataclasses.is_dataclass(obj): - return dataclasses.asdict(obj) + return dataclasses.asdict(obj) # type: ignore[call-overload] if isinstance(obj, Enum): return obj.value if isinstance(obj, PurePath): diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 7d7ebb5..27ebcc5 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -873,7 +873,9 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: return result - async def execute(self, exhaust_results=True, return_raw_result=False): + async def execute( + self, exhaust_results=True, return_raw_result=False, return_query_args=False + ): args: List[Union[str, bytes]] = [ "FT.SEARCH", self.model.Meta.index_name, @@ -898,6 +900,9 @@ async def execute(self, exhaust_results=True, return_raw_result=False): if self.nocontent: args.append("NOCONTENT") + if return_query_args: + return self.model.Meta.index_name, args + # Reset the cache if we're executing from offset 0. if self.offset == 0: self._model_cache.clear() @@ -931,6 +936,10 @@ async def execute(self, exhaust_results=True, return_raw_result=False): self._model_cache += _results return self._model_cache + async def get_query(self): + query = self.copy() + return await query.execute(return_query_args=True) + async def first(self): query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields) results = await query.execute(exhaust_results=False) diff --git a/tests/test_find_query.py b/tests/test_find_query.py new file mode 100644 index 0000000..ecd14e4 --- /dev/null +++ b/tests/test_find_query.py @@ -0,0 +1,449 @@ +# type: ignore + +import abc +import dataclasses +import datetime +import decimal +import uuid +from collections import namedtuple +from typing import Dict, List, Optional, Set, Union +from unittest import mock + +import pytest +import pytest_asyncio + +from aredis_om import ( + EmbeddedJsonModel, + Field, + FindQuery, + HashModel, + JsonModel, + Migrator, + NotFoundError, + QueryNotSupportedError, + RedisModelError, +) + +# We need to run this check as sync code (during tests) even in async mode +# because we call it in the top-level module scope. +from redis_om import has_redis_json +from tests._compat import EmailStr, PositiveInt, ValidationError + +from .conftest import py_test_mark_asyncio + + +if not has_redis_json(): + pytestmark = pytest.mark.skip + +today = datetime.date.today() + + +@pytest_asyncio.fixture +async def m(key_prefix, redis): + class BaseJsonModel(JsonModel, abc.ABC): + class Meta: + global_key_prefix = key_prefix + + class Note(EmbeddedJsonModel): + # TODO: This was going to be a full-text search example, but + # we can't index embedded documents for full-text search in + # the preview release. + description: str = Field(index=True) + created_on: datetime.datetime + + class Address(EmbeddedJsonModel): + address_line_1: str + address_line_2: Optional[str] = None + city: str = Field(index=True) + state: str + country: str + postal_code: str = Field(index=True) + note: Optional[Note] = None + + class Item(EmbeddedJsonModel): + price: decimal.Decimal + name: str = Field(index=True) + + class Order(EmbeddedJsonModel): + items: List[Item] + created_on: datetime.datetime + + class Member(BaseJsonModel): + first_name: str = Field(index=True, case_sensitive=True) + last_name: str = Field(index=True) + email: Optional[EmailStr] = Field(index=True, default=None) + join_date: datetime.date + age: Optional[PositiveInt] = Field(index=True, default=None) + bio: Optional[str] = Field(index=True, full_text_search=True, default="") + + # Creates an embedded model. + address: Address + + # Creates an embedded list of models. + orders: Optional[List[Order]] = None + + await Migrator().run() + + return namedtuple( + "Models", ["BaseJsonModel", "Note", "Address", "Item", "Order", "Member"] + )(BaseJsonModel, Note, Address, Item, Order, Member) + + +@pytest.fixture() +def address(m): + try: + yield m.Address( + address_line_1="1 Main St.", + city="Portland", + state="OR", + country="USA", + postal_code="11111", + ) + except Exception as e: + raise e + + +@pytest_asyncio.fixture() +async def members(address, m): + member1 = m.Member( + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + age=38, + join_date=today, + address=address, + bio="Andrew is a software engineer", + ) + + member2 = m.Member( + first_name="Kim", + last_name="Brookins", + email="k@example.com", + age=34, + join_date=today, + address=address, + bio="Kim is a newer hire", + ) + + member3 = m.Member( + first_name="Andrew", + last_name="Smith", + email="as@example.com", + age=100, + join_date=today, + address=address, + bio="Andrew is old", + ) + + await member1.save() + await member2.save() + await member3.save() + + yield member1, member2, member3 + + +@py_test_mark_asyncio +async def test_find_query_in(members, m): + # << means "in" + member1, member2, member3 = members + model_name, fq = await FindQuery( + expressions=[m.Member.pk << [member1.pk, member2.pk, member3.pk]], + model=m.Member, + ).get_query() + in_str = ( + "(@pk:{" + + str(member1.pk) + + "|" + + str(member2.pk) + + "|" + + str(member3.pk) + + "})" + ) + assert fq == ["FT.SEARCH", model_name, in_str, "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_not_in(members, m): + # >> means "not in" + member1, member2, member3 = members + model_name, fq = await FindQuery( + expressions=[m.Member.pk >> [member2.pk, member3.pk]], model=m.Member + ).get_query() + not_in_str = "-(@pk:{" + str(member2.pk) + "|" + str(member3.pk) + "})" + assert fq == ["FT.SEARCH", model_name, not_in_str, "LIMIT", 0, 1000] + + +# experssion testing; (==, !=, <, <=, >, >=, |, &, ~) +@py_test_mark_asyncio +async def test_find_query_eq(m): + model_name, fq = await FindQuery( + expressions=[m.Member.first_name == "Andrew"], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "@first_name:{Andrew}", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_ne(m): + model_name, fq = await FindQuery( + expressions=[m.Member.first_name != "Andrew"], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "-(@first_name:{Andrew})", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_lt(m): + model_name, fq = await FindQuery( + expressions=[m.Member.age < 40], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "@age:[-inf (40]", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_le(m): + model_name, fq = await FindQuery( + expressions=[m.Member.age <= 38], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "@age:[-inf 38]", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_gt(m): + model_name, fq = await FindQuery( + expressions=[m.Member.age > 38], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "@age:[(38 +inf]", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_ge(m): + model_name, fq = await FindQuery( + expressions=[m.Member.age >= 38], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "@age:[38 +inf]", "LIMIT", 0, 1000] + + +# tests for sorting and text search with and, or, not +@py_test_mark_asyncio +async def test_find_query_sort(m): + model_name, fq = await FindQuery( + expressions=[m.Member.age > 0], model=m.Member, sort_fields=["age"] + ).get_query() + assert fq == [ + "FT.SEARCH", + model_name, + "@age:[(0 +inf]", + "LIMIT", + 0, + 1000, + "SORTBY", + "age", + "asc", + ] + + +@py_test_mark_asyncio +async def test_find_query_sort_desc(m): + model_name, fq = await FindQuery( + expressions=[m.Member.age > 0], model=m.Member, sort_fields=["-age"] + ).get_query() + assert fq == [ + "FT.SEARCH", + model_name, + "@age:[(0 +inf]", + "LIMIT", + 0, + 1000, + "SORTBY", + "age", + "desc", + ] + + +@py_test_mark_asyncio +async def test_find_query_text_search(m): + model_name, fq = await FindQuery( + expressions=[m.Member.bio == "test"], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "@bio:{test}", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_text_search_and(m, members): + model_name, fq = await FindQuery( + expressions=[m.Member.age < 40, m.Member.first_name == "Andrew"], model=m.Member + ).get_query() + assert fq == [ + "FT.SEARCH", + model_name, + "(@age:[-inf (40]) (@first_name:{Andrew})", + "LIMIT", + 0, + 1000, + ] + + +@py_test_mark_asyncio +async def test_find_query_text_search_or(m, members): + model_name, fq = await FindQuery( + expressions=[(m.Member.age < 40) | (m.Member.first_name == "Andrew")], + model=m.Member, + ).get_query() + assert fq == [ + "FT.SEARCH", + model_name, + "(@age:[-inf (40])| (@first_name:{Andrew})", + "LIMIT", + 0, + 1000, + ] + + +@py_test_mark_asyncio +async def test_find_query_text_search_not(m): + model_name, fq = await FindQuery( + expressions=[~(m.Member.first_name == "Andrew")], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "-(@first_name:{Andrew})", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_text_search_not_and(m, members): + model_name, fq = await FindQuery( + expressions=[~((m.Member.first_name == "Andrew") & (m.Member.age < 40))], + model=m.Member, + ).get_query() + assert fq == [ + "FT.SEARCH", + model_name, + "-((@first_name:{Andrew}) (@age:[-inf (40]))", + "LIMIT", + 0, + 1000, + ] + + +@py_test_mark_asyncio +async def test_find_query_text_search_not_or(m, members): + model_name, fq = await FindQuery( + expressions=[~((m.Member.first_name == "Andrew") | (m.Member.age < 40))], + model=m.Member, + ).get_query() + assert fq == [ + "FT.SEARCH", + model_name, + "-((@first_name:{Andrew})| (@age:[-inf (40]))", + "LIMIT", + 0, + 1000, + ] + + +@py_test_mark_asyncio +async def test_find_query_text_search_not_or_and(m, members): + model_name, fq = await FindQuery( + expressions=[ + ~( + ((m.Member.first_name == "Andrew") | (m.Member.age < 40)) + & (m.Member.last_name == "Brookins") + ) + ], + model=m.Member, + ).get_query() + assert fq == [ + "FT.SEARCH", + model_name, + "-(((@first_name:{Andrew})| (@age:[-inf (40])) (@last_name:{Brookins}))", + "LIMIT", + 0, + 1000, + ] + + +# text search operators; contains, startswith, endswith, fuzzy +@py_test_mark_asyncio +async def test_find_query_text_contains(m): + model_name, fq = await FindQuery( + expressions=[m.Member.first_name.contains("drew")], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "(@first_name:{*drew*})", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_text_startswith(m): + model_name, fq = await FindQuery( + expressions=[m.Member.first_name.startswith("An")], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "(@first_name:{An*})", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_text_endswith(m): + model_name, fq = await FindQuery( + expressions=[m.Member.first_name.endswith("ew")], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "(@first_name:{*ew})", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_test_fuzzy(m): + model_name, fq = await FindQuery( + expressions=[m.Member.bio % "%newb%"], model=m.Member + ).get_query() + assert fq == ["FT.SEARCH", model_name, "@bio_fts:%newb%", "LIMIT", 0, 1000] + + +# limit, offset, page_size +@py_test_mark_asyncio +async def test_find_query_limit_one(m): + model_name, fq = await FindQuery( + expressions=[m.Member.first_name == "Andrew"], model=m.Member, limit=1 + ).get_query() + assert fq == ["FT.SEARCH", model_name, "@first_name:{Andrew}", "LIMIT", 0, 1] + + +@py_test_mark_asyncio +async def test_find_query_limit_offset(m): + model_name, fq = await FindQuery( + expressions=[m.Member.first_name == "Andrew"], model=m.Member, limit=1, offset=1 + ).get_query() + assert fq == ["FT.SEARCH", model_name, "@first_name:{Andrew}", "LIMIT", 1, 1] + + +@py_test_mark_asyncio +async def test_find_query_page_size(m): + # note that this test in unintuitive. + # page_size gets resolved in a while True loop that makes copies of the intial query and adds the limit and offset each time + model_name, fq = await FindQuery( + expressions=[m.Member.first_name == "Andrew"], model=m.Member, page_size=1 + ).get_query() + assert fq == ["FT.SEARCH", model_name, "@first_name:{Andrew}", "LIMIT", 0, 1000] + + +@py_test_mark_asyncio +async def test_find_query_monster(m): + # test monster query with everything everywhere all at once + # including ors, nots, ands, less thans, greater thans, text search + model_name, fq = await FindQuery( + expressions=[ + ~( + ((m.Member.first_name == "Andrew") | (m.Member.age < 40)) + & ( + ( + m.Member.last_name.contains("oo") + | ~(m.Member.email.startswith("z")) + ) + ) + ) + ], + model=m.Member, + limit=1, + offset=1, + ).get_query() + assert fq == [ + "FT.SEARCH", + model_name, + "-(((@first_name:{Andrew})| (@age:[-inf (40])) (((@last_name:{*oo*}))| -((@email:{z*}))))", + "LIMIT", + 1, + 1, + ] diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 9533972..9737367 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -860,22 +860,34 @@ class TypeWithUuid(HashModel): async def test_xfix_queries(members, m): member1, member2, member3 = members - result = await m.Member.find(m.Member.first_name.startswith("And")).first() + result = await m.Member.find( + m.Member.first_name.startswith("And") and m.Member.last_name == "Brookins" + ).first() assert result.last_name == "Brookins" - result = await m.Member.find(m.Member.last_name.endswith("ins")).first() + result = await m.Member.find( + m.Member.last_name.endswith("ins") and m.Member.last_name == "Brookins" + ).first() assert result.last_name == "Brookins" - result = await m.Member.find(m.Member.last_name.contains("ook")).first() + result = await m.Member.find( + m.Member.last_name.contains("ook") and m.Member.last_name == "Brookins" + ).first() assert result.last_name == "Brookins" - result = await m.Member.find(m.Member.bio % "great*").first() + result = await m.Member.find( + m.Member.bio % "great*" and m.Member.first_name == "Andrew" + ).first() assert result.first_name == "Andrew" - result = await m.Member.find(m.Member.bio % "*rty").first() + result = await m.Member.find( + m.Member.bio % "*rty" and m.Member.first_name == "Andrew" + ).first() assert result.first_name == "Andrew" - result = await m.Member.find(m.Member.bio % "*eat*").first() + result = await m.Member.find( + m.Member.bio % "*eat*" and m.Member.first_name == "Andrew" + ).first() assert result.first_name == "Andrew" diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 24f0f62..ea27555 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -969,22 +969,34 @@ async def test_xfix_queries(m): age=34, ).save() - result = await m.Member.find(m.Member.first_name.startswith("Ste")).first() + result = await m.Member.find( + m.Member.first_name.startswith("Ste") and m.Member.first_name == "Steve" + ).first() assert result.first_name == "Steve" - result = await m.Member.find(m.Member.last_name.endswith("llo")).first() + result = await m.Member.find( + m.Member.last_name.endswith("llo") and m.Member.first_name == "Steve" + ).first() assert result.first_name == "Steve" - result = await m.Member.find(m.Member.address.city.contains("llite")).first() + result = await m.Member.find( + m.Member.address.city.contains("llite") and m.Member.first_name == "Steve" + ).first() assert result.first_name == "Steve" - result = await m.Member.find(m.Member.bio % "tw*").first() + result = await m.Member.find( + m.Member.bio % "tw*" and m.Member.first_name == "Steve" + ).first() assert result.first_name == "Steve" - result = await m.Member.find(m.Member.bio % "*cker").first() + result = await m.Member.find( + m.Member.bio % "*cker" and m.Member.first_name == "Steve" + ).first() assert result.first_name == "Steve" - result = await m.Member.find(m.Member.bio % "*ack*").first() + result = await m.Member.find( + m.Member.bio % "*ack*" and m.Member.first_name == "Steve" + ).first() assert result.first_name == "Steve"