From 867e708f3ee2e60a7fb9787042b0e2e823a5a551 Mon Sep 17 00:00:00 2001 From: Nurlan Date: Fri, 16 Jun 2023 15:26:25 +0600 Subject: [PATCH] Additional test case added for each search api: task and model Signed-off-by: Nurlan --- .../ml_commons/ml_commons_client.py | 4 +++ tests/common.py | 2 +- tests/ml_commons/test_ml_commons_client.py | 36 +++++++++++++++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py index accc330d..52e73316 100644 --- a/opensearch_py_ml/ml_commons/ml_commons_client.py +++ b/opensearch_py_ml/ml_commons/ml_commons_client.py @@ -391,6 +391,8 @@ def search_task(self, input_json) -> object: if isinstance(input_json, str): try: json_obj = json.loads(input_json) + if not isinstance(json_obj, dict): + return "Invalid JSON object passed as argument." API_BODY = json.dumps(json_obj) except json.JSONDecodeError: return "Invalid JSON string passed as argument." @@ -419,6 +421,8 @@ def search_model(self, input_json) -> object: if isinstance(input_json, str): try: json_obj = json.loads(input_json) + if not isinstance(json_obj, dict): + return "Invalid JSON object passed as argument." API_BODY = json.dumps(json_obj) except json.JSONDecodeError: return "Invalid JSON string passed as argument." diff --git a/tests/common.py b/tests/common.py index f332b0f6..df06ddde 100644 --- a/tests/common.py +++ b/tests/common.py @@ -57,7 +57,7 @@ ) _pd_ecommerce.insert(2, "customer_birth_date", None) _pd_ecommerce.index = _pd_ecommerce.index.map(str) # make index 'object' not int -_pd_ecommerce["customer_birth_date"].astype("datetime64[ns]") +_pd_ecommerce["customer_birth_date"].astype("datetime64") _oml_ecommerce = oml.DataFrame(OPENSEARCH_TEST_CLIENT, ECOMMERCE_INDEX_NAME) diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py index 387820f3..29c114b4 100644 --- a/tests/ml_commons/test_ml_commons_client.py +++ b/tests/ml_commons/test_ml_commons_client.py @@ -130,6 +130,24 @@ def test_search(): raised = True assert raised == False, "Raised Exception in searching task" + raised = False + try: + search_task_obj = ml_client.search_task(input_json="15") + assert search_task_obj == "Invalid JSON object passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching task" + + raised = False + try: + search_task_obj = ml_client.search_task( + input_json='{"query": {"match_all": {}},size: 1}' + ) + assert search_task_obj == "Invalid JSON string passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching task" + # Search model cases raised = False try: @@ -159,6 +177,24 @@ def test_search(): raised = True assert raised == False, "Raised Exception in searching model" + raised = False + try: + search_model_obj = ml_client.search_model(input_json="15") + assert search_model_obj == "Invalid JSON object passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching model" + + raised = False + try: + search_model_obj = ml_client.search_model( + input_json='{"query": {"match_all": {}},size: 1}' + ) + assert search_model_obj == "Invalid JSON string passed as argument." + except: # noqa: E722 + raised = True + assert raised == False, "Raised Exception in searching model" + def test_DEPRECATED_integration_pretrained_model_upload_unload_delete(): raised = False