From 1d42e3f6e23890b61d6f829e6817b10a44faa189 Mon Sep 17 00:00:00 2001 From: Jiaxiao Zheng Date: Tue, 25 Feb 2020 17:49:56 -0800 Subject: [PATCH] [Backend] Fix parameter patching (#3145) * staging changes * fix unit test. Patch both workflow spec and api run pipeline spec. * fix the condition * update per comments --- .../apiserver/resource/resource_manager.go | 9 +++ .../resource/resource_manager_test.go | 71 ++++++++++++++++++- .../resource/resource_manager_util.go | 47 ++++++++++++ backend/src/apiserver/server/run_server.go | 17 ----- .../src/apiserver/server/run_server_test.go | 7 +- backend/src/apiserver/server/util.go | 16 ----- backend/src/common/util/workflow.go | 3 +- 7 files changed, 130 insertions(+), 40 deletions(-) diff --git a/backend/src/apiserver/resource/resource_manager.go b/backend/src/apiserver/resource/resource_manager.go index 06a982cc5e6..a3ed4101326 100644 --- a/backend/src/apiserver/resource/resource_manager.go +++ b/backend/src/apiserver/resource/resource_manager.go @@ -42,6 +42,9 @@ const ( defaultPipelineRunnerServiceAccountEnvVar = "DefaultPipelineRunnerServiceAccount" defaultPipelineRunnerServiceAccount = "pipeline-runner" defaultServiceAccount = "default-editor" + HasDefaultBucketEnvVar = "HAS_DEFAULT_BUCKET" + ProjectIDEnvVar = "PROJECT_ID" + DefaultBucketNameEnvVar = "BUCKET_NAME" ) type ClientManagerInterface interface { @@ -283,6 +286,12 @@ func (r *ResourceManager) CreateRun(apiRun *api.Run) (*model.RunDetail, error) { } // Append provided parameter workflow.OverrideParameters(parameters) + + err = OverrideParameterWithSystemDefault(workflow, apiRun) + if err != nil { + return nil, err + } + // Add label to the workflow so it can be persisted by persistent agent later. workflow.SetLabels(util.LabelKeyWorkflowRunId, runId) // Add run name annotation to the workflow so that it can be logged by the Metadata Writer. diff --git a/backend/src/apiserver/resource/resource_manager_test.go b/backend/src/apiserver/resource/resource_manager_test.go index 5519709cfab..3fcdeb0a4ec 100644 --- a/backend/src/apiserver/resource/resource_manager_test.go +++ b/backend/src/apiserver/resource/resource_manager_test.go @@ -43,7 +43,6 @@ func initEnvVars() { type FakeBadObjectStore struct{} - func (m *FakeBadObjectStore) GetPipelineKey(pipelineID string) string { return pipelineID } @@ -149,6 +148,28 @@ func initWithOneTimeRun(t *testing.T) (*FakeClientManager, *ResourceManager, *mo return store, manager, runDetail } +func initWithPatchedRun(t *testing.T) (*FakeClientManager, *ResourceManager, *model.RunDetail) { + store, manager, exp := initWithExperiment(t) + apiRun := &api.Run{ + Name: "run1", + PipelineSpec: &api.PipelineSpec{ + WorkflowManifest: testWorkflow.ToStringForStore(), + Parameters: []*api.Parameter{ + {Name: "param1", Value: "{{kfp-default-bucket}}"}, + }, + }, + ResourceReferences: []*api.ResourceReference{ + { + Key: &api.ResourceKey{Type: api.ResourceType_EXPERIMENT, Id: exp.UUID}, + Relationship: api.Relationship_OWNER, + }, + }, + } + runDetail, err := manager.CreateRun(apiRun) + assert.Nil(t, err) + return store, manager, runDetail +} + func initWithOneTimeFailedRun(t *testing.T) (*FakeClientManager, *ResourceManager, *model.RunDetail) { store, manager, exp := initWithExperiment(t) apiRun := &api.Run{ @@ -397,6 +418,54 @@ func TestCreateRun_ThroughWorkflowSpec(t *testing.T) { assert.Equal(t, expectedRunDetail, runDetail, "CreateRun stored invalid data in database") } +func TestCreateRun_ThroughWorkflowSpecWithPatch(t *testing.T) { + viper.Set(HasDefaultBucketEnvVar, "true") + viper.Set(ProjectIDEnvVar, "test-project-id") + viper.Set(DefaultBucketNameEnvVar, "test-default-bucket") + store, manager, runDetail := initWithPatchedRun(t) + expectedExperimentUUID := runDetail.ExperimentUUID + expectedRuntimeWorkflow := testWorkflow.DeepCopy() + expectedRuntimeWorkflow.Spec.Arguments.Parameters = []v1alpha1.Parameter{ + {Name: "param1", Value: util.StringPointer("test-default-bucket")}} + expectedRuntimeWorkflow.Labels = map[string]string{util.LabelKeyWorkflowRunId: "123e4567-e89b-12d3-a456-426655440000"} + expectedRuntimeWorkflow.Annotations = map[string]string{util.AnnotationKeyRunName: "run1"} + expectedRuntimeWorkflow.Spec.ServiceAccountName = defaultPipelineRunnerServiceAccount + expectedRunDetail := &model.RunDetail{ + Run: model.Run{ + UUID: "123e4567-e89b-12d3-a456-426655440000", + ExperimentUUID: expectedExperimentUUID, + DisplayName: "run1", + Name: "workflow-name", + Namespace: "test-ns", + StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), + CreatedAtInSec: 2, + Conditions: "Running", + PipelineSpec: model.PipelineSpec{ + WorkflowSpecManifest: testWorkflow.ToStringForStore(), + Parameters: "[{\"name\":\"param1\",\"value\":\"test-default-bucket\"}]", + }, + ResourceReferences: []*model.ResourceReference{ + { + ResourceUUID: "123e4567-e89b-12d3-a456-426655440000", + ResourceType: common.Run, + ReferenceUUID: DefaultFakeUUID, + ReferenceName: "e1", + ReferenceType: common.Experiment, + Relationship: common.Owner, + }, + }, + }, + PipelineRuntime: model.PipelineRuntime{ + WorkflowRuntimeManifest: util.NewWorkflow(expectedRuntimeWorkflow).ToStringForStore(), + }, + } + assert.Equal(t, expectedRunDetail, runDetail, "The CreateRun return has unexpected value.") + assert.Equal(t, 1, store.ArgoClientFake.GetWorkflowCount(), "Workflow CRD is not created.") + runDetail, err := manager.GetRun(runDetail.UUID) + assert.Nil(t, err) + assert.Equal(t, expectedRunDetail, runDetail, "CreateRun stored invalid data in database") +} + func TestCreateRun_ThroughPipelineVersion(t *testing.T) { // Create experiment, pipeline, and pipeline version. store, manager, experiment, pipeline := initWithExperimentAndPipeline(t) diff --git a/backend/src/apiserver/resource/resource_manager_util.go b/backend/src/apiserver/resource/resource_manager_util.go index 2ea2372350e..5a1b52ac8f6 100644 --- a/backend/src/apiserver/resource/resource_manager_util.go +++ b/backend/src/apiserver/resource/resource_manager_util.go @@ -16,6 +16,7 @@ package resource import ( "errors" + "fmt" "regexp" "strings" "time" @@ -24,6 +25,7 @@ import ( "github.com/argoproj/argo/workflow/common" api "github.com/kubeflow/pipelines/backend/api/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/client" + servercommon "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/common/util" scheduledworkflow "github.com/kubeflow/pipelines/backend/src/crd/pkg/apis/scheduledworkflow/v1beta1" apierr "k8s.io/apimachinery/pkg/api/errors" @@ -181,3 +183,48 @@ func deletePods(k8sCoreClient client.KubernetesCoreInterface, podsToDelete []str } return nil } + +// Mutate default values of specified pipeline spec. +// Args: +// text: (part of) pipeline file in string. +func PatchPipelineDefaultParameter(text string) (string, error) { + defaultBucket := servercommon.GetStringConfig(DefaultBucketNameEnvVar) + projectId := servercommon.GetStringConfig(ProjectIDEnvVar) + toPatch := map[string]string{ + "{{kfp-default-bucket}}": defaultBucket, + "{{kfp-project-id}}": projectId, + } + for key, value := range toPatch { + text = strings.Replace(text, key, value, -1) + } + return text, nil +} + +// Patch the system-specified default parameters if available. +func OverrideParameterWithSystemDefault(workflow util.Workflow, apiRun *api.Run) error { + // Patch the default value to workflow spec. + if servercommon.GetBoolConfigWithDefault(HasDefaultBucketEnvVar, false) { + patchedSlice := make([]wfv1.Parameter, 0) + for _, currentParam := range workflow.Spec.Arguments.Parameters { + desiredValue, err := PatchPipelineDefaultParameter(*currentParam.Value) + if err != nil { + return fmt.Errorf("failed to patch default value to pipeline. Error: %v", err) + } + patchedSlice = append(patchedSlice, wfv1.Parameter{ + Name: currentParam.Name, + Value: util.StringPointer(desiredValue), + }) + } + workflow.Spec.Arguments.Parameters = patchedSlice + + // Patched the default value to apiRun + for _, param := range apiRun.PipelineSpec.Parameters { + var err error + param.Value, err = PatchPipelineDefaultParameter(param.Value) + if err != nil { + return fmt.Errorf("failed to patch default value to pipeline. Error: %v", err) + } + } + } + return nil +} diff --git a/backend/src/apiserver/server/run_server.go b/backend/src/apiserver/server/run_server.go index 2c0364f8b89..ef4541fff41 100644 --- a/backend/src/apiserver/server/run_server.go +++ b/backend/src/apiserver/server/run_server.go @@ -16,7 +16,6 @@ package server import ( "context" - "fmt" "github.com/golang/protobuf/ptypes/empty" api "github.com/kubeflow/pipelines/backend/api/go_client" @@ -31,23 +30,7 @@ type RunServer struct { resourceManager *resource.ResourceManager } -const ( - HasDefaultBucketEnvVar = "HAS_DEFAULT_BUCKET" - ProjectIDEnvVar = "PROJECT_ID" - DefaultBucketNameEnvVar = "BUCKET_NAME" -) - func (s *RunServer) CreateRun(ctx context.Context, request *api.CreateRunRequest) (*api.RunDetail, error) { - // Patch default values - for _, param := range request.Run.PipelineSpec.Parameters { - if common.GetBoolConfigWithDefault(HasDefaultBucketEnvVar, false) { - var err error - param.Value, err = PatchPipelineDefaultParameter(param.Value) - if err != nil { - return nil, fmt.Errorf("failed to patch default value to pipeline. Error: %v", err) - } - } - } err := s.validateCreateRunRequest(request) if err != nil { return nil, util.Wrap(err, "Validate create run request failed.") diff --git a/backend/src/apiserver/server/run_server_test.go b/backend/src/apiserver/server/run_server_test.go index caef700f601..a01644674b7 100644 --- a/backend/src/apiserver/server/run_server_test.go +++ b/backend/src/apiserver/server/run_server_test.go @@ -64,9 +64,6 @@ func TestCreateRun(t *testing.T) { func TestCreateRunPatch(t *testing.T) { clients, manager, experiment := initWithExperiment(t) - viper.Set(HasDefaultBucketEnvVar, "true") - viper.Set(ProjectIDEnvVar, "test-project-id") - viper.Set(DefaultBucketNameEnvVar, "test-default-bucket") defer clients.Close() server := NewRunServer(manager) run := &api.Run{ @@ -75,8 +72,8 @@ func TestCreateRunPatch(t *testing.T) { PipelineSpec: &api.PipelineSpec{ WorkflowManifest: testWorkflowPatch.ToStringForStore(), Parameters: []*api.Parameter{ - {Name: "param1", Value: "{{kfp-default-bucket}}"}, - {Name: "param2", Value: "{{kfp-project-id}}"}}, + {Name: "param1", Value: "test-default-bucket"}, + {Name: "param2", Value: "test-project-id"}}, }, } runDetail, err := server.CreateRun(nil, &api.CreateRunRequest{Run: run}) diff --git a/backend/src/apiserver/server/util.go b/backend/src/apiserver/server/util.go index 2bf36d20fe4..94938693aab 100644 --- a/backend/src/apiserver/server/util.go +++ b/backend/src/apiserver/server/util.go @@ -182,22 +182,6 @@ func ReadPipelineFile(fileName string, fileReader io.Reader, maxFileLength int) return processedFile, nil } -// Mutate default values of specified pipeline spec. -// Args: -// text: (part of) pipeline file in string. -func PatchPipelineDefaultParameter(text string) (string, error) { - defaultBucket := common.GetStringConfig(DefaultBucketNameEnvVar) - projectId := common.GetStringConfig(ProjectIDEnvVar) - toPatch := map[string]string{ - "{{kfp-default-bucket}}": defaultBucket, - "{{kfp-project-id}}": projectId, - } - for key, value := range toPatch { - text = strings.Replace(text, key, value, -1) - } - return text, nil -} - func printParameters(params []*api.Parameter) string { var s strings.Builder for _, p := range params { diff --git a/backend/src/common/util/workflow.go b/backend/src/common/util/workflow.go index 5d3f6c53f54..417ad8d9859 100644 --- a/backend/src/common/util/workflow.go +++ b/backend/src/common/util/workflow.go @@ -15,6 +15,8 @@ package util import ( + "strings" + workflowapi "github.com/argoproj/argo/pkg/apis/workflow/v1alpha1" "github.com/golang/glog" swfregister "github.com/kubeflow/pipelines/backend/src/crd/pkg/apis/scheduledworkflow" @@ -22,7 +24,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/json" - "strings" ) // Workflow is a type to help manipulate Workflow objects.