diff --git a/README.md b/README.md index 12ed30b..5f6aedd 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ A python library to handle communication with Sinequa REST API. ``` ## Example Usage -``` +```python import pynequa from pynequa.models import QueryParams diff --git a/pynequa/api/api.py b/pynequa/api/api.py index 30d0bf0..da5a115 100644 --- a/pynequa/api/api.py +++ b/pynequa/api/api.py @@ -5,13 +5,13 @@ class API: - ''' - API Class handles all HTTP Requests + """ + API Class handles all HTTP Requests - Attributes: - base_url(string): REST API base URL for Sinequa instance - access_token(string): token for Sinequa authentication - ''' + Attributes: + base_url(string): REST API base URL for Sinequa instance + access_token(string): token for Sinequa authentication + """ def __init__(self, access_token: str, base_url: str) -> None: if not access_token or not base_url: @@ -21,9 +21,7 @@ def __init__(self, access_token: str, base_url: str) -> None: self.base_url = base_url def _get_headers(self) -> Dict: - headers = { - "Authorization": f"Bearer {self.access_token}" - } + headers = {"Authorization": f"Bearer {self.access_token}"} return headers def _get_url(self, endpoint) -> str: @@ -31,19 +29,22 @@ def _get_url(self, endpoint) -> str: def get(self, endpoint) -> Dict: """ - This method handles GET method. + This method handles GET method. """ - session = requests.Session() - resp = session.get(self._get_url(endpoint=endpoint), - headers=self._get_headers()) - session.close - return resp.json() + with requests.Session() as session: + resp = session.get( + self._get_url(endpoint=endpoint), headers=self._get_headers() + ) + return resp.json() def post(self, endpoint, payload) -> Dict: """ - This method handles POST method. + This method handles POST method. """ - session = requests.Session() - resp = session.post(self._get_url(endpoint=endpoint), - headers=self._get_headers(), json=payload) - return resp.json() + with requests.Session() as session: + resp = session.post( + self._get_url(endpoint=endpoint), + headers=self._get_headers(), + json=payload, + ) + return resp.json() diff --git a/pynequa/models.py b/pynequa/models.py index 553cf13..5065021 100644 --- a/pynequa/models.py +++ b/pynequa/models.py @@ -31,6 +31,7 @@ class TreeParams(AbstractParams): Possible values: '=', '!=', '<', '<=', '>=', '>', 'between', 'not between'. value (str): The filter value (required). """ + box: str = "" column: str = "" op: str = "" @@ -83,9 +84,37 @@ def generate_payload(self, **kwargs) -> Dict: @dataclass class AdvancedParams(AbstractParams): - col_name: str = "" - col_value: str = None - value: str or int = None + """ + AdvancedParams represents the elemental advanced params. + Remember following things: + + 1. col_name is required. + 2. col_value has to be either "str" or "List[str]". + 3. if col_value is not present, the value could be a dict of + "value" and "operator". + + + Example: + "advanced": { + "docformat": [ + "ppt", + "pdf" + ], + "modified": [ + { + "value": "2019-01-01", + "operator": ">=" + }, + { + "value": "2019-12-31", + "operator": "<=" + } + ] + } + """ + col_name: str + col_value: str or List[str] = None + value: str or int = None operator: str = None debug: bool = False @@ -94,11 +123,15 @@ def generate_payload(self, **kwargs) -> Dict: This method generates payload for AdvancedParams. """ - payload = { - self.col_name: self.col_value, - "value": self.value, - "operator": self.operator - } + payload = {} + # To prevent payloads with empty values + if self.col_name and self.col_value: + payload[self.col_name] = self.col_value + if self.value and self.operator: + payload[self.col_name] = { + "value": self.value, + "operator": self.operator + } if self.debug: logger.debug(payload) @@ -135,7 +168,8 @@ class QueryParams(AbstractParams): aggregations: Optional[List[str]] = field(default_factory=lambda: []) order_by: Optional[str] = None group_by: Optional[str] = None - advanced: Optional[AdvancedParams] = None + advanced: Optional[List[AdvancedParams]] = field( + default_factory=lambda: []) debug: bool = False def _prepare_query_args(self, query_name: str) -> Dict: @@ -222,8 +256,21 @@ def _prepare_query_args(self, query_name: str) -> Dict: if self.group_by is not None: params["groupBy"] = self.group_by - if self.advanced is not None: - params["advanced"] = self.advanced.generate_payload() + if len(self.advanced) > 0: + advanced_param_payload = {} + for advanced_param in self.advanced: + column_name = advanced_param.col_name + payload_value = advanced_param.generate_payload()[column_name] + if column_name in advanced_param_payload: + advanced_param_payload[column_name].append( + payload_value + ) + elif isinstance(payload_value, dict): + advanced_param_payload[column_name] = [payload_value] + else: + advanced_param_payload[column_name] = payload_value + + params["advanced"] = advanced_param_payload return params diff --git a/tests/test_models.py b/tests/test_models.py index d3d166f..1f1ee41 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,7 @@ -from pynequa.models import QueryParams +from pynequa.models import QueryParams, AdvancedParams import unittest import logging +import json class TestQueryParams(unittest.TestCase): @@ -12,10 +13,12 @@ def test_query_params_payload(self): """ qp = QueryParams( name="query", - search_text="What was Landsat-9 launched?" + search_text="What was Landsat-9 launched?", + page_size=20, ) payload = qp.generate_payload() + print(payload) logging.debug(payload) keys_which_must_be_in_payload = [ @@ -30,6 +33,58 @@ def test_query_params_payload(self): if key not in payload: self.assertEqual(key, "test", f"{key} is mising in payload") + def test_query_params_with_advanced_params(self): + """ + Test if advanced params are correctly + generated in query param payload or not. + """ + + ap1 = AdvancedParams( + col_name="collection", + col_value="accounting" + ) + + ap2 = AdvancedParams( + col_name="docformat", + col_value=["pdf", "docx"] + ) + + ap3 = AdvancedParams( + col_name="modified", + value="2019-01-01", + operator=">=" + ) + + ap4 = AdvancedParams( + col_name="modified", + value="2019-12-31", + operator="<=" + ) + + qp = QueryParams( + name="query", + search_text="What was Landsat-9 launched?", + advanced=[ + ap1, + ap2, + ap3, + ap4 + ] + ) + + payload = qp.generate_payload() + + expected_payload = { + "collection": "accounting", + "docformat": ["pdf", "docx"], + "modified": [ + {"value": "2019-01-01", "operator": ">="}, + {"value": "2019-12-31", "operator": "<="} + ] + } + + assert payload["advanced"] == expected_payload + if __name__ == '__main__': unittest.main()