Skip to content

Commit

Permalink
Additional test case added for each search api: task and model
Browse files Browse the repository at this point in the history
Signed-off-by: Nurlan <nabzalbekov0@gmail.com>
  • Loading branch information
Nurlanprog committed Jun 16, 2023
1 parent 33be035 commit 867e708
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
4 changes: 4 additions & 0 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ def search_task(self, input_json) -> object:
if isinstance(input_json, str):
try:
json_obj = json.loads(input_json)
if not isinstance(json_obj, dict):
return "Invalid JSON object passed as argument."
API_BODY = json.dumps(json_obj)
except json.JSONDecodeError:
return "Invalid JSON string passed as argument."
Expand Down Expand Up @@ -419,6 +421,8 @@ def search_model(self, input_json) -> object:
if isinstance(input_json, str):
try:
json_obj = json.loads(input_json)
if not isinstance(json_obj, dict):
return "Invalid JSON object passed as argument."
API_BODY = json.dumps(json_obj)
except json.JSONDecodeError:
return "Invalid JSON string passed as argument."
Expand Down
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
)
_pd_ecommerce.insert(2, "customer_birth_date", None)
_pd_ecommerce.index = _pd_ecommerce.index.map(str) # make index 'object' not int
_pd_ecommerce["customer_birth_date"].astype("datetime64[ns]")
_pd_ecommerce["customer_birth_date"].astype("datetime64")
_oml_ecommerce = oml.DataFrame(OPENSEARCH_TEST_CLIENT, ECOMMERCE_INDEX_NAME)


Expand Down
36 changes: 36 additions & 0 deletions tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,24 @@ def test_search():
raised = True
assert raised == False, "Raised Exception in searching task"

raised = False
try:
search_task_obj = ml_client.search_task(input_json="15")
assert search_task_obj == "Invalid JSON object passed as argument."
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching task"

raised = False
try:
search_task_obj = ml_client.search_task(
input_json='{"query": {"match_all": {}},size: 1}'
)
assert search_task_obj == "Invalid JSON string passed as argument."
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching task"

# Search model cases
raised = False
try:
Expand Down Expand Up @@ -159,6 +177,24 @@ def test_search():
raised = True
assert raised == False, "Raised Exception in searching model"

raised = False
try:
search_model_obj = ml_client.search_model(input_json="15")
assert search_model_obj == "Invalid JSON object passed as argument."
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"

raised = False
try:
search_model_obj = ml_client.search_model(
input_json='{"query": {"match_all": {}},size: 1}'
)
assert search_model_obj == "Invalid JSON string passed as argument."
except: # noqa: E722
raised = True
assert raised == False, "Raised Exception in searching model"


def test_DEPRECATED_integration_pretrained_model_upload_unload_delete():
raised = False
Expand Down

0 comments on commit 867e708

Please sign in to comment.