Skip to content

Commit

Permalink
Add a recursion sample (#1016)
Browse files Browse the repository at this point in the history
* add a While in the ops group

* deepcopy the while conditions when entering and exiting

* add while condition resolution in the compiler

* define graph component decorator

* remove while loop related codes

* fixes

* remove while loop related code

* fix bugs

* generate a unique ops group name and being able to retrieve by name

* resolve the opsgroups inputs and dependencies based on the pipelineparam in the condition

* add a recursive ops_groups

* fix bugs of the recursive opsgroup template name

* resolve the recursive template name and arguments

* add validity checks

* add more comments

* add usage comment in graph_component

* add a sample

* add unit test for the graph opsgraph

* refactor the opsgroup

* add unit test for the graph_component decorator

* exposing graph_component decorator

* add recursive compiler unit tests

* add the sample test

* fix the bug of opsgroup name
adjust the graph_component usage example
fix index bugs
use with statement in the graph_component instead of directly calling
the enter/exit functions

* add a todo to combine the graph_component and component decorators

* fix some merging bug

* fix typo

* add more comments in the sample

* update comments
  • Loading branch information
gaoning777 authored and k8s-ci-robot committed Mar 28, 2019
1 parent f59c25b commit 1d617b5
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
70 changes: 70 additions & 0 deletions samples/basic/recursion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import kfp.dsl as dsl

class FlipCoinOp(dsl.ContainerOp):
"""Flip a coin and output heads or tails randomly."""

def __init__(self):
super(FlipCoinOp, self).__init__(
name='Flip',
image='python:alpine3.6',
command=['sh', '-c'],
arguments=['python -c "import random; result = \'heads\' if random.randint(0,1) == 0 '
'else \'tails\'; print(result)" | tee /tmp/output'],
file_outputs={'output': '/tmp/output'})

class PrintOp(dsl.ContainerOp):
"""Print a message."""

def __init__(self, msg):
super(PrintOp, self).__init__(
name='Print',
image='alpine:3.6',
command=['echo', msg],
)

# Use the dsl.graph_component to decorate functions that are
# recursively called.
@dsl.graph_component
def flip_component(flip_result):
print_flip = PrintOp(flip_result)
flipA = FlipCoinOp().after(print_flip)
with dsl.Condition(flipA.output == 'heads'):
# When the flip_component is called recursively, the flipA.output
# from inside the graph component will be passed to the next flip_component
# as the input whereas the flip_result in the current graph component
# comes from the flipA.output in the flipcoin function.
flip_component(flipA.output)
# Return a dictionary of string to arguments
# such that the downstream components that depend
# on this graph component can access the output.
return {'flip_result': flipA.output}

@dsl.pipeline(
name='pipeline flip coin',
description='shows how to use dsl.Condition.'
)
def flipcoin():
flipA = FlipCoinOp()
flip_loop = flip_component(flipA.output)
# flip_loop is a graph_component with the outputs field
# filled with the returned dictionary.
PrintOp('cool, it is over. %s' % flip_loop.outputs['flip_result'])

if __name__ == '__main__':
import kfp.compiler as compiler
compiler.Compiler().compile(flipcoin, __file__ + '.tar.gz')
12 changes: 12 additions & 0 deletions test/e2e_test_gke_v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,18 @@ spec:
value: "{{inputs.parameters.namespace}}"
- name: test-name
value: "paralleljoin"
- name: run-recursive-tests
template: run-basic-e2e-tests
arguments:
parameters:
- name: test-results-gcs-dir
value: "{{inputs.parameters.test-results-gcs-dir}}"
- name: sample-tests-image
value: "{{inputs.parameters.target-image-prefix}}{{inputs.parameters.basic-e2e-tests-image-suffix}}"
- name: namespace
value: "{{inputs.parameters.namespace}}"
- name: test-name
value: "recursion"

# Build and push image
- name: build-image
Expand Down
13 changes: 13 additions & 0 deletions test/sample-test/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,19 @@ elif [ "$TEST_NAME" == "paralleljoin" ]; then

echo "Copy the test results to GCS ${RESULTS_GCS_DIR}/"
gsutil cp ${SAMPLE_PARALLEL_JOIN_TEST_RESULT} ${RESULTS_GCS_DIR}/${SAMPLE_PARALLEL_JOIN_TEST_RESULT}
elif [ "$TEST_NAME" == "recursion" ]; then
SAMPLE_RECURSION_TEST_RESULT=junit_SampleRecursionOutput.xml
SAMPLE_RECURSION_TEST_OUTPUT=${RESULTS_GCS_DIR}

# Compile samples
cd ${BASE_DIR}/samples/basic
dsl-compile --py recursion.py --output recursion.tar.gz

cd "${TEST_DIR}"
python3 run_basic_test.py --input ${BASE_DIR}/samples/basic/recursion.tar.gz --result $SAMPLE_RECURSION_TEST_RESULT --output $SAMPLE_RECURSION_TEST_OUTPUT --testname recursion --namespace ${NAMESPACE}

echo "Copy the test results to GCS ${RESULTS_GCS_DIR}/"
gsutil cp ${SAMPLE_RECURSION_TEST_RESULT} ${RESULTS_GCS_DIR}/${SAMPLE_RECURSION_TEST_RESULT}
elif [ "$TEST_NAME" == "xgboost" ]; then
SAMPLE_XGBOOST_TEST_RESULT=junit_SampleXGBoostOutput.xml
SAMPLE_XGBOOST_TEST_OUTPUT=${RESULTS_GCS_DIR}
Expand Down

0 comments on commit 1d617b5

Please sign in to comment.