fix: resolve Artifact Registry tags when creating PipelineJob
PiperOrigin-RevId: 570460499
sararob authored and copybara-github committed Oct 3, 2023
1 parent efe88f9 commit f04ca35
Showing 9 changed files with 11 additions and 78 deletions.
27 changes: 1 addition & 26 deletions google/cloud/aiplatform/pipeline_jobs.py
@@ -18,7 +18,6 @@
import datetime
import logging
import re
import requests
import tempfile
import time
from typing import Any, Callable, Dict, List, Optional, Union
@@ -200,10 +199,7 @@ def __init__(
scheduled tasks will continue to completion.
Raises:
ValueError:
If job_id or labels have incorrect format.
If an invalid Artifact Registry URI is passed to template_path or there is a
credentials error retrieving the AR template path.
ValueError: If job_id or labels have incorrect format.
"""
if not display_name:
display_name = self.__class__._generate_display_name()
@@ -218,27 +214,6 @@
project=project, location=location
)

# TODO(b/293637096): remove this when AR IAM is updated
# If it's an Artifact Registry URI with a tag, we need to replace the tag with the version due to how the pipeline service handles auth
# See https://github.com/googleapis/python-aiplatform/issues/2398
if re.match(_VALID_AR_URL, template_path):
if "sha256" not in template_path.split("/")[-1]:
template_uri_prefix = template_path.split("kfp.pkg.dev/")[0]
ar_region = template_uri_prefix.split("//")[1][:-1]
ar_project, ar_repo, ar_package, ar_tag = template_path.split(
"kfp.pkg.dev/"
)[1].split("/")
tag_name = f"projects/{ar_project}/locations/{ar_region}/repositories/{ar_repo}/packages/{ar_package}/tags/{ar_tag}"

response = requests.get(
f"https://artifactregistry.googleapis.com/v1/{tag_name}",
auth=self.credentials.token,
)
response.raise_for_status()

version = response.json()["version"].split("/")[-1]
template_path = f"{template_uri_prefix}kfp.pkg.dev/{ar_project}/{ar_repo}/{ar_package}/{version}"

# this loads both .yaml and .json files because YAML is a superset of JSON
pipeline_json = yaml_utils.load_yaml(
template_path, self.project, self.credentials
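
For reference, here is a minimal usage sketch of what this change enables: a tagged Artifact Registry template path is now handed to PipelineJob as-is and resolved by the pipeline service, instead of being resolved to a sha256 version client-side. The project, repository, package, and bucket names below are placeholders, not values from this repository.

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholder project/region

# A tagged Artifact Registry template path ("latest" is the tag). After this
# change the tag is no longer resolved to a sha256 version in the client
# before the job is created; the pipeline service resolves it.
job = aiplatform.PipelineJob(
    display_name="my-pipeline",
    template_path="https://us-central1-kfp.pkg.dev/my-project/my-repo/my-package/latest",
    pipeline_root="gs://my-staging-bucket/pipeline_root",  # placeholder bucket
)
job.submit()
```
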
4 changes: 4 additions & 0 deletions tests/system/aiplatform/test_model_evaluation.py
@@ -20,6 +20,7 @@

import pytest

from google import auth
from google.cloud import storage

from google.cloud import aiplatform
@@ -86,6 +87,9 @@ def test_model_evaluate_custom_tabular_model(self, staging_bucket, shared_state)
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
credentials=auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
),
)

custom_model = aiplatform.Model(
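
One note for anyone adapting the pattern above outside a test: google.auth.default() returns a (credentials, project_id) tuple rather than a bare credentials object, so the credentials element is typically unpacked before being passed to aiplatform.init. A minimal sketch with placeholder project and region:

```python
from google import auth
from google.cloud import aiplatform

# google.auth.default() returns (credentials, project_id); keep the first element.
credentials, _ = auth.default(
    scopes=["https://www.googleapis.com/auth/cloud-platform"]
)

aiplatform.init(
    project="my-project",    # placeholder project ID
    location="us-central1",  # placeholder region
    credentials=credentials,
)
```
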
20 changes: 0 additions & 20 deletions tests/unit/aiplatform/conftest.py
@@ -15,9 +15,7 @@
# limitations under the License.
#

import json
import pytest
import requests

from google import auth
from google.api_core import operation
@@ -515,21 +513,3 @@ def add_context_children_mock():
metadata_service_client_v1.MetadataServiceClient, "add_context_children"
) as add_context_children_mock:
yield add_context_children_mock


@pytest.fixture
def mock_artifact_registry_request():
with mock.patch.object(
requests,
"get",
) as mock_artifact_registry:
ar_response = requests.models.Response()
ar_response.status_code = 200
ar_response._content = json.dumps(
{
"name": "projects/proj/locations/us-central1/repositories/repo/packages/pack",
"version": "projects/proj/locations/us-central1/repositories/repo/packages/pack/versions/sha256:5d3a03",
}
).encode("utf-8")
mock_artifact_registry.return_value = ar_response
yield mock_artifact_registry
3 changes: 0 additions & 3 deletions tests/unit/aiplatform/constants.py
@@ -307,9 +307,6 @@ class PipelineJobConstants:
_TEST_PIPELINE_JOB_ID = "sample-test-pipeline-202111111"
_TEST_PIPELINE_JOB_NAME = f"projects/{ProjectConstants._TEST_PROJECT}/locations/{ProjectConstants._TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}"
_TEST_PIPELINE_CREATE_TIME = datetime.now()
_TEST_AR_TEMPLATE_VERSION = (
"https://us-central1-kfp.pkg.dev/proj/repo/pack/sha256:5d3a03"
)


@dataclasses.dataclass(frozen=True)
8 changes: 0 additions & 8 deletions tests/unit/aiplatform/test_language_models.py
@@ -1461,7 +1461,6 @@ def test_text_generation_response_repr(self):
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_tune_text_generation_model(
self,
mock_pipeline_service_create,
@@ -1544,7 +1543,6 @@ def test_tune_text_generation_model(
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_tune_text_generation_model_ga(
self,
mock_pipeline_service_create,
@@ -1625,7 +1623,6 @@ def test_tune_text_generation_model_ga(
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_tune_chat_model(
self,
mock_pipeline_service_create,
@@ -1679,7 +1676,6 @@ def test_tune_chat_model(
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_tune_code_generation_model(
self,
mock_pipeline_service_create,
@@ -1726,7 +1722,6 @@ def test_tune_code_generation_model(
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_tune_code_chat_model(
self,
mock_pipeline_service_create,
@@ -2835,7 +2830,6 @@ class TestLanguageModelEvaluation:
["https://us-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_model_evaluation_text_generation_task_with_gcs_input(
self,
job_spec,
@@ -3059,7 +3053,6 @@ def test_evaluate_raises_on_ga_language_model(
["https://us-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_model_evaluation_text_generation_task_on_base_model(
self,
job_spec,
@@ -3108,7 +3101,6 @@ def test_model_evaluation_text_generation_task_on_base_model(
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_model_evaluation_text_classification_base_model_only_summary_metrics(
self,
job_spec,
16 changes: 4 additions & 12 deletions tests/unit/aiplatform/test_model_evaluation.py
@@ -105,7 +105,7 @@
_TEST_PIPELINE_CREATE_TIME = datetime.datetime.now()

_TEST_KFP_TEMPLATE_URI = "https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation-automl-tabular-classification-pipeline/1.0.0"
_TEST_KFP_TEMPLATE_VERSION = "https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation-automl-tabular-classification-pipeline/sha256:5d3a03"

_TEST_TEMPLATE_REF = {
"base_uri": "https://us-kfp.pkg.dev/vertex-evaluation/pipeline-templates/evaluation",
"tag": "20230713_1737",
@@ -848,7 +848,6 @@ def test_get_model_evaluation_pipeline_job(
"job_spec",
[_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON],
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_get_model_evaluation_bp_job(
self,
mock_pipeline_service_create,
@@ -903,7 +902,6 @@ def test_get_model_evaluation_bp_job(
"job_spec",
[_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON],
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_get_model_evaluation_mlmd_resource(
self,
mock_pipeline_service_create,
@@ -1023,9 +1021,7 @@ def test_init_model_evaluation_job_with_invalid_pipeline_job_name_raises(
"job_spec",
[_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON],
)
@pytest.mark.usefixtures(
"mock_pipeline_service_create", "mock_artifact_registry_request"
)
@pytest.mark.usefixtures("mock_pipeline_service_create")
def test_model_evaluation_job_submit(
self,
job_spec,
@@ -1104,7 +1100,7 @@ def test_model_evaluation_job_submit(
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
template_uri=_TEST_KFP_TEMPLATE_VERSION,
template_uri=_TEST_KFP_TEMPLATE_URI,
)

mock_model_eval_job_create.assert_called_with(
@@ -1124,7 +1120,6 @@
"job_spec",
[_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON],
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_model_evaluation_job_submit_with_experiment(
self,
mock_pipeline_service_create,
@@ -1211,7 +1206,7 @@ def test_model_evaluation_job_submit_with_experiment(
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
template_uri=_TEST_KFP_TEMPLATE_VERSION,
template_uri=_TEST_KFP_TEMPLATE_URI,
)

mock_model_eval_job_create.assert_called_with(
@@ -1230,7 +1225,6 @@
"job_spec",
[_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON],
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_get_model_evaluation_with_successful_pipeline_run_returns_resource(
self,
mock_pipeline_service_create,
@@ -1299,7 +1293,6 @@ def test_get_model_evaluation_with_successful_pipeline_run_returns_resource(
"job_spec",
[_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON],
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_model_evaluation_job_get_model_evaluation_with_failed_pipeline_run_raises(
self,
mock_pipeline_service_create,
@@ -1349,7 +1342,6 @@ def test_model_evaluation_job_get_model_evaluation_with_failed_pipeline_run_raises(
"job_spec",
[_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON],
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_model_evaluation_job_get_model_evaluation_with_pending_pipeline_run_returns_none(
self,
mock_pipeline_service_create,
3 changes: 0 additions & 3 deletions tests/unit/aiplatform/test_models.py
@@ -3469,7 +3469,6 @@ def test_raw_predict(self, raw_predict_mock):
"job_spec_json",
[_TEST_MODEL_EVAL_PIPELINE_JOB],
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_model_evaluate_with_gcs_input_uris(
self,
get_model_mock,
@@ -3516,7 +3515,6 @@ def test_model_evaluate_with_gcs_input_uris(
"job_spec_json",
[_TEST_MODEL_EVAL_PIPELINE_JOB_WITH_BQ_INPUT],
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_model_evaluate_with_bigquery_input(
self,
get_model_mock,
@@ -3550,7 +3548,6 @@ def test_model_evaluate_with_bigquery_input(
"job_spec_json",
[_TEST_MODEL_EVAL_PIPELINE_JOB],
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_model_evaluate_using_initialized_staging_bucket(
self,
get_model_mock,
4 changes: 1 addition & 3 deletions tests/unit/aiplatform/test_pipeline_job_schedules.py
@@ -30,7 +30,6 @@
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import pipeline_jobs
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
import constants as test_constants
from google.cloud.aiplatform.compat.services import (
pipeline_service_client,
schedule_service_client,
@@ -771,7 +770,6 @@ def test_call_schedule_service_create_with_different_timezone(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_call_schedule_service_create_artifact_registry(
self,
mock_schedule_service_create,
@@ -836,7 +834,7 @@ def test_call_schedule_service_create_artifact_registry(
"pipeline_spec": dict_to_struct(pipeline_spec),
"service_account": _TEST_SERVICE_ACCOUNT,
"network": _TEST_NETWORK,
"template_uri": test_constants.PipelineJobConstants._TEST_AR_TEMPLATE_VERSION,
"template_uri": _TEST_AR_TEMPLATE_PATH,
},
},
)
4 changes: 1 addition & 3 deletions tests/unit/aiplatform/test_pipeline_jobs.py
@@ -36,7 +36,6 @@
from google.cloud.aiplatform import pipeline_jobs
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
from google.cloud.aiplatform.utils import gcs_utils
import constants as test_constants
from google.cloud import storage
from google.protobuf import json_format
from google.protobuf import field_mask_pb2 as field_mask
@@ -554,7 +553,6 @@ def test_run_call_pipeline_service_create(
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("mock_artifact_registry_request")
def test_run_call_pipeline_service_create_artifact_registry(
self,
mock_pipeline_service_create,
@@ -614,7 +612,7 @@ def test_run_call_pipeline_service_create_artifact_registry(
runtime_config=runtime_config,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
template_uri=test_constants.PipelineJobConstants._TEST_AR_TEMPLATE_VERSION,
template_uri=_TEST_AR_TEMPLATE_PATH,
)

mock_pipeline_service_create.assert_called_once_with(
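
For context, a minimal sketch of the kind of check these updated tests now make: the tagged Artifact Registry path supplied by the caller should appear verbatim as the created job's template_uri, rather than a resolved sha256 version. The constant and mock names below are placeholders, not this repository's fixtures.

```python
# Placeholder constant; the real tests use module-level constants and fixtures.
_TEST_AR_TEMPLATE_PATH = "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"


def assert_template_uri_passthrough(mock_pipeline_service_create):
    """Verify the tagged AR path was forwarded to the service unchanged."""
    _, kwargs = mock_pipeline_service_create.call_args
    assert kwargs["pipeline_job"].template_uri == _TEST_AR_TEMPLATE_PATH
```
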
