Skip to content

Commit

Permalink
Merge pull request #60 from AryaXAI/bugfix_test
Browse files Browse the repository at this point in the history
Upload Data - Feature Encoding & Model Inference fix
  • Loading branch information
chintanarya authored Nov 21, 2024
2 parents b90d888 + 7333d51 commit 5f3392e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
2 changes: 1 addition & 1 deletion aryaxai/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class ProjectConfig(TypedDict):
feature_exclude: Optional[List[str]]
drop_duplicate_uid: Optional[bool]
handle_errors: Optional[bool]

feature_encodings: Optional[dict]

class DataConfig(TypedDict):
tags: List[str]
Expand Down
29 changes: 21 additions & 8 deletions aryaxai/core/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,8 @@ def upload_data(
"pred_label": "",
"feature_exclude": [],
"drop_duplicate_uid: "",
"handle_errors": False
"handle_errors": False,
"feature_encodings": Dict[str, str] # {"feature_name":"labelencode | countencode | onehotencode"}
},
defaults to None
:return: response
Expand Down Expand Up @@ -625,6 +626,19 @@ def upload_file_and_return_path() -> str:
feature for feature in column_names if feature not in feature_exclude
]

feature_encodings = config.get("feature_encodings", None)
if feature_encodings:
Validate.value_against_list(
"feature_encodings_feature",
list(feature_encodings.keys()),
column_names,
)
Validate.value_against_list(
"feature_encodings_feature",
list(feature_encodings.values()),
["labelencode", "countencode", "onehotencode"],
)

payload = {
"project_name": self.project_name,
"project_type": config["project_type"],
Expand All @@ -639,7 +653,7 @@ def upload_file_and_return_path() -> str:
"handle_errors": config.get("handle_errors", False),
"feature_exclude": feature_exclude,
"feature_include": feature_include,
"feature_encodings": {},
"feature_encodings": feature_encodings,
"feature_actual_used": [],
},
}
Expand Down Expand Up @@ -2097,13 +2111,12 @@ def model_inference(
event_id=run_model_res["event_id"],
)

download_tag_payload = {
"project_name": self.project_name,
"tag": f"{tag}_{model}_Inference",
}
auth_token = self.api_client.get_auth_token()

uri = f"{DOWNLOAD_TAG_DATA_URI}?project_name={self.project_name}&tag={tag}_{model}_Inference&token={auth_token}"

tag_data = self.api_client.request(
"POST", DOWNLOAD_TAG_DATA_URI, download_tag_payload
tag_data = self.api_client.base_request(
"GET", uri
)

tag_data_df = pd.read_csv(io.StringIO(tag_data.text))
Expand Down

0 comments on commit 5f3392e

Please sign in to comment.