Skip to content

Commit

Permalink
Merge pull request #50 from AryaXAI/sdk-v2
Browse files Browse the repository at this point in the history
fixes
  • Loading branch information
chintanarya authored Jul 25, 2024
2 parents c3de826 + a671e8f commit 97e0b98
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions aryaxai/core/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,8 @@ def create_monitoring_trigger(self, payload: dict) -> str:
"baseline_date": { "start_date": "", "end_date": ""},
"current_date": { "start_date": "", "end_date": ""},
"base_line_tag": "",
"current_tag": ""
"current_tag": "",
"instance_type": "" #Instance type to used for running trigger
} OR Target Drift Trigger Payload
{
"trigger_type": "" #["Data Drift", "Target Drift", "Model Performance"]
Expand All @@ -1489,7 +1490,8 @@ def create_monitoring_trigger(self, payload: dict) -> str:
"base_line_tag": "",
"current_tag": "",
"baseline_true_label": "",
"current_true_label": ""
"current_true_label": "",
"instance_type": "" #Instance type to used for running trigger
} OR Model Performance Trigger Payload
{
"trigger_type": "" #["Data Drift", "Target Drift", "Model Performance"]
Expand All @@ -1504,7 +1506,8 @@ def create_monitoring_trigger(self, payload: dict) -> str:
"current_date": { "start_date": "", "end_date": ""},
"base_line_tag": "",
"baseline_true_label": "",
"baseline_pred_label": ""
"baseline_pred_label": "",
"instance_type": "" #Instance type to used for running trigger
}
:return: response
"""
Expand Down Expand Up @@ -1636,6 +1639,17 @@ def create_monitoring_trigger(self, payload: dict) -> str:

Validate.validate_date_feature_val(payload, tags_info["alldatetimefeatures"])

if payload.get("instance_type"):
custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI)
Validate.value_against_list(
"instance_type",
payload["instance_type"],
[
server["instance_name"]
for server in custom_batch_servers.get("details", [])
],
)

payload = {
"project_name": self.project_name,
"modify_req": {
Expand Down Expand Up @@ -2006,7 +2020,7 @@ def model_inference(
self,
tag: str,
model_name: Optional[str] = None,
instance_tye: Optional[str] = "shared",
instance_type: Optional[str] = None,
) -> pd.DataFrame:
"""Run model inference on data
Expand All @@ -2033,11 +2047,22 @@ def model_inference(
or models.loc[models["status"] == "active"]["model_name"].values[0]
)

if instance_type:
custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI)
Validate.value_against_list(
"instance_type",
instance_type,
[
server["instance_name"]
for server in custom_batch_servers.get("details", [])
],
)

run_model_payload = {
"project_name": self.project_name,
"model_name": model,
"tags": tag,
"instance_type": instance_tye,
"instance_type": instance_type,
}

run_model_res = self.api_client.post(RUN_MODEL_ON_DATA_URI, run_model_payload)
Expand Down

0 comments on commit 97e0b98

Please sign in to comment.