[Backend] Improve parameter patching (#3016)
* update
* move patching logic
* update placeholders
* move to api converter
* refactor some tests
* fix main.go
* default remote build to false
* patch everything
* remove api converter patching
* clean up
* fix constant
Jiaxiao Zheng authored Feb 10, 2020
1 parent 9b8e14c commit acb2038
Showing 10 changed files with 110 additions and 49 deletions.
2 changes: 1 addition & 1 deletion backend/build_api_server.sh
@@ -36,7 +36,7 @@ LONGOPTS=use_remote_build,gcp_credentials_file:
 PARSED=$(getopt --longoptions=$LONGOPTS --options=$OPTS --name "$0" -- "$@")
 eval set -- "$PARSED"
 
-USE_REMOTE_BUILD=true
+USE_REMOTE_BUILD=false
 GCP_CREDENTIALS_FILE="gs://ml-pipeline-test-bazel/ml-pipeline-test-bazel-builder-credentials.json"
 MACHINE_ARCH=`uname -m`
19 changes: 0 additions & 19 deletions backend/src/apiserver/main.go
@@ -40,12 +40,6 @@ import (
 	"google.golang.org/grpc/reflection"
 )
 
-const (
-	HasDefaultBucketEnvVar  = "HAS_DEFAULT_BUCKET"
-	ProjectIDEnvVar         = "PROJECT_ID"
-	DefaultBucketNameEnvVar = "BUCKET_NAME"
-)
-
 var (
 	rpcPortFlag  = flag.String("rpcPortFlag", ":8887", "RPC Port")
 	httpPortFlag = flag.String("httpPortFlag", ":8888", "Http Proxy Port")
@@ -196,19 +190,6 @@ func loadSamples(resourceManager *resource.ResourceManager) error {
 		if configErr != nil {
 			return fmt.Errorf("Failed to decompress the file %s. Error: %v", config.Name, configErr)
 		}
-		// Patch the default bucket name read from ConfigMap
-		if common.GetBoolConfigWithDefault(HasDefaultBucketEnvVar, false) {
-			defaultBucket := common.GetStringConfig(DefaultBucketNameEnvVar)
-			projectId := common.GetStringConfig(ProjectIDEnvVar)
-			patchMap := map[string]string{
-				"<your-gcs-bucket>": defaultBucket,
-				"<your-project-id>": projectId,
-			}
-			pipelineFile, err = server.PatchPipelineDefaultParameter(pipelineFile, patchMap)
-			if err != nil {
-				return fmt.Errorf("Failed to patch default value to %s. Error: %v", config.Name, err)
-			}
-		}
 		_, configErr = resourceManager.CreatePipeline(config.Name, config.Description, pipelineFile)
 		if configErr != nil {
 			// Log the error but not fail. The API Server pod can restart and it could potentially cause name collision.
17 changes: 17 additions & 0 deletions backend/src/apiserver/server/run_server.go
@@ -16,6 +16,7 @@ package server
 
 import (
 	"context"
+	"fmt"
 
 	"github.com/golang/protobuf/ptypes/empty"
 	api "github.com/kubeflow/pipelines/backend/api/go_client"
@@ -30,7 +31,23 @@ type RunServer struct {
 	resourceManager *resource.ResourceManager
 }
 
+const (
+	HasDefaultBucketEnvVar  = "HAS_DEFAULT_BUCKET"
+	ProjectIDEnvVar         = "PROJECT_ID"
+	DefaultBucketNameEnvVar = "BUCKET_NAME"
+)
+
 func (s *RunServer) CreateRun(ctx context.Context, request *api.CreateRunRequest) (*api.RunDetail, error) {
+	// Patch default values
+	for _, param := range request.Run.PipelineSpec.Parameters {
+		if common.GetBoolConfigWithDefault(HasDefaultBucketEnvVar, false) {
+			var err error
+			param.Value, err = PatchPipelineDefaultParameter(param.Value)
+			if err != nil {
+				return nil, fmt.Errorf("failed to patch default value to pipeline. Error: %v", err)
+			}
+		}
+	}
 	err := s.validateCreateRunRequest(request)
 	if err != nil {
 		return nil, util.Wrap(err, "Validate create run request failed.")
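Taken together, this hunk moves placeholder resolution from sample-load time to run-submission time: every incoming parameter value is rewritten before the request is validated. Below is a self-contained sketch of that loop, not the server code itself: a plain Parameter struct stands in for the generated api.Parameter type, and a bool argument stands in for the common.GetBoolConfigWithDefault lookup.

package main

import (
	"fmt"
	"strings"
)

// Parameter is a stand-in for the generated api.Parameter type.
type Parameter struct {
	Name  string
	Value string
}

// patchRunParameters mirrors the CreateRun hunk above: when a default bucket
// is configured, every parameter value has its placeholders rewritten before
// the request goes on to validation.
func patchRunParameters(params []*Parameter, hasDefaultBucket bool, patch map[string]string) {
	if !hasDefaultBucket {
		return // mirrors common.GetBoolConfigWithDefault(HasDefaultBucketEnvVar, false)
	}
	for _, param := range params {
		for placeholder, value := range patch {
			param.Value = strings.Replace(param.Value, placeholder, value, -1)
		}
	}
}

func main() {
	params := []*Parameter{
		{Name: "param1", Value: "{{kfp-default-bucket}}"},
		{Name: "param2", Value: "{{kfp-project-id}}"},
	}
	patchRunParameters(params, true, map[string]string{
		"{{kfp-default-bucket}}": "test-default-bucket",
		"{{kfp-project-id}}":     "test-project-id",
	})
	for _, p := range params {
		fmt.Printf("%s=%s\n", p.Name, p.Value)
	}
	// Output:
	// param1=test-default-bucket
	// param2=test-project-id
}

One detail visible in the diff: the HAS_DEFAULT_BUCKET check sits inside the loop. Since the flag cannot change mid-request, hoisting it out of the loop, as the sketch does, is behaviorally equivalent and avoids repeated config lookups.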
56 changes: 56 additions & 0 deletions backend/src/apiserver/server/run_server_test.go
@@ -62,6 +62,62 @@ func TestCreateRun(t *testing.T) {
 	assert.Equal(t, expectedRunDetail, *runDetail)
 }
 
+func TestCreateRunPatch(t *testing.T) {
+	clients, manager, experiment := initWithExperiment(t)
+	viper.Set(HasDefaultBucketEnvVar, "true")
+	viper.Set(ProjectIDEnvVar, "test-project-id")
+	viper.Set(DefaultBucketNameEnvVar, "test-default-bucket")
+	defer clients.Close()
+	server := NewRunServer(manager)
+	run := &api.Run{
+		Name:               "123",
+		ResourceReferences: validReference,
+		PipelineSpec: &api.PipelineSpec{
+			WorkflowManifest: testWorkflowPatch.ToStringForStore(),
+			Parameters: []*api.Parameter{
+				{Name: "param1", Value: "{{kfp-default-bucket}}"},
+				{Name: "param2", Value: "{{kfp-project-id}}"}},
+		},
+	}
+	runDetail, err := server.CreateRun(nil, &api.CreateRunRequest{Run: run})
+	assert.Nil(t, err)
+
+	expectedRuntimeWorkflow := testWorkflowPatch.DeepCopy()
+	expectedRuntimeWorkflow.Spec.Arguments.Parameters = []v1alpha1.Parameter{
+		{Name: "param1", Value: util.StringPointer("test-default-bucket")},
+		{Name: "param2", Value: util.StringPointer("test-project-id")},
+	}
+	expectedRuntimeWorkflow.Labels = map[string]string{util.LabelKeyWorkflowRunId: "123e4567-e89b-12d3-a456-426655440000"}
+	expectedRuntimeWorkflow.Annotations = map[string]string{util.AnnotationKeyRunName: "123"}
+	expectedRuntimeWorkflow.Spec.ServiceAccountName = "pipeline-runner"
+	expectedRunDetail := api.RunDetail{
+		Run: &api.Run{
+			Id:           "123e4567-e89b-12d3-a456-426655440000",
+			Name:         "123",
+			StorageState: api.Run_STORAGESTATE_AVAILABLE,
+			CreatedAt:    &timestamp.Timestamp{Seconds: 2},
+			ScheduledAt:  &timestamp.Timestamp{},
+			FinishedAt:   &timestamp.Timestamp{},
+			PipelineSpec: &api.PipelineSpec{
+				WorkflowManifest: testWorkflowPatch.ToStringForStore(),
+				Parameters: []*api.Parameter{
+					{Name: "param1", Value: "test-default-bucket"},
+					{Name: "param2", Value: "test-project-id"}},
+			},
+			ResourceReferences: []*api.ResourceReference{
+				{
+					Key:  &api.ResourceKey{Type: api.ResourceType_EXPERIMENT, Id: experiment.UUID},
+					Name: "123", Relationship: api.Relationship_OWNER,
+				},
+			},
+		},
+		PipelineRuntime: &api.PipelineRuntime{
+			WorkflowManifest: util.NewWorkflow(expectedRuntimeWorkflow).ToStringForStore(),
+		},
+	}
+	assert.Equal(t, expectedRunDetail, *runDetail)
+}
+
 func TestCreateRun_Unauthorized(t *testing.T) {
 	clients, manager, _ := initWithExperiment_KFAM_Unauthorized(t)
 	defer clients.Close()
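TestCreateRunPatch drives the server's configuration with viper.Set, which suggests the common.GetBoolConfigWithDefault and common.GetStringConfig helpers read through viper. A minimal, runnable sketch of that assumption:

package main

import (
	"fmt"

	"github.com/spf13/viper"
)

func main() {
	// Before anything is set, the gate is off and patching would be skipped.
	fmt.Println(viper.GetBool("HAS_DEFAULT_BUCKET")) // false

	// Mirrors the viper.Set calls at the top of TestCreateRunPatch; viper
	// casts the string "true" to a bool on read.
	viper.Set("HAS_DEFAULT_BUCKET", "true")
	viper.Set("PROJECT_ID", "test-project-id")
	viper.Set("BUCKET_NAME", "test-default-bucket")

	fmt.Println(viper.GetBool("HAS_DEFAULT_BUCKET")) // true
	fmt.Println(viper.GetString("BUCKET_NAME"))      // test-default-bucket
	fmt.Println(viper.GetString("PROJECT_ID"))       // test-project-id
}

A caveat of this pattern: viper settings are process-global, so values set in one test leak into later tests in the same package unless they are reset.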
6 changes: 6 additions & 0 deletions backend/src/apiserver/server/test_util.go
@@ -42,6 +42,12 @@ var testWorkflow2 = util.NewWorkflow(&v1alpha1.Workflow{
 	Spec: v1alpha1.WorkflowSpec{Arguments: v1alpha1.Arguments{Parameters: []v1alpha1.Parameter{{Name: "param1"}}}},
 })
 
+var testWorkflowPatch = util.NewWorkflow(&v1alpha1.Workflow{
+	TypeMeta:   v1.TypeMeta{APIVersion: "argoproj.io/v1alpha1", Kind: "Workflow"},
+	ObjectMeta: v1.ObjectMeta{Name: "workflow-name", UID: "workflow2"},
+	Spec:       v1alpha1.WorkflowSpec{Arguments: v1alpha1.Arguments{Parameters: []v1alpha1.Parameter{{Name: "param1"}, {Name: "param2"}}}},
+})
+
 var validReference = []*api.ResourceReference{
 	{
 		Key: &api.ResourceKey{
18 changes: 11 additions & 7 deletions backend/src/apiserver/server/util.go
@@ -182,16 +182,20 @@ func ReadPipelineFile(fileName string, fileReader io.Reader, maxFileLength int)
 	return processedFile, nil
 }
 
-// Mutate default values of specified pipeline file.
+// Mutate default values of specified pipeline spec.
 // Args:
-//	file: pipeline file in bytes.
-//	toPatch: mapping from the old value to its new value.
-func PatchPipelineDefaultParameter(file []byte, toPatch map[string]string) ([]byte, error) {
-	pipelineRawString := string(file)
+//	text: (part of) pipeline file in string.
+func PatchPipelineDefaultParameter(text string) (string, error) {
+	defaultBucket := common.GetStringConfig(DefaultBucketNameEnvVar)
+	projectId := common.GetStringConfig(ProjectIDEnvVar)
+	toPatch := map[string]string{
+		"{{kfp-default-bucket}}": defaultBucket,
+		"{{kfp-project-id}}":     projectId,
+	}
 	for key, value := range toPatch {
-		pipelineRawString = strings.Replace(pipelineRawString, key, value, -1)
+		text = strings.Replace(text, key, value, -1)
 	}
-	return []byte(pipelineRawString), nil
+	return text, nil
 }
 
 func printParameters(params []*api.Parameter) string {
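The signature change is the heart of this hunk: the function now takes and returns a string rather than a []byte, and the placeholder map is fixed inside the function instead of being supplied by the caller, so every call site gets the same {{kfp-default-bucket}} and {{kfp-project-id}} rewrite. A self-contained sketch of the substitution, with literal values standing in for the common.GetStringConfig lookups:

package main

import (
	"fmt"
	"strings"
)

// patchDefaultParameter reproduces the replacement logic above; the two
// values are hard-coded stand-ins for the common.GetStringConfig(...) reads.
func patchDefaultParameter(text string) string {
	toPatch := map[string]string{
		"{{kfp-default-bucket}}": "my-bucket",
		"{{kfp-project-id}}":     "my-project",
	}
	for key, value := range toPatch {
		// -1 means replace every occurrence, not just the first.
		text = strings.Replace(text, key, value, -1)
	}
	return text
}

func main() {
	in := "gs://{{kfp-default-bucket}}/tfx_taxi_simple in {{kfp-project-id}}"
	fmt.Println(patchDefaultParameter(in))
	// Output: gs://my-bucket/tfx_taxi_simple in my-project
}

Since the rewritten function can no longer fail, the (string, error) return shape mainly keeps its call sites uniform with the previous version.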
13 changes: 0 additions & 13 deletions backend/src/apiserver/server/util_test.go
@@ -131,19 +131,6 @@ func TestReadPipelineFile_YAML(t *testing.T) {
 	assert.Equal(t, expectedFileBytes, fileBytes)
 }
 
-func TestParameterPatch(t *testing.T) {
-	file, _ := os.Open("test/arguments-parameters.yaml")
-	fileBytes, err := ReadPipelineFile("arguments-parameters.yaml", file, MaxFileLength)
-	patchMap := map[string]string{
-		"hello": "new-hello",
-	}
-	fileBytes, err = PatchPipelineDefaultParameter(fileBytes, patchMap)
-	assert.Nil(t, err)
-
-	expectedFileBytes, _ := ioutil.ReadFile("test/patched-arguments-parameters.yaml")
-	assert.Equal(t, expectedFileBytes, fileBytes)
-}
-
 func TestReadPipelineFile_Zip(t *testing.T) {
 	file, _ := os.Open("test/arguments_zip/arguments-parameters.zip")
 	pipelineFile, err := ReadPipelineFile("arguments-parameters.zip", file, MaxFileLength)
2 changes: 1 addition & 1 deletion
@@ -54,7 +54,7 @@
 
 # Path of pipeline root, should be a GCS path.
 pipeline_root = os.path.join(
-    'gs://<your-gcs-bucket>', 'tfx_taxi_simple', kfp.dsl.RUN_ID_PLACEHOLDER
+    'gs://{{kfp-default-bucket}}', 'tfx_taxi_simple', kfp.dsl.RUN_ID_PLACEHOLDER
 )
18 changes: 14 additions & 4 deletions samples/core/parameterized_tfx_oss/taxi_pipeline_notebook.ipynb
@@ -79,12 +79,12 @@
 "# In TFX MLMD schema, pipeline name is used as the unique id of each pipeline.\n",
 "# Assigning workflow ID as part of pipeline name allows the user to bypass\n",
 "# some schema checks which are redundant for experimental pipelines.\n",
-"pipeline_name = 'taxi_pipeline_with_parameters_' + kfp.dsl.RUN_ID_PLACEHOLDER\n",
+"pipeline_name = 'taxi_pipeline_with_parameters'\n",
 "\n",
 "# Path of pipeline data root, should be a GCS path.\n",
 "# Note that when running on KFP, the pipeline root is always a runtime parameter.\n",
 "# The value specified here will be its default.\n",
-"pipeline_root = os.path.join('gs://my-bucket', 'tfx_taxi_simple',\n",
+"pipeline_root = os.path.join('gs://{{kfp-default-bucket}}', 'tfx_taxi_simple',\n",
 "                             kfp.dsl.RUN_ID_PLACEHOLDER)\n",
 "\n",
 "# Location of input data, should be a GCS path under which there is a csv file.\n",
@@ -308,7 +308,8 @@
 ").create_run_from_pipeline_package(\n",
 "    pipeline_name + '.tar.gz', \n",
 "    arguments={\n",
-"        'pipeline-root': 'gs://<your-gcs-bucket>/tfx_taxi_simple/' + kfp.dsl.RUN_ID_PLACEHOLDER,\n",
+"        # Uncomment following lines in order to use custom GCS bucket/module file/training data.\n",
+"        # 'pipeline-root': 'gs://<your-gcs-bucket>/tfx_taxi_simple/' + kfp.dsl.RUN_ID_PLACEHOLDER,\n",
 "        # 'module-file': '<gcs path to the module file>', # delete this line to use default module file.\n",
 "        # 'data-root': '<gcs path to the data>' # delete this line to use default data.\n",
 "})"
@@ -332,8 +333,17 @@
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
 "version": "3.7.5rc1"
-}
+},
+"pycharm": {
+ "stem_cell": {
+  "cell_type": "raw",
+  "source": [],
+  "metadata": {
+   "collapsed": false
+  }
+ }
+}
 },
 "nbformat": 4,
 "nbformat_minor": 4
 }
8 changes: 4 additions & 4 deletions samples/core/xgboost_training_cm/xgboost_training_cm.py
@@ -160,7 +160,7 @@ def dataproc_train_op(
         region=region,
         cluster_name=cluster_name,
         main_class=_TRAINER_MAIN_CLS,
-        spark_job=json.dumps({ 'jarFileUris': [_XGBOOST_PKG]}),
+        spark_job=json.dumps({'jarFileUris': [_XGBOOST_PKG]}),
         args=json.dumps([
             str(config),
             str(rounds),
@@ -189,7 +189,7 @@ def dataproc_predict_op(
         region=region,
         cluster_name=cluster_name,
         main_class=_PREDICTOR_MAIN_CLS,
-        spark_job=json.dumps({ 'jarFileUris': [_XGBOOST_PKG]}),
+        spark_job=json.dumps({'jarFileUris': [_XGBOOST_PKG]}),
         args=json.dumps([
             str(model),
             str(data),
@@ -205,8 +205,8 @@ def dataproc_predict_op(
     description='A trainer that does end-to-end distributed training for XGBoost models.'
 )
 def xgb_train_pipeline(
-    output='gs://<your-gcs-bucket>',
-    project='<your-project-id>',
+    output='gs://{{kfp-default-bucket}}',
+    project='{{kfp-project-id}}',
     cluster_name='xgb-%s' % dsl.RUN_ID_PLACEHOLDER,
     region='us-central1',
     train_data='gs://ml-pipeline-playground/sfpd/train.csv',
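Because the patch in CreateRun is gated on HAS_DEFAULT_BUCKET, these new defaults only resolve on clusters where that flag is configured; everywhere else the literal placeholder string would be submitted as-is. A small illustrative sketch of that fall-through (the helper name and values are hypothetical):

package main

import "fmt"

// resolveOutput is a hypothetical helper illustrating the gate: with no
// default bucket configured, the pipeline default above reaches the
// workflow engine verbatim, placeholder and all.
func resolveOutput(hasDefaultBucket bool, bucket string) string {
	output := "gs://{{kfp-default-bucket}}" // default of xgb_train_pipeline
	if !hasDefaultBucket {
		return output
	}
	return "gs://" + bucket
}

func main() {
	fmt.Println(resolveOutput(false, ""))         // gs://{{kfp-default-bucket}}
	fmt.Println(resolveOutput(true, "my-bucket")) // gs://my-bucket
}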
