diff --git a/pynequa/models.py b/pynequa/models.py index 5065021..32d8170 100644 --- a/pynequa/models.py +++ b/pynequa/models.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from abc import abstractmethod, ABC from dataclasses import dataclass, field from loguru import logger @@ -168,7 +168,7 @@ class QueryParams(AbstractParams): aggregations: Optional[List[str]] = field(default_factory=lambda: []) order_by: Optional[str] = None group_by: Optional[str] = None - advanced: Optional[List[AdvancedParams]] = field( + advanced: Optional[Union[AdvancedParams, List[AdvancedParams]]] = field( default_factory=lambda: []) debug: bool = False @@ -256,9 +256,14 @@ def _prepare_query_args(self, query_name: str) -> Dict: if self.group_by is not None: params["groupBy"] = self.group_by - if len(self.advanced) > 0: + advanced = ( + [self.advanced] + if isinstance(self.advanced, AdvancedParams) + else self.advanced + ) + if len(advanced) > 0: advanced_param_payload = {} - for advanced_param in self.advanced: + for advanced_param in advanced: column_name = advanced_param.col_name payload_value = advanced_param.generate_payload()[column_name] if column_name in advanced_param_payload: