Skip to content

Commit

Permalink
Added artifact bucket mock to the unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ark-kun committed Jul 2, 2022
1 parent 9c1e1ac commit 9aecbaf
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import pipeline_jobs
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
from google.cloud.aiplatform.utils import gcs_utils
from google.cloud import storage
from google.protobuf import json_format

Expand Down Expand Up @@ -204,6 +205,31 @@ def mock_pipeline_service_create():
yield mock_create_pipeline_job


@pytest.fixture
def mock_pipeline_bucket_exists():
def mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
output_artifacts_gcs_dir,
service_account,
project,
location,
credentials,
):
output_artifacts_gcs_dir = (
output_artifacts_gcs_dir
or gcs_utils.generate_gcs_directory_for_pipeline_artifacts(
project=project,
location=location,
)
)
return output_artifacts_gcs_dir

with mock.patch(
"google.cloud.aiplatform.utils.gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist",
new=mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist,
) as mock_context:
yield mock_context


def make_pipeline_job(state):
return gca_pipeline_job.PipelineJob(
name=_TEST_PIPELINE_JOB_NAME,
Expand Down Expand Up @@ -322,6 +348,7 @@ def test_run_call_pipeline_service_create(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
sync,
Expand Down Expand Up @@ -399,6 +426,7 @@ def test_run_call_pipeline_service_create_artifact_registry(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
mock_request_urlopen,
job_spec,
mock_load_yaml_and_json,
Expand Down Expand Up @@ -482,6 +510,7 @@ def test_run_call_pipeline_service_create_with_timeout(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
sync,
Expand Down Expand Up @@ -563,6 +592,7 @@ def test_run_call_pipeline_service_create_with_timeout_not_explicitly_set(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
sync,
Expand Down Expand Up @@ -644,6 +674,7 @@ def test_run_call_pipeline_service_create_with_failure_policy(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
failure_policy,
Expand Down Expand Up @@ -728,6 +759,7 @@ def test_run_call_pipeline_service_create_legacy(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
sync,
Expand Down Expand Up @@ -809,6 +841,7 @@ def test_run_call_pipeline_service_create_tfx(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
sync,
Expand Down Expand Up @@ -886,6 +919,7 @@ def test_submit_call_pipeline_service_pipeline_job_create(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
):
Expand Down Expand Up @@ -961,6 +995,7 @@ def test_done_method_pipeline_service(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
):
Expand Down Expand Up @@ -999,6 +1034,7 @@ def test_submit_call_pipeline_service_pipeline_job_create_legacy(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
):
Expand Down Expand Up @@ -1079,6 +1115,7 @@ def test_get_pipeline_job(self, mock_pipeline_service_get):
@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get",
"mock_pipeline_bucket_exists",
)
@pytest.mark.parametrize(
"job_spec",
Expand Down Expand Up @@ -1116,6 +1153,7 @@ def test_cancel_pipeline_job(
@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get",
"mock_pipeline_bucket_exists",
)
@pytest.mark.parametrize(
"job_spec",
Expand Down Expand Up @@ -1190,6 +1228,7 @@ def test_cancel_pipeline_job_without_running(
@pytest.mark.usefixtures(
"mock_pipeline_service_create",
"mock_pipeline_service_get_with_fail",
"mock_pipeline_bucket_exists",
)
@pytest.mark.parametrize(
"job_spec",
Expand Down Expand Up @@ -1230,6 +1269,7 @@ def test_clone_pipeline_job(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
):
Expand Down Expand Up @@ -1307,6 +1347,7 @@ def test_clone_pipeline_job_with_all_args(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
):
Expand Down

0 comments on commit 9aecbaf

Please sign in to comment.