Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WithParams #2044

Merged
merged 20 commits into from
Sep 17, 2019
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 91 additions & 67 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
from collections import defaultdict
import inspect
import tarfile
import zipfile
from typing import Any, Callable, Set, List, Text, Dict
from typing import Callable, Set, List, Text, Dict, Tuple, Any, Union, Optional

import yaml
from kfp.dsl import _container_op, _for_loop
from kfp.dsl import _for_loop

from .. import dsl
from ._k8s_helper import K8sHelper
Expand Down Expand Up @@ -339,17 +338,16 @@ def _get_dependencies(self, pipeline, root_group, op_groups, opsgroups_groups, o
upstream_op_names.add(param.op_name)
upstream_op_names |= set(op.dependent_names)

for op_name in upstream_op_names:
for upstream_op_name in upstream_op_names:
# the dependent op could be either a BaseOp or an opsgroup
if op_name in pipeline.ops:
upstream_op = pipeline.ops[op_name]
elif op_name in opsgroups:
upstream_op = opsgroups[op_name]
if upstream_op_name in pipeline.ops:
upstream_op = pipeline.ops[upstream_op_name]
elif upstream_op_name in opsgroups:
upstream_op = opsgroups[upstream_op_name]
else:
raise ValueError('compiler cannot find the ' + op_name)
raise ValueError('compiler cannot find the ' + upstream_op_name)

upstream_groups, downstream_groups = \
self._get_uncommon_ancestors(op_groups, opsgroups_groups, upstream_op, op)
upstream_groups, downstream_groups = self._get_uncommon_ancestors(op_groups, opsgroups_groups, upstream_op, op)
dependencies[downstream_groups[0]].add(upstream_groups[0])

# Generate dependencies based on the recursive opsgroups
Expand Down Expand Up @@ -463,66 +461,88 @@ def _group_to_dag_template(self, group, inputs, outputs, dependencies):

# Generate arguments section for this task.
if inputs.get(sub_group.name, None):
arguments = []
for param_name, dependent_name in inputs[sub_group.name]:
if dependent_name:
# The value comes from an upstream sibling.
# Special handling for recursive subgroup: argument name comes from the existing opsgroup
if is_recursive_subgroup:
for index, input in enumerate(sub_group.inputs):
if param_name == self._pipelineparam_full_name(input):
break
referenced_input = sub_group.recursive_ref.inputs[index]
full_name = self._pipelineparam_full_name(referenced_input)
arguments.append({
'name': full_name,
'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
})
else:
arguments.append({
'name': param_name,
'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
})
task['arguments'] = {'parameters': self.get_arguments_for_sub_group(sub_group, is_recursive_subgroup, inputs)}

# additional task modifications for withItems and withParam
if isinstance(sub_group, dsl.ParallelFor):
if sub_group.items_is_pipeline_param:
# these loop args are a 'withParam' rather than 'withItems'.
# i.e., rather than a static list, they are either the output of another task or were input
# as global pipeline parameters

pipeline_param = sub_group.loop_args
if pipeline_param.op_name is None:
withparam_value = '{{workflow.parameters.%s}}' % pipeline_param.name
else:
# The value comes from its parent.
# Special handling for recursive subgroup: argument name comes from the existing opsgroup
if is_recursive_subgroup:
for index, input in enumerate(sub_group.inputs):
if param_name == self._pipelineparam_full_name(input):
break
referenced_input = sub_group.recursive_ref.inputs[index]
full_name = self._pipelineparam_full_name(referenced_input)
arguments.append({
'name': full_name,
'value': '{{inputs.parameters.%s}}' % param_name
})
else:
if isinstance(sub_group, dsl.ParallelFor):
if sub_group.loop_args.name in param_name:
if _for_loop.LoopArgumentVariable.name_is_loop_arguments_variable(param_name):
subvar_name = _for_loop.LoopArgumentVariable.get_subvar_name(param_name)
value = '{{item.%s}}' % subvar_name
elif _for_loop.LoopArguments.name_is_loop_arguments(param_name):
value = '{{item}}'
else:
raise ValueError("Failed to match loop args with param. param_name: {}, ".format(param_name) +
"sub_group.loop_args.name: {}.".format(sub_group.loop_args.name))
else:
value = '{{inputs.parameters.%s}}' % param_name
task['withItems'] = sub_group.loop_args.to_list_for_task_yaml()
else:
value = '{{inputs.parameters.%s}}' % param_name
arguments.append({
'name': param_name,
'value': value,
})
arguments.sort(key=lambda x: x['name'])
task['arguments'] = {'parameters': arguments}
param_name = '%s-%s' % (pipeline_param.op_name, pipeline_param.name)
withparam_value = '{{tasks.%s.outputs.parameters.%s}}' % (pipeline_param.op_name, param_name)

# these loop args are the output of another task
if 'dependencies' not in task or task['dependencies'] is None:
task['dependencies'] = []
if pipeline_param.op_name not in task['dependencies']:
task['dependencies'].append(pipeline_param.op_name)

task['withParam'] = withparam_value
else:
task['withItems'] = sub_group.loop_args.to_list_for_task_yaml()

if isinstance(sub_group, dsl.ContainerOp) and sub_group.artifact_arguments:
artifact_argument_structs = []
for input_name, argument in sub_group.artifact_arguments.items():
artifact_argument_dict = {'name': input_name}
if isinstance(argument, str):
artifact_argument_dict['raw'] = {'data': str(argument)}
else:
raise TypeError('Argument "{}" was passed to the artifact input "{}", but only constant strings are supported at this moment.'.format(str(argument), input_name))
artifact_argument_structs.append(artifact_argument_dict)
task.setdefault('arguments', {})['artifacts'] = artifact_argument_structs

tasks.append(task)
tasks.sort(key=lambda x: x['name'])
template['dag'] = {'tasks': tasks}
return template

def get_arguments_for_sub_group(
        self,
        sub_group: Union[OpsGroup, dsl._container_op.BaseOp],
        is_recursive_subgroup: Optional[bool],
        inputs: Dict[Text, Tuple[Text, Text]],
):
    """Build the Argo ``arguments.parameters`` list for one sub-group task.

    Args:
        sub_group: The op or opsgroup whose task arguments are being generated.
        is_recursive_subgroup: True when this sub-group refers back to an
            existing (recursive) opsgroup; argument names then come from the
            referenced group's inputs rather than from the parameter itself.
        inputs: Mapping of group name to (param_name, dependent_name) pairs.

    Returns:
        A list of ``{'name': ..., 'value': ...}`` dicts sorted by name.
    """
    arguments = []
    for param_name, dependent_name in inputs[sub_group.name]:
        if is_recursive_subgroup:
            # Special handling for a recursive subgroup: the argument name
            # comes from the opsgroup this group recursively refers to,
            # matched by position in its input list.
            for index, group_input in enumerate(sub_group.inputs):
                if param_name == self._pipelineparam_full_name(group_input):
                    break
            else:
                # Previously a failed match silently reused the last index;
                # fail loudly instead of emitting a wrong argument name.
                raise ValueError(
                    'Failed to find an input of recursive subgroup {} matching parameter {}.'.format(
                        sub_group.name, param_name))
            referenced_input = sub_group.recursive_ref.inputs[index]
            argument_name = self._pipelineparam_full_name(referenced_input)
        else:
            argument_name = param_name

        # Default argument value; the special cases below may override it.
        argument_value = '{{inputs.parameters.%s}}' % param_name
        if isinstance(sub_group, dsl.ParallelFor):
            if sub_group.loop_args.name in param_name:
                if _for_loop.LoopArgumentVariable.name_is_loop_arguments_variable(param_name):
                    # A subvariable of the loop item, e.g. {{item.a}}.
                    subvar_name = _for_loop.LoopArgumentVariable.get_subvar_name(param_name)
                    argument_value = '{{item.%s}}' % subvar_name
                elif _for_loop.LoopArguments.name_is_loop_arguments(param_name) or sub_group.items_is_pipeline_param:
                    # The whole loop item.
                    argument_value = '{{item}}'
                else:
                    raise ValueError("Failed to match loop args with parameter. param_name: {}.".format(param_name))
        elif dependent_name:
            # The value is produced by an upstream sibling task.
            argument_value = '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)

        arguments.append({
            'name': argument_name,
            'value': argument_value,
        })

    # Deterministic ordering keeps the compiled YAML stable across runs.
    arguments.sort(key=lambda x: x['name'])

    return arguments

def _create_dag_templates(self, pipeline, op_transformers=None, op_to_templates_handler=None):
"""Create all groups and ops templates in the pipeline.

Expand Down Expand Up @@ -580,6 +600,7 @@ def _create_dag_templates(self, pipeline, op_transformers=None, op_to_templates_

for op in pipeline.ops.values():
templates.extend(op_to_templates_handler(op))

return templates

def _create_volumes(self, pipeline):
Expand All @@ -605,7 +626,10 @@ def _create_pipeline_workflow(self, args, pipeline, op_transformers=None):
for arg in args:
param = {'name': arg.name}
if arg.value is not None:
param['value'] = str(arg.value)
if isinstance(arg.value, (list, tuple)):
param['value'] = json.dumps(arg.value)
else:
param['value'] = str(arg.value)
input_params.append(param)

# Templates
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kfp/dsl/_container_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ def __init__(
container_kwargs: Dict = None,
artifact_argument_paths: List[InputArgumentPath] = None,
file_outputs: Dict[str, str] = None,
output_artifact_paths : Dict[str, str]=None,
output_artifact_paths: Dict[str, str]=None,
artifact_location: V1alpha1ArtifactLocation=None,
is_exit_handler=False,
pvolumes: Dict[str, V1Volume] = None,
Expand Down
51 changes: 37 additions & 14 deletions sdk/python/kfp/dsl/_for_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import List, Union, Dict, Text, Any, Tuple
from typing import List, Union, Dict, Text, Any, Tuple, Optional

from kfp import dsl

Expand All @@ -19,7 +19,7 @@ class LoopArguments(dsl.PipelineParam):
def _subvar_name_is_legal(cls, proposed_variable_name: Text):
return re.match(cls.LEGAL_SUBVAR_NAME_REGEX, proposed_variable_name) is not None

def __init__(self, items: ItemList, code: Text):
def __init__(self, items: Union[ItemList, dsl.PipelineParam], code: Text, name_override: Optional[Text]=None, op_name: Optional[Text]=None, *args, **kwargs):
"""_LoopArguments represent the set of items to loop over in a ParallelFor loop. This class shoudn't be
instantiated by the user but rather is created by _ops_group.ParallelFor.

Expand All @@ -29,12 +29,18 @@ def __init__(self, items: ItemList, code: Text):
code: A unique code used to identify these loop arguments. Should match the code for the ParallelFor
ops_group which created these _LoopArguments. This prevents parameter name collisions.
"""
super().__init__(name=self._make_name(code))
if name_override is None:
super().__init__(name=self._make_name(code), *args, **kwargs)
else:
super().__init__(name=name_override, op_name=op_name, *args, **kwargs)

if not isinstance(items, (list, tuple, dsl.PipelineParam)):
raise TypeError("Expected list, tuple, or PipelineParam, got {}.".format(type(items)))

if not isinstance(items, (list, tuple)):
raise TypeError("Expected list or tuple, got {}.".format(type(items)))
if isinstance(items, tuple):
items = list(items)

if isinstance(items[0], dict):
if isinstance(items, list) and isinstance(items[0], dict):
subvar_names = set(items[0].keys())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does Argo resolve {{item.a}} when the key a is missing?

for item in items:
if not set(item.keys()) == subvar_names:
Expand All @@ -48,10 +54,31 @@ def __init__(self, items: ItemList, code: Text):
"name.".format(subvar_name))
setattr(self, subvar_name, LoopArgumentVariable(self.name, subvar_name))

self.items = items
self.items_or_pipeline_param = items
self.referenced_subvar_names = []

@classmethod
def from_pipeline_param(cls, param: dsl.PipelineParam) -> 'LoopArguments':
    """Alternate constructor: wrap an existing PipelineParam as loop arguments.

    Used when a ParallelFor receives a PipelineParam (compiled to Argo
    ``withParam``) instead of a literal item list (``withItems``); the wrapper
    reuses the param's name, op_name, and value rather than generating a
    coded name.
    """
    # Use cls rather than hard-coding LoopArguments so that subclasses of
    # this alternate constructor produce instances of themselves.
    return cls(
        items=param,
        code=None,
        name_override=param.name,
        op_name=param.op_name,
        value=param.value,
    )

def __getattr__(self, item):
    # __getattr__ is only invoked for attributes NOT found through normal
    # lookup, so subvariables explicitly set in __init__ bypass this path.
    # Overriding it lets users reference item.<subvar> without declaring
    # subvariable names ahead of time; each referenced name is recorded so
    # the compiler later knows which subvariables were actually used.
    # NOTE(review): this also fires for dunder/protocol lookups (e.g. during
    # copy/pickle) and mutates state from an attribute read — confirm that
    # is intended.
    self.referenced_subvar_names.append(item)
    return LoopArgumentVariable(self.name, item)

def to_list_for_task_yaml(self):
    """Return the literal loop items for serialization into task YAML."""
    items = self.items_or_pipeline_param
    if not isinstance(items, (list, tuple)):
        # A PipelineParam has no static item list to serialize.
        raise ValueError("You should only call this method on loop args which have list items, "
                         "not pipeline param items.")
    return items
kevinbache marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def _make_name(cls, code: Text):
Expand All @@ -63,7 +90,7 @@ def name_is_loop_arguments(cls, param_name: Text) -> bool:
"""Return True if the given parameter name looks like it came from a loop arguments parameter."""
return re.match(
'%s-[0-9a-f]{%s}' % (cls.LOOP_ITEM_PARAM_NAME_BASE, cls.NUM_CODE_CHARS),
param_name
param_name,
) is not None


Expand Down Expand Up @@ -102,11 +129,7 @@ def get_name(cls, loop_args_name: Text, this_variable_name: Text) -> Text:
def name_is_loop_arguments_variable(cls, param_name: Text) -> bool:
"""Return True if the given parameter name looks like it came from a LoopArgumentsVariable."""
return re.match(
'%s-[0-9a-f]{%s}%s.*' % (
LoopArguments.LOOP_ITEM_PARAM_NAME_BASE,
LoopArguments.NUM_CODE_CHARS,
cls.SUBVAR_NAME_DELIMITER
),
'.+%s.+' % cls.SUBVAR_NAME_DELIMITER,
param_name
) is not None

Expand Down
16 changes: 11 additions & 5 deletions sdk/python/kfp/dsl/_ops_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from typing import Union
import uuid

from kfp.dsl import _for_loop
from kfp.dsl import _for_loop, _pipeline_param

from . import _container_op
from . import _pipeline
from ._pipeline_param import ConditionOperator


class OpsGroup(object):
"""Represents a logical group of ops and group of OpsGroups.
Expand Down Expand Up @@ -93,6 +93,7 @@ def after(self, dependency):
self.dependencies.append(dependency)
return self


class ExitHandler(OpsGroup):
"""Represents an exit handler that is invoked upon exiting a group of ops.

Expand Down Expand Up @@ -168,13 +169,18 @@ class ParallelFor(OpsGroup):
def _get_unique_id_code():
return uuid.uuid4().hex[:_for_loop.LoopArguments.NUM_CODE_CHARS]

def __init__(self, loop_args: _for_loop.ItemList):
# random code to id this loop
def __init__(self, loop_args: Union[_for_loop.ItemList, _pipeline_param.PipelineParam]):
self.items_is_pipeline_param = isinstance(loop_args, _pipeline_param.PipelineParam)

# use a random code to uniquely identify this loop
code = self._get_unique_id_code()
group_name = 'for-loop-{}'.format(code)
super().__init__(self.TYPE_NAME, name=group_name)

if not isinstance(loop_args, _for_loop.LoopArguments):
if self.items_is_pipeline_param:
loop_args = _for_loop.LoopArguments.from_pipeline_param(loop_args)
elif not self.items_is_pipeline_param and not isinstance(loop_args, _for_loop.LoopArguments):
# we were passed a raw list, wrap it in loop args
loop_args = _for_loop.LoopArguments(loop_args, code)

self.loop_args = loop_args
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kfp/dsl/_pipeline_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __str__(self):
return '{{pipelineparam:op=%s;name=%s}}' % (op_name, self.name)

def __repr__(self):
return str({self.__class__.__name__: self.__dict__})
return str({self.__class__.__name__: self.__dict__})

def __eq__(self, other):
return ConditionOperator('==', self, other)
Expand Down
14 changes: 13 additions & 1 deletion sdk/python/tests/compiler/compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,6 @@ def init_container_pipeline():
init_container = init_containers[0]
self.assertEqual(init_container, {'image':'alpine:latest', 'command': ['echo', 'bye'], 'name': 'echo'})


def test_delete_resource_op(self):
"""Test a pipeline with a delete resource operation."""
from kubernetes import client as k8s
Expand Down Expand Up @@ -714,6 +713,19 @@ def some_pipeline():
self.assertIsNone(delete_op_template.get("failureCondition"))
self.assertDictEqual(delete_op_template.get("outputs", {}), {})

def test_withparam_global(self):
    """Compile-check a pipeline whose withParam source is a global pipeline parameter."""
    self._test_py_compile_yaml('withparam_global')

def test_withparam_global_dict(self):
    """Compile-check withParam over a global parameter whose items are dicts (item subvariables)."""
    self._test_py_compile_yaml('withparam_global_dict')

def test_withparam_output(self):
    """Compile-check a pipeline whose withParam source is another task's output."""
    self._test_py_compile_yaml('withparam_output')

def test_withparam_output_dict(self):
    """Compile-check withParam over a task output whose items are dicts (item subvariables)."""
    self._test_py_compile_yaml('withparam_output_dict')

def test_py_input_artifact_raw_value(self):
    """Compile-check a pipeline that passes a constant raw value to an artifact input."""
    self._test_py_compile_yaml('input_artifact_raw_value')

Loading