Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk): Support dsl.ParallelFor over list of Artifacts #10441

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
7145184
samples message
KevinGrantLee Jan 29, 2024
8d4e7ed
update pr number
KevinGrantLee Jan 29, 2024
241cf75
Split LoopArgument into LoopParameterArgument and LoopArtifactArgument
KevinGrantLee Jan 30, 2024
c738cd4
formatting
KevinGrantLee Jan 30, 2024
3fe31e8
address some comments
KevinGrantLee Jan 30, 2024
3050e4c
resolve release notes conflict
KevinGrantLee Jan 30, 2024
fa9027b
Merge branch 'master' into parallelfor-list-artifacts
KevinGrantLee Jan 30, 2024
c8b65de
Merge branch 'master' of https://github.com/KevinGrantLee/pipelines i…
KevinGrantLee Jan 31, 2024
3462a15
Merge branch 'parallelfor-list-artifacts' of https://github.com/Kevin…
KevinGrantLee Jan 31, 2024
43ef672
flatten loops in pipeline_spec_builder
KevinGrantLee Jan 31, 2024
e5d5c9b
update artifact type checking logic
KevinGrantLee Jan 31, 2024
974be3a
simplify artifact checking logic
KevinGrantLee Jan 31, 2024
f649fe8
re-add issubtype_of_artifact
KevinGrantLee Jan 31, 2024
4e2f033
move name_is_loop_argument to for_loop_test.py
KevinGrantLee Jan 31, 2024
29bc8b1
simplify LoopArtifactArgument is_artifact_list=False logic
KevinGrantLee Jan 31, 2024
9a31af8
update typeerror
KevinGrantLee Jan 31, 2024
231aecb
typo
KevinGrantLee Jan 31, 2024
3090549
typo
KevinGrantLee Jan 31, 2024
753115a
small fix
KevinGrantLee Jan 31, 2024
b4a01dd
formatting
KevinGrantLee Jan 31, 2024
49edbd7
formatting
KevinGrantLee Jan 31, 2024
b3de65b
remove issubtype_of_artifact()
KevinGrantLee Feb 1, 2024
a75bff3
small changes
KevinGrantLee Feb 2, 2024
63cd40a
assert LoopArtifactArgument channel.is_artifact_list is True
KevinGrantLee Feb 2, 2024
0500050
Merge branch 'master' of https://github.com/KevinGrantLee/pipelines i…
KevinGrantLee Feb 3, 2024
163ecab
remove whitespace
KevinGrantLee Feb 3, 2024
7709b66
remove newline
KevinGrantLee Feb 3, 2024
4337f87
Update single artifact check and error message
KevinGrantLee Feb 5, 2024
106f8b6
formatting
KevinGrantLee Feb 5, 2024
d29d8cb
add unit test for is_artifact_list==False
KevinGrantLee Feb 5, 2024
d1141fa
formatting
KevinGrantLee Feb 5, 2024
19e693c
update valueerror test.
KevinGrantLee Feb 6, 2024
564f666
typo
KevinGrantLee Feb 6, 2024
11ea34f
regex formatting
KevinGrantLee Feb 6, 2024
d367372
formatting
KevinGrantLee Feb 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* Support local execution of sequential pipelines [\#10423](https://github.com/kubeflow/pipelines/pull/10423)
* Support local execution of `dsl.importer` components [\#10431](https://github.com/kubeflow/pipelines/pull/10431)
* Support local execution of pipelines in pipelines [\#10440](https://github.com/kubeflow/pipelines/pull/10440)
* Support dsl.ParallelFor over list of Artifacts [\#10441](https://github.com/kubeflow/pipelines/pull/10441)

## Breaking changes

Expand Down
12 changes: 7 additions & 5 deletions sdk/python/kfp/compiler/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def get_inputs_for_all_groups(
channel_to_add = channel

while isinstance(channel_to_add, (
for_loop.LoopArgument,
for_loop.LoopParameterArgument,
for_loop.LoopArtifactArgument,
for_loop.LoopArgumentVariable,
)):
channels_to_add.append(channel_to_add)
Expand Down Expand Up @@ -309,10 +310,11 @@ def get_inputs_for_all_groups(
# loop items, we have to go from bottom-up because the
# PipelineChannel can be originated from the middle a DAG,
# which is not needed and visible to its parent DAG.
if isinstance(
channel,
(for_loop.LoopArgument, for_loop.LoopArgumentVariable
)) and channel.is_with_items_loop_argument:
if isinstance(channel, (
for_loop.LoopParameterArgument,
for_loop.LoopArtifactArgument,
for_loop.LoopArgumentVariable,
)) and channel.is_with_items_loop_argument:
for group_name in task_name_to_parent_groups[
task.name][::-1]:

Expand Down
86 changes: 64 additions & 22 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from google.protobuf import json_format
from google.protobuf import struct_pb2
import kfp
from kfp import dsl
from kfp.compiler import compiler_utils
from kfp.dsl import component_factory
from kfp.dsl import for_loop
Expand Down Expand Up @@ -128,8 +129,10 @@ def build_task_spec_for_task(
task._task_spec.retry_policy.to_proto())

for input_name, input_value in task.inputs.items():
# since LoopArgument and LoopArgumentVariable are narrower types than PipelineParameterChannel, start with it
if isinstance(input_value, for_loop.LoopArgument):
# Since LoopParameterArgument and LoopArtifactArgument and LoopArgumentVariable are narrower
# types than PipelineParameterChannel, start with them.

if isinstance(input_value, for_loop.LoopParameterArgument):

component_input_parameter = (
compiler_utils.additional_input_name_for_pipeline_channel(
Expand All @@ -140,6 +143,17 @@ def build_task_spec_for_task(
input_name].component_input_parameter = (
component_input_parameter)

elif isinstance(input_value, for_loop.LoopArtifactArgument):

component_input_artifact = (
compiler_utils.additional_input_name_for_pipeline_channel(
input_value))
assert component_input_artifact in parent_component_inputs.artifacts, \
f'component_input_artifact: {component_input_artifact} not found. All inputs: {parent_component_inputs}'
pipeline_task_spec.inputs.artifacts[
input_name].component_input_artifact = (
component_input_artifact)

elif isinstance(input_value, for_loop.LoopArgumentVariable):

component_input_parameter = (
Expand All @@ -155,7 +169,7 @@ def build_task_spec_for_task(
f'parseJson(string_value)["{input_value.subvar_name}"]')
elif isinstance(input_value,
pipeline_channel.PipelineArtifactChannel) or (
isinstance(input_value, for_loop.Collected) and
isinstance(input_value, dsl.Collected) and
input_value.is_artifact_channel):

if input_value.task_name:
Expand Down Expand Up @@ -190,7 +204,7 @@ def build_task_spec_for_task(

elif isinstance(input_value,
pipeline_channel.PipelineParameterChannel) or (
isinstance(input_value, for_loop.Collected) and
isinstance(input_value, dsl.Collected) and
not input_value.is_artifact_channel):
if input_value.task_name:

Expand Down Expand Up @@ -683,19 +697,25 @@ def build_component_spec_for_group(
input_name = compiler_utils.additional_input_name_for_pipeline_channel(
channel)

if isinstance(channel, pipeline_channel.PipelineArtifactChannel):
if isinstance(channel, (pipeline_channel.PipelineArtifactChannel,
for_loop.LoopArtifactArgument)):
component_spec.input_definitions.artifacts[
input_name].artifact_type.CopyFrom(
type_utils.bundled_artifact_to_artifact_proto(
channel.channel_type))
component_spec.input_definitions.artifacts[
input_name].is_artifact_list = channel.is_artifact_list
else:
# channel is one of PipelineParameterChannel, LoopArgument, or
# LoopArgumentVariable.
elif isinstance(channel,
(pipeline_channel.PipelineParameterChannel,
for_loop.LoopParameterArgument,
for_loop.LoopArgumentVariable, dsl.Collected)):
component_spec.input_definitions.parameters[
input_name].parameter_type = type_utils.get_parameter_type(
channel.channel_type)
else:
raise TypeError(
f'Expected PipelineParameterChannel, PipelineArtifactChannel, LoopParameterArgument, LoopArtifactArgument, LoopArgumentVariable, or Collected, got {type(channel)}.'
)

for output_name, output in output_pipeline_channels.items():
if isinstance(output, pipeline_channel.PipelineArtifactChannel):
Expand Down Expand Up @@ -747,13 +767,34 @@ def _update_task_spec_for_loop_group(
loop_argument_item_name = compiler_utils.additional_input_name_for_pipeline_channel(
group.loop_argument.full_name)

loop_arguments_item = f'{input_parameter_name}-{for_loop.LoopArgument.LOOP_ITEM_NAME_BASE}'
loop_arguments_item = f'{input_parameter_name}-{for_loop.LOOP_ITEM_NAME_BASE}'
KevinGrantLee marked this conversation as resolved.
Show resolved Hide resolved
assert loop_arguments_item == loop_argument_item_name

pipeline_task_spec.parameter_iterator.items.input_parameter = (
input_parameter_name)
pipeline_task_spec.parameter_iterator.item_input = (
loop_argument_item_name)
if isinstance(group.loop_argument, for_loop.LoopParameterArgument):
pipeline_task_spec.parameter_iterator.items.input_parameter = (
input_parameter_name)
pipeline_task_spec.parameter_iterator.item_input = (
loop_argument_item_name)

_pop_input_from_task_spec(
task_spec=pipeline_task_spec,
input_name=pipeline_task_spec.parameter_iterator.item_input)

elif isinstance(group.loop_argument, for_loop.LoopArtifactArgument):
input_artifact_name = compiler_utils.additional_input_name_for_pipeline_channel(
loop_items_channel)

pipeline_task_spec.artifact_iterator.items.input_artifact = input_artifact_name
pipeline_task_spec.artifact_iterator.item_input = (
loop_argument_item_name)

_pop_input_from_task_spec(
task_spec=pipeline_task_spec,
input_name=pipeline_task_spec.artifact_iterator.item_input)
else:
KevinGrantLee marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
f'Expected LoopParameterArgument or LoopArtifactArgument, got {type(group.loop_argument)}.'
)

# If the loop items itself is a loop arguments variable, handle the
# subvar name.
Expand All @@ -777,14 +818,14 @@ def _update_task_spec_for_loop_group(
pipeline_task_spec.parameter_iterator.item_input = (
input_parameter_name)

_pop_input_from_task_spec(
task_spec=pipeline_task_spec,
input_name=pipeline_task_spec.parameter_iterator.item_input)

if (group.parallelism_limit > 0):
pipeline_task_spec.iterator_policy.parallelism_limit = (
group.parallelism_limit)

_pop_input_from_task_spec(
task_spec=pipeline_task_spec,
input_name=pipeline_task_spec.parameter_iterator.item_input)


def _binary_operations_to_cel_conjunctive(
operations: List[pipeline_channel.ConditionOperation]) -> str:
Expand Down Expand Up @@ -1290,10 +1331,11 @@ def build_spec_by_group(

for channel in subgroup_input_channels:
# Skip 'withItems' loop arguments if it's from an inner loop.
if isinstance(
channel,
(for_loop.LoopArgument, for_loop.LoopArgumentVariable
)) and channel.is_with_items_loop_argument:
if isinstance(channel, (
for_loop.LoopParameterArgument,
for_loop.LoopArtifactArgument,
for_loop.LoopArgumentVariable,
)) and channel.is_with_items_loop_argument:
withitems_loop_arg_found_in_self_or_upstream = False
for group_name in group_name_to_parent_groups[
subgroup.name][::-1]:
Expand Down Expand Up @@ -1782,7 +1824,7 @@ def _rename_component_refs(
def validate_pipeline_outputs_dict(
pipeline_outputs_dict: Dict[str, pipeline_channel.PipelineChannel]):
for channel in pipeline_outputs_dict.values():
if isinstance(channel, for_loop.Collected):
if isinstance(channel, dsl.Collected):
# this validation doesn't apply to Collected
continue

Expand Down
Loading