Skip to content

Commit

Permalink
Merge pull request #65 from AryaXAI/bugfix_test
Browse files Browse the repository at this point in the history
Upload feature mapping, data description and model
  • Loading branch information
chintanarya authored Dec 2, 2024
2 parents 254babb + 8b55d9e commit fd5fac6
Showing 1 changed file with 256 additions and 0 deletions.
256 changes: 256 additions & 0 deletions aryaxai/core/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def files(self) -> pd.DataFrame:
.rename(columns={"filepath": "file_name"})
)

files_df = files_df.loc[files_df['status'] == "active"]
files_df["file_name"] = files_df["file_name"].apply(
lambda file_path: file_path.split("/")[-1]
)
Expand Down Expand Up @@ -801,6 +802,138 @@ def upload_file_and_return_path() -> str:

return res.get("details", "Data description upload successful")

def upload_feature_mapping_dataconnectors(
self,
data_connector_name: str,
bucket_name: Optional[str] = None,
file_path: Optional[str] = None
) -> str:
"""uploads feature mapping for the project
:param data_connector_name: name of the data connector
:param bucket_name: if data connector has buckets # Example: s3/gcs buckets
:param file_path: filepath from the bucket for the data to read
:return: response
"""

def get_connector() -> str | pd.DataFrame:
url = f"{LIST_DATA_CONNECTORS}?project_name={self.project_name}"
res = self.api_client.post(url)

if res["success"]:
df = pd.DataFrame(res["details"])
filtered_df = df.loc[df['link_service_name'] == data_connector_name]
if filtered_df.empty:
return "No data connector found"
return filtered_df

return res["details"]

connectors = get_connector()
if isinstance(connectors, pd.DataFrame):
value = connectors.loc[connectors['link_service_name'] == data_connector_name, 'link_service_type'].values[0]
ds_type = value

if ds_type == "s3" or ds_type == "gcs":
if not bucket_name:
return "Missing argument bucket_name"
if not file_path:
return "Missing argument file_path"
else:
return connectors

def upload_file_and_return_path() -> str:
res = self.api_client.post(
f"{UPLOAD_FILE_DATA_CONNECTORS}?project_name={self.project_name}&link_service_name={data_connector_name}&data_type=feature_mapping&bucket_name={bucket_name}&file_path={file_path}")

print(res)
if not res["success"]:
raise Exception(res.get("details"))
uploaded_path = res.get("metadata").get("filepath")

return uploaded_path

uploaded_path = upload_file_and_return_path()

payload = {
"path": uploaded_path,
"type": "feature_mapping",
"project_name": self.project_name,
}
res = self.api_client.post(UPLOAD_DATA_URI, payload)

if not res["success"]:
self.delete_file(uploaded_path)
raise Exception(res.get("details"))

return res.get("details", "Feature mapping upload successful")

def upload_data_description_dataconnectors(
self,
data_connector_name: str,
bucket_name: Optional[str] = None,
file_path: Optional[str] = None,
) -> str:
"""uploads data description for the project
:param data_connector_name: name of the data connector
:param bucket_name: if data connector has buckets # Example: s3/gcs buckets
:param file_path: filepath from the bucket for the data to read
:return: response
"""

def get_connector() -> str | pd.DataFrame:
url = f"{LIST_DATA_CONNECTORS}?project_name={self.project_name}"
res = self.api_client.post(url)

if res["success"]:
df = pd.DataFrame(res["details"])
filtered_df = df.loc[df['link_service_name'] == data_connector_name]
if filtered_df.empty:
return "No data connector found"
return filtered_df

return res["details"]

connectors = get_connector()
if isinstance(connectors, pd.DataFrame):
value = connectors.loc[connectors['link_service_name'] == data_connector_name, 'link_service_type'].values[0]
ds_type = value

if ds_type == "s3" or ds_type == "gcs":
if not bucket_name:
return "Missing argument bucket_name"
if not file_path:
return "Missing argument file_path"
else:
return connectors

def upload_file_and_return_path() -> str:
res = self.api_client.post(
f"{UPLOAD_FILE_DATA_CONNECTORS}?project_name={self.project_name}&link_service_name={data_connector_name}&data_type=data_description&bucket_name={bucket_name}&file_path={file_path}")

print(res)
if not res["success"]:
raise Exception(res.get("details"))
uploaded_path = res.get("metadata").get("filepath")

return uploaded_path

uploaded_path = upload_file_and_return_path()

payload = {
"path": uploaded_path,
"type": "data_description",
"project_name": self.project_name,
}
res = self.api_client.post(UPLOAD_DATA_URI, payload)

if not res["success"]:
self.delete_file(uploaded_path)
raise Exception(res.get("details"))

return res.get("details", "Data description upload successful")

def upload_model_types(self) -> dict:
"""Model types which can be uploaded using upload_model()
Expand Down Expand Up @@ -904,6 +1037,128 @@ def upload_file_and_return_path() -> str:
lambda: self.delete_file(uploaded_path),
)

def upload_model_dataconnectors(
self,
data_connector_name: str,
model_architecture: str,
model_type: str,
model_name: str,
model_data_tags: List[str],
model_test_tags: Optional[List[str]] = None,
instance_type: Optional[str] = None,
explainability_method: Optional[List[str]] = ["shap"],
bucket_name: Optional[str] = None,
file_path: Optional[str] = None,
):
"""Uploads your custom model on AryaXAI
:param data_connector_name: name of the data connector
:param model_architecture: architecture of model ["machine_learning", "deep_learning"]
:param model_type: type of the model based on the architecture ["Xgboost","Lgboost","CatBoost","Random_forest","Linear_Regression","Logistic_Regression","Gaussian_NaiveBayes","SGD"]
use upload_model_types() method to get all allowed model_types
:param model_name: name of the model
:param model_data_tags: data tags for model
:param model_test_tags: test tags for model (optional)
:param instance_type: instance to be used for uploading model (optional)
:param explainability_method: explainability method to be used while uploading model ["shap", "lime"] (optional)
:param bucket_name: if data connector has buckets # Example: s3/gcs buckets
:param file_path: filepath from the bucket for the data to read
"""

def get_connector() -> str | pd.DataFrame:
url = f"{LIST_DATA_CONNECTORS}?project_name={self.project_name}"
res = self.api_client.post(url)

if res["success"]:
df = pd.DataFrame(res["details"])
filtered_df = df.loc[df['link_service_name'] == data_connector_name]
if filtered_df.empty:
return "No data connector found"
return filtered_df

return res["details"]

connectors = get_connector()
if isinstance(connectors, pd.DataFrame):
value = connectors.loc[connectors['link_service_name'] == data_connector_name, 'link_service_type'].values[0]
ds_type = value

if ds_type == "s3" or ds_type == "gcs":
if not bucket_name:
return "Missing argument bucket_name"
if not file_path:
return "Missing argument file_path"
else:
return connectors

def upload_file_and_return_path() -> str:
res = self.api_client.post(
f"{UPLOAD_FILE_DATA_CONNECTORS}?project_name={self.project_name}&link_service_name={data_connector_name}&data_type=model&bucket_name={bucket_name}&file_path={file_path}")

print(res)
if not res["success"]:
raise Exception(res.get("details"))
uploaded_path = res.get("metadata").get("filepath")

return uploaded_path

model_types = self.api_client.get(GET_MODEL_TYPES_URI)
valid_model_architecture = model_types.get("model_architecture").keys()
Validate.value_against_list(
"model_achitecture", model_architecture, valid_model_architecture
)

valid_model_types = model_types.get("model_architecture")[model_architecture]
Validate.value_against_list("model_type", model_type, valid_model_types)

tags = self.tags()
Validate.value_against_list("model_data_tags", model_data_tags, tags)

if model_test_tags is not None:
Validate.value_against_list("model_test_tags", model_test_tags, tags)

uploaded_path = upload_file_and_return_path()

if instance_type:
custom_batch_servers = self.api_client.get(AVAILABLE_BATCH_SERVERS_URI)
Validate.value_against_list(
"instance_type",
instance_type,
[
server["instance_name"]
for server in custom_batch_servers.get("details", [])
],
)

if explainability_method:
Validate.value_against_list("explainability_method", explainability_method, ["shap", "lime"])

payload = {
"project_name": self.project_name,
"model_name": model_name,
"model_architecture": model_architecture,
"model_type": model_type,
"model_path": uploaded_path,
"model_data_tags": model_data_tags,
"model_test_tags": model_test_tags,
"explainability_method": explainability_method
}

if instance_type:
payload["instance_type"] = instance_type

res = self.api_client.post(UPLOAD_MODEL_URI, payload)

if not res.get("success"):
raise Exception(res.get("details"))

poll_events(
self.api_client,
self.project_name,
res["event_id"],
lambda: self.delete_file(uploaded_path),
)

def data_observations(self, tag: str) -> pd.DataFrame:
"""Data Observations for the project
Expand Down Expand Up @@ -2434,6 +2689,7 @@ def list_data_connectors(self) -> str | pd.DataFrame:

if res["success"]:
df = pd.DataFrame(res["details"])
df = df.drop(["_id", "region", "gcp_project_name", "gcp_project_id", "gdrive_file_name"], axis = 1, errors = "ignore")
return df

return res["details"]
Expand Down

0 comments on commit fd5fac6

Please sign in to comment.