diff --git a/sdk/python/kfp/components/_python_op.py b/sdk/python/kfp/components/_python_op.py index ab675051c0f..a68b480437e 100644 --- a/sdk/python/kfp/components/_python_op.py +++ b/sdk/python/kfp/components/_python_op.py @@ -20,6 +20,9 @@ 'InputPath', 'InputTextFile', 'InputBinaryFile', + 'OutputPath', + 'OutputTextFile', + 'OutputBinaryFile', ] from ._yaml_utils import dump_yaml @@ -58,8 +61,36 @@ def __init__(self, type=None): #OutputFile[GcsPath[Gzipped[Text]]] -class OutputFile(Generic[T], str): - pass +class OutputPath: + '''When creating component from function, OutputPath should be used as function parameter annotation to tell the system that the function wants to output data by writing it into a file with the given path instead of returning the data from the function.''' + def __init__(self, type=None): + self.type = type + + +class OutputTextFile: + '''When creating component from function, OutputTextFile should be used as function parameter annotation to tell the system that the function wants to output data by writing it into a given text file stream (`io.TextIOWrapper`) instead of returning the data from the function.''' + def __init__(self, type=None): + self.type = type + + +class OutputBinaryFile: + '''When creating component from function, OutputBinaryFile should be used as function parameter annotation to tell the system that the function wants to output data by writing it into a given binary file stream (`io.BytesIO`) instead of returning the data from the function.''' + def __init__(self, type=None): + self.type = type + + +def _make_parent_dirs_and_return_path(file_path: str): + import os + os.makedirs(os.path.dirname(file_path), exist_ok=True) + return file_path + + +def _parent_dirs_maker_that_returns_open_file(mode: str, encoding: str = None): + def make_parent_dirs_and_return_path(file_path: str): + import os + os.makedirs(os.path.dirname(file_path), exist_ok=True) + return open(file_path, mode=mode, encoding=encoding) + return 
make_parent_dirs_and_return_path #TODO: Replace this image name with another name once people decide what to replace it with. @@ -204,23 +235,33 @@ def annotation_to_type_struct(annotation): for parameter in parameters: parameter_annotation = parameter.annotation passing_style = None - if isinstance(parameter_annotation, (InputPath, InputTextFile, InputBinaryFile)): + if isinstance(parameter_annotation, (InputPath, InputTextFile, InputBinaryFile, OutputPath, OutputTextFile, OutputBinaryFile)): passing_style = type(parameter_annotation) parameter_annotation = parameter_annotation.type + if parameter.default is not inspect.Parameter.empty: + raise ValueError('Default values for file inputs/outputs are not supported. If you need them for some reason, please create an issue and write about your usage scenario.') # TODO: Fix the input names: "number_file_path" parameter should be exposed as "number" input type_struct = annotation_to_type_struct(parameter_annotation) #TODO: Humanize the input/output names - input_spec = InputSpec( - name=parameter.name, - type=type_struct, - ) - if parameter.default is not inspect.Parameter.empty: - input_spec.optional = True - if parameter.default is not None: - input_spec.default = serialize_value(parameter.default, type_struct) - input_spec._passing_style = passing_style - inputs.append(input_spec) + if isinstance(parameter.annotation, (OutputPath, OutputTextFile, OutputBinaryFile)): + output_spec = OutputSpec( + name=parameter.name, + type=type_struct, + ) + output_spec._passing_style = passing_style + outputs.append(output_spec) + else: + input_spec = InputSpec( + name=parameter.name, + type=type_struct, + ) + if parameter.default is not inspect.Parameter.empty: + input_spec.optional = True + if parameter.default is not None: + input_spec.default = serialize_value(parameter.default, type_struct) + input_spec._passing_style = passing_style + inputs.append(input_spec) #Analyzing the return type annotations. 
return_ann = signature.return_annotation @@ -234,6 +275,7 @@ def annotation_to_type_struct(annotation): name=field_name, type=type_struct, ) + output_spec._passing_style = None outputs.append(output_spec) elif signature.return_annotation is not None and signature.return_annotation != inspect.Parameter.empty: type_struct = annotation_to_type_struct(signature.return_annotation) @@ -241,6 +283,7 @@ def annotation_to_type_struct(annotation): name=single_output_name_const, type=type_struct, ) + output_spec._passing_style = None outputs.append(output_spec) #Component name and description are derived from the function's name and docstribng, but can be overridden by @python_component function decorator @@ -304,16 +347,30 @@ def get_deserializer_and_register_definitions(type_name): pre_func_definitions = set() def get_argparse_type_for_input_file(passing_style): + if passing_style is None: + return None + pre_func_definitions.add(inspect.getsource(passing_style)) + if passing_style is InputPath: - pre_func_definitions.add(inspect.getsource(InputPath)) return 'str' elif passing_style is InputTextFile: - pre_func_definitions.add(inspect.getsource(InputTextFile)) return "argparse.FileType('rt')" elif passing_style is InputBinaryFile: - pre_func_definitions.add(inspect.getsource(InputBinaryFile)) return "argparse.FileType('rb')" - return None + # For Output* we cannot use the built-in argparse.FileType objects since they do not create parent directories. 
+ elif passing_style is OutputPath: + # ~= return 'str' + pre_func_definitions.add(inspect.getsource(_make_parent_dirs_and_return_path)) + return _make_parent_dirs_and_return_path.__name__ + elif passing_style is OutputTextFile: + # ~= return "argparse.FileType('wt')" + pre_func_definitions.add(inspect.getsource(_parent_dirs_maker_that_returns_open_file)) + return _parent_dirs_maker_that_returns_open_file.__name__ + "('wt')" + elif passing_style is OutputBinaryFile: + # ~= return "argparse.FileType('wb')" + pre_func_definitions.add(inspect.getsource(_parent_dirs_maker_that_returns_open_file)) + return _parent_dirs_maker_that_returns_open_file.__name__ + "('wb')" + raise NotImplementedError('Unexpected data passing style: "{}".'.format(str(passing_style))) def get_serializer_and_register_definitions(type_name) -> str: if type_name in type_name_to_serializer: @@ -333,10 +390,12 @@ def get_serializer_and_register_definitions(type_name) -> str: description_repr=repr(component_spec.description or ''), ), ] + outputs_passed_through_func_return_tuple = [output for output in (component_spec.outputs or []) if output._passing_style is None] + file_outputs_passed_using_func_parameters = [output for output in (component_spec.outputs or []) if output._passing_style is not None] arguments = [] - for input in component_spec.inputs: + for input in component_spec.inputs + file_outputs_passed_using_func_parameters: param_flag = "--" + input.name.replace("_", "-") - is_required = not input.optional + is_required = isinstance(input, OutputSpec) or not input.optional line = '_parser.add_argument("{param_flag}", dest="{param_var}", type={param_type}, required={is_required}, default=argparse.SUPPRESS)'.format( param_flag=param_flag, param_var=input.name, @@ -347,6 +406,8 @@ def get_serializer_and_register_definitions(type_name) -> str: if input._passing_style in [InputPath, InputTextFile, InputBinaryFile]: arguments_for_input = [param_flag, InputPathPlaceholder(input.name)] + elif 
input._passing_style in [OutputPath, OutputTextFile, OutputBinaryFile]: + arguments_for_input = [param_flag, OutputPathPlaceholder(input.name)] else: arguments_for_input = [param_flag, InputValuePlaceholder(input.name)] @@ -362,7 +423,7 @@ def get_serializer_and_register_definitions(type_name) -> str: ) ) - if component_spec.outputs: + if outputs_passed_through_func_return_tuple: param_flag="----output-paths" output_param_var="_output_paths" line = '_parser.add_argument("{param_flag}", dest="{param_var}", type=str, nargs={nargs})'.format( @@ -375,11 +436,9 @@ def get_serializer_and_register_definitions(type_name) -> str: arguments.extend(OutputPathPlaceholder(output.name) for output in component_spec.outputs) output_serialization_expression_strings = [] - if component_spec.outputs: - outputs_produced_by_func_return_value = component_spec.outputs - for output in outputs_produced_by_func_return_value: - serializer_call_str = get_serializer_and_register_definitions(output.type) - output_serialization_expression_strings.append(serializer_call_str) + for output in outputs_passed_through_func_return_tuple: + serializer_call_str = get_serializer_and_register_definitions(output.type) + output_serialization_expression_strings.append(serializer_call_str) pre_func_code = '\n'.join(list(pre_func_definitions)) diff --git a/sdk/python/tests/components/test_python_op.py b/sdk/python/tests/components/test_python_op.py index 7424efeeb31..93d9b321692 100644 --- a/sdk/python/tests/components/test_python_op.py +++ b/sdk/python/tests/components/test_python_op.py @@ -562,6 +562,52 @@ def consume_file_path(number_file: InputBinaryFile(int)) -> int: self.helper_test_component_using_local_call(task_factory, arguments={'number_file': "42"}, expected_output_values={'output': '42'}) + def test_output_path(self): + from kfp.components import OutputPath + def write_to_file_path(number_file_path: OutputPath(int)): + with open(number_file_path, 'w') as f: + f.write(str(42)) + + task_factory = 
comp.func_to_container_op(write_to_file_path) + + self.assertFalse(task_factory.component_spec.inputs) + self.assertEqual(len(task_factory.component_spec.outputs), 1) + self.assertEqual(task_factory.component_spec.outputs[0].type, 'Integer') + + # TODO: Fix the output names: "number_file_path" should be exposed as "number" output + self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'number_file_path': '42'}) + + + def test_output_text_file(self): + from kfp.components import OutputTextFile + def write_to_file_path(number_file: OutputTextFile(int)): + number_file.write(str(42)) + + task_factory = comp.func_to_container_op(write_to_file_path) + + self.assertFalse(task_factory.component_spec.inputs) + self.assertEqual(len(task_factory.component_spec.outputs), 1) + self.assertEqual(task_factory.component_spec.outputs[0].type, 'Integer') + + # TODO: Fix the output names: "number_file" should be exposed as "number" output + self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'number_file': '42'}) + + + def test_output_binary_file(self): + from kfp.components import OutputBinaryFile + def write_to_file_path(number_file: OutputBinaryFile(int)): + number_file.write(b'42') + + task_factory = comp.func_to_container_op(write_to_file_path) + + self.assertFalse(task_factory.component_spec.inputs) + self.assertEqual(len(task_factory.component_spec.outputs), 1) + self.assertEqual(task_factory.component_spec.outputs[0].type, 'Integer') + + # TODO: Fix the output names: "number_file" should be exposed as "number" output + self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'number_file': '42'}) + + def test_end_to_end_python_component_pipeline_compilation(self): import kfp.components as comp