Skip to content

Commit

Permalink
feat(sdk): Support dsl.ParallelFor over list of Artifacts (kubeflow#1…
Browse files Browse the repository at this point in the history
…0441)

* samples message

* update pr number

* Split LoopArgument into LoopParameterArgument and LoopArtifactArgument

* formatting

* address some comments

* resolve release notes conflict

* flatten loops in pipeline_spec_builder

* update artifact type checking logic

* simplify artifact checking logic

* re-add issubtype_of_artifact

* move name_is_loop_argument to for_loop_test.py

* simplify LoopArtifactArgument is_artifact_list=False logic

* update typeerror

* typo

* typo

* small fix

* formatting

* formatting

* remove issubtype_of_artifact()

* small changes

* assert LoopArtifactArgument channel.is_artifact_list is True

* remove whitespace

* remove newline

* Update single artifact check and error message

* formatting

* add unit test for is_artifact_list==False

* formatting

* update valueerror test.

* typo

* regex formatting

* formatting
  • Loading branch information
KevinGrantLee authored and petethegreat committed Mar 27, 2024
1 parent e672edb commit c6c6661
Show file tree
Hide file tree
Showing 11 changed files with 755 additions and 76 deletions.
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* Support local execution of sequential pipelines [\#10423](https://github.com/kubeflow/pipelines/pull/10423)
* Support local execution of `dsl.importer` components [\#10431](https://github.com/kubeflow/pipelines/pull/10431)
* Support local execution of pipelines in pipelines [\#10440](https://github.com/kubeflow/pipelines/pull/10440)
* Support dsl.ParallelFor over list of Artifacts [\#10441](https://github.com/kubeflow/pipelines/pull/10441)

## Breaking changes

Expand Down
12 changes: 7 additions & 5 deletions sdk/python/kfp/compiler/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def get_inputs_for_all_groups(
channel_to_add = channel

while isinstance(channel_to_add, (
for_loop.LoopArgument,
for_loop.LoopParameterArgument,
for_loop.LoopArtifactArgument,
for_loop.LoopArgumentVariable,
)):
channels_to_add.append(channel_to_add)
Expand Down Expand Up @@ -309,10 +310,11 @@ def get_inputs_for_all_groups(
# loop items, we have to go from bottom-up because the
# PipelineChannel can be originated from the middle a DAG,
# which is not needed and visible to its parent DAG.
if isinstance(
channel,
(for_loop.LoopArgument, for_loop.LoopArgumentVariable
)) and channel.is_with_items_loop_argument:
if isinstance(channel, (
for_loop.LoopParameterArgument,
for_loop.LoopArtifactArgument,
for_loop.LoopArgumentVariable,
)) and channel.is_with_items_loop_argument:
for group_name in task_name_to_parent_groups[
task.name][::-1]:

Expand Down
86 changes: 64 additions & 22 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from google.protobuf import json_format
from google.protobuf import struct_pb2
import kfp
from kfp import dsl
from kfp.compiler import compiler_utils
from kfp.dsl import component_factory
from kfp.dsl import for_loop
Expand Down Expand Up @@ -128,8 +129,10 @@ def build_task_spec_for_task(
task._task_spec.retry_policy.to_proto())

for input_name, input_value in task.inputs.items():
# since LoopArgument and LoopArgumentVariable are narrower types than PipelineParameterChannel, start with it
if isinstance(input_value, for_loop.LoopArgument):
# Since LoopParameterArgument and LoopArtifactArgument and LoopArgumentVariable are narrower
# types than PipelineParameterChannel, start with them.

if isinstance(input_value, for_loop.LoopParameterArgument):

component_input_parameter = (
compiler_utils.additional_input_name_for_pipeline_channel(
Expand All @@ -140,6 +143,17 @@ def build_task_spec_for_task(
input_name].component_input_parameter = (
component_input_parameter)

elif isinstance(input_value, for_loop.LoopArtifactArgument):

component_input_artifact = (
compiler_utils.additional_input_name_for_pipeline_channel(
input_value))
assert component_input_artifact in parent_component_inputs.artifacts, \
f'component_input_artifact: {component_input_artifact} not found. All inputs: {parent_component_inputs}'
pipeline_task_spec.inputs.artifacts[
input_name].component_input_artifact = (
component_input_artifact)

elif isinstance(input_value, for_loop.LoopArgumentVariable):

component_input_parameter = (
Expand All @@ -155,7 +169,7 @@ def build_task_spec_for_task(
f'parseJson(string_value)["{input_value.subvar_name}"]')
elif isinstance(input_value,
pipeline_channel.PipelineArtifactChannel) or (
isinstance(input_value, for_loop.Collected) and
isinstance(input_value, dsl.Collected) and
input_value.is_artifact_channel):

if input_value.task_name:
Expand Down Expand Up @@ -190,7 +204,7 @@ def build_task_spec_for_task(

elif isinstance(input_value,
pipeline_channel.PipelineParameterChannel) or (
isinstance(input_value, for_loop.Collected) and
isinstance(input_value, dsl.Collected) and
not input_value.is_artifact_channel):
if input_value.task_name:

Expand Down Expand Up @@ -683,19 +697,25 @@ def build_component_spec_for_group(
input_name = compiler_utils.additional_input_name_for_pipeline_channel(
channel)

if isinstance(channel, pipeline_channel.PipelineArtifactChannel):
if isinstance(channel, (pipeline_channel.PipelineArtifactChannel,
for_loop.LoopArtifactArgument)):
component_spec.input_definitions.artifacts[
input_name].artifact_type.CopyFrom(
type_utils.bundled_artifact_to_artifact_proto(
channel.channel_type))
component_spec.input_definitions.artifacts[
input_name].is_artifact_list = channel.is_artifact_list
else:
# channel is one of PipelineParameterChannel, LoopArgument, or
# LoopArgumentVariable.
elif isinstance(channel,
(pipeline_channel.PipelineParameterChannel,
for_loop.LoopParameterArgument,
for_loop.LoopArgumentVariable, dsl.Collected)):
component_spec.input_definitions.parameters[
input_name].parameter_type = type_utils.get_parameter_type(
channel.channel_type)
else:
raise TypeError(
f'Expected PipelineParameterChannel, PipelineArtifactChannel, LoopParameterArgument, LoopArtifactArgument, LoopArgumentVariable, or Collected, got {type(channel)}.'
)

for output_name, output in output_pipeline_channels.items():
if isinstance(output, pipeline_channel.PipelineArtifactChannel):
Expand Down Expand Up @@ -747,13 +767,34 @@ def _update_task_spec_for_loop_group(
loop_argument_item_name = compiler_utils.additional_input_name_for_pipeline_channel(
group.loop_argument.full_name)

loop_arguments_item = f'{input_parameter_name}-{for_loop.LoopArgument.LOOP_ITEM_NAME_BASE}'
loop_arguments_item = f'{input_parameter_name}-{for_loop.LOOP_ITEM_NAME_BASE}'
assert loop_arguments_item == loop_argument_item_name

pipeline_task_spec.parameter_iterator.items.input_parameter = (
input_parameter_name)
pipeline_task_spec.parameter_iterator.item_input = (
loop_argument_item_name)
if isinstance(group.loop_argument, for_loop.LoopParameterArgument):
pipeline_task_spec.parameter_iterator.items.input_parameter = (
input_parameter_name)
pipeline_task_spec.parameter_iterator.item_input = (
loop_argument_item_name)

_pop_input_from_task_spec(
task_spec=pipeline_task_spec,
input_name=pipeline_task_spec.parameter_iterator.item_input)

elif isinstance(group.loop_argument, for_loop.LoopArtifactArgument):
input_artifact_name = compiler_utils.additional_input_name_for_pipeline_channel(
loop_items_channel)

pipeline_task_spec.artifact_iterator.items.input_artifact = input_artifact_name
pipeline_task_spec.artifact_iterator.item_input = (
loop_argument_item_name)

_pop_input_from_task_spec(
task_spec=pipeline_task_spec,
input_name=pipeline_task_spec.artifact_iterator.item_input)
else:
raise TypeError(
f'Expected LoopParameterArgument or LoopArtifactArgument, got {type(group.loop_argument)}.'
)

# If the loop items itself is a loop arguments variable, handle the
# subvar name.
Expand All @@ -777,14 +818,14 @@ def _update_task_spec_for_loop_group(
pipeline_task_spec.parameter_iterator.item_input = (
input_parameter_name)

_pop_input_from_task_spec(
task_spec=pipeline_task_spec,
input_name=pipeline_task_spec.parameter_iterator.item_input)

if (group.parallelism_limit > 0):
pipeline_task_spec.iterator_policy.parallelism_limit = (
group.parallelism_limit)

_pop_input_from_task_spec(
task_spec=pipeline_task_spec,
input_name=pipeline_task_spec.parameter_iterator.item_input)


def _binary_operations_to_cel_conjunctive(
operations: List[pipeline_channel.ConditionOperation]) -> str:
Expand Down Expand Up @@ -1290,10 +1331,11 @@ def build_spec_by_group(

for channel in subgroup_input_channels:
# Skip 'withItems' loop arguments if it's from an inner loop.
if isinstance(
channel,
(for_loop.LoopArgument, for_loop.LoopArgumentVariable
)) and channel.is_with_items_loop_argument:
if isinstance(channel, (
for_loop.LoopParameterArgument,
for_loop.LoopArtifactArgument,
for_loop.LoopArgumentVariable,
)) and channel.is_with_items_loop_argument:
withitems_loop_arg_found_in_self_or_upstream = False
for group_name in group_name_to_parent_groups[
subgroup.name][::-1]:
Expand Down Expand Up @@ -1782,7 +1824,7 @@ def _rename_component_refs(
def validate_pipeline_outputs_dict(
pipeline_outputs_dict: Dict[str, pipeline_channel.PipelineChannel]):
for channel in pipeline_outputs_dict.values():
if isinstance(channel, for_loop.Collected):
if isinstance(channel, dsl.Collected):
# this validation doesn't apply to Collected
continue

Expand Down
Loading

0 comments on commit c6c6661

Please sign in to comment.