diff --git a/sdk/python/kfp/compiler/_default_transformers.py b/sdk/python/kfp/compiler/_default_transformers.py
--- a/sdk/python/kfp/compiler/_default_transformers.py
+++ b/sdk/python/kfp/compiler/_default_transformers.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 from typing import Callable, Dict, Optional, Text
+
 from ..dsl._container_op import BaseOp, ContainerOp
 
 # Pod label indicating the SDK type from which the pipeline is
@@ -20,6 +22,16 @@
 _SDK_ENV_LABEL = 'pipelines.kubeflow.org/pipeline-sdk-type'
 _SDK_ENV_DEFAULT = 'kfp'
 
+# Common prefix of KFP OOB components url paths.
+_OOB_COMPONENT_PATH_PREFIX = 'https://raw.githubusercontent.com/kubeflow/'\
+                             'pipelines'
+
+# Key for component origin path pod label.
+COMPONENT_PATH_LABEL_KEY = 'pipelines.kubeflow.org/component_origin_path'
+
+# Key for component spec digest pod label.
+COMPONENT_DIGEST_LABEL_KEY = 'pipelines.kubeflow.org/component_digest'
+
 
 def get_default_telemetry_labels() -> Dict[Text, Text]:
   """Returns the default pod labels for telemetry purpose."""
@@ -68,3 +80,52 @@ def _add_pod_labels(task):
     return task
 
   return _add_pod_labels
+
+
+def _remove_suffix(string: Text, suffix: Text) -> Text:
+  """Removes the suffix from a string."""
+  if suffix and string.endswith(suffix):
+    return string[:-len(suffix)]
+  else:
+    return string
+
+
+def add_name_for_oob_components() -> Callable:
+  """Adds the OOB component name if applicable.
+
+  Returns:
+    An op transformer. For a task whose component reference has a url under
+    the KFP OOB prefix, the sanitized origin path is attached as a pod label;
+    a task loaded from any other url is returned untouched. When the
+    component reference carries a digest, it is attached as a pod label too,
+    truncated to the k8s label value length limit.
+  """
+
+  def _add_name_for_oob_components(task):
+    # Detect the component origin uri in component_ref if exists, and
+    # attach the OOB component name as a pod label.
+    component_ref = getattr(task, '_component_ref', None)
+    if component_ref:
+      if component_ref.url:
+        origin_path = _remove_suffix(
+            component_ref.url, 'component.yaml').rstrip('/')
+        # Only include KFP OOB components.
+        if origin_path.startswith(_OOB_COMPONENT_PATH_PREFIX):
+          # Keep only what follows the 7th '/' of the url.
+          origin_path = origin_path.split('/', 7)[-1]
+        else:
+          return task
+        # Clean the label to comply with the k8s label convention.
+        origin_path = re.sub('[^-a-z0-9A-Z_.]', '.', origin_path)
+        origin_path_label = origin_path[-63:].strip('-_.')
+        task.add_pod_label(COMPONENT_PATH_LABEL_KEY, origin_path_label)
+      if component_ref.digest:
+        # A k8s label value holds at most 63 characters, which a digest may
+        # exceed (e.g. a sha256 hex digest has 64), so only the first 63
+        # characters can be preserved.
+        task.add_pod_label(
+            COMPONENT_DIGEST_LABEL_KEY, component_ref.digest[:63])
+
+    return task
+
+  return _add_name_for_oob_components
diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py
index d6cbddb203f..bb70b45d488 100644
--- a/sdk/python/kfp/compiler/compiler.py
+++ b/sdk/python/kfp/compiler/compiler.py
@@ -27,7 +27,7 @@
 from .. import dsl
 from ._k8s_helper import convert_k8s_obj_to_json, sanitize_k8s_name
 from ._op_to_template import _op_to_template
-from ._default_transformers import add_pod_env, add_pod_labels, get_default_telemetry_labels
+from ._default_transformers import add_pod_env, add_pod_labels, add_name_for_oob_components, get_default_telemetry_labels
 
 from ..components.structures import InputSpec
 from ..components._yaml_utils import dump_yaml
@@ -836,6 +836,7 @@ def _create_workflow(self,
     if allow_telemetry:
       pod_labels = get_default_telemetry_labels()
       op_transformers.append(add_pod_labels(pod_labels))
+      op_transformers.append(add_name_for_oob_components())
 
     op_transformers.extend(pipeline_conf.op_transformers)
 
diff --git a/sdk/python/tests/compiler/compiler_tests.py b/sdk/python/tests/compiler/compiler_tests.py
--- a/sdk/python/tests/compiler/compiler_tests.py
+++ b/sdk/python/tests/compiler/compiler_tests.py
@@ -26,6 +26,8 @@
 import unittest
 import yaml
 
+from kfp import components
+from kfp.compiler._default_transformers import COMPONENT_DIGEST_LABEL_KEY, COMPONENT_PATH_LABEL_KEY
 from kfp.dsl._component import component
 from kfp.dsl import ContainerOp, pipeline
 from kfp.dsl.types import Integer, InconsistentTypeException
@@ -40,6 +42,11 @@ def some_op():
       command=['sleep 1'],
   )
 
+_TEST_GCS_DOWNLOAD_COMPONENT_URL = 'https://raw.githubusercontent.com/kubeflow/'\
+                                   'pipelines/2dac60c400ad8767b452649d08f328df'\
+                                   'af230f96/components/google-cloud/storage/'\
+                                   'download/component.yaml'
+
 
 class TestCompiler(unittest.TestCase):
   # Define the places of samples covered by unit tests.
@@ -711,6 +718,30 @@ def some_pipeline():
       container = template.get('container', None)
       if container:
         self.assertEqual(template['retryStrategy']['limit'], 5)
+
+  def test_oob_component_label(self):
+    """Checks the origin path and digest pod labels of an OOB component."""
+    gcs_download_op = components.load_component_from_url(
+        _TEST_GCS_DOWNLOAD_COMPONENT_URL)
+
+    @dsl.pipeline(name='some_pipeline')
+    def some_pipeline():
+      _download_task = gcs_download_op('gs://some_bucket/some_dir/some_file')
+
+    workflow_dict = compiler.Compiler()._compile(some_pipeline)
+
+    found_download_task = False
+    for template in workflow_dict['spec']['templates']:
+      if template.get('container', None):
+        found_download_task = True
+        labels = template['metadata']['labels']
+        self.assertEqual(
+            labels[COMPONENT_PATH_LABEL_KEY], 'google-cloud.storage.download')
+        digest_label = labels.get(COMPONENT_DIGEST_LABEL_KEY)
+        self.assertIsNotNone(digest_label)
+        # K8s rejects label values longer than 63 characters.
+        self.assertLessEqual(len(digest_label), 63)
+    self.assertTrue(found_download_task, 'download task not found in workflow.')
 
   def test_image_pull_policy(self):
     def some_op():