From 7333d51363c11a12d3108e88fa06ce928fccdfe2 Mon Sep 17 00:00:00 2001 From: Siddhant Mishra Date: Thu, 21 Nov 2024 13:37:51 +0530 Subject: [PATCH] Upload Data - Feature Encoding & Model Inference fix --- aryaxai/common/types.py | 2 +- aryaxai/core/project.py | 29 +++++++++++++++++++++-------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/aryaxai/common/types.py b/aryaxai/common/types.py index 0db178c..e253e78 100644 --- a/aryaxai/common/types.py +++ b/aryaxai/common/types.py @@ -10,7 +10,7 @@ class ProjectConfig(TypedDict): feature_exclude: Optional[List[str]] drop_duplicate_uid: Optional[bool] handle_errors: Optional[bool] - + feature_encodings: Optional[dict] class DataConfig(TypedDict): tags: List[str] diff --git a/aryaxai/core/project.py b/aryaxai/core/project.py index c2c2e2f..1c262d3 100755 --- a/aryaxai/core/project.py +++ b/aryaxai/core/project.py @@ -534,7 +534,8 @@ def upload_data( "pred_label": "", "feature_exclude": [], "drop_duplicate_uid: "", - "handle_errors": False + "handle_errors": False, + "feature_encodings": Dict[str, str] # {"feature_name":"labelencode | countencode | onehotencode"} }, defaults to None :return: response @@ -625,6 +626,19 @@ def upload_file_and_return_path() -> str: feature for feature in column_names if feature not in feature_exclude ] + feature_encodings = config.get("feature_encodings", None) + if feature_encodings: + Validate.value_against_list( + "feature_encodings_feature", + list(feature_encodings.keys()), + column_names, + ) + Validate.value_against_list( + "feature_encodings_feature", + list(feature_encodings.values()), + ["labelencode", "countencode", "onehotencode"], + ) + payload = { "project_name": self.project_name, "project_type": config["project_type"], @@ -639,7 +653,7 @@ def upload_file_and_return_path() -> str: "handle_errors": config.get("handle_errors", False), "feature_exclude": feature_exclude, "feature_include": feature_include, - "feature_encodings": {}, + "feature_encodings": feature_encodings, "feature_actual_used": [], }, } @@ -2097,13 +2111,12 @@ def model_inference( event_id=run_model_res["event_id"], ) - download_tag_payload = { - "project_name": self.project_name, - "tag": f"{tag}_{model}_Inference", - } + auth_token = self.api_client.get_auth_token() + + uri = f"{DOWNLOAD_TAG_DATA_URI}?project_name={self.project_name}&tag={tag}_{model}_Inference&token={auth_token}" - tag_data = self.api_client.request( - "POST", DOWNLOAD_TAG_DATA_URI, download_tag_payload + tag_data = self.api_client.base_request( + "GET", uri ) tag_data_df = pd.read_csv(io.StringIO(tag_data.text))