diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py index a1b267db278..0a8b2df1509 100644 --- a/sdk/python/kfp/compiler/compiler.py +++ b/sdk/python/kfp/compiler/compiler.py @@ -453,7 +453,7 @@ def _group_to_template(self, group, inputs, outputs, dependencies): def _create_templates(self, pipeline, op_transformers=None, op_to_templates_handler=None): """Create all groups and ops templates in the pipeline. - + Args: pipeline: Pipeline context object to get all the pipeline data from. op_transformers: A list of functions that are applied to all ContainerOp instances that are being processed. @@ -463,6 +463,13 @@ def _create_templates(self, pipeline, op_transformers=None, op_to_templates_hand op_to_templates_handler = op_to_templates_handler or (lambda op : [_op_to_template(op)]) new_root_group = pipeline.groups[0] + # Call the transformation functions before determining the inputs/outputs, otherwise + # the user would not be able to use pipeline parameters in the container definition + # (for example as pod labels) - the generated template is invalid. + for op in pipeline.ops.values(): + for transformer in op_transformers or []: + transformer(op) + # Generate core data structures to prepare for argo yaml generation # op_groups: op name -> list of ancestor groups including the current op # opsgroups: a dictionary of ospgroup.name -> opsgroup @@ -486,8 +493,6 @@ def _create_templates(self, pipeline, op_transformers=None, op_to_templates_hand templates.append(template) for op in pipeline.ops.values(): - for transformer in op_transformers or []: - op = transformer(op) or op templates.extend(op_to_templates_handler(op)) return templates diff --git a/sdk/python/tests/compiler/compiler_tests.py b/sdk/python/tests/compiler/compiler_tests.py index 00e6ec2f667..bf590fd9e61 100644 --- a/sdk/python/tests/compiler/compiler_tests.py +++ b/sdk/python/tests/compiler/compiler_tests.py @@ -366,6 +366,10 @@ def test_py_param_substitutions(self): """Test pipeline param_substitutions.""" self._test_py_compile_yaml('param_substitutions') + def test_py_param_op_transform(self): + """Test pipeline param_op_transform.""" + self._test_py_compile_yaml('param_op_transform') + def test_type_checking_with_consistent_types(self): """Test type check pipeline parameters against component metadata.""" @component @@ -471,7 +475,7 @@ def op(): def pipeline(): task1 = op() task2 = op().after(task1) - + compiler.Compiler()._compile(pipeline) def _test_op_to_template_yaml(self, ops, file_base_name): diff --git a/sdk/python/tests/compiler/testdata/param_op_transform.py b/sdk/python/tests/compiler/testdata/param_op_transform.py new file mode 100644 index 00000000000..bc7ba5193f0 --- /dev/null +++ b/sdk/python/tests/compiler/testdata/param_op_transform.py @@ -0,0 +1,28 @@ +from typing import Callable + +import kfp.dsl as dsl + +def add_common_labels(param): + + def _add_common_labels(op: dsl.ContainerOp) -> dsl.ContainerOp: + return op.add_pod_label('param', param) + + return _add_common_labels + +@dsl.pipeline( + name="Parameters in Op transformation functions", + description="Test that parameters used in Op transformation functions as pod labels " + "would be correcly identified and set as arguments in he generated yaml" +) +def param_substitutions(param = dsl.PipelineParam(name='param')): + dsl.get_pipeline_conf().op_transformers.append(add_common_labels(param)) + + op = dsl.ContainerOp( + name="cop", + image="image", + ) + + +if __name__ == '__main__': + import kfp.compiler as compiler + compiler.Compiler().compile(param_substitutions, __file__ + '.yaml') diff --git a/sdk/python/tests/compiler/testdata/param_op_transform.yaml b/sdk/python/tests/compiler/testdata/param_op_transform.yaml new file mode 100644 index 00000000000..a9f800c2acd --- /dev/null +++ b/sdk/python/tests/compiler/testdata/param_op_transform.yaml @@ -0,0 +1,40 @@ +apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + generateName: parameters-in-op-transformation-functions- +spec: + arguments: + parameters: + - name: param + entrypoint: parameters-in-op-transformation-functions + serviceAccountName: pipeline-runner + templates: + - container: + image: image + inputs: + parameters: + - name: param + metadata: + labels: + param: '{{inputs.parameters.param}}' + name: cop + outputs: + artifacts: + - name: mlpipeline-ui-metadata + optional: true + path: /mlpipeline-ui-metadata.json + - name: mlpipeline-metrics + optional: true + path: /mlpipeline-metrics.json + - dag: + tasks: + - arguments: + parameters: + - name: param + value: '{{inputs.parameters.param}}' + name: cop + template: cop + inputs: + parameters: + - name: param + name: parameters-in-op-transformation-functions