diff --git a/samples/basic/recursion.py b/samples/basic/recursion.py new file mode 100644 index 00000000000..b9057a97d5f --- /dev/null +++ b/samples/basic/recursion.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import kfp.dsl as dsl + +class FlipCoinOp(dsl.ContainerOp): + """Flip a coin and output heads or tails randomly.""" + + def __init__(self): + super(FlipCoinOp, self).__init__( + name='Flip', + image='python:alpine3.6', + command=['sh', '-c'], + arguments=['python -c "import random; result = \'heads\' if random.randint(0,1) == 0 ' + 'else \'tails\'; print(result)" | tee /tmp/output'], + file_outputs={'output': '/tmp/output'}) + +class PrintOp(dsl.ContainerOp): + """Print a message.""" + + def __init__(self, msg): + super(PrintOp, self).__init__( + name='Print', + image='alpine:3.6', + command=['echo', msg], + ) + +# Use the dsl.graph_component to decorate functions that are +# recursively called. +@dsl.graph_component +def flip_component(flip_result): + print_flip = PrintOp(flip_result) + flipA = FlipCoinOp().after(print_flip) + with dsl.Condition(flipA.output == 'heads'): + # When the flip_component is called recursively, the flipA.output + # from inside the graph component will be passed to the next flip_component + # as the input whereas the flip_result in the current graph component + # comes from the flipA.output in the flipcoin function. + flip_component(flipA.output) + # Return a dictionary of string to arguments + # such that the downstream components that depend + # on this graph component can access the output. + return {'flip_result': flipA.output} + +@dsl.pipeline( + name='pipeline flip coin', + description='shows how to use dsl.Condition.' +) +def flipcoin(): + flipA = FlipCoinOp() + flip_loop = flip_component(flipA.output) + # flip_loop is a graph_component with the outputs field + # filled with the returned dictionary. + PrintOp('cool, it is over. %s' % flip_loop.outputs['flip_result']) + +if __name__ == '__main__': + import kfp.compiler as compiler + compiler.Compiler().compile(flipcoin, __file__ + '.tar.gz') diff --git a/test/e2e_test_gke_v2.yaml b/test/e2e_test_gke_v2.yaml index 48edc315981..47ba67329a6 100644 --- a/test/e2e_test_gke_v2.yaml +++ b/test/e2e_test_gke_v2.yaml @@ -155,6 +155,18 @@ spec: value: "{{inputs.parameters.namespace}}" - name: test-name value: "paralleljoin" + - name: run-recursive-tests + template: run-basic-e2e-tests + arguments: + parameters: + - name: test-results-gcs-dir + value: "{{inputs.parameters.test-results-gcs-dir}}" + - name: sample-tests-image + value: "{{inputs.parameters.target-image-prefix}}{{inputs.parameters.basic-e2e-tests-image-suffix}}" + - name: namespace + value: "{{inputs.parameters.namespace}}" + - name: test-name + value: "recursion" # Build and push image - name: build-image diff --git a/test/sample-test/run_test.sh b/test/sample-test/run_test.sh index 82a23e35516..a1e84778cc3 100755 --- a/test/sample-test/run_test.sh +++ b/test/sample-test/run_test.sh @@ -248,6 +248,19 @@ elif [ "$TEST_NAME" == "paralleljoin" ]; then echo "Copy the test results to GCS ${RESULTS_GCS_DIR}/" gsutil cp ${SAMPLE_PARALLEL_JOIN_TEST_RESULT} ${RESULTS_GCS_DIR}/${SAMPLE_PARALLEL_JOIN_TEST_RESULT} +elif [ "$TEST_NAME" == "recursion" ]; then + SAMPLE_RECURSION_TEST_RESULT=junit_SampleRecursionOutput.xml + SAMPLE_RECURSION_TEST_OUTPUT=${RESULTS_GCS_DIR} + + # Compile samples + cd ${BASE_DIR}/samples/basic + dsl-compile --py recursion.py --output recursion.tar.gz + + cd "${TEST_DIR}" + python3 run_basic_test.py --input ${BASE_DIR}/samples/basic/recursion.tar.gz --result $SAMPLE_RECURSION_TEST_RESULT --output $SAMPLE_RECURSION_TEST_OUTPUT --testname recursion --namespace ${NAMESPACE} + + echo "Copy the test results to GCS ${RESULTS_GCS_DIR}/" + gsutil cp ${SAMPLE_RECURSION_TEST_RESULT} ${RESULTS_GCS_DIR}/${SAMPLE_RECURSION_TEST_RESULT} elif [ "$TEST_NAME" == "xgboost" ]; then SAMPLE_XGBOOST_TEST_RESULT=junit_SampleXGBoostOutput.xml SAMPLE_XGBOOST_TEST_OUTPUT=${RESULTS_GCS_DIR}