Skip to content

Commit

Permalink
Merge pull request #74 from AryaXAI/bugfix_test
Browse files Browse the repository at this point in the history
Models list - Create Policy and Handle Imbalance Data Key
  • Loading branch information
amey-balekundri authored Dec 24, 2024
2 parents d58fcba + 9433f3e commit c313a6f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aryaxai/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class ProjectConfig(TypedDict):
drop_duplicate_uid: Optional[bool]
handle_errors: Optional[bool]
feature_encodings: Optional[dict]
handle_data_imbalance: Optional[bool]

class DataConfig(TypedDict):
tags: List[str]
Expand All @@ -23,6 +24,7 @@ class DataConfig(TypedDict):
explainability_sample_percentage: float
lime_explainability_iterations: int
explainability_method: List[str]
handle_data_imbalance: Optional[bool]

class SyntheticDataConfig(TypedDict):
model_name: str
Expand Down
9 changes: 9 additions & 0 deletions aryaxai/core/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def upload_data(
"feature_exclude": [],
"drop_duplicate_uid: "",
"handle_errors": False,
"handle_data_imbalance": False, # SMOTE sampling
"feature_encodings": Dict[str, str] # {"feature_name":"labelencode | countencode | onehotencode"}
},
defaults to None
Expand Down Expand Up @@ -595,6 +596,7 @@ def upload_file_and_return_path() -> str:
"feature_exclude": [],
"drop_duplicate_uid": False,
"handle_errors": False,
"handle_data_imbalance": False
}
raise Exception(
f"Project Config is required, since no config is set for project \n {json.dumps(config,indent=1)}"
Expand Down Expand Up @@ -670,6 +672,7 @@ def upload_file_and_return_path() -> str:
"feature_include": feature_include,
"feature_encodings": feature_encodings,
"feature_actual_used": [],
"handle_data_imbalance": config.get("handle_data_imbalance", False)
},
}

Expand Down Expand Up @@ -2146,6 +2149,7 @@ def train_model(
"explainability_sample_percentage": float # Explainability sample percentage to be used
"lime_explainability_iterations": int # Lime Explainability iterations to be used
"explainability_method": str # List of explainability method ["shap", "lime"]
"handle_data_imbalance": bool # Handle data imbalance using SMOTE
},
defaults to None
:param model_config: config with hyper parameters for the model, defaults to None
Expand Down Expand Up @@ -2300,6 +2304,7 @@ def train_model(
tags = data_conf.get("tags") or project_config["metadata"]["tags"]
test_tags = data_conf.get("test_tags") or project_config["metadata"]["test_tags"] or []
use_optuna = data_conf.get("use_optuna") or project_config["metadata"]["use_optuna"] or False
handle_data_imbalance = data_conf.get("handle_data_imbalance") or project_config["metadata"]["handle_data_imbalance"] or False

payload = {
"project_name": self.project_name,
Expand All @@ -2317,6 +2322,7 @@ def train_model(
"test_tags": test_tags,
"use_optuna": use_optuna,
"explainability_method": explainability_method,
"handle_data_imbalance": handle_data_imbalance
},
"sample_percentage": data_conf.get("sample_percentage"),
"explainability_sample_percentage": data_conf.get(
Expand Down Expand Up @@ -3678,6 +3684,7 @@ def create_policy(
statement: str,
decision: str,
input: Optional[str] = None,
models: Optional[list] = []
) -> str:
"""Creates New Policy
Expand All @@ -3695,6 +3702,7 @@ def create_policy(
the content inside the curly brackets represents the feature name
:param decision: decision of policy
:param input: custom input for the decision if input selected for decision of policy
:param models: List of trained model names - The policy will only execute for the selected model. In case of empty list will execute for all models
:return: response
"""
configuration, expression = build_expression(expression)
Expand All @@ -3720,6 +3728,7 @@ def create_policy(
"metadata": {"expression": expression},
"statement": [statement],
"decision": input if decision == "input" else decision,
"models": models
}

res = self.api_client.post(CREATE_POLICY_URI, payload)
Expand Down

0 comments on commit c313a6f

Please sign in to comment.