
Commit

internal
PiperOrigin-RevId: 304066804
tfx-copybara authored and tensorflow-extended-team committed Mar 31, 2020
1 parent f581530 commit bfbc958
Showing 5 changed files with 37 additions and 0 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
@@ -19,6 +19,8 @@
delete and create new ones when pipelines are updated for KFP. (Requires
kfp >= 0.3.0)
* Added ability to enable quantization in tflite rewriter.
* Added k8s pod labels when the pipeline is executed via KubeflowDagRunner for
better usage telemetry.

### Deprecations

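For reference, the two labels added by this change end up on each component pod roughly as follows; the keys come from kubeflow_dag_runner.py below, while the values here are only illustrative (the UUID is generated per compiled pipeline):

```python
# Illustrative only: the keys are the constants defined in kubeflow_dag_runner.py;
# the values vary ('tfx', 'tfx-template', or 'tfx-cli', plus a random UUID).
example_pod_labels = {
    'pipelines.kubeflow.org/pipeline-sdk-type': 'tfx',
    'pipelines.kubeflow.org/pipeline-uuid': 'e0c9a2a8-1f3b-4a7d-9c2e-2f6d1a0b4c5d',
}
```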
3 changes: 3 additions & 0 deletions tfx/experimental/templates/taxi/kubeflow_dag_runner.py
@@ -69,6 +69,9 @@ def run():
tfx_image=tfx_image
)

# Set the SDK type label environment variable.
os.environ[kubeflow_dag_runner.SDK_ENV_LABEL] = 'tfx-template'

kubeflow_dag_runner.KubeflowDagRunner(config=runner_config).run(
pipeline.create_pipeline(
pipeline_name=configs.PIPELINE_NAME,
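A minimal sketch of how this environment variable interacts with the runner (pipeline construction omitted; assumes the default KubeflowDagRunnerConfig is sufficient for the example):

```python
import os

from tfx.orchestration.kubeflow import kubeflow_dag_runner

# Setting the variable before the runner is constructed makes every component
# pod carry 'pipeline-sdk-type: tfx-template' instead of the default 'tfx'.
os.environ[kubeflow_dag_runner.SDK_ENV_LABEL] = 'tfx-template'

runner = kubeflow_dag_runner.KubeflowDagRunner(
    config=kubeflow_dag_runner.KubeflowDagRunnerConfig())
# runner.run(my_pipeline) would then compile the pipeline with the label
# applied to each pod, as shown in kubeflow_dag_runner.py below.
```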
19 changes: 19 additions & 0 deletions tfx/orchestration/kubeflow/kubeflow_dag_runner.py
@@ -20,6 +20,7 @@
import os
import re
from typing import Callable, List, Optional, Text, Type
import uuid

from kfp import compiler
from kfp import dsl
@@ -55,6 +56,14 @@
# Default TFX container image to use in KubeflowDagRunner.
_KUBEFLOW_TFX_IMAGE = 'tensorflow/tfx:%s' % (version.__version__)

# The pod label indicating the SDK environment.
# LINT.IfChange
SDK_ENV_LABEL = 'pipelines.kubeflow.org/pipeline-sdk-type'
# LINT.ThenChange(../../tools/cli/handler/base_handler.py)

# The pod label holding the pipeline's unique ID.
PIPELINE_UUID_LABEL = 'pipelines.kubeflow.org/pipeline-uuid'


def _mount_config_map_op(config_map_name: Text) -> OpFunc:
"""Mounts all key-value pairs found in the named Kubernetes ConfigMap.
@@ -235,6 +244,9 @@ def __init__(
self._compiler = compiler.Compiler()
self._params = [] # List of dsl.PipelineParam used in this pipeline.
self._deduped_parameter_names = set() # Set of unique param names used.
# Set the SDK environment label. This is intentionally kept out of the user
# interface. Defaults to 'tfx'.
self._sdk_env = os.getenv(SDK_ENV_LABEL) or 'tfx'

def _parse_parameter_from_component(
self, component: base_component.BaseComponent) -> None:
@@ -308,6 +320,11 @@ def _construct_pipeline_graph(self, pipeline: tfx_pipeline.Pipeline,
for operator in self._config.pipeline_operator_funcs:
kfp_component.container_op.apply(operator)

kfp_component.container_op.add_pod_label(SDK_ENV_LABEL, self._sdk_env)
assert self._pipeline_id, 'Failed to generate pipeline ID.'
kfp_component.container_op.add_pod_label(PIPELINE_UUID_LABEL,
self._pipeline_id)

component_to_kfp_op[component] = kfp_component.container_op

def run(self, pipeline: tfx_pipeline.Pipeline):
@@ -322,6 +339,8 @@ def run(self, pipeline: tfx_pipeline.Pipeline):
dsl_pipeline_root = dsl.PipelineParam(
name=pipeline_root.name, value=pipeline.pipeline_info.pipeline_root)
self._params.append(dsl_pipeline_root)
# Randomly generate a unique pipeline ID.
self._pipeline_id = str(uuid.uuid4())

def _construct_pipeline():
"""Constructs a Kubeflow pipeline.
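The label is ultimately attached through the KFP SDK's ContainerOp.add_pod_label, the same call the runner makes per component. A standalone sketch, assuming a pre-2.0 kfp SDK; the demo pipeline and output file name are made up for illustration:

```python
import uuid

from kfp import compiler
from kfp import dsl

SDK_ENV_LABEL = 'pipelines.kubeflow.org/pipeline-sdk-type'
PIPELINE_UUID_LABEL = 'pipelines.kubeflow.org/pipeline-uuid'


@dsl.pipeline(name='label-demo', description='Pod label sketch.')
def label_demo_pipeline():
  op = dsl.ContainerOp(name='echo', image='alpine', command=['echo', 'hello'])
  # Mirrors what KubeflowDagRunner does for every component op; in the runner
  # the UUID is generated once per run() call and shared by all components.
  op.add_pod_label(SDK_ENV_LABEL, 'tfx')
  op.add_pod_label(PIPELINE_UUID_LABEL, str(uuid.uuid4()))


compiler.Compiler().compile(label_demo_pipeline, 'label_demo.yaml')
```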
9 changes: 9 additions & 0 deletions tfx/orchestration/kubeflow/kubeflow_dag_runner_test.py
@@ -92,6 +92,15 @@ def testTwoStepPipeline(self):
]
self.assertEqual(1, len(statistics_gen_container))

# Ensure the pod labels are correctly attached.
metadata = [
c['metadata'] for c in pipeline['spec']['templates'] if 'dag' not in c
]
for m in metadata:
self.assertEqual('tfx', m['labels'][kubeflow_dag_runner.SDK_ENV_LABEL])
self.assertIsNotNone(
m['labels'][kubeflow_dag_runner.PIPELINE_UUID_LABEL])

# Ensure dependencies between components are captured.
dag = [c for c in pipeline['spec']['templates'] if 'dag' in c]
self.assertEqual(1, len(dag))
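The same check works on any compiled pipeline package outside the test; a rough sketch, assuming the workflow was compiled to a plain YAML file (by default the runner writes a .tar.gz archive containing it):

```python
import yaml

# Load the Argo workflow produced by the compiler (file name from the sketch
# above; adjust to wherever the pipeline package was written).
with open('label_demo.yaml') as f:
  workflow = yaml.safe_load(f)

for template in workflow['spec']['templates']:
  if 'dag' in template:  # The DAG template itself does not run in a pod.
    continue
  labels = template.get('metadata', {}).get('labels', {})
  print(template['name'],
        labels.get('pipelines.kubeflow.org/pipeline-sdk-type'),
        labels.get('pipelines.kubeflow.org/pipeline-uuid'))
```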
4 changes: 4 additions & 0 deletions tfx/tools/cli/handler/base_handler.py
@@ -135,6 +135,10 @@ def _extract_pipeline_args(self) -> Dict[Text, Any]:
# Store temp_file path in temp_env.
temp_env[labels.TFX_JSON_EXPORT_PIPELINE_ARGS_PATH] = temp_file

# Mark the SDK environment if it is not already set (e.g. by a template).
if 'pipelines.kubeflow.org/pipeline-sdk-type' not in temp_env:
temp_env['pipelines.kubeflow.org/pipeline-sdk-type'] = 'tfx-cli'

# Run dsl with mock environment to store pipeline args in temp_file.
self._subprocess_call([sys.executable, pipeline_dsl_path], env=temp_env)
if os.stat(temp_file).st_size != 0:
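The CLI only injects the variable into the environment of the subprocess that executes the pipeline DSL; the runner constructed inside that DSL then turns it into the pod label. A condensed sketch of the flow (the helper name below is hypothetical, not the actual handler method):

```python
import os
import subprocess
import sys

SDK_ENV_LABEL = 'pipelines.kubeflow.org/pipeline-sdk-type'


def run_pipeline_dsl(pipeline_dsl_path: str) -> None:
  """Hypothetical helper mirroring what _extract_pipeline_args does."""
  env = os.environ.copy()
  # A template may already have marked itself (e.g. 'tfx-template');
  # otherwise the CLI identifies itself as 'tfx-cli'.
  env.setdefault(SDK_ENV_LABEL, 'tfx-cli')
  # KubeflowDagRunner, constructed inside the DSL file, reads this variable
  # and attaches it as the pipeline-sdk-type pod label.
  subprocess.check_call([sys.executable, pipeline_dsl_path], env=env)
```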
