Skip to content

Commit

Permalink
SDK - Components refactoring (kubeflow#2865)
Browse files Browse the repository at this point in the history
* SDK - Components refactoring

This change is a pure refactoring of the implementation of component task creation.
For pipelines compiled using the DSL compiler (the compile() function or the command-line program) nothing should change.

The main goal of the refactoring is to change the way the component instantiation can be customized.
Previously, the flow was like this:

`ComponentSpec` + arguments --> `TaskSpec` --resolving+transform--> `ContainerOp`

This PR changes it to a more direct path:

`ComponentSpec` + arguments --constructor--> `ContainerOp`
or
`ComponentSpec` + arguments --constructor--> `TaskSpec`
or
`ComponentSpec` + arguments --constructor--> `SomeCustomTask`

The original approach where the flow always passes through `TaskSpec` had some issues since TaskSpec only accepts string arguments (and two
other reference classes). This made it harder to handle custom types of arguments like PipelineParam or Channel.

Low-level refactoring changes:

Resolving of command-line argument placeholders has been extracted into a function usable by different task constructors.

Changed `_components._created_task_transformation_handler` to `_components._container_task_constructor`. Previously, the handler was receiving a `TaskSpec` instance. Now it receives `ComponentSpec` + arguments [+ `ComponentReference`].
Moved the `ContainerOp` construction handler setup to the `kfp.dsl.Pipeline` context class as planned.
Extracted `TaskSpec` creation to `_components._create_task_spec_from_component_and_arguments`.
Refactored `_dsl_bridge.create_container_op_from_task` to `_components._resolve_command_line_and_paths` which returns `_ResolvedCommandLineAndPaths`.
Renamed `_dsl_bridge._create_container_op_from_resolved_task` to `_dsl_bridge._create_container_op_from_component_and_arguments`.
The signature of `_components._resolve_graph_task` was changed and it now returns `_ResolvedGraphTask` instead of modified `TaskSpec`.

Some of the component tests still expect ContainerOp and its attributes.
These tests will be changed later.

* Adapted the _python_op tests

* Fixed linter failure

I do not want to add any top-level kfp imports in this file to prevent circular references.

* Added docstrings

* Fixed the return type forward reference
  • Loading branch information
Ark-kun authored and Jeffwan committed Dec 9, 2020
1 parent f81cac3 commit e34b4b0
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 213 deletions.
296 changes: 247 additions & 49 deletions sdk/python/kfp/components/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
'load_component_from_file',
]

import copy
import sys
from collections import OrderedDict
from typing import Any, List, Mapping, NamedTuple, Sequence, Union
from ._naming import _sanitize_file_name, _sanitize_python_function_name, generate_unique_name_conversion_table
from ._yaml_utils import load_yaml
from ._structures import ComponentSpec
from ._structures import *
from ._data_passing import serialize_value, type_name_to_type
from kfp.dsl import PipelineParam
from kfp.dsl.types import verify_type_compatibility


_default_component_name = 'Component'

Expand Down Expand Up @@ -170,13 +171,71 @@ def _generate_output_file_name(port_name):
return _outputs_dir + '/' + _sanitize_file_name(port_name) + '/' + _single_io_file_name


# Holds the transformation functions that are called each time a TaskSpec instance is created from a component.
# If there are multiple handlers, only the last one is used.
_created_task_transformation_handler = []
def _react_to_incompatible_reference_type(
input_type,
argument_type,
input_name: str,
):
"""Raises error for the case when the argument type is incompatible with the input type."""
message = 'Argument with type "{}" was passed to the input "{}" that has type "{}".'.format(argument_type, input_name, input_type)
raise TypeError(message)


def _create_task_spec_from_component_and_arguments(
    component_spec: ComponentSpec,
    arguments: Mapping[str, Any],
    component_ref: ComponentReference = None,
) -> TaskSpec:
    """Build a TaskSpec for a component invoked with the given arguments.

    Reference arguments (graph inputs and task outputs) are type-checked
    against the declared input type and passed through unchanged. Every other
    argument is treated as a constant and serialized with ``serialize_value``.

    Missing or extra arguments are not checked here since the dynamic factory
    function checks that.

    Args:
        component_spec: Spec of the component being instantiated.
        arguments: Maps input names to constant values or reference arguments.
        component_ref: Optional reference to the component. It is shallow-copied
            and the copy's ``spec`` is pointed at ``component_spec``; when
            omitted, a new reference wrapping ``component_spec`` is created.

    Returns:
        The new TaskSpec, with its outputs initialized.

    Raises:
        TypeError: If a reference argument carries a type that differs from
            the type declared by the input (both types being truthy).
    """
    if component_ref is not None:
        # Work on a shallow copy so the caller's reference object is left alone.
        task_component_ref = copy.copy(component_ref)
        task_component_ref.spec = component_spec
    else:
        task_component_ref = ComponentReference(spec=component_spec)

    def to_task_argument(name, value):
        declared_type = component_spec._inputs_dict[name].type

        if isinstance(value, GraphInputArgument):
            passed_type = value.graph_input.type
        elif isinstance(value, TaskOutputArgument):
            passed_type = value.task_output.type
        else:
            # Anything that is not a reference is a constant value.
            return serialize_value(value, declared_type)

        # Untyped sides are compatible with everything, so only compare when both are set.
        if passed_type and declared_type and passed_type != declared_type:
            _react_to_incompatible_reference_type(declared_type, passed_type, name)
        return value

    prepared_arguments = {
        name: to_task_argument(name, value)
        for name, value in arguments.items()
    }

    new_task = TaskSpec(
        component_ref=task_component_ref,
        arguments=prepared_arguments,
    )
    new_task._init_outputs()
    return new_task


# Task constructor used when no override has been installed: it builds a plain TaskSpec.
_default_container_task_constructor = _create_task_spec_from_component_and_arguments

#TODO: Move to the dsl.Pipeline context class
from . import _dsl_bridge
_created_task_transformation_handler.append(_dsl_bridge.create_container_op_from_task)
# Holds the function that constructs a task object based on ComponentSpec, arguments and ComponentReference.
# Framework authors can override this constructor function to construct different framework-specific task-like objects.
# The task object should have the task.outputs dictionary with keys corresponding to the ComponentSpec outputs.
# The default constructor creates an instance of the TaskSpec class.
_container_task_constructor = _default_container_task_constructor


# When True, graph components are expanded through _resolve_graph_task even if the container task constructor has not been overridden.
_always_expand_graph_components = False


class _DefaultValue:
Expand Down Expand Up @@ -210,44 +269,33 @@ def _create_task_factory_from_component_spec(component_spec:ComponentSpec, compo
component_ref.spec = component_spec

def create_task_from_component_and_arguments(pythonic_arguments):
arguments = {}
# Not checking for missing or extra arguments since the dynamic factory function checks that
for argument_name, argument_value in pythonic_arguments.items():
if isinstance(argument_value, _DefaultValue): # Skipping passing arguments for optional values that have not been overridden.
continue
input_name = pythonic_name_to_input_name[argument_name]
input_type = component_spec._inputs_dict[input_name].type

if isinstance(argument_value, (GraphInputArgument, TaskOutputArgument, PipelineParam)):
# argument_value is a reference

if isinstance(argument_value, PipelineParam):
reference_type = argument_value.param_type
argument_value = str(argument_value)
elif isinstance(argument_value, TaskOutputArgument):
reference_type = argument_value.task_output.type
else:
reference_type = None

verify_type_compatibility(reference_type, input_type, 'Incompatible argument passed to the input "{}" of component "{}": '.format(input_name, component_spec.name))

arguments[input_name] = argument_value
else:
# argument_value is a constant value
serialized_argument_value = serialize_value(argument_value, input_type)
arguments[input_name] = serialized_argument_value

task = TaskSpec(
component_ref=component_ref,
arguments = {
pythonic_name_to_input_name[argument_name]: argument_value
for argument_name, argument_value in pythonic_arguments.items()
if not isinstance(argument_value, _DefaultValue) # Skipping passing arguments for optional values that have not been overridden.
}

if (
isinstance(component_spec.implementation, GraphImplementation)
and (
# When the container task constructor is not overridden, we just construct TaskSpec for both container and graph tasks.
# If the container task constructor is overridden, we should expand the graph components so that the override is called for all sub-tasks.
_container_task_constructor != _default_container_task_constructor
or _always_expand_graph_components
)
):
return _resolve_graph_task(
component_spec=component_spec,
arguments=arguments,
component_ref=component_ref,
)

task = _container_task_constructor(
component_spec=component_spec,
arguments=arguments,
component_ref=component_ref,
)
task._init_outputs()

if isinstance(component_spec.implementation, GraphImplementation):
return _resolve_graph_task(task, component_spec)

if _created_task_transformation_handler:
task = _created_task_transformation_handler[-1](task)
return task

import inspect
Expand Down Expand Up @@ -284,14 +332,161 @@ def component_default_to_func_default(component_default: str, is_optional: bool)
return task_factory


def _resolve_graph_task(graph_task: TaskSpec, graph_component_spec: ComponentSpec) -> TaskSpec:
# Result of resolving a container component's command line against concrete arguments.
# Returned by _resolve_command_line_and_paths.
_ResolvedCommandLineAndPaths = NamedTuple(
'_ResolvedCommandLineAndPaths',
[
# Container command with all placeholders expanded into strings.
('command', Sequence[str]),
# Container args with all placeholders expanded into strings.
('args', Sequence[str]),
# Input name -> path produced by the input path generator, for inputs consumed as files.
('input_paths', Mapping[str, str]),
# Output name -> path, taken from the container's file_outputs or produced by the output path generator.
('output_paths', Mapping[str, str]),
# Input name -> serialized argument, for inputs whose value was inserted into the command line.
('inputs_consumed_by_value', Mapping[str, str]),
],
)


def _resolve_command_line_and_paths(
    component_spec: ComponentSpec,
    arguments: Mapping[str, str],
    input_path_generator=_generate_input_file_name,
    output_path_generator=_generate_output_file_name,
    argument_serializer=serialize_value,
) -> _ResolvedCommandLineAndPaths:
    """Resolves the command line argument placeholders.

    Also produces the maps of the generated input/output paths.

    Args:
        component_spec: Spec of a container component.
        arguments: Maps input names to argument values. An input whose value
            is missing or None is treated as "not provided".
        input_path_generator: Called with an input name; returns the path used
            for an input that is consumed as a file.
        output_path_generator: Called with an output name; returns the path
            used for an output file.
        argument_serializer: Called with (value, input type); returns the
            string inserted for an input that is consumed by value.

    Returns:
        A _ResolvedCommandLineAndPaths with the expanded command and args, the
        generated input and output paths, and the serialized values of the
        inputs consumed by value.

    Raises:
        TypeError: If the component is not a container component, or if the
            command line contains an unrecognized placeholder object.
        ValueError: If a non-optional input has no value, if two different
            paths are specified for the same output, or if an "if" condition
            expands to a string that is not a recognized truth value.
    """
    argument_values = arguments

    if not isinstance(component_spec.implementation, ContainerImplementation):
        raise TypeError('Only container components have command line to resolve')

    inputs_dict = {input_spec.name: input_spec for input_spec in component_spec.inputs or []}
    container_spec = component_spec.implementation.container

    output_paths = OrderedDict()  # Preserving the order to make the kubernetes output names deterministic
    # Paths fixed by the container's file_outputs take precedence and must not be contradicted later.
    unconfigurable_output_paths = container_spec.file_outputs or {}
    for output in component_spec.outputs or []:
        if output.name in unconfigurable_output_paths:
            output_paths[output.name] = unconfigurable_output_paths[output.name]

    input_paths = OrderedDict()
    inputs_consumed_by_value = {}

    def parse_truth_string(text) -> bool:
        """Converts a truth-value string to bool.

        Local replacement for distutils.util.strtobool, since distutils was
        removed from the standard library in Python 3.12. Accepts the same
        spellings and raises the same ValueError for anything else.
        """
        lowered = text.lower()
        if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
            return True
        if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
            return False
        raise ValueError('invalid truth value {!r}'.format(lowered))

    def expand_command_part(arg) -> Union[str, List[str], None]:
        """Expands one command-line item into a string, a list of strings, or None (item omitted)."""
        if arg is None:
            return None
        if isinstance(arg, (str, int, float, bool)):
            return str(arg)

        if isinstance(arg, InputValuePlaceholder):
            input_name = arg.input_name
            input_spec = inputs_dict[input_name]
            input_value = argument_values.get(input_name, None)
            if input_value is not None:
                serialized_argument = argument_serializer(input_value, input_spec.type)
                inputs_consumed_by_value[input_name] = serialized_argument
                return serialized_argument
            else:
                if input_spec.optional:
                    return None
                else:
                    raise ValueError('No value provided for input {}'.format(input_name))

        if isinstance(arg, InputPathPlaceholder):
            input_name = arg.input_name
            input_value = argument_values.get(input_name, None)
            if input_value is not None:
                input_path = input_path_generator(input_name)
                input_paths[input_name] = input_path
                return input_path
            else:
                input_spec = inputs_dict[input_name]
                if input_spec.optional:
                    # Even when we support default values there is no need to check for a default here.
                    # In current execution flow (called by python task factory), the missing argument would be replaced with the default value by python itself.
                    return None
                else:
                    raise ValueError('No value provided for input {}'.format(input_name))

        elif isinstance(arg, OutputPathPlaceholder):
            output_name = arg.output_name
            output_filename = output_path_generator(output_name)
            if arg.output_name in output_paths:
                if output_paths[output_name] != output_filename:
                    raise ValueError('Conflicting output files specified for port {}: {} and {}'.format(output_name, output_paths[output_name], output_filename))
            else:
                output_paths[output_name] = output_filename

            return output_filename

        elif isinstance(arg, ConcatPlaceholder):
            expanded_argument_strings = expand_argument_list(arg.items)
            return ''.join(expanded_argument_strings)

        elif isinstance(arg, IfPlaceholder):
            arg = arg.if_structure
            condition_result = expand_command_part(arg.condition)
            # Python gotcha: bool('False') == True, so the string has to be parsed.
            # The `and` short-circuit also handles None and [].
            condition_result_bool = condition_result and parse_truth_string(condition_result)
            result_node = arg.then_value if condition_result_bool else arg.else_value
            if result_node is None:
                return []
            if isinstance(result_node, list):
                expanded_result = expand_argument_list(result_node)
            else:
                expanded_result = expand_command_part(result_node)
            return expanded_result

        elif isinstance(arg, IsPresentPlaceholder):
            argument_is_present = argument_values.get(arg.input_name, None) is not None
            return str(argument_is_present)
        else:
            raise TypeError('Unrecognized argument type: {}'.format(arg))

    def expand_argument_list(argument_list):
        """Expands every item of a list, dropping omitted items and flattening nested lists."""
        expanded_list = []
        if argument_list is not None:
            for part in argument_list:
                expanded_part = expand_command_part(part)
                if expanded_part is not None:
                    if isinstance(expanded_part, list):
                        expanded_list.extend(expanded_part)
                    else:
                        expanded_list.append(str(expanded_part))
        return expanded_list

    expanded_command = expand_argument_list(container_spec.command)
    expanded_args = expand_argument_list(container_spec.args)

    return _ResolvedCommandLineAndPaths(
        command=expanded_command,
        args=expanded_args,
        input_paths=input_paths,
        output_paths=output_paths,
        inputs_consumed_by_value=inputs_consumed_by_value,
    )


# Result of expanding a graph component into its sub-tasks.
# Constructed and returned by _resolve_graph_task.
_ResolvedGraphTask = NamedTuple(
'_ResolvedGraphTask',
[
# Spec of the graph component that was expanded.
('component_spec', ComponentSpec),
# Reference to the graph component, as passed to _resolve_graph_task.
('component_ref', ComponentReference),
# Graph output name -> resolved value of that graph output.
('outputs', Mapping[str, Any]),
# Arguments the graph component was invoked with, keyed by input name.
('task_arguments', Mapping[str, Any]),
],
)


def _resolve_graph_task(
component_spec: ComponentSpec,
arguments: Mapping[str, Any],
component_ref: ComponentReference = None,
) -> TaskSpec:
from ..components import ComponentStore
component_store = ComponentStore.default_store

graph = graph_component_spec.implementation.graph
graph = component_spec.implementation.graph

graph_input_arguments = {input.name: input.default for input in graph_component_spec.inputs if input.default is not None}
graph_input_arguments.update(graph_task.arguments)
graph_input_arguments = {input.name: input.default for input in component_spec.inputs if input.default is not None}
graph_input_arguments.update(arguments)

outputs_of_tasks = {}
def resolve_argument(argument):
Expand Down Expand Up @@ -326,7 +521,10 @@ def resolve_argument(argument):
resolved_graph_outputs = OrderedDict([(output_name, resolve_argument(argument)) for output_name, argument in graph.output_values.items()])

# For resolved graph component tasks task.outputs point to the actual tasks that originally produced the output that is later returned from the graph
graph_task.output_references = graph_task.outputs
graph_task.outputs = resolved_graph_outputs

graph_task = _ResolvedGraphTask(
component_ref=component_ref,
component_spec=component_spec,
outputs = resolved_graph_outputs,
task_arguments=arguments,
)
return graph_task
Loading

0 comments on commit e34b4b0

Please sign in to comment.