From e40c24d69ca15b7dad740c744cdaa2c01dbb6083 Mon Sep 17 00:00:00 2001 From: Tyler Kalbach Date: Mon, 30 Sep 2024 19:47:19 -0400 Subject: [PATCH] Factor out logic to recurse through dags to get parameter value, add tests --- backend/src/v2/driver/driver.go | 61 ++++++++++++++--- backend/src/v2/driver/driver_test.go | 99 ++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 10 deletions(-) diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 60d04d7f53a..6cb25f0507b 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -1165,19 +1165,31 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, // This is the case where the input comes from the output of an upstream task. case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter: - cfg := resolveUpstreamParametersConfig{ - ctx: ctx, - paramSpec: paramSpec, - dag: dag, - pipeline: pipeline, - mlmd: mlmd, - inputs: inputs, - name: name, - paramError: paramError, + // cfg := resolveUpstreamParametersConfig{ + // ctx: ctx, + // paramSpec: paramSpec, + // dag: dag, + // pipeline: pipeline, + // mlmd: mlmd, + // inputs: inputs, + // name: name, + // paramError: paramError, + // } + // if err := resolveUpstreamParameters(cfg); err != nil { + // return nil, err + // } + outputParameters := paramSpec.GetTaskOutputParameter() + producerTask := outputParameters.GetProducerTask() + outputParameterKey := outputParameters.GetOutputParameterKey() + tasks, err := getDAGTasks(ctx, dag, pipeline, mlmd, nil) + if err != nil { + return nil, err } - if err := resolveUpstreamParameters(cfg); err != nil { + parameterValue, err := recurseParameters(tasks, producerTask, outputParameterKey) + if err != nil { return nil, err } + inputs.ParameterValues[name] = parameterValue case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue: runtimeValue := paramSpec.GetRuntimeValue() @@ -1236,6 +1248,35 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, return inputs, nil } +func recurseParameters(tasks map[string]*metadata.Execution, taskName string, outputParameterKey string) (parameterValue *structpb.Value, err error) { + upstreamTask := tasks[taskName] + _, outputParametersCustomProperty, err := upstreamTask.GetParameters() + if err != nil { + return nil, err + } + if *upstreamTask.GetExecution().Type == "system.DAGExecution" { + // recurse + var outputParametersMap map[string]string + b, err := outputParametersCustomProperty[outputParameterKey].GetStructValue().MarshalJSON() + if err != nil { + return nil, err + } + json.Unmarshal(b, &outputParametersMap) + glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap) + nextTask := outputParametersMap["producer_subtask"] + outputParameterKey = outputParametersMap["output_parameter_key"] + downstreamParameterMapping, err := recurseParameters(tasks, nextTask, outputParameterKey) + if err != nil { + return nil, err + } + return downstreamParameterMapping, nil + } else { + // base case + return outputParametersCustomProperty[outputParameterKey], nil + } + +} + // resolveUpstreamParametersConfig is just a config struct used to store the // input parameters of the resolveUpstreamParameters function. type resolveUpstreamParametersConfig struct { diff --git a/backend/src/v2/driver/driver_test.go b/backend/src/v2/driver/driver_test.go index be64723ccfb..22c1ccbcbff 100644 --- a/backend/src/v2/driver/driver_test.go +++ b/backend/src/v2/driver/driver_test.go @@ -17,12 +17,14 @@ import ( "encoding/json" "testing" + "google.golang.org/protobuf/types/known/structpb" k8sres "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" "github.com/kubeflow/pipelines/backend/src/v2/metadata" "github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform" + "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata" "github.com/spf13/viper" "github.com/stretchr/testify/assert" k8score "k8s.io/api/core/v1" @@ -1621,3 +1623,100 @@ func Test_extendPodSpecPatch_GenericEphemeralVolume(t *testing.T) { }) } } + +func TestRecurseParametersBase(t *testing.T) { + Id := int64(10) + Name := "some-name" + Type := "system.Execution" + CustomProperties := map[string]*ml_metadata.Value{ + "inputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"name": structpb.NewStringValue("task1")}}}}, + "outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"output_key": structpb.NewStringValue("dataset")}}}}, + } + tasks := map[string]*metadata.Execution{"test": metadata.NewExecution(&ml_metadata.Execution{ + Id: &Id, + Name: &Name, + Type: &Type, + CustomProperties: CustomProperties, + })} + parameterValueMapping, err := recurseParameters(tasks, "test", "output_key") + assert.Nil(t, err) + assert.Equal(t, parameterValueMapping, structpb.NewStringValue("dataset")) + assert.Equal(t, *tasks["test"].GetExecution().Type, "system.Execution") +} + +func TestRecurseParameters(t *testing.T) { + task1ID := int64(10) + task1Name := "some-name" + task2ID := int64(10) + task2Name := "some-name" + dagType := "system.DAGExecution" + executionType := "system.Execution" + subtaskValue, _ := structpb.NewValue("task2") + outputParameterKey, _ := structpb.NewValue("dataset2") + nestedOutputParameterKey, _ := structpb.NewValue("dataset2") + output := structpb.Value{Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"producer_subtask": subtaskValue, "output_parameter_key": outputParameterKey}}}} + CustomProperties := map[string]*ml_metadata.Value{ + "outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset": &output}}}}, + } + nestedCustomProperties := map[string]*ml_metadata.Value{ + "outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset2": nestedOutputParameterKey}}}}, + } + tasks := map[string]*metadata.Execution{ + "task1": metadata.NewExecution(&ml_metadata.Execution{ + Id: &task1ID, + Name: &task1Name, + Type: &dagType, + CustomProperties: CustomProperties}), + "task2": metadata.NewExecution(&ml_metadata.Execution{ + Id: &task2ID, + Name: &task2Name, + Type: &executionType, + CustomProperties: nestedCustomProperties}), + } + parameterValueMapping, err := recurseParameters(tasks, "task1", "dataset") + assert.Nil(t, err) + assert.Equal(t, parameterValueMapping, nestedOutputParameterKey) +} + +func TestRecurseParametersMulti(t *testing.T) { + taskIDs := []int64{10, 11, 12} + taskNames := []string{"task1-name", "task2-name", "task3-name"} + dagType, executionType := "system.DAGExecution", "system.Execution" + subtaskValues := []*structpb.Value{structpb.NewStringValue("task2"), structpb.NewStringValue("task2")} + outputParameterKeys := []*structpb.Value{structpb.NewStringValue("dataset"), structpb.NewStringValue("dataset"), structpb.NewStringValue("dataset2")} + output := structpb.Value{Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"producer_subtask": subtaskValues[0], "output_parameter_key": outputParameterKeys[0]}}}} + output2 := structpb.Value{Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"producer_subtask": subtaskValues[0], "output_parameter_key": outputParameterKeys[1]}}}} + CustomProperties := map[string]*ml_metadata.Value{ + "outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset": &output, "dataset2": &output2}}}}, + } + nestedCustomProperties := map[string]*ml_metadata.Value{ + "outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset": outputParameterKeys[2]}}}}, + } + tasks := map[string]*metadata.Execution{ + "task1": metadata.NewExecution(&ml_metadata.Execution{ + Id: &taskIDs[0], + Name: &taskNames[0], + Type: &dagType, + CustomProperties: CustomProperties}), + "task2": metadata.NewExecution(&ml_metadata.Execution{ + Id: &taskIDs[1], + Name: &taskNames[1], + Type: &executionType, + CustomProperties: nestedCustomProperties}), + "task3": metadata.NewExecution(&ml_metadata.Execution{ + Id: &taskIDs[2], + Name: &taskNames[2], + Type: &dagType, + CustomProperties: CustomProperties, + }), + } + parameterValueMapping, err := recurseParameters(tasks, "task1", "dataset") + assert.Nil(t, err) + assert.Equal(t, parameterValueMapping, outputParameterKeys[2]) + parameterValueMapping, err = recurseParameters(tasks, "task1", "dataset2") + assert.Nil(t, err) + assert.Equal(t, parameterValueMapping, outputParameterKeys[2]) + parameterValueMapping, err = recurseParameters(tasks, "task3", "dataset") + assert.Nil(t, err) + assert.Equal(t, parameterValueMapping, outputParameterKeys[2]) +}