diff --git a/sdk/python/kfp/components/_python_op.py b/sdk/python/kfp/components/_python_op.py index e6b597976e3..0304a2daa00 100644 --- a/sdk/python/kfp/components/_python_op.py +++ b/sdk/python/kfp/components/_python_op.py @@ -463,11 +463,11 @@ def get_serializer_and_register_definitions(type_name) -> str: line = '_parser.add_argument("{param_flag}", dest="{param_var}", type=str, nargs={nargs})'.format( param_flag=param_flag, param_var=output_param_var, - nargs=len(component_spec.outputs), + nargs=len(outputs_passed_through_func_return_tuple), ) arg_parse_code_lines.append(line) arguments.append(param_flag) - arguments.extend(OutputPathPlaceholder(output.name) for output in component_spec.outputs) + arguments.extend(OutputPathPlaceholder(output.name) for output in outputs_passed_through_func_return_tuple) output_serialization_expression_strings = [] for output in outputs_passed_through_func_return_tuple: diff --git a/sdk/python/tests/components/test_python_op.py b/sdk/python/tests/components/test_python_op.py index e899a5a55a1..5c8a4044697 100644 --- a/sdk/python/tests/components/test_python_op.py +++ b/sdk/python/tests/components/test_python_op.py @@ -17,7 +17,7 @@ import unittest from contextlib import contextmanager from pathlib import Path -from typing import Callable, Sequence +from typing import Callable, NamedTuple, Sequence import kfp import kfp.components as comp @@ -631,6 +631,68 @@ def write_to_file_path(number_file: OutputBinaryFile(int)): self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'number': '42'}) + def test_output_path_plus_return_value(self): + from kfp.components import OutputPath + def write_to_file_path(number_file_path: OutputPath(int)) -> str: + with open(number_file_path, 'w') as f: + f.write(str(42)) + return 'Hello' + + task_factory = comp.func_to_container_op(write_to_file_path) + + self.assertFalse(task_factory.component_spec.inputs) + self.assertEqual(len(task_factory.component_spec.outputs), 2) + self.assertEqual(task_factory.component_spec.outputs[0].type, 'Integer') + self.assertEqual(task_factory.component_spec.outputs[1].type, 'String') + + self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'number': '42', 'output': 'Hello'}) + + + def test_all_data_passing_ways(self): + from kfp.components import InputTextFile, InputPath, OutputTextFile, OutputPath + def write_to_file_path( + file_input1_path: InputPath(str), + file_input2_file: InputTextFile(str), + file_output1_path: OutputPath(str), + file_output2_file: OutputTextFile(str), + value_input1: str = 'foo', + value_input2: str = 'foo', + ) -> NamedTuple( + 'Outputs', [ + ('return_output1', str), + ('return_output2', str), + ] + ): + with open(file_input1_path, 'r') as file_input1_file: + with open(file_output1_path, 'w') as file_output1_file: + file_output1_file.write(file_input1_file.read()) + + file_output2_file.write(file_input2_file.read()) + + return (value_input1, value_input2) + + task_factory = comp.func_to_container_op(write_to_file_path) + + self.assertEqual(set(input.name for input in task_factory.component_spec.inputs), {'file_input1', 'file_input2', 'value_input1', 'value_input2'}) + self.assertEqual(set(output.name for output in task_factory.component_spec.outputs), {'file_output1', 'file_output2', 'return_output1', 'return_output2'}) + + self.helper_test_component_using_local_call( + task_factory, + arguments={ + 'file_input1': 'file_input1_value', + 'file_input2': 'file_input2_value', + 'value_input1': 'value_input1_value', + 'value_input2': 'value_input2_value', + }, + expected_output_values={ + 'file_output1': 'file_input1_value', + 'file_output2': 'file_input2_value', + 'return_output1': 'value_input1_value', + 'return_output2': 'value_input2_value', + }, + ) + + def test_file_input_name_conversion(self): # Checking the input name conversion rules for file inputs: # For InputPath, the "_path" suffix is removed