Skip to content

Commit

Permalink
Factor out logic to recurse through dags to get parameter value, add …
Browse files Browse the repository at this point in the history
…tests
  • Loading branch information
boarder7395 committed Oct 3, 2024
1 parent 87aff35 commit e40c24d
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 10 deletions.
61 changes: 51 additions & 10 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1165,19 +1165,31 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,

// This is the case where the input comes from the output of an upstream task.
case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter:
cfg := resolveUpstreamParametersConfig{
ctx: ctx,
paramSpec: paramSpec,
dag: dag,
pipeline: pipeline,
mlmd: mlmd,
inputs: inputs,
name: name,
paramError: paramError,
// cfg := resolveUpstreamParametersConfig{
// ctx: ctx,
// paramSpec: paramSpec,
// dag: dag,
// pipeline: pipeline,
// mlmd: mlmd,
// inputs: inputs,
// name: name,
// paramError: paramError,
// }
// if err := resolveUpstreamParameters(cfg); err != nil {
// return nil, err
// }
outputParameters := paramSpec.GetTaskOutputParameter()
producerTask := outputParameters.GetProducerTask()
outputParameterKey := outputParameters.GetOutputParameterKey()
tasks, err := getDAGTasks(ctx, dag, pipeline, mlmd, nil)
if err != nil {
return nil, err
}
if err := resolveUpstreamParameters(cfg); err != nil {
parameterValue, err := recurseParameters(tasks, producerTask, outputParameterKey)
if err != nil {
return nil, err
}
inputs.ParameterValues[name] = parameterValue

case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue:
runtimeValue := paramSpec.GetRuntimeValue()
Expand Down Expand Up @@ -1236,6 +1248,35 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
return inputs, nil
}

// recurseParameters resolves the value of the output parameter
// outputParameterKey produced by the task named taskName.
//
// tasks maps task names to their executions. If the producer is a plain
// execution, the value stored under outputParameterKey in its output
// parameters is returned as-is (nil if the key is absent). If the producer is
// a DAG execution (type "system.DAGExecution"), the entry under
// outputParameterKey is expected to be a struct holding the string fields
// "producer_subtask" and "output_parameter_key"; that reference is followed
// recursively until a non-DAG execution is reached.
//
// An error is returned when a referenced task is not present in tasks, when a
// DAG execution has no entry for the requested key, when the reference cannot
// be decoded, or when reading a task's parameters fails.
func recurseParameters(tasks map[string]*metadata.Execution, taskName string, outputParameterKey string) (parameterValue *structpb.Value, err error) {
	upstreamTask, ok := tasks[taskName]
	if !ok || upstreamTask == nil {
		return nil, fmt.Errorf("producer task %q not found", taskName)
	}
	_, outputParametersCustomProperty, err := upstreamTask.GetParameters()
	if err != nil {
		return nil, fmt.Errorf("reading parameters of task %q: %w", taskName, err)
	}

	// Base case: a regular execution stores the parameter value directly.
	// GetType is used instead of dereferencing Type so that an execution
	// without a type is treated as a non-DAG execution rather than panicking.
	if upstreamTask.GetExecution().GetType() != "system.DAGExecution" {
		return outputParametersCustomProperty[outputParameterKey], nil
	}

	// Recursive case: a DAG execution stores a reference to the sub-task that
	// actually produced the value.
	reference, ok := outputParametersCustomProperty[outputParameterKey]
	if !ok {
		return nil, fmt.Errorf("output parameter %q not found in DAG task %q", outputParameterKey, taskName)
	}
	b, err := reference.GetStructValue().MarshalJSON()
	if err != nil {
		return nil, fmt.Errorf("encoding output parameter %q of DAG task %q: %w", outputParameterKey, taskName, err)
	}
	var outputParametersMap map[string]string
	if err := json.Unmarshal(b, &outputParametersMap); err != nil {
		return nil, fmt.Errorf("decoding output parameter %q of DAG task %q: %w", outputParameterKey, taskName, err)
	}
	glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap)

	nextTask, ok := outputParametersMap["producer_subtask"]
	if !ok {
		return nil, fmt.Errorf("output parameter %q of DAG task %q has no producer_subtask", outputParameterKey, taskName)
	}
	return recurseParameters(tasks, nextTask, outputParametersMap["output_parameter_key"])
}

// resolveUpstreamParametersConfig is just a config struct used to store the
// input parameters of the resolveUpstreamParameters function.
type resolveUpstreamParametersConfig struct {
Expand Down
99 changes: 99 additions & 0 deletions backend/src/v2/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ import (
"encoding/json"
"testing"

"google.golang.org/protobuf/types/known/structpb"
k8sres "k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/v2/metadata"
"github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform"
"github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
k8score "k8s.io/api/core/v1"
Expand Down Expand Up @@ -1621,3 +1623,100 @@ func Test_extendPodSpecPatch_GenericEphemeralVolume(t *testing.T) {
})
}
}

// TestRecurseParametersBase covers the base case of recurseParameters: the
// producer task is a plain "system.Execution", so the value stored under the
// requested output key is expected to be returned without any recursion.
func TestRecurseParametersBase(t *testing.T) {
	id := int64(10)
	name := "some-name"
	executionType := "system.Execution"
	customProperties := map[string]*ml_metadata.Value{
		"inputs":  {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"name": structpb.NewStringValue("task1")}}}},
		"outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"output_key": structpb.NewStringValue("dataset")}}}},
	}
	tasks := map[string]*metadata.Execution{"test": metadata.NewExecution(&ml_metadata.Execution{
		Id:               &id,
		Name:             &name,
		Type:             &executionType,
		CustomProperties: customProperties,
	})}

	parameterValue, err := recurseParameters(tasks, "test", "output_key")

	// testify expects (expected, actual); keeping that order makes failure
	// output label the two values correctly.
	assert.NoError(t, err)
	assert.Equal(t, structpb.NewStringValue("dataset"), parameterValue)
	assert.Equal(t, "system.Execution", *tasks["test"].GetExecution().Type)
}

// TestRecurseParameters covers a single level of indirection: "task1" is a
// DAG execution whose "dataset" output refers to the "dataset2" output of
// "task2", a plain execution. Resolving task1/dataset is expected to yield
// the value stored by task2.
func TestRecurseParameters(t *testing.T) {
	// Distinct IDs and names keep the two fixture executions distinguishable.
	task1ID := int64(10)
	task1Name := "task1-name"
	task2ID := int64(11)
	task2Name := "task2-name"
	dagType := "system.DAGExecution"
	executionType := "system.Execution"

	// NewStringValue cannot fail, so there is no error to discard.
	producerSubtask := structpb.NewStringValue("task2")
	referencedKey := structpb.NewStringValue("dataset2")
	finalValue := structpb.NewStringValue("dataset2")

	// Reference stored by the DAG in place of an actual value.
	reference := structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
		"producer_subtask":     producerSubtask,
		"output_parameter_key": referencedKey,
	}})
	dagCustomProperties := map[string]*ml_metadata.Value{
		"outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset": reference}}}},
	}
	nestedCustomProperties := map[string]*ml_metadata.Value{
		"outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset2": finalValue}}}},
	}
	tasks := map[string]*metadata.Execution{
		"task1": metadata.NewExecution(&ml_metadata.Execution{
			Id:               &task1ID,
			Name:             &task1Name,
			Type:             &dagType,
			CustomProperties: dagCustomProperties,
		}),
		"task2": metadata.NewExecution(&ml_metadata.Execution{
			Id:               &task2ID,
			Name:             &task2Name,
			Type:             &executionType,
			CustomProperties: nestedCustomProperties,
		}),
	}

	parameterValue, err := recurseParameters(tasks, "task1", "dataset")

	assert.NoError(t, err)
	assert.Equal(t, finalValue, parameterValue)
}

// TestRecurseParametersMulti covers several DAG executions that all resolve
// to the same leaf: "task1" and "task3" are DAG executions whose "dataset"
// and "dataset2" outputs both refer to the "dataset" output of "task2", a
// plain execution. Every lookup is expected to yield the value stored by
// task2.
func TestRecurseParametersMulti(t *testing.T) {
	taskIDs := []int64{10, 11, 12}
	taskNames := []string{"task1-name", "task2-name", "task3-name"}
	dagType, executionType := "system.DAGExecution", "system.Execution"

	producerSubtask := structpb.NewStringValue("task2")
	referencedKey := structpb.NewStringValue("dataset")
	finalValue := structpb.NewStringValue("dataset2")

	// newReference builds the struct a DAG stores in place of an actual value.
	newReference := func() *structpb.Value {
		return structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
			"producer_subtask":     producerSubtask,
			"output_parameter_key": referencedKey,
		}})
	}
	dagCustomProperties := map[string]*ml_metadata.Value{
		"outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{
			"dataset":  newReference(),
			"dataset2": newReference(),
		}}}},
	}
	nestedCustomProperties := map[string]*ml_metadata.Value{
		"outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset": finalValue}}}},
	}
	tasks := map[string]*metadata.Execution{
		"task1": metadata.NewExecution(&ml_metadata.Execution{
			Id:               &taskIDs[0],
			Name:             &taskNames[0],
			Type:             &dagType,
			CustomProperties: dagCustomProperties,
		}),
		"task2": metadata.NewExecution(&ml_metadata.Execution{
			Id:               &taskIDs[1],
			Name:             &taskNames[1],
			Type:             &executionType,
			CustomProperties: nestedCustomProperties,
		}),
		"task3": metadata.NewExecution(&ml_metadata.Execution{
			Id:               &taskIDs[2],
			Name:             &taskNames[2],
			Type:             &dagType,
			CustomProperties: dagCustomProperties,
		}),
	}

	// Each case names the DAG task and output key to resolve.
	cases := []struct {
		task string
		key  string
	}{
		{task: "task1", key: "dataset"},
		{task: "task1", key: "dataset2"},
		{task: "task3", key: "dataset"},
	}
	for _, tc := range cases {
		parameterValue, err := recurseParameters(tasks, tc.task, tc.key)
		assert.NoError(t, err)
		assert.Equal(t, finalValue, parameterValue)
	}
}

0 comments on commit e40c24d

Please sign in to comment.