Skip to content

Commit

Permalink
Factor out logic to recurse through dags to get parameter value, add …
Browse files Browse the repository at this point in the history
…tests
  • Loading branch information
boarder7395 committed Oct 3, 2024
1 parent 87aff35 commit 73bb608
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 79 deletions.
115 changes: 36 additions & 79 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1165,19 +1165,24 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,

// This is the case where the input comes from the output of an upstream task.
case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter:
cfg := resolveUpstreamParametersConfig{
ctx: ctx,
paramSpec: paramSpec,
dag: dag,
pipeline: pipeline,
mlmd: mlmd,
inputs: inputs,
name: name,
paramError: paramError,
outputParameters := paramSpec.GetTaskOutputParameter()
producerTask := outputParameters.GetProducerTask()
outputParameterKey := outputParameters.GetOutputParameterKey()
if producerTask == "" {
return nil, paramError(fmt.Errorf("producer task is empty"))
}
if err := resolveUpstreamParameters(cfg); err != nil {
if outputParameterKey == "" {
return nil, paramError(fmt.Errorf("output parameter key is empty"))
}
tasks, err := getDAGTasks(ctx, dag, pipeline, mlmd, nil)
if err != nil {
return nil, err
}
parameterValue, err := resolveUpstreamParameters(tasks, producerTask, outputParameterKey)
if err != nil {
return nil, err
}
inputs.ParameterValues[name] = parameterValue

case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue:
runtimeValue := paramSpec.GetRuntimeValue()
Expand Down Expand Up @@ -1236,82 +1241,34 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
return inputs, nil
}

// resolveUpstreamParametersConfig is just a config struct used to store the
// input parameters of the resolveUpstreamParameters function.
type resolveUpstreamParametersConfig struct {
ctx context.Context
paramSpec *pipelinespec.TaskInputsSpec_InputParameterSpec
dag *metadata.DAG
pipeline *metadata.Pipeline
mlmd *metadata.Client
inputs *pipelinespec.ExecutorInput_Inputs
name string
paramError func(error) error
}

// resolveUpstreamParameters resolves input parameters that come from upstream
// tasks. These tasks can be components/containers, which is relatively
// straightforward, or DAGs, in which case, we need to traverse the graph until
// we arrive at a component/container (since there can be n nested DAGs).
func resolveUpstreamParameters(cfg resolveUpstreamParametersConfig) error {
taskOutput := cfg.paramSpec.GetTaskOutputParameter()
glog.V(4).Info("taskOutput: ", taskOutput)
if taskOutput.GetProducerTask() == "" {
return cfg.paramError(fmt.Errorf("producer task is empty"))
}
if taskOutput.GetOutputParameterKey() == "" {
return cfg.paramError(fmt.Errorf("output parameter key is empty"))
}
tasks, err := getDAGTasks(cfg.ctx, cfg.dag, cfg.pipeline, cfg.mlmd, nil)
// resolveUpstreamParameters is a recursive function that traverses the DAG to get a downstream parameter.
func resolveUpstreamParameters(tasks map[string]*metadata.Execution, taskName string, outputParameterKey string) (parameterValue *structpb.Value, err error) {
upstreamTask := tasks[taskName]
_, outputParametersCustomProperty, err := upstreamTask.GetParameters()
if err != nil {
return cfg.paramError(err)
return nil, err
}

// The producer is the task that produces the output that we need to
// consume.
producer := tasks[taskOutput.GetProducerTask()]
outputParameterKey := taskOutput.GetOutputParameterKey()
glog.V(4).Info("producer: ", producer)
currentTask := producer
currentSubTaskMaybeDAG := true
// Continue looping until we reach a sub-task that is NOT a DAG.
for currentSubTaskMaybeDAG {
glog.V(4).Info("currentTask: ", currentTask.TaskName())
_, outputParametersCustomProperty, err := currentTask.GetParameters()
if *upstreamTask.GetExecution().Type == "system.DAGExecution" {
// recurse
var outputParametersMap map[string]string
b, err := outputParametersCustomProperty[outputParameterKey].GetStructValue().MarshalJSON()
if err != nil {
return err
return nil, err
}
// If the current task is a DAG:
if *currentTask.GetExecution().Type == "system.DAGExecution" {
// Since currentTask is a DAG, we need to deserialize its
// output parameter map so that we can look its
// corresponding producer sub-task, reassign currentTask,
// and iterate through this loop again.
var outputParametersMap map[string]string
b, err := outputParametersCustomProperty[outputParameterKey].GetStructValue().MarshalJSON()
if err != nil {
return err
}
json.Unmarshal(b, &outputParametersMap)
glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap)
subTaskName := outputParametersMap["producer_subtask"]
outputParameterKey = outputParametersMap["output_parameter_key"]
glog.V(4).Infof(
"Overriding currentTask, %v, output with currentTask's producer_subtask, %v, output.",
currentTask.TaskName(),
subTaskName,
)

// Reassign sub-task before running through the loop again.
currentTask = tasks[subTaskName]
} else {
cfg.inputs.ParameterValues[cfg.name] = outputParametersCustomProperty[outputParameterKey]
// Exit the loop.
currentSubTaskMaybeDAG = false
json.Unmarshal(b, &outputParametersMap)
glog.V(4).Info("Deserialized outputParametersMap: ", outputParametersMap)
nextTask := outputParametersMap["producer_subtask"]
outputParameterKey = outputParametersMap["output_parameter_key"]
downstreamParameterMapping, err := resolveUpstreamParameters(tasks, nextTask, outputParameterKey)
if err != nil {
return nil, err
}
return downstreamParameterMapping, nil
} else {
// base case
return outputParametersCustomProperty[outputParameterKey], nil
}

return nil
}

// resolveUpstreamArtifactsConfig is just a config struct used to store the
Expand Down
99 changes: 99 additions & 0 deletions backend/src/v2/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ import (
"encoding/json"
"testing"

"google.golang.org/protobuf/types/known/structpb"
k8sres "k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/v2/metadata"
"github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform"
"github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
k8score "k8s.io/api/core/v1"
Expand Down Expand Up @@ -1621,3 +1623,100 @@ func Test_extendPodSpecPatch_GenericEphemeralVolume(t *testing.T) {
})
}
}

func TestResolveUpstreamParametersBase(t *testing.T) {
Id := int64(10)
Name := "some-name"
Type := "system.Execution"
CustomProperties := map[string]*ml_metadata.Value{
"inputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"name": structpb.NewStringValue("task1")}}}},
"outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"output_key": structpb.NewStringValue("dataset")}}}},
}
tasks := map[string]*metadata.Execution{"test": metadata.NewExecution(&ml_metadata.Execution{
Id: &Id,
Name: &Name,
Type: &Type,
CustomProperties: CustomProperties,
})}
parameterValueMapping, err := resolveUpstreamParameters(tasks, "test", "output_key")
assert.Nil(t, err)
assert.Equal(t, parameterValueMapping, structpb.NewStringValue("dataset"))
assert.Equal(t, *tasks["test"].GetExecution().Type, "system.Execution")
}

func TestResolveUpstreamParametersParameters(t *testing.T) {
task1ID := int64(10)
task1Name := "some-name"
task2ID := int64(10)
task2Name := "some-name"
dagType := "system.DAGExecution"
executionType := "system.Execution"
subtaskValue, _ := structpb.NewValue("task2")
outputParameterKey, _ := structpb.NewValue("dataset2")
nestedOutputParameterKey, _ := structpb.NewValue("dataset2")
output := structpb.Value{Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"producer_subtask": subtaskValue, "output_parameter_key": outputParameterKey}}}}
CustomProperties := map[string]*ml_metadata.Value{
"outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset": &output}}}},
}
nestedCustomProperties := map[string]*ml_metadata.Value{
"outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset2": nestedOutputParameterKey}}}},
}
tasks := map[string]*metadata.Execution{
"task1": metadata.NewExecution(&ml_metadata.Execution{
Id: &task1ID,
Name: &task1Name,
Type: &dagType,
CustomProperties: CustomProperties}),
"task2": metadata.NewExecution(&ml_metadata.Execution{
Id: &task2ID,
Name: &task2Name,
Type: &executionType,
CustomProperties: nestedCustomProperties}),
}
parameterValueMapping, err := resolveUpstreamParameters(tasks, "task1", "dataset")
assert.Nil(t, err)
assert.Equal(t, parameterValueMapping, nestedOutputParameterKey)
}

func TestResolveUpstreamParametersMulti(t *testing.T) {
taskIDs := []int64{10, 11, 12}
taskNames := []string{"task1-name", "task2-name", "task3-name"}
dagType, executionType := "system.DAGExecution", "system.Execution"
subtaskValues := []*structpb.Value{structpb.NewStringValue("task2"), structpb.NewStringValue("task2")}
outputParameterKeys := []*structpb.Value{structpb.NewStringValue("dataset"), structpb.NewStringValue("dataset"), structpb.NewStringValue("dataset2")}
output := structpb.Value{Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"producer_subtask": subtaskValues[0], "output_parameter_key": outputParameterKeys[0]}}}}
output2 := structpb.Value{Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"producer_subtask": subtaskValues[0], "output_parameter_key": outputParameterKeys[1]}}}}
CustomProperties := map[string]*ml_metadata.Value{
"outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset": &output, "dataset2": &output2}}}},
}
nestedCustomProperties := map[string]*ml_metadata.Value{
"outputs": {Value: &ml_metadata.Value_StructValue{StructValue: &structpb.Struct{Fields: map[string]*structpb.Value{"dataset": outputParameterKeys[2]}}}},
}
tasks := map[string]*metadata.Execution{
"task1": metadata.NewExecution(&ml_metadata.Execution{
Id: &taskIDs[0],
Name: &taskNames[0],
Type: &dagType,
CustomProperties: CustomProperties}),
"task2": metadata.NewExecution(&ml_metadata.Execution{
Id: &taskIDs[1],
Name: &taskNames[1],
Type: &executionType,
CustomProperties: nestedCustomProperties}),
"task3": metadata.NewExecution(&ml_metadata.Execution{
Id: &taskIDs[2],
Name: &taskNames[2],
Type: &dagType,
CustomProperties: CustomProperties,
}),
}
parameterValueMapping, err := resolveUpstreamParameters(tasks, "task1", "dataset")
assert.Nil(t, err)
assert.Equal(t, parameterValueMapping, outputParameterKeys[2])
parameterValueMapping, err = resolveUpstreamParameters(tasks, "task1", "dataset2")
assert.Nil(t, err)
assert.Equal(t, parameterValueMapping, outputParameterKeys[2])
parameterValueMapping, err = resolveUpstreamParameters(tasks, "task3", "dataset")
assert.Nil(t, err)
assert.Equal(t, parameterValueMapping, outputParameterKeys[2])
}

0 comments on commit 73bb608

Please sign in to comment.