Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursion bug fix #1061

Merged
merged 11 commits into from
Apr 2, 2019
6 changes: 1 addition & 5 deletions samples/basic/recursion.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ def flip_component(flip_result):
# as the input whereas the flip_result in the current graph component
# comes from the flipA.output in the flipcoin function.
flip_component(flipA.output)
# Return a dictionary of string to arguments
# such that the downstream components that depend
# on this graph component can access the output.
return {'flip_result': flipA.output}

@dsl.pipeline(
name='pipeline flip coin',
Expand All @@ -63,7 +59,7 @@ def flipcoin():
flip_loop = flip_component(flipA.output)
# flip_loop is a graph_component with the outputs field
# filled with the returned dictionary.
PrintOp('cool, it is over. %s' % flip_loop.outputs['flip_result'])
PrintOp('cool, it is over. %s' % flipA.output).after(flip_loop)

if __name__ == '__main__':
import kfp.compiler as compiler
Expand Down
221 changes: 150 additions & 71 deletions sdk/python/kfp/compiler/compiler.py

Large diffs are not rendered by default.

7 changes: 1 addition & 6 deletions sdk/python/kfp/dsl/_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,7 @@ def _graph_component(*args, **kargs):
with graph_ops_group:
# Call the function
if not graph_ops_group.recursive_ref:
graph_ops_group.outputs = func(*args, **kargs)
if not isinstance(graph_ops_group.outputs, dict):
raise ValueError(func.__name__ + ' needs to return a dictionary of string to PipelineParam.')
for output in graph_ops_group.outputs:
if not (isinstance(output, str) and isinstance(graph_ops_group.outputs[output], PipelineParam)):
raise ValueError(func.__name__ + ' needs to return a dictionary of string to PipelineParam.')
func(*args, **kargs)

return graph_ops_group
return _graph_component
4 changes: 2 additions & 2 deletions sdk/python/kfp/dsl/_container_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ def _decorated(*args, **kwargs):
# attributes specific to `ContainerOp`
self._inputs = []
self.file_outputs = file_outputs
self.dependent_op_names = []
self.dependent_names = []
self.is_exit_handler = is_exit_handler
self._metadata = None

Expand Down Expand Up @@ -851,7 +851,7 @@ def apply(self, mod_func):

def after(self, op):
"""Specify explicit dependency on another op."""
self.dependent_op_names.append(op.name)
self.dependent_names.append(op.name)
return self

def add_volume(self, volume):
Expand Down
15 changes: 8 additions & 7 deletions sdk/python/kfp/dsl/_ops_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, group_type: str, name: str=None):
self.ops = list()
self.groups = list()
self.name = name
self.dependencies = []
# recursive_ref points to the opsgroups with the same name if exists.
self.recursive_ref = None

Expand Down Expand Up @@ -80,6 +81,11 @@ def __enter__(self):
def __exit__(self, *args):
_pipeline.Pipeline.get_default_pipeline().pop_ops_group()

def after(self, dependency):
"""Specify explicit dependency on another op."""
self.dependencies.append(dependency)
return self

class ExitHandler(OpsGroup):
"""Represents an exit handler that is invoked upon exiting a group of ops.

Expand All @@ -101,7 +107,7 @@ def __init__(self, exit_op: _container_op.ContainerOp):
ValueError is the exit_op is invalid.
"""
super(ExitHandler, self).__init__('exit_handler')
if exit_op.dependent_op_names:
if exit_op.dependent_names:
raise ValueError('exit_op cannot depend on any other ops.')

self.exit_op = exit_op
Expand Down Expand Up @@ -137,9 +143,4 @@ def __init__(self, name):
super(Graph, self).__init__(group_type='graph', name=name)
self.inputs = []
self.outputs = {}
self.dependencies = []

def after(self, dependency):
"""Specify explicit dependency on another op."""
self.dependencies.append(dependency)
return self
self.dependencies = []
8 changes: 6 additions & 2 deletions sdk/python/tests/compiler/compiler_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,13 @@ def test_py_image_pull_secret(self):
"""Test pipeline imagepullsecret."""
self._test_py_compile_yaml('imagepullsecret')

def test_py_recursive(self):
def test_py_recursive_do_while(self):
"""Test pipeline recursive."""
self._test_py_compile_yaml('recursive')
self._test_py_compile_yaml('recursive_do_while')

def test_py_recursive_while(self):
"""Test pipeline recursive."""
self._test_py_compile_yaml('recursive_while')

def test_type_checking_with_consistent_types(self):
"""Test type check pipeline parameters against component metadata."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,17 @@ def flip_component(flip_result):
flipA = FlipCoinOp().after(print_flip)
with dsl.Condition(flipA.output == 'heads'):
flip_component(flipA.output)
return {'flip_result': flipA.output}

@dsl.pipeline(
name='pipeline flip coin',
description='shows how to use graph_component.'
)
def recursive():
flipA = FlipCoinOp()
flipB = FlipCoinOp()
flip_loop = flip_component(flipA.output)
PrintOp('cool, it is over. %s' % flip_loop.outputs['flip_result'])
flip_loop.after(flipB)
PrintOp('cool, it is over. %s' % flipA.output).after(flip_loop)

if __name__ == '__main__':
import kfp.compiler as compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ spec:
- arguments:
parameters:
- name: flip-output
value: '{{inputs.parameters.flip-2-output}}'
value: '{{inputs.parameters.flip-3-output}}'
name: graph-flip-component-1
template: graph-flip-component-1
inputs:
parameters:
- name: flip-2-output
- name: flip-3-output
name: condition-2
- container:
args:
Expand Down Expand Up @@ -102,21 +102,62 @@ spec:
- name: flip-2-output
valueFrom:
path: /tmp/output
- container:
args:
- python -c "import random; result = 'heads' if random.randint(0,1) == 0 else
'tails'; print(result)" | tee /tmp/output
command:
- sh
- -c
image: python:alpine3.6
name: flip-3
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
s3:
accessKeySecret:
key: accesskey
name: mlpipeline-minio-artifact
bucket: mlpipeline
endpoint: minio-service.kubeflow:9000
insecure: true
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
secretKeySecret:
key: secretkey
name: mlpipeline-minio-artifact
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
s3:
accessKeySecret:
key: accesskey
name: mlpipeline-minio-artifact
bucket: mlpipeline
endpoint: minio-service.kubeflow:9000
insecure: true
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
secretKeySecret:
key: secretkey
name: mlpipeline-minio-artifact
parameters:
- name: flip-3-output
valueFrom:
path: /tmp/output
- dag:
tasks:
- arguments:
parameters:
- name: flip-2-output
value: '{{tasks.flip-2.outputs.parameters.flip-2-output}}'
- name: flip-3-output
value: '{{tasks.flip-3.outputs.parameters.flip-3-output}}'
dependencies:
- flip-2
- flip-3
name: condition-2
template: condition-2
when: '{{tasks.flip-2.outputs.parameters.flip-2-output}} == heads'
when: '{{tasks.flip-3.outputs.parameters.flip-3-output}} == heads'
- dependencies:
- print
name: flip-2
template: flip-2
name: flip-3
template: flip-3
- arguments:
parameters:
- name: flip-output
Expand All @@ -127,28 +168,27 @@ spec:
parameters:
- name: flip-output
name: graph-flip-component-1
outputs:
parameters:
- name: flip-2-output
valueFrom:
parameter: '{{tasks.flip-2.outputs.parameters.flip-2-output}}'
- dag:
tasks:
- name: flip
template: flip
- name: flip-2
template: flip-2
- arguments:
parameters:
- name: flip-output
value: '{{tasks.flip.outputs.parameters.flip-output}}'
dependencies:
- flip
- flip-2
name: graph-flip-component-1
template: graph-flip-component-1
- arguments:
parameters:
- name: flip-2-output
value: '{{tasks.graph-flip-component-1.outputs.parameters.flip-2-output}}'
- name: flip-output
value: '{{tasks.flip.outputs.parameters.flip-output}}'
dependencies:
- flip
- graph-flip-component-1
name: print-2
template: print-2
Expand Down Expand Up @@ -193,11 +233,11 @@ spec:
- container:
command:
- echo
- cool, it is over. {{inputs.parameters.flip-2-output}}
- cool, it is over. {{inputs.parameters.flip-output}}
image: alpine:3.6
inputs:
parameters:
- name: flip-2-output
- name: flip-output
name: print-2
outputs:
artifacts:
Expand Down
59 changes: 59 additions & 0 deletions sdk/python/tests/compiler/testdata/recursive_while.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import kfp.dsl as dsl

class FlipCoinOp(dsl.ContainerOp):
"""Flip a coin and output heads or tails randomly."""

def __init__(self):
super(FlipCoinOp, self).__init__(
name='Flip',
image='python:alpine3.6',
command=['sh', '-c'],
arguments=['python -c "import random; result = \'heads\' if random.randint(0,1) == 0 '
'else \'tails\'; print(result)" | tee /tmp/output'],
file_outputs={'output': '/tmp/output'})

class PrintOp(dsl.ContainerOp):
"""Print a message."""

def __init__(self, msg):
super(PrintOp, self).__init__(
name='Print',
image='alpine:3.6',
command=['echo', msg],
)

@dsl._component.graph_component
def flip_component(flip_result):
with dsl.Condition(flip_result == 'heads'):
print_flip = PrintOp(flip_result)
flipA = FlipCoinOp().after(print_flip)
flip_component(flipA.output)

@dsl.pipeline(
name='pipeline flip coin',
description='shows how to use dsl.Condition.'
)
def flipcoin():
flipA = FlipCoinOp()
flipB = FlipCoinOp()
flip_loop = flip_component(flipA.output)
flip_loop.after(flipB)
PrintOp('cool, it is over. %s' % flipA.output).after(flip_loop)

if __name__ == '__main__':
import kfp.compiler as compiler
compiler.Compiler().compile(flipcoin, __file__ + '.tar.gz')
Loading