Skip to content

Commit

Permalink
Bugfix advanced params (#28)
Browse files Browse the repository at this point in the history
* Remove operator and value in payload if they are None in AdvancedParams

Sample:

```python
{
    'app': 'vanilla-search',
    'query': {
        'name': 'query',
        'text': 'himawari',
        'isFirstpage': False,
        'strictRefine': False,
        'removeDuplicates': False,
        'action': 'search',
        'page': 1,
        'pageSize': 10,
        'advanced': {
            'collection': '/user_needs_database/snwg-assessments-2020/',
            # 'value': None, 'operator': None,
        }
    }
}
```

* Bugfix payload generation for AdvancedParams

Now col_name and values are checked to see if empty.

* added support for advanced params according to documentation

---------

Co-authored-by: anisbhsl <bhusal.anish12@gmail.com>
  • Loading branch information
NISH1001 and anisbhsl authored Feb 27, 2024
1 parent 8878504 commit 4fabf47
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 34 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ A python library to handle communication with Sinequa REST API.
```

## Example Usage
```
```python
import pynequa
from pynequa.models import QueryParams

Expand Down
41 changes: 21 additions & 20 deletions pynequa/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@


class API:
'''
API Class handles all HTTP Requests
"""
API Class handles all HTTP Requests
Attributes:
base_url(string): REST API base URL for Sinequa instance
access_token(string): token for Sinequa authentication
'''
Attributes:
base_url(string): REST API base URL for Sinequa instance
access_token(string): token for Sinequa authentication
"""

def __init__(self, access_token: str, base_url: str) -> None:
if not access_token or not base_url:
Expand All @@ -21,29 +21,30 @@ def __init__(self, access_token: str, base_url: str) -> None:
self.base_url = base_url

def _get_headers(self) -> Dict:
headers = {
"Authorization": f"Bearer {self.access_token}"
}
headers = {"Authorization": f"Bearer {self.access_token}"}
return headers

def _get_url(self, endpoint) -> str:
return os.path.join(self.base_url, endpoint)

def get(self, endpoint) -> Dict:
"""
This method handles GET method.
This method handles GET method.
"""
session = requests.Session()
resp = session.get(self._get_url(endpoint=endpoint),
headers=self._get_headers())
session.close
return resp.json()
with requests.Session() as session:
resp = session.get(
self._get_url(endpoint=endpoint), headers=self._get_headers()
)
return resp.json()

def post(self, endpoint, payload) -> Dict:
"""
This method handles POST method.
This method handles POST method.
"""
session = requests.Session()
resp = session.post(self._get_url(endpoint=endpoint),
headers=self._get_headers(), json=payload)
return resp.json()
with requests.Session() as session:
resp = session.post(
self._get_url(endpoint=endpoint),
headers=self._get_headers(),
json=payload,
)
return resp.json()
69 changes: 58 additions & 11 deletions pynequa/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class TreeParams(AbstractParams):
Possible values: '=', '!=', '<', '<=', '>=', '>', 'between', 'not between'.
value (str): The filter value (required).
"""

box: str = ""
column: str = ""
op: str = ""
Expand Down Expand Up @@ -83,9 +84,37 @@ def generate_payload(self, **kwargs) -> Dict:

@dataclass
class AdvancedParams(AbstractParams):
col_name: str = ""
col_value: str = None
value: str or int = None
"""
AdvancedParams represents the elemental advanced params.
Remember following things:
1. col_name is required.
2. col_value has to be either "str" or "List[str]".
3. if col_value is not present, the value could be a dict of
"value" and "operator".
Example:
"advanced": {
"docformat": [
"ppt",
"pdf"
],
"modified": [
{
"value": "2019-01-01",
"operator": ">="
},
{
"value": "2019-12-31",
"operator": "<="
}
]
}
"""
col_name: str
col_value: str or List[str] = None
value: str or int = None
operator: str = None
debug: bool = False

Expand All @@ -94,11 +123,15 @@ def generate_payload(self, **kwargs) -> Dict:
This method generates payload for
AdvancedParams.
"""
payload = {
self.col_name: self.col_value,
"value": self.value,
"operator": self.operator
}
payload = {}
# To prevent payloads with empty values
if self.col_name and self.col_value:
payload[self.col_name] = self.col_value
if self.value and self.operator:
payload[self.col_name] = {
"value": self.value,
"operator": self.operator
}

if self.debug:
logger.debug(payload)
Expand Down Expand Up @@ -135,7 +168,8 @@ class QueryParams(AbstractParams):
aggregations: Optional[List[str]] = field(default_factory=lambda: [])
order_by: Optional[str] = None
group_by: Optional[str] = None
advanced: Optional[AdvancedParams] = None
advanced: Optional[List[AdvancedParams]] = field(
default_factory=lambda: [])
debug: bool = False

def _prepare_query_args(self, query_name: str) -> Dict:
Expand Down Expand Up @@ -222,8 +256,21 @@ def _prepare_query_args(self, query_name: str) -> Dict:
if self.group_by is not None:
params["groupBy"] = self.group_by

if self.advanced is not None:
params["advanced"] = self.advanced.generate_payload()
if len(self.advanced) > 0:
advanced_param_payload = {}
for advanced_param in self.advanced:
column_name = advanced_param.col_name
payload_value = advanced_param.generate_payload()[column_name]
if column_name in advanced_param_payload:
advanced_param_payload[column_name].append(
payload_value
)
elif isinstance(payload_value, dict):
advanced_param_payload[column_name] = [payload_value]
else:
advanced_param_payload[column_name] = payload_value

params["advanced"] = advanced_param_payload

return params

Expand Down
59 changes: 57 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pynequa.models import QueryParams
from pynequa.models import QueryParams, AdvancedParams
import unittest
import logging
import json


class TestQueryParams(unittest.TestCase):
Expand All @@ -12,10 +13,12 @@ def test_query_params_payload(self):
"""
qp = QueryParams(
name="query",
search_text="What was Landsat-9 launched?"
search_text="What was Landsat-9 launched?",
page_size=20,
)

payload = qp.generate_payload()
print(payload)
logging.debug(payload)

keys_which_must_be_in_payload = [
Expand All @@ -30,6 +33,58 @@ def test_query_params_payload(self):
if key not in payload:
self.assertEqual(key, "test", f"{key} is mising in payload")

def test_query_params_with_advanced_params(self):
"""
Test if advanced params are correctly
generated in query param payload or not.
"""

ap1 = AdvancedParams(
col_name="collection",
col_value="accounting"
)

ap2 = AdvancedParams(
col_name="docformat",
col_value=["pdf", "docx"]
)

ap3 = AdvancedParams(
col_name="modified",
value="2019-01-01",
operator=">="
)

ap4 = AdvancedParams(
col_name="modified",
value="2019-12-31",
operator="<="
)

qp = QueryParams(
name="query",
search_text="What was Landsat-9 launched?",
advanced=[
ap1,
ap2,
ap3,
ap4
]
)

payload = qp.generate_payload()

expected_payload = {
"collection": "accounting",
"docformat": ["pdf", "docx"],
"modified": [
{"value": "2019-01-01", "operator": ">="},
{"value": "2019-12-31", "operator": "<="}
]
}

assert payload["advanced"] == expected_payload


if __name__ == '__main__':
unittest.main()

0 comments on commit 4fabf47

Please sign in to comment.