Commit 22151e2
fix: Fix the sync option for Model Monitor job creation
PiperOrigin-RevId: 653408498
vertex-sdk-bot authored and copybara-github committed Jul 18, 2024
1 parent 217faf8 commit 22151e2
Showing 2 changed files with 296 additions and 116 deletions.
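
The behavior under test is easiest to see in the new assertions below: ModelMonitor.run and ModelMonitoringJob.create now honor the sync option, and a job started with sync=False is waited on explicitly. A minimal sketch of that call pattern, assuming ModelMonitor is exported from vertexai.resources.preview like the ModelMonitoringJob used in these tests; the project, location, and resource names are placeholders, and most run() arguments (datasets, objective and notification specs) are omitted for brevity:

    from google.cloud import aiplatform
    from vertexai.resources.preview import ModelMonitor

    # Placeholder project/location, not values from this commit.
    aiplatform.init(project="my-project", location="us-central1")

    # Placeholder model resource; the tests pass _TEST_* constants here.
    model_monitor = ModelMonitor.create(
        model_name="projects/my-project/locations/us-central1/models/123",
        model_version_id="1",
        display_name="my-model-monitor",
    )

    # sync=True blocks until the job is created; sync=False returns
    # immediately, and wait() blocks until the job reaches a terminal state.
    job = model_monitor.run(
        display_name="my-monitoring-job",
        sync=False,
    )
    job.wait()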
104 changes: 82 additions & 22 deletions tests/unit/vertexai/test_model_monitors.py
@@ -36,6 +36,7 @@
     model_monitoring_stats_v1beta1 as gca_model_monitoring_stats,
     schedule_service_v1beta1 as gca_schedule_service,
     schedule_v1beta1 as gca_schedule,
+    job_state_v1beta1 as gca_job_state,
     explanation_v1beta1 as explanation,
 )
 from vertexai.resources.preview import (
@@ -51,7 +52,10 @@

 # -*- coding: utf-8 -*-

-_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
+_TEST_CREDENTIALS = mock.Mock(
+    spec=auth_credentials.AnonymousCredentials(),
+    universe_domain="googleapis.com",
+)
 _TEST_DESCRIPTION = "test description"
 _TEST_JSON_CONTENT_TYPE = "application/json"
 _TEST_LOCATION = "us-central1"
@@ -178,6 +182,9 @@
             user_emails=[_TEST_NOTIFICATION_EMAIL]
         ),
     ),
+    explanation_spec=explanation.ExplanationSpec(
+        parameters=explanation.ExplanationParameters(top_k=10)
+    ),
 )
 _TEST_UPDATED_MODEL_MONITOR_OBJ = gca_model_monitor.ModelMonitor(
     name=_TEST_MODEL_MONITOR_RESOURCE_NAME,
@@ -222,6 +229,9 @@
             user_emails=[_TEST_NOTIFICATION_EMAIL, "456@test.com"]
         ),
     ),
+    explanation_spec=explanation.ExplanationSpec(
+        parameters=explanation.ExplanationParameters(top_k=10)
+    ),
 )
 _TEST_CREATE_MODEL_MONITORING_JOB_OBJ = gca_model_monitoring_job.ModelMonitoringJob(
     display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
@@ -249,7 +259,9 @@
                 vertex_dataset=_TEST_TARGET_RESOURCE
             )
         ),
-        explanation_spec=explanation.ExplanationSpec(),
+        explanation_spec=explanation.ExplanationSpec(
+            parameters=explanation.ExplanationParameters(top_k=10)
+        ),
     ),
     output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec(
         gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH)
@@ -288,7 +300,9 @@
                 vertex_dataset=_TEST_TARGET_RESOURCE
             )
         ),
-        explanation_spec=explanation.ExplanationSpec(),
+        explanation_spec=explanation.ExplanationSpec(
+            parameters=explanation.ExplanationParameters(top_k=10)
+        ),
     ),
     output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec(
         gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH)
@@ -299,6 +313,7 @@
             )
         ),
     ),
+    state=gca_job_state.JobState.JOB_STATE_SUCCEEDED,
 )
 _TEST_CRON = "America/New_York 1 * * * *"
 _TEST_SCHEDULE_OBJ = gca_schedule.Schedule(
@@ -336,7 +351,9 @@
                 vertex_dataset=_TEST_TARGET_RESOURCE
             )
         ),
-        explanation_spec=explanation.ExplanationSpec(),
+        explanation_spec=explanation.ExplanationSpec(
+            parameters=explanation.ExplanationParameters(top_k=10)
+        ),
     ),
     output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec(
         gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH)
@@ -564,7 +581,12 @@ def get_model_monitoring_job_mock():
         model_monitoring_service_client.ModelMonitoringServiceClient,
         "get_model_monitoring_job",
     ) as get_model_monitoring_job_mock:
-        get_model_monitor_mock.return_value = _TEST_MODEL_MONITORING_JOB_OBJ
+        model_monitoring_job_mock = mock.Mock(
+            spec=gca_model_monitoring_job.ModelMonitoringJob
+        )
+        model_monitoring_job_mock.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED
+        model_monitoring_job_mock.name = _TEST_MODEL_MONITORING_JOB_RESOURCE_NAME
+        get_model_monitoring_job_mock.return_value = model_monitoring_job_mock
         yield get_model_monitoring_job_mock

@@ -762,6 +784,9 @@ def test_create_schedule(self, create_schedule_mock):
             notification_spec=ml_monitoring.spec.NotificationSpec(
                 user_emails=[_TEST_NOTIFICATION_EMAIL]
             ),
+            explanation_spec=explanation.ExplanationSpec(
+                parameters=explanation.ExplanationParameters(top_k=10)
+            ),
         )
         test_model_monitor.create_schedule(
             display_name=_TEST_SCHEDULE_NAME,
@@ -851,9 +876,12 @@ def test_update_schedule(self, update_schedule_mock, get_schedule_mock):
         assert get_schedule_mock.call_count == 1

     @pytest.mark.usefixtures(
-        "create_model_monitoring_job_mock", "create_model_monitor_mock"
+        "create_model_monitoring_job_mock",
+        "create_model_monitor_mock",
+        "get_model_monitoring_job_mock",
     )
-    def test_run_model_monitoring_job(self, create_model_monitoring_job_mock):
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_run_model_monitoring_job(self, create_model_monitoring_job_mock, sync):
         aiplatform.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
@@ -866,6 +894,15 @@ def test_run_model_monitoring_job(self, create_model_monitoring_job_mock):
             model_name=_TEST_MODEL_NAME,
             display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME,
             model_version_id=_TEST_MODEL_VERSION_ID,
+        )
+        test_model_monitoring_job = test_model_monitor.run(
+            display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
+            baseline_dataset=ml_monitoring.spec.MonitoringInput(
+                vertex_dataset=_TEST_BASELINE_RESOURCE
+            ),
+            target_dataset=ml_monitoring.spec.MonitoringInput(
+                vertex_dataset=_TEST_TARGET_RESOURCE
+            ),
             tabular_objective_spec=ml_monitoring.spec.TabularObjective(
                 feature_drift_spec=ml_monitoring.spec.DataDriftSpec(
                     default_categorical_alert_threshold=0.1,
@@ -876,13 +913,15 @@
             notification_spec=ml_monitoring.spec.NotificationSpec(
                 user_emails=[_TEST_NOTIFICATION_EMAIL]
             ),
-        )
-        test_model_monitor.run(
-            display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
-            target_dataset=ml_monitoring.spec.MonitoringInput(
-                vertex_dataset=_TEST_TARGET_RESOURCE
+            explanation_spec=explanation.ExplanationSpec(
+                parameters=explanation.ExplanationParameters(top_k=10)
             ),
+            sync=sync,
         )

+        if not sync:
+            test_model_monitoring_job.wait()
+
         create_model_monitoring_job_mock.assert_called_once_with(
             request=gca_model_monitoring_service.CreateModelMonitoringJobRequest(
                 parent=_TEST_MODEL_MONITOR_RESOURCE_NAME,
             )
         )

     @pytest.mark.usefixtures(
-        "create_model_monitoring_job_mock", "create_model_monitor_mock"
+        "create_model_monitoring_job_mock",
+        "create_model_monitor_mock",
+        "get_model_monitoring_job_mock",
     )
     def test_run_model_monitoring_job_with_user_id(
         self, create_model_monitoring_job_mock
@@ -908,6 +949,15 @@ def test_run_model_monitoring_job_with_user_id(
             model_name=_TEST_MODEL_NAME,
             display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME,
             model_version_id=_TEST_MODEL_VERSION_ID,
+        )
+        test_model_monitor.run(
+            display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
+            baseline_dataset=ml_monitoring.spec.MonitoringInput(
+                vertex_dataset=_TEST_BASELINE_RESOURCE
+            ),
+            target_dataset=ml_monitoring.spec.MonitoringInput(
+                vertex_dataset=_TEST_TARGET_RESOURCE
+            ),
             tabular_objective_spec=ml_monitoring.spec.TabularObjective(
                 feature_drift_spec=ml_monitoring.spec.DataDriftSpec(
                     default_categorical_alert_threshold=0.1,
@@ -918,11 +968,8 @@
             notification_spec=ml_monitoring.spec.NotificationSpec(
                 user_emails=[_TEST_NOTIFICATION_EMAIL]
             ),
-        )
-        test_model_monitor.run(
-            display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
-            target_dataset=ml_monitoring.spec.MonitoringInput(
-                vertex_dataset=_TEST_TARGET_RESOURCE
+            explanation_spec=explanation.ExplanationSpec(
+                parameters=explanation.ExplanationParameters(top_k=10)
             ),
             model_monitoring_job_id=_TEST_MODEL_MONITORING_JOB_USER_ID,
         )
@@ -938,6 +985,7 @@ def test_run_model_monitoring_job_with_user_id(
         "create_model_monitoring_job_mock",
         "create_model_monitor_mock",
         "search_metrics_mock",
+        "get_model_monitoring_job_mock",
     )
     def test_search_metrics(self, search_metrics_mock):
         aiplatform.init(
@@ -978,6 +1026,7 @@ def test_search_metrics(self, search_metrics_mock):
         "create_model_monitoring_job_mock",
         "create_model_monitor_mock",
         "search_alerts_mock",
+        "get_model_monitoring_job_mock",
     )
     def test_search_alerts(self, search_alerts_mock):
         aiplatform.init(
@@ -1047,14 +1096,17 @@ def test_delete_model_monitor(self, delete_model_monitor_mock, force):
             )
         )

-    @pytest.mark.usefixtures("create_model_monitoring_job_mock")
-    def test_create_model_monitoring_job(self, create_model_monitoring_job_mock):
+    @pytest.mark.usefixtures(
+        "create_model_monitoring_job_mock", "get_model_monitoring_job_mock"
+    )
+    @pytest.mark.parametrize("sync", [True, False])
+    def test_create_model_monitoring_job(self, create_model_monitoring_job_mock, sync):
         aiplatform.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
             credentials=_TEST_CREDENTIALS,
         )
-        ModelMonitoringJob.create(
+        test_model_monitoring_job = ModelMonitoringJob.create(
             display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
             model_monitor_name=_TEST_MODEL_MONITOR_RESOURCE_NAME,
             tabular_objective_spec=ml_monitoring.spec.TabularObjective(
@@ -1073,8 +1125,15 @@ def test_create_model_monitoring_job(self, create_model_monitoring_job_mock):
             notification_spec=ml_monitoring.spec.NotificationSpec(
                 user_emails=[_TEST_NOTIFICATION_EMAIL]
             ),
-            explanation_spec=explanation.ExplanationSpec(),
+            explanation_spec=explanation.ExplanationSpec(
+                parameters=explanation.ExplanationParameters(top_k=10)
+            ),
+            sync=sync,
         )

+        if not sync:
+            test_model_monitoring_job.wait()
+
         create_model_monitoring_job_mock.assert_called_once_with(
             request=gca_model_monitoring_service.CreateModelMonitoringJobRequest(
                 parent=_TEST_MODEL_MONITOR_RESOURCE_NAME,
@@ -1086,6 +1145,7 @@
         "create_model_monitor_mock",
         "create_model_monitoring_job_mock",
         "delete_model_monitoring_job_mock",
+        "get_model_monitoring_job_mock",
     )
     def test_delete_model_monitoring_job(self, delete_model_monitoring_job_mock):
         aiplatform.init(
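
The reworked get_model_monitoring_job_mock fixture pins the mocked job's state to JOB_STATE_SUCCEEDED, which is presumably what lets wait() return in the sync=False tests: the first poll already observes a terminal state. A self-contained sketch of that polling pattern using only unittest.mock; the FakeMonitoringJob class and its method names are illustrative, not the SDK's:

    from unittest import mock

    TERMINAL_STATES = {"JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"}

    class FakeMonitoringJob:
        """Illustrative stand-in for a job resource with a blocking wait()."""

        def __init__(self, client, name):
            self._client = client
            self.name = name

        def wait(self):
            # Poll the service until the job reports a terminal state.
            while (
                self._client.get_model_monitoring_job(name=self.name).state
                not in TERMINAL_STATES
            ):
                pass

    # Mirror the fixture: the mocked get call always reports a succeeded job,
    # so wait() exits on its first poll instead of spinning forever.
    client = mock.Mock()
    client.get_model_monitoring_job.return_value = mock.Mock(state="JOB_STATE_SUCCEEDED")
    FakeMonitoringJob(client, "jobs/123").wait()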