Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/pip/setuptools-gte-69.2-and-lt-73.0
Browse files Browse the repository at this point in the history
  • Loading branch information
slorello89 committed Aug 6, 2024
2 parents 1daf1a1 + 424b842 commit 6724f1a
Show file tree
Hide file tree
Showing 5 changed files with 573 additions and 30 deletions.
2 changes: 1 addition & 1 deletion aredis_om/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def jsonable_encoder(
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj)
return dataclasses.asdict(obj) # type: ignore[call-overload]
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, PurePath):
Expand Down
58 changes: 41 additions & 17 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ClassVar,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -141,10 +142,10 @@ def embedded(cls):

def is_supported_container_type(typ: Optional[type]) -> bool:
# TODO: Wait, why don't we support indexing sets?
if typ == list or typ == tuple:
if typ == list or typ == tuple or typ == Literal:
return True
unwrapped = get_origin(typ)
return unwrapped == list or unwrapped == tuple
return unwrapped == list or unwrapped == tuple or unwrapped == Literal


def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]):
Expand Down Expand Up @@ -872,7 +873,9 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:

return result

async def execute(self, exhaust_results=True, return_raw_result=False):
async def execute(
self, exhaust_results=True, return_raw_result=False, return_query_args=False
):
args: List[Union[str, bytes]] = [
"FT.SEARCH",
self.model.Meta.index_name,
Expand All @@ -897,6 +900,9 @@ async def execute(self, exhaust_results=True, return_raw_result=False):
if self.nocontent:
args.append("NOCONTENT")

if return_query_args:
return self.model.Meta.index_name, args

# Reset the cache if we're executing from offset 0.
if self.offset == 0:
self._model_cache.clear()
Expand Down Expand Up @@ -930,6 +936,10 @@ async def execute(self, exhaust_results=True, return_raw_result=False):
self._model_cache += _results
return self._model_cache

async def get_query(self):
query = self.copy()
return await query.execute(return_query_args=True)

async def first(self):
query = self.copy(offset=0, limit=1, sort_fields=self.sort_fields)
results = await query.execute(exhaust_results=False)
Expand Down Expand Up @@ -1414,6 +1424,8 @@ def outer_type_or_annotation(field):
if not isinstance(field.annotation, type):
raise AttributeError(f"could not extract outer type from field {field}")
return field.annotation
elif get_origin(field.annotation) == Literal:
return str
else:
return field.annotation.__args__[0]

Expand Down Expand Up @@ -2057,21 +2069,33 @@ def schema_for_type(
# find any values marked as indexed.
if is_container_type and not is_vector:
field_type = get_origin(typ)
embedded_cls = get_args(typ)
if not embedded_cls:
log.warning(
"Model %s defined an empty list or tuple field: %s", cls, name
if field_type == Literal:
path = f"{json_path}.{name}"
return cls.schema_for_type(
path,
name,
name_prefix,
str,
field_info,
parent_type=field_type,
)
else:
embedded_cls = get_args(typ)
if not embedded_cls:
log.warning(
"Model %s defined an empty list or tuple field: %s", cls, name
)
return ""
path = f"{json_path}.{name}[*]"
embedded_cls = embedded_cls[0]
return cls.schema_for_type(
path,
name,
name_prefix,
embedded_cls,
field_info,
parent_type=field_type,
)
return ""
embedded_cls = embedded_cls[0]
return cls.schema_for_type(
f"{json_path}.{name}[*]",
name,
name_prefix,
embedded_cls,
field_info,
parent_type=field_type,
)
elif field_is_model:
name_prefix = f"{name_prefix}_{name}" if name_prefix else name
sub_fields = []
Expand Down
Loading

0 comments on commit 6724f1a

Please sign in to comment.