From 95b107c8727245e1836f9cbddd3f2e331532dd62 Mon Sep 17 00:00:00 2001 From: sina chavoshi Date: Wed, 20 Jul 2022 17:48:14 -0700 Subject: [PATCH] feat: Change the Metadata SDK _Context class to an external class (#1519) * feat: Change the Metadata SDK _Context class to an external class * Add base schema class for context * Add additional context schema types * Add additional context schema types * Add create method to Context. * Fix unit test failure. * add unit tests * fix lint issue * Add Context to root __init__. * correct import path --- google/cloud/aiplatform/__init__.py | 1 + google/cloud/aiplatform/metadata/context.py | 154 ++++++++++++++- .../metadata/experiment_resources.py | 32 +-- .../metadata/experiment_run_resource.py | 22 +-- google/cloud/aiplatform/metadata/metadata.py | 2 +- .../metadata/schema/base_context.py | 111 +++++++++++ .../metadata/schema/system/context_schema.py | 185 ++++++++++++++++++ google/cloud/aiplatform/pipeline_jobs.py | 8 +- .../aiplatform/test_metadata_resources.py | 16 +- tests/unit/aiplatform/test_metadata_schema.py | 82 ++++++++ tests/unit/aiplatform/test_utils.py | 8 +- 11 files changed, 575 insertions(+), 46 deletions(-) create mode 100644 google/cloud/aiplatform/metadata/schema/base_context.py create mode 100644 google/cloud/aiplatform/metadata/schema/system/context_schema.py diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 88b450460f..73fad5a223 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -95,6 +95,7 @@ ExperimentRun = metadata.experiment_run_resource.ExperimentRun Artifact = metadata.artifact.Artifact Execution = metadata.execution.Execution +Context = metadata.context.Context __all__ = ( diff --git a/google/cloud/aiplatform/metadata/context.py b/google/cloud/aiplatform/metadata/context.py index d072a6e047..d1c7dea99c 100644 --- a/google/cloud/aiplatform/metadata/context.py +++ b/google/cloud/aiplatform/metadata/context.py @@ -19,6 +19,8 @@ import proto +from google.auth import credentials as auth_credentials + from google.cloud.aiplatform import base from google.cloud.aiplatform import utils from google.cloud.aiplatform.metadata import utils as metadata_utils @@ -31,10 +33,11 @@ ) from google.cloud.aiplatform.metadata import artifact from google.cloud.aiplatform.metadata import execution +from google.cloud.aiplatform.metadata import metadata_store from google.cloud.aiplatform.metadata import resource -class _Context(resource._Resource): +class Context(resource._Resource): """Metadata Context resource for Vertex AI""" _resource_noun = "contexts" @@ -81,6 +84,153 @@ def get_artifacts(self) -> List[artifact.Artifact]: credentials=self.credentials, ) + @classmethod + def create( + cls, + schema_title: str, + *, + resource_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "Context": + """Creates a new Metadata Context. + + Args: + schema_title (str): + Required. schema_title identifies the schema title used by the Context. + Please reference https://cloud.google.com/vertex-ai/docs/ml-metadata/system-schemas. + resource_id (str): + Optional. The portion of the Context name with + the format. 
This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//Contexts/. + display_name (str): + Optional. The user-defined name of the Context. + schema_version (str): + Optional. schema_version specifies the version used by the Context. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Context to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Context. + metadata_store_id (str): + Optional. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores//Contexts/ + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Optional. Project used to create this Context. Overrides project set in + aiplatform.init. + location (str): + Optional. Location used to create this Context. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials used to create this Context. Overrides + credentials set in aiplatform.init. + + Returns: + Context: Instantiated representation of the managed Metadata Context. + """ + return cls._create( + resource_id=resource_id, + schema_title=schema_title, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=metadata, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + + # TODO() refactor code to move _create to _Resource class. + @classmethod + def _create( + cls, + resource_id: str, + schema_title: str, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "Context": + """Creates a new Metadata resource. + + Args: + resource_id (str): + Required. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores///. + schema_title (str): + Required. schema_title identifies the schema title used by the resource. + display_name (str): + Optional. The user-defined name of the resource. + schema_version (str): + Optional. schema_version specifies the version used by the resource. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the resource to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the resource. + metadata_store_id (str): + The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores/// + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Project used to create this resource. Overrides project set in + aiplatform.init. + location (str): + Location used to create this resource. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Custom credentials used to create this resource. Overrides + credentials set in aiplatform.init. + + Returns: + resource (_Resource): + Instantiated representation of the managed Metadata resource. 
+ + """ + api_client = cls._instantiate_client(location=location, credentials=credentials) + + parent = utils.full_resource_name( + resource_name=metadata_store_id, + resource_noun=metadata_store._MetadataStore._resource_noun, + parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name, + format_resource_name_method=metadata_store._MetadataStore._format_resource_name, + project=project, + location=location, + ) + + resource = cls._create_resource( + client=api_client, + parent=parent, + resource_id=resource_id, + schema_title=schema_title, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=metadata, + ) + + self = cls._empty_constructor( + project=project, location=location, credentials=credentials + ) + self._gca_resource = resource + + return self + @classmethod def _create_resource( cls, @@ -147,7 +297,7 @@ def _list_resources( ) return client.list_contexts(request=list_request) - def add_context_children(self, contexts: List["_Context"]): + def add_context_children(self, contexts: List["Context"]): """Adds the provided contexts as children of this context. Args: diff --git a/google/cloud/aiplatform/metadata/experiment_resources.py b/google/cloud/aiplatform/metadata/experiment_resources.py index e0c48b2f00..908e30561a 100644 --- a/google/cloud/aiplatform/metadata/experiment_resources.py +++ b/google/cloud/aiplatform/metadata/experiment_resources.py @@ -119,13 +119,13 @@ def __init__( ) with _SetLoggerLevel(resource): - experiment_context = context._Context(**metadata_args) + experiment_context = context.Context(**metadata_args) self._validate_experiment_context(experiment_context) self._metadata_context = experiment_context @staticmethod - def _validate_experiment_context(experiment_context: context._Context): + def _validate_experiment_context(experiment_context: context.Context): """Validates this context is an experiment context. Args: @@ -146,7 +146,7 @@ def _validate_experiment_context(experiment_context: context._Context): ) @staticmethod - def _is_tensorboard_experiment(context: context._Context) -> bool: + def _is_tensorboard_experiment(context: context.Context) -> bool: """Returns True if Experiment is a Tensorboard Experiment created by CustomJob.""" return constants.TENSORBOARD_CUSTOM_JOB_EXPERIMENT_FIELD in context.metadata @@ -192,7 +192,7 @@ def create( ) with _SetLoggerLevel(resource): - experiment_context = context._Context._create( + experiment_context = context.Context._create( resource_id=experiment_name, display_name=experiment_name, description=description, @@ -248,7 +248,7 @@ def get_or_create( ) with _SetLoggerLevel(resource): - experiment_context = context._Context.get_or_create( + experiment_context = context.Context.get_or_create( resource_id=experiment_name, display_name=experiment_name, description=description, @@ -303,7 +303,7 @@ def list( ) with _SetLoggerLevel(resource): - experiment_contexts = context._Context.list( + experiment_contexts = context.Context.list( filter=filter_str, project=project, location=location, @@ -341,7 +341,7 @@ def delete(self, *, delete_backing_tensorboard_runs: bool = False): runs under this experiment that we used to store time series metrics. 
""" - experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context._Context][ + experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context.Context][ constants.SYSTEM_EXPERIMENT_RUN ].list(experiment=self) for experiment_run in experiment_runs: @@ -380,11 +380,11 @@ def get_data_frame(self) -> "pd.DataFrame": # noqa: F821 filter_str = metadata_utils._make_filter_string( schema_title=sorted( - list(_SUPPORTED_LOGGABLE_RESOURCES[context._Context].keys()) + list(_SUPPORTED_LOGGABLE_RESOURCES[context.Context].keys()) ), parent_contexts=[self._metadata_context.resource_name], ) - contexts = context._Context.list(filter_str, **service_request_args) + contexts = context.Context.list(filter_str, **service_request_args) filter_str = metadata_utils._make_filter_string( schema_title=list( @@ -398,7 +398,7 @@ def get_data_frame(self) -> "pd.DataFrame": # noqa: F821 rows = [] for metadata_context in contexts: row_dict = ( - _SUPPORTED_LOGGABLE_RESOURCES[context._Context][ + _SUPPORTED_LOGGABLE_RESOURCES[context.Context][ metadata_context.schema_title ] ._query_experiment_row(metadata_context) @@ -568,7 +568,7 @@ class _VertexResourceWithMetadata(NamedTuple): """Represents a resource coupled with it's metadata representation""" resource: base.VertexAiResourceNoun - metadata: Union[artifact.Artifact, execution.Execution, context._Context] + metadata: Union[artifact.Artifact, execution.Execution, context.Context] class _ExperimentLoggableSchema(NamedTuple): @@ -581,7 +581,7 @@ class _ExperimentLoggableSchema(NamedTuple): """ title: str - type: Union[Type[context._Context], Type[execution.Execution]] = context._Context + type: Union[Type[context.Context], Type[execution.Execution]] = context.Context class _ExperimentLoggable(abc.ABC): @@ -618,7 +618,7 @@ class PipelineJob(..., experiment_loggable_schemas= _SUPPORTED_LOGGABLE_RESOURCES[schema.type][schema.title] = cls @abc.abstractmethod - def _get_context(self) -> context._Context: + def _get_context(self) -> context.Context: """Should return the metadata context that represents this resource. The subclass should enforce this context exists. @@ -631,7 +631,7 @@ def _get_context(self) -> context._Context: @classmethod @abc.abstractmethod def _query_experiment_row( - cls, node: Union[context._Context, execution.Execution] + cls, node: Union[context.Context, execution.Execution] ) -> _ExperimentRow: """Should return parameters and metrics for this resource as a run row. 
@@ -716,6 +716,6 @@ def _associate_to_experiment(self, experiment: Union[str, Experiment]): # Context -> 'system.ExperimentRun' -> aiplatform.ExperimentRun # Execution -> 'system.Run' -> aiplatform.ExperimentRun _SUPPORTED_LOGGABLE_RESOURCES: Dict[ - Union[Type[context._Context], Type[execution.Execution]], + Union[Type[context.Context], Type[execution.Execution]], Dict[str, _ExperimentLoggable], -] = {execution.Execution: dict(), context._Context: dict()} +] = {execution.Execution: dict(), context.Context: dict()} diff --git a/google/cloud/aiplatform/metadata/experiment_run_resource.py b/google/cloud/aiplatform/metadata/experiment_run_resource.py index 7d1c5f465e..055bd8d981 100644 --- a/google/cloud/aiplatform/metadata/experiment_run_resource.py +++ b/google/cloud/aiplatform/metadata/experiment_run_resource.py @@ -78,7 +78,7 @@ class ExperimentRun( experiment_resources._ExperimentLoggable, experiment_loggable_schemas=( experiment_resources._ExperimentLoggableSchema( - title=constants.SYSTEM_EXPERIMENT_RUN, type=context._Context + title=constants.SYSTEM_EXPERIMENT_RUN, type=context.Context ), # backwards compatibility with Preview Experiment runs experiment_resources._ExperimentLoggableSchema( @@ -136,9 +136,9 @@ def __init__( credentials=credentials, ) - def _get_context() -> context._Context: + def _get_context() -> context.Context: with experiment_resources._SetLoggerLevel(resource): - run_context = context._Context( + run_context = context.Context( **{**metadata_args, "resource_name": run_id} ) if run_context.schema_title != constants.SYSTEM_EXPERIMENT_RUN: @@ -212,7 +212,7 @@ def _v1_format_artifact_name(run_id: str) -> str: """Formats resource id of legacy metric artifact for this run.""" return f"{run_id}-metrics" - def _get_context(self) -> context._Context: + def _get_context(self) -> context.Context: """Returns this metadata context that represents this run. Returns: @@ -427,7 +427,7 @@ def list( parent_contexts=[experiment.resource_name], ) - run_contexts = context._Context.list(filter=filter_str, **metadata_args) + run_contexts = context.Context.list(filter=filter_str, **metadata_args) filter_str = metadata_utils._make_filter_string( schema_title=constants.SYSTEM_RUN, in_context=[experiment.resource_name] @@ -435,7 +435,7 @@ def list( run_executions = execution.Execution.list(filter=filter_str, **metadata_args) - def _initialize_experiment_run(context: context._Context) -> ExperimentRun: + def _initialize_experiment_run(context: context.Context) -> ExperimentRun: this_experiment_run = cls.__new__(cls) this_experiment_run._experiment = experiment this_experiment_run._run_name = context.display_name @@ -489,7 +489,7 @@ def _initialize_v1_experiment_run( @classmethod def _query_experiment_row( - cls, node: Union[context._Context, execution.Execution] + cls, node: Union[context.Context, execution.Execution] ) -> experiment_resources._ExperimentRow: """Retrieves the runs metric and parameters into an experiment run row. @@ -507,7 +507,7 @@ def _query_experiment_row( name=node.display_name, ) - if isinstance(node, context._Context): + if isinstance(node, context.Context): this_experiment_run._backing_tensorboard_run = ( this_experiment_run._lookup_tensorboard_run_artifact() ) @@ -526,7 +526,7 @@ def _query_experiment_row( row.state = node.state.name return row - def _get_logged_pipeline_runs(self) -> List[context._Context]: + def _get_logged_pipeline_runs(self) -> List[context.Context]: """Returns Pipeline Run contexts logged to this Experiment Run. 
Returns: @@ -544,7 +544,7 @@ def _get_logged_pipeline_runs(self) -> List[context._Context]: parent_contexts=[self._metadata_node.resource_name], ) - return context._Context.list(filter=filter_str, **service_request_args) + return context.Context.list(filter=filter_str, **service_request_args) def _get_latest_time_series_metric_columns(self) -> Dict[str, Union[float, int]]: """Determines the latest step for each time series metric. @@ -666,7 +666,7 @@ def create( def _create_context(): with experiment_resources._SetLoggerLevel(resource): - return context._Context._create( + return context.Context._create( resource_id=run_id, display_name=run_name, schema_title=constants.SYSTEM_EXPERIMENT_RUN, diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py index f321a622b3..6f67a6ddf6 100644 --- a/google/cloud/aiplatform/metadata/metadata.py +++ b/google/cloud/aiplatform/metadata/metadata.py @@ -95,7 +95,7 @@ def _get_experiment_or_pipeline_resource_name( NotFound exception if experiment or pipeline does not exist. """ - this_context = context._Context(resource_name=name) + this_context = context.Context(resource_name=name) if this_context.schema_title != expected_schema: raise ValueError( diff --git a/google/cloud/aiplatform/metadata/schema/base_context.py b/google/cloud/aiplatform/metadata/schema/base_context.py new file mode 100644 index 0000000000..a39835da00 --- /dev/null +++ b/google/cloud/aiplatform/metadata/schema/base_context.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc + +from typing import Optional, Dict + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform.metadata import constants +from google.cloud.aiplatform.metadata import context + + +class BaseContextSchema(metaclass=abc.ABCMeta): + """Base class for Metadata Context schema.""" + + @property + @classmethod + @abc.abstractmethod + def schema_title(cls) -> str: + """Identifies the Vertex Metadta schema title used by the resource.""" + pass + + def __init__( + self, + *, + context_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + metadata: Optional[Dict] = None, + description: Optional[str] = None, + ): + + """Initializes the Context with the given name, URI and metadata. + + Args: + context_id (str): + Optional. The portion of the Context name with + the following format, this is globally unique in a metadataStore. + projects/123/locations/us-central1/metadataStores//Contexts/. + display_name (str): + Optional. The user-defined name of the Context. + schema_version (str): + Optional. schema_version specifies the version used by the Context. + If not set, defaults to use the latest version. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Context. + description (str): + Optional. Describes the purpose of the Context to be created. 
+ """ + self.context_id = context_id + self.display_name = display_name + self.schema_version = schema_version or constants._DEFAULT_SCHEMA_VERSION + self.metadata = metadata + self.description = description + + def create( + self, + *, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "context.Context": + """Creates a new Metadata Context. + + Args: + metadata_store_id (str): + Optional. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores//Contexts/ + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Optional. Project used to create this Context. Overrides project set in + aiplatform.init. + location (str): + Optional. Location used to create this Context. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials used to create this Context. Overrides + credentials set in aiplatform.init. + Returns: + Context: Instantiated representation of the managed Metadata Context. + + """ + return context.Context.create( + resource_id=self.context_id, + schema_title=self.schema_title, + display_name=self.display_name, + schema_version=self.schema_version, + description=self.description, + metadata=self.metadata, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) diff --git a/google/cloud/aiplatform/metadata/schema/system/context_schema.py b/google/cloud/aiplatform/metadata/schema/system/context_schema.py new file mode 100644 index 0000000000..940c63ff26 --- /dev/null +++ b/google/cloud/aiplatform/metadata/schema/system/context_schema.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +from typing import Optional, Dict + +from google.cloud.aiplatform.metadata.schema import base_context + + +class Experiment(base_context.BaseContextSchema): + """Context schema for a Experiment context.""" + + schema_title = "system.Experiment" + + def __init__( + self, + *, + context_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + metadata: Optional[Dict] = None, + description: Optional[str] = None, + ): + """Args: + context_id (str): + Optional. The portion of the context name with + the following format, this is globally unique in a metadataStore. + projects/123/locations/us-central1/metadataStores//contexts/. + display_name (str): + Optional. The user-defined name of the context. + schema_version (str): + Optional. schema_version specifies the version used by the context. + If not set, defaults to use the latest version. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the context. + description (str): + Optional. Describes the purpose of the context to be created. 
+ """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + super(Experiment, self).__init__( + context_id=context_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + ) + + +class ExperimentRun(base_context.BaseContextSchema): + """Context schema for a ExperimentRun context.""" + + schema_title = "system.ExperimentRun" + + def __init__( + self, + *, + experiment_id: Optional[str] = None, + context_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + metadata: Optional[Dict] = None, + description: Optional[str] = None, + ): + """Args: + experiment_id (str): + Optional. The experiment_id that this experiment_run belongs to. + context_id (str): + Optional. The portion of the context name with + the following format, this is globally unique in a metadataStore. + projects/123/locations/us-central1/metadataStores//contexts/. + display_name (str): + Optional. The user-defined name of the context. + schema_version (str): + Optional. schema_version specifies the version used by the context. + If not set, defaults to use the latest version. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the context. + description (str): + Optional. Describes the purpose of the context to be created. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + extended_metadata["experiment_id"] = experiment_id + super(ExperimentRun, self).__init__( + context_id=context_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + ) + + +class Pipeline(base_context.BaseContextSchema): + """Context schema for a Pipeline context.""" + + schema_title = "system.Pipeline" + + def __init__( + self, + *, + context_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + metadata: Optional[Dict] = None, + description: Optional[str] = None, + ): + """Args: + context_id (str): + Optional. The portion of the context name with + the following format, this is globally unique in a metadataStore. + projects/123/locations/us-central1/metadataStores//contexts/. + display_name (str): + Optional. The user-defined name of the context. + schema_version (str): + Optional. schema_version specifies the version used by the context. + If not set, defaults to use the latest version. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the context. + description (str): + Optional. Describes the purpose of the context to be created. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + super(Pipeline, self).__init__( + context_id=context_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + ) + + +class PipelineRun(base_context.BaseContextSchema): + """Context schema for a PipelineRun context.""" + + schema_title = "system.PipelineRun" + + def __init__( + self, + *, + pipeline_id: Optional[str] = None, + context_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + metadata: Optional[Dict] = None, + description: Optional[str] = None, + ): + """Args: + pipeline_id (str): + Optional. PipelineJob resource name corresponding to this run. + context_id (str): + Optional. The portion of the context name with + the following format, this is globally unique in a metadataStore. 
+ projects/123/locations/us-central1/metadataStores//contexts/. + display_name (str): + Optional. The user-defined name of the context. + schema_version (str): + Optional. schema_version specifies the version used by the context. + If not set, defaults to use the latest version. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the context. + description (str): + Optional. Describes the purpose of the context to be created. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + extended_metadata["pipeline_id"] = pipeline_id + super(PipelineRun, self).__init__( + context_id=context_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + ) diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py index e4d93d1e92..31adf372f3 100644 --- a/google/cloud/aiplatform/pipeline_jobs.py +++ b/google/cloud/aiplatform/pipeline_jobs.py @@ -554,7 +554,7 @@ def _has_failed(self) -> bool: return self.state in _PIPELINE_ERROR_STATES - def _get_context(self) -> context._Context: + def _get_context(self) -> context.Context: """Returns the PipelineRun Context for this PipelineJob in the MetadataStore. Returns: @@ -583,7 +583,7 @@ def _get_context(self) -> context._Context: "Cannot associate PipelineJob to Experiment because PipelineJob context could not be found." ) - return context._Context( + return context.Context( resource=pipeline_run_context, project=self.project, location=self.location, @@ -592,7 +592,7 @@ def _get_context(self) -> context._Context: @classmethod def _query_experiment_row( - cls, node: context._Context + cls, node: context.Context ) -> experiment_resources._ExperimentRow: """Queries the PipelineJob metadata as an experiment run parameter and metric row. 
@@ -924,7 +924,7 @@ def get_associated_experiment(self) -> Optional["aiplatform.Experiment"]: ) pipeline_experiment_resources = [ - context._Context(resource_name=c)._gca_resource + context.Context(resource_name=c)._gca_resource for c in pipeline_parent_contexts if c != self._gca_resource.job_detail.pipeline_context.name ] diff --git a/tests/unit/aiplatform/test_metadata_resources.py b/tests/unit/aiplatform/test_metadata_resources.py index 66d4799ee1..1844399e7f 100644 --- a/tests/unit/aiplatform/test_metadata_resources.py +++ b/tests/unit/aiplatform/test_metadata_resources.py @@ -371,14 +371,14 @@ def teardown_method(self): def test_init_context(self, get_context_mock): aiplatform.init(project=_TEST_PROJECT) - context._Context(resource_name=_TEST_CONTEXT_NAME) + context.Context(resource_name=_TEST_CONTEXT_NAME) get_context_mock.assert_called_once_with( name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY ) def test_init_context_with_id(self, get_context_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - context._Context( + context.Context( resource_name=_TEST_CONTEXT_ID, metadata_store_id=_TEST_METADATA_STORE ) get_context_mock.assert_called_once_with( @@ -390,7 +390,7 @@ def test_get_or_create_context( ): aiplatform.init(project=_TEST_PROJECT) - my_context = context._Context.get_or_create( + my_context = context.Context.get_or_create( resource_id=_TEST_CONTEXT_ID, schema_title=_TEST_SCHEMA_TITLE, display_name=_TEST_DISPLAY_NAME, @@ -424,7 +424,7 @@ def test_get_or_create_context( def test_update_context(self, update_context_mock): aiplatform.init(project=_TEST_PROJECT) - my_context = context._Context._create( + my_context = context.Context._create( resource_id=_TEST_CONTEXT_ID, schema_title=_TEST_SCHEMA_TITLE, display_name=_TEST_DISPLAY_NAME, @@ -452,7 +452,7 @@ def test_list_contexts(self, list_contexts_mock): aiplatform.init(project=_TEST_PROJECT) filter = "test-filter" - context_list = context._Context.list( + context_list = context.Context.list( filter=filter, metadata_store_id=_TEST_METADATA_STORE ) @@ -481,7 +481,7 @@ def test_add_artifacts_and_executions( ): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - my_context = context._Context.get_or_create( + my_context = context.Context.get_or_create( resource_id=_TEST_CONTEXT_ID, schema_title=_TEST_SCHEMA_TITLE, display_name=_TEST_DISPLAY_NAME, @@ -504,7 +504,7 @@ def test_add_artifacts_and_executions( @pytest.mark.usefixtures("get_context_mock") def test_add_artifacts_only(self, add_context_artifacts_and_executions_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - my_context = context._Context.get_or_create( + my_context = context.Context.get_or_create( resource_id=_TEST_CONTEXT_ID, schema_title=_TEST_SCHEMA_TITLE, display_name=_TEST_DISPLAY_NAME, @@ -526,7 +526,7 @@ def test_add_artifacts_only(self, add_context_artifacts_and_executions_mock): @pytest.mark.usefixtures("get_context_mock") def test_add_executions_only(self, add_context_artifacts_and_executions_mock): aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) - my_context = context._Context.get_or_create( + my_context = context.Context.get_or_create( resource_id=_TEST_CONTEXT_ID, schema_title=_TEST_SCHEMA_TITLE, display_name=_TEST_DISPLAY_NAME, diff --git a/tests/unit/aiplatform/test_metadata_schema.py b/tests/unit/aiplatform/test_metadata_schema.py index 0dd97fdcbc..b133d5eebe 100644 --- a/tests/unit/aiplatform/test_metadata_schema.py +++ b/tests/unit/aiplatform/test_metadata_schema.py @@ -29,12 +29,16 @@ 
from google.cloud.aiplatform.metadata import metadata from google.cloud.aiplatform.metadata.schema import base_artifact from google.cloud.aiplatform.metadata.schema import base_execution +from google.cloud.aiplatform.metadata.schema import base_context from google.cloud.aiplatform.metadata.schema.google import ( artifact_schema as google_artifact_schema, ) from google.cloud.aiplatform.metadata.schema.system import ( artifact_schema as system_artifact_schema, ) +from google.cloud.aiplatform.metadata.schema.system import ( + context_schema as system_context_schema, +) from google.cloud.aiplatform.metadata.schema.system import ( execution_schema as system_execution_schema, ) @@ -73,6 +77,10 @@ _TEST_EXECUTION_ID = "test-execution-id" _TEST_EXECUTION_NAME = f"{_TEST_PARENT}/executions/{_TEST_EXECUTION_ID}" +# context +_TEST_CONTEXT_ID = "test-context-id" +_TEST_CONTEXT_NAME = f"{_TEST_PARENT}/contexts/{_TEST_CONTEXT_ID}" + @pytest.fixture def create_artifact_mock(): @@ -106,6 +114,20 @@ def create_execution_mock(): yield create_execution_mock +@pytest.fixture +def create_context_mock(): + with patch.object(MetadataServiceClient, "create_context") as create_context_mock: + create_context_mock.return_value = GapicExecution( + name=_TEST_CONTEXT_NAME, + display_name=_TEST_DISPLAY_NAME, + schema_title=_TEST_SCHEMA_TITLE, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_METADATA, + ) + yield create_context_mock + + @pytest.mark.usefixtures("google_auth_mock") class TestMetadataBaseArtifactSchema: def setup_method(self): @@ -521,6 +543,66 @@ def test_system_run_execution_schema_title_is_set_correctly(self): assert execution.schema_title == "system.Run" +@pytest.mark.usefixtures("google_auth_mock") +class TestMetadataSystemSchemaContext: + def setup_method(self): + reload(initializer) + reload(metadata) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + # Test system.Context Schemas + @pytest.mark.usefixtures("create_context_mock") + def test_create_is_called_with_default_parameters(self, create_context_mock): + aiplatform.init(project=_TEST_PROJECT) + + class TestContext(base_context.BaseContextSchema): + schema_title = _TEST_SCHEMA_TITLE + + context = TestContext( + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + context.create(metadata_store_id=_TEST_METADATA_STORE) + create_context_mock.assert_called_once_with( + parent=f"{_TEST_PARENT}/metadataStores/{_TEST_METADATA_STORE}", + context=mock.ANY, + context_id=None, + ) + _, _, kwargs = create_context_mock.mock_calls[0] + assert kwargs["context"].schema_title == _TEST_SCHEMA_TITLE + assert kwargs["context"].display_name == _TEST_DISPLAY_NAME + assert kwargs["context"].description == _TEST_DESCRIPTION + assert kwargs["context"].metadata == _TEST_UPDATED_METADATA + + def test_system_experiment_schema_title_is_set_correctly(self): + context = system_context_schema.Experiment() + assert context.schema_title == "system.Experiment" + + def test_system_experiment_run_schema_title_is_set_correctly(self): + context = system_context_schema.ExperimentRun() + assert context.schema_title == "system.ExperimentRun" + + def test_system_experiment_run_parameters_are_set_correctly(self): + context = system_context_schema.ExperimentRun(experiment_id=_TEST_CONTEXT_ID) + assert context.metadata["experiment_id"] == _TEST_CONTEXT_ID + + def test_system_pipeline_schema_title_is_set_correctly(self): + context = 
system_context_schema.Pipeline() + assert context.schema_title == "system.Pipeline" + + def test_system_pipeline_run_schema_title_is_set_correctly(self): + context = system_context_schema.PipelineRun() + assert context.schema_title == "system.PipelineRun" + + def test_system_pipeline_run_parameters_are_set_correctly(self): + context = system_context_schema.PipelineRun(pipeline_id=_TEST_CONTEXT_ID) + assert context.metadata["pipeline_id"] == _TEST_CONTEXT_ID + + @pytest.mark.usefixtures("google_auth_mock") class TestMetadataUtils: def setup_method(self): diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 6651990c5a..811940175c 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -74,8 +74,8 @@ def test_invalid_region_does_not_raise_with_valid_region(): ( "contexts", "123456", - aiplatform.metadata.context._Context._parse_resource_name, - aiplatform.metadata.context._Context._format_resource_name, + aiplatform.metadata.context.Context._parse_resource_name, + aiplatform.metadata.context.Context._format_resource_name, { aiplatform.metadata.metadata_store._MetadataStore._resource_noun: "default" }, @@ -147,8 +147,8 @@ def test_full_resource_name_with_full_name( ( "123", "contexts", - aiplatform.metadata.context._Context._parse_resource_name, - aiplatform.metadata.context._Context._format_resource_name, + aiplatform.metadata.context.Context._parse_resource_name, + aiplatform.metadata.context.Context._format_resource_name, { aiplatform.metadata.metadata_store._MetadataStore._resource_noun: "default" },
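
For reviewers, a minimal usage sketch of the API surface this patch exposes (not part of the diff itself): the now-public `Context.create` classmethod and the new system context schema classes. The project, location, display names, and resource IDs below are illustrative placeholders, not values from the change.

    # Usage sketch only; assumes aiplatform.init has valid credentials for the
    # placeholder project/location used here.
    from google.cloud import aiplatform
    from google.cloud.aiplatform.metadata.schema.system import (
        context_schema as system_context_schema,
    )

    aiplatform.init(project="my-project", location="us-central1")  # placeholders

    # Create a context directly through the newly exported Context class,
    # mirroring the Context.create signature added in this patch.
    experiment_context = aiplatform.Context.create(
        schema_title="system.Experiment",
        resource_id="my-experiment",          # placeholder resource id
        display_name="my-experiment",
        description="example experiment context",
        metadata_store_id="default",
    )

    # Or construct a typed schema object and call create() on it, as the new
    # unit tests do for BaseContextSchema subclasses.
    run_schema = system_context_schema.ExperimentRun(
        experiment_id=experiment_context.name,  # stored under metadata["experiment_id"]
        display_name="my-experiment-run",
    )
    run_context = run_schema.create(metadata_store_id="default")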