diff --git a/backend/src/agent/persistence/client/pipeline_client.go b/backend/src/agent/persistence/client/pipeline_client.go index 25359933615..b40b45f7e7b 100644 --- a/backend/src/agent/persistence/client/pipeline_client.go +++ b/backend/src/agent/persistence/client/pipeline_client.go @@ -55,15 +55,20 @@ func NewPipelineClient( basePath string, mlPipelineServiceName string, mlPipelineServiceHttpPort string, - mlPipelineServiceGRPCPort string) (*PipelineClient, error) { + mlPipelineServiceGRPCPort string, + mlPipelineServiceTLSEnabled bool) (*PipelineClient, error) { httpAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceHttpPort) grpcAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceGRPCPort) - err := util.WaitForAPIAvailable(initializeTimeout, basePath, httpAddress) + scheme := "http" + if mlPipelineServiceTLSEnabled { + scheme = "https" + } + err := util.WaitForAPIAvailable(initializeTimeout, basePath, httpAddress, scheme) if err != nil { return nil, errors.Wrapf(err, "Failed to initialize pipeline client. Error: %s", err.Error()) } - connection, err := util.GetRpcConnection(grpcAddress) + connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled) if err != nil { return nil, errors.Wrapf(err, "Failed to get RPC connection. Error: %s", err.Error()) diff --git a/backend/src/agent/persistence/main.go b/backend/src/agent/persistence/main.go index 4da32a7095e..f3510bce3d3 100644 --- a/backend/src/agent/persistence/main.go +++ b/backend/src/agent/persistence/main.go @@ -16,6 +16,7 @@ package main import ( "flag" + "strconv" "time" "github.com/kubeflow/pipelines/backend/src/agent/persistence/client" @@ -29,21 +30,22 @@ import ( ) var ( - masterURL string - kubeconfig string - initializeTimeout time.Duration - timeout time.Duration - mlPipelineAPIServerName string - mlPipelineAPIServerPort string - mlPipelineAPIServerBasePath string - mlPipelineServiceHttpPort string - mlPipelineServiceGRPCPort string - namespace string - ttlSecondsAfterWorkflowFinish int64 - numWorker int - clientQPS float64 - clientBurst int - saTokenRefreshIntervalInSecs int64 + masterURL string + kubeconfig string + initializeTimeout time.Duration + timeout time.Duration + mlPipelineAPIServerName string + mlPipelineAPIServerPort string + mlPipelineAPIServerBasePath string + mlPipelineServiceHttpPort string + mlPipelineServiceGRPCPort string + mlPipelineServiceTLSEnabledStr string + namespace string + ttlSecondsAfterWorkflowFinish int64 + numWorker int + clientQPS float64 + clientBurst int + saTokenRefreshIntervalInSecs int64 ) const ( @@ -55,6 +57,7 @@ const ( mlPipelineAPIServerNameFlagName = "mlPipelineAPIServerName" mlPipelineAPIServerHttpPortFlagName = "mlPipelineServiceHttpPort" mlPipelineAPIServerGRPCPortFlagName = "mlPipelineServiceGRPCPort" + mlPipelineAPIServerTLSEnabled = "mlPipelineServiceTLSEnabled" namespaceFlagName = "namespace" ttlSecondsAfterWorkflowFinishFlagName = "ttlSecondsAfterWorkflowFinish" numWorkerName = "numWorker" @@ -102,6 +105,11 @@ func main() { log.Fatalf("Error starting Service Account Token Refresh Ticker due to: %v", err) } + mlPipelineServiceTLSEnabled, err := strconv.ParseBool(mlPipelineServiceTLSEnabledStr) + if err != nil { + log.Fatalf("Error parsing boolean flag %s, please provide a valid bool value (true/false). 
%v", mlPipelineAPIServerTLSEnabled, err) + } + pipelineClient, err := client.NewPipelineClient( initializeTimeout, timeout, @@ -109,7 +117,8 @@ func main() { mlPipelineAPIServerBasePath, mlPipelineAPIServerName, mlPipelineServiceHttpPort, - mlPipelineServiceGRPCPort) + mlPipelineServiceGRPCPort, + mlPipelineServiceTLSEnabled) if err != nil { log.Fatalf("Error creating ML pipeline API Server client: %v", err) } @@ -136,6 +145,7 @@ func init() { flag.StringVar(&mlPipelineAPIServerName, mlPipelineAPIServerNameFlagName, "ml-pipeline", "Name of the ML pipeline API server.") flag.StringVar(&mlPipelineServiceHttpPort, mlPipelineAPIServerHttpPortFlagName, "8888", "Http Port of the ML pipeline API server.") flag.StringVar(&mlPipelineServiceGRPCPort, mlPipelineAPIServerGRPCPortFlagName, "8887", "GRPC Port of the ML pipeline API server.") + flag.StringVar(&mlPipelineServiceTLSEnabledStr, mlPipelineAPIServerTLSEnabled, "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').") flag.StringVar(&mlPipelineAPIServerBasePath, mlPipelineAPIServerBasePathFlagName, "/apis/v1beta1", "The base path for the ML pipeline API server.") flag.StringVar(&namespace, namespaceFlagName, "", "The namespace name used for Kubernetes informers to obtain the listers.") diff --git a/backend/src/apiserver/main.go b/backend/src/apiserver/main.go index 3e1f828b40c..597144c873b 100644 --- a/backend/src/apiserver/main.go +++ b/backend/src/apiserver/main.go @@ -16,10 +16,12 @@ package main import ( "context" + "crypto/tls" "encoding/json" "flag" "fmt" "github.com/kubeflow/pipelines/backend/src/apiserver/client" + "google.golang.org/grpc/credentials" "io" "io/ioutil" "math" @@ -52,21 +54,49 @@ var ( httpPortFlag = flag.String("httpPortFlag", ":8888", "Http Proxy Port") configPath = flag.String("config", "", "Path to JSON file containing config") sampleConfigPath = flag.String("sampleconfig", "", "Path to samples") + tlsCertPath = flag.String("tlsCertPath", "", "Path to the public tls cert.") + tlsCertKeyPath = flag.String("tlsCertKeyPath", "", "Path to the private tls key cert.") collectMetricsFlag = flag.Bool("collectMetricsFlag", true, "Whether to collect Prometheus metrics in API server.") ) type RegisterHttpHandlerFromEndpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error +func initCerts() (*tls.Config, error) { + if *tlsCertPath == "" && *tlsCertKeyPath == "" { + // User can choose not to provide certs + return nil, nil + } else if *tlsCertPath == "" { + return nil, fmt.Errorf("Missing tlsCertPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required.") + } else if *tlsCertKeyPath == "" { + return nil, fmt.Errorf("Missing tlsCertKeyPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required.") + } + serverCert, err := tls.LoadX509KeyPair(*tlsCertPath, *tlsCertKeyPath) + if err != nil { + return nil, err + } + config := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } + glog.Info("TLS cert key/pair loaded.") + return config, err +} + func main() { flag.Parse() initConfig() clientManager := cm.NewClientManager() + + tlsConfig, err := initCerts() + if err != nil { + glog.Fatalf("Failed to parse Cert paths. Err: %v", err) + } + resourceManager := resource.NewResourceManager( &clientManager, &resource.ResourceManagerOptions{CollectMetrics: *collectMetricsFlag}, ) - err := loadSamples(resourceManager) + err = loadSamples(resourceManager) if err != nil { glog.Fatalf("Failed to load samples. 
Err: %v", err) } @@ -78,8 +108,8 @@ func main() { } } - go startRpcServer(resourceManager) - startHttpProxy(resourceManager) + go startRpcServer(resourceManager, tlsConfig) + startHttpProxy(resourceManager, tlsConfig) clientManager.Close() } @@ -93,13 +123,25 @@ func grpcCustomMatcher(key string) (string, bool) { return strings.ToLower(key), false } -func startRpcServer(resourceManager *resource.ResourceManager) { - glog.Info("Starting RPC server") +func startRpcServer(resourceManager *resource.ResourceManager, tlsConfig *tls.Config) { + var s *grpc.Server + if tlsConfig != nil { + glog.Info("Starting RPC server (TLS enabled)") + tlsCredentials := credentials.NewTLS(tlsConfig) + s = grpc.NewServer( + grpc.Creds(tlsCredentials), + grpc.UnaryInterceptor(apiServerInterceptor), + grpc.MaxRecvMsgSize(math.MaxInt32), + ) + } else { + glog.Info("Starting RPC server") + s = grpc.NewServer(grpc.UnaryInterceptor(apiServerInterceptor), grpc.MaxRecvMsgSize(math.MaxInt32)) + } + listener, err := net.Listen("tcp", *rpcPortFlag) if err != nil { glog.Fatalf("Failed to start RPC server: %v", err) } - s := grpc.NewServer(grpc.UnaryInterceptor(apiServerInterceptor), grpc.MaxRecvMsgSize(math.MaxInt32)) sharedExperimentServer := server.NewExperimentServer(resourceManager, &server.ExperimentServerOptions{CollectMetrics: *collectMetricsFlag}) sharedPipelineServer := server.NewPipelineServer( @@ -141,8 +183,7 @@ func startRpcServer(resourceManager *resource.ResourceManager) { glog.Info("RPC server started") } -func startHttpProxy(resourceManager *resource.ResourceManager) { - glog.Info("Starting Http Proxy") +func startHttpProxy(resourceManager *resource.ResourceManager, tlsConfig *tls.Config) { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) @@ -150,21 +191,21 @@ func startHttpProxy(resourceManager *resource.ResourceManager) { // Create gRPC HTTP MUX and register services for v1beta1 api. 
runtimeMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(grpcCustomMatcher)) - registerHttpHandlerFromEndpoint(apiv1beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv1beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv1beta1.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv1beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv1beta1.RegisterTaskServiceHandlerFromEndpoint, "TaskService", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv1beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv1beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv1beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService", ctx, runtimeMux) + registerHttpHandlerFromEndpoint(apiv1beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv1beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv1beta1.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv1beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv1beta1.RegisterTaskServiceHandlerFromEndpoint, "TaskService", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv1beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv1beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv1beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService", ctx, runtimeMux, tlsConfig) // Create gRPC HTTP MUX and register services for v2beta1 api. 
- registerHttpHandlerFromEndpoint(apiv2beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv2beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux) - registerHttpHandlerFromEndpoint(apiv2beta1.RegisterArtifactServiceHandlerFromEndpoint, "ArtifactService", ctx, runtimeMux) + registerHttpHandlerFromEndpoint(apiv2beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv2beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux, tlsConfig) + registerHttpHandlerFromEndpoint(apiv2beta1.RegisterArtifactServiceHandlerFromEndpoint, "ArtifactService", ctx, runtimeMux, tlsConfig) // Create a top level mux to include both pipeline upload server and gRPC servers. topMux := mux.NewRouter() @@ -197,13 +238,35 @@ func startHttpProxy(resourceManager *resource.ResourceManager) { // Register a handler for Prometheus to poll. topMux.Handle("/metrics", promhttp.Handler()) - http.ListenAndServe(*httpPortFlag, topMux) + if tlsConfig != nil { + glog.Info("Starting Https Proxy") + https := http.Server{ + TLSConfig: tlsConfig, + Addr: *httpPortFlag, + Handler: topMux, + } + https.ListenAndServeTLS("", "") + } else { + glog.Info("Starting Http Proxy") + http.ListenAndServe(*httpPortFlag, topMux) + } + glog.Info("Http Proxy started") } -func registerHttpHandlerFromEndpoint(handler RegisterHttpHandlerFromEndpoint, serviceName string, ctx context.Context, mux *runtime.ServeMux) { +func registerHttpHandlerFromEndpoint(handler RegisterHttpHandlerFromEndpoint, serviceName string, ctx context.Context, mux *runtime.ServeMux, tlsConfig *tls.Config) { endpoint := "localhost" + *rpcPortFlag - opts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))} + var opts []grpc.DialOption + if tlsConfig != nil { + // local client connections via http proxy to grpc should not require tls + tlsConfig.InsecureSkipVerify = true + opts = []grpc.DialOption{ + grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)), + } + } else { + opts = []grpc.DialOption{grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))} + } if err := handler(ctx, mux, endpoint, opts); err != nil { glog.Fatalf("Failed to register %v handler: %v", serviceName, err) diff --git a/backend/src/common/util/service.go b/backend/src/common/util/service.go index 92c036a31bd..8544963db3b 100644 --- a/backend/src/common/util/service.go +++ b/backend/src/common/util/service.go @@ -15,7 +15,10 @@ package util import ( + "crypto/tls" "fmt" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "net/http" "strings" "time" @@ -28,9 +31,9 @@ import ( "k8s.io/client-go/tools/clientcmd" ) -func WaitForAPIAvailable(initializeTimeout time.Duration, basePath string, apiAddress 
string) error { +func WaitForAPIAvailable(initializeTimeout time.Duration, basePath string, apiAddress string, scheme string) error { operation := func() error { - response, err := http.Get(fmt.Sprintf("http://%s%s/healthz", apiAddress, basePath)) + response, err := http.Get(fmt.Sprintf("%s://%s%s/healthz", scheme, apiAddress, basePath)) if err != nil { return err } @@ -74,8 +77,17 @@ func GetKubernetesClientFromClientConfig(clientConfig clientcmd.ClientConfig) ( return clientSet, config, namespace, nil } -func GetRpcConnection(address string) (*grpc.ClientConn, error) { - conn, err := grpc.Dial(address, grpc.WithInsecure()) +func GetRpcConnection(address string, tlsEnabled bool) (*grpc.ClientConn, error) { + creds := insecure.NewCredentials() + if tlsEnabled { + config := &tls.Config{} + creds = credentials.NewTLS(config) + } + + conn, err := grpc.Dial( + address, + grpc.WithTransportCredentials(creds), + ) if err != nil { return nil, errors.Wrapf(err, "Failed to create gRPC connection") } diff --git a/backend/src/v2/cacheutils/cache.go b/backend/src/v2/cacheutils/cache.go index 529d73aee9c..2ba2486d52a 100644 --- a/backend/src/v2/cacheutils/cache.go +++ b/backend/src/v2/cacheutils/cache.go @@ -3,9 +3,12 @@ package cacheutils import ( "context" "crypto/sha256" + "crypto/tls" "encoding/hex" "encoding/json" "fmt" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "os" "google.golang.org/grpc" @@ -111,10 +114,21 @@ type Client struct { } // NewClient creates a Client. -func NewClient() (*Client, error) { +func NewClient(mlPipelineServiceTLSEnabled bool) (*Client, error) { + creds := insecure.NewCredentials() + if mlPipelineServiceTLSEnabled { + config := &tls.Config{ + InsecureSkipVerify: false, + } + creds = credentials.NewTLS(config) + } cacheEndPoint := cacheDefaultEndpoint() glog.Infof("Connecting to cache endpoint %s", cacheEndPoint) - conn, err := grpc.Dial(cacheEndPoint, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)), grpc.WithInsecure()) + conn, err := grpc.Dial( + cacheEndPoint, + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)), + grpc.WithTransportCredentials(creds), + ) if err != nil { return nil, fmt.Errorf("metadata.NewClient() failed: %w", err) } diff --git a/backend/src/v2/cmd/driver/main.go b/backend/src/v2/cmd/driver/main.go index 793ccfe1b80..98127c28446 100644 --- a/backend/src/v2/cmd/driver/main.go +++ b/backend/src/v2/cmd/driver/main.go @@ -68,6 +68,8 @@ var ( // the value stored in the paths will be either 'true' or 'false' cachedDecisionPath = flag.String("cached_decision_path", "", "Cached Decision output path") conditionPath = flag.String("condition_path", "", "Condition output path") + + mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').") ) // func RootDAG(pipelineName string, runID string, component *pipelinespec.ComponentSpec, task *pipelinespec.PipelineTaskSpec, mlmd *metadata.Client) (*Execution, error) { @@ -147,18 +149,24 @@ func drive() (err error) { if err != nil { return err } - cacheClient, err := cacheutils.NewClient() + mlPipelineServiceTLSEnabled, err := strconv.ParseBool(*mlPipelineServiceTLSEnabledStr) + if err != nil { + return err + } + + cacheClient, err := cacheutils.NewClient(mlPipelineServiceTLSEnabled) if err != nil { return err } options := driver.Options{ - PipelineName: *pipelineName, - RunID: *runID, - Namespace: namespace, - 
Component: componentSpec, - Task: taskSpec, - DAGExecutionID: *dagExecutionID, - IterationIndex: *iterationIndex, + PipelineName: *pipelineName, + RunID: *runID, + Namespace: namespace, + Component: componentSpec, + Task: taskSpec, + DAGExecutionID: *dagExecutionID, + IterationIndex: *iterationIndex, + MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled, } var execution *driver.Execution var driverErr error diff --git a/backend/src/v2/cmd/launcher-v2/main.go b/backend/src/v2/cmd/launcher-v2/main.go index 8fb4e8d7625..3ac4245142f 100644 --- a/backend/src/v2/cmd/launcher-v2/main.go +++ b/backend/src/v2/cmd/launcher-v2/main.go @@ -19,6 +19,7 @@ import ( "context" "flag" "fmt" + "strconv" "github.com/golang/glog" "github.com/kubeflow/pipelines/backend/src/v2/component" @@ -27,20 +28,21 @@ import ( // TODO: use https://github.com/spf13/cobra as a framework to create more complex CLI tools with subcommands. var ( - copy = flag.String("copy", "", "copy this binary to specified destination path") - pipelineName = flag.String("pipeline_name", "", "pipeline context name") - runID = flag.String("run_id", "", "pipeline run uid") - parentDagID = flag.Int64("parent_dag_id", 0, "parent DAG execution ID") - executorType = flag.String("executor_type", "container", "The type of the ExecutorSpec") - executionID = flag.Int64("execution_id", 0, "Execution ID of this task.") - executorInputJSON = flag.String("executor_input", "", "The JSON-encoded ExecutorInput.") - componentSpecJSON = flag.String("component_spec", "", "The JSON-encoded ComponentSpec.") - importerSpecJSON = flag.String("importer_spec", "", "The JSON-encoded ImporterSpec.") - taskSpecJSON = flag.String("task_spec", "", "The JSON-encoded TaskSpec.") - podName = flag.String("pod_name", "", "Kubernetes Pod name.") - podUID = flag.String("pod_uid", "", "Kubernetes Pod UID.") - mlmdServerAddress = flag.String("mlmd_server_address", "", "The MLMD gRPC server address.") - mlmdServerPort = flag.String("mlmd_server_port", "8080", "The MLMD gRPC server port.") + copy = flag.String("copy", "", "copy this binary to specified destination path") + pipelineName = flag.String("pipeline_name", "", "pipeline context name") + runID = flag.String("run_id", "", "pipeline run uid") + parentDagID = flag.Int64("parent_dag_id", 0, "parent DAG execution ID") + executorType = flag.String("executor_type", "container", "The type of the ExecutorSpec") + executionID = flag.Int64("execution_id", 0, "Execution ID of this task.") + executorInputJSON = flag.String("executor_input", "", "The JSON-encoded ExecutorInput.") + componentSpecJSON = flag.String("component_spec", "", "The JSON-encoded ComponentSpec.") + importerSpecJSON = flag.String("importer_spec", "", "The JSON-encoded ImporterSpec.") + taskSpecJSON = flag.String("task_spec", "", "The JSON-encoded TaskSpec.") + podName = flag.String("pod_name", "", "Kubernetes Pod name.") + podUID = flag.String("pod_uid", "", "Kubernetes Pod UID.") + mlmdServerAddress = flag.String("mlmd_server_address", "", "The MLMD gRPC server address.") + mlmdServerPort = flag.String("mlmd_server_port", "8080", "The MLMD gRPC server port.") + mlPipelineServiceTLSEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').") ) func main() { @@ -64,14 +66,20 @@ func run() error { if err != nil { return err } + + mlPipelineServiceTLSEnabled, err := strconv.ParseBool(*mlPipelineServiceTLSEnabledStr) + if err != nil { + return err + } launcherV2Opts := 
&component.LauncherV2Options{ - Namespace: namespace, - PodName: *podName, - PodUID: *podUID, - MLMDServerAddress: *mlmdServerAddress, - MLMDServerPort: *mlmdServerPort, - PipelineName: *pipelineName, - RunID: *runID, + Namespace: namespace, + PodName: *podName, + PodUID: *podUID, + MLMDServerAddress: *mlmdServerAddress, + MLMDServerPort: *mlmdServerPort, + PipelineName: *pipelineName, + RunID: *runID, + MLPipelineTLSEnabled: mlPipelineServiceTLSEnabled, } switch *executorType { diff --git a/backend/src/v2/compiler/argocompiler/argo.go b/backend/src/v2/compiler/argocompiler/argo.go index a5cfed5faef..c8dca58bef1 100644 --- a/backend/src/v2/compiler/argocompiler/argo.go +++ b/backend/src/v2/compiler/argocompiler/argo.go @@ -122,6 +122,13 @@ func Compile(jobArg *pipelinespec.PipelineJob, kubernetesSpecArg *pipelinespec.S spec: spec, executors: deploy.GetExecutors(), } + + mlPipelineTLSEnabled, err := GetMLPipelineServiceTLSEnabled() + if err != nil { + return nil, err + } + c.mlPipelineServiceTLSEnabled = mlPipelineTLSEnabled + if opts != nil { if opts.DriverImage != "" { c.driverImage = opts.DriverImage @@ -151,10 +158,11 @@ type workflowCompiler struct { spec *pipelinespec.PipelineSpec executors map[string]*pipelinespec.PipelineDeploymentConfig_ExecutorSpec // state - wf *wfapi.Workflow - templates map[string]*wfapi.Template - driverImage string - launcherImage string + wf *wfapi.Workflow + templates map[string]*wfapi.Template + driverImage string + launcherImage string + mlPipelineServiceTLSEnabled bool } func (c *workflowCompiler) Resolver(name string, component *pipelinespec.ComponentSpec, resolver *pipelinespec.PipelineDeploymentConfig_ResolverSpec) error { diff --git a/backend/src/v2/compiler/argocompiler/common.go b/backend/src/v2/compiler/argocompiler/common.go index 2d203fc7acb..75684510511 100644 --- a/backend/src/v2/compiler/argocompiler/common.go +++ b/backend/src/v2/compiler/argocompiler/common.go @@ -14,7 +14,23 @@ package argocompiler -import k8score "k8s.io/api/core/v1" +import ( + "fmt" + wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + k8score "k8s.io/api/core/v1" + "os" + "strconv" + "strings" +) + +const ( + DefaultMLPipelineServiceHost = "ml-pipeline.kubeflow.svc.cluster.local" + DefaultMLPipelineServicePortGRPC = "8887" + MLPipelineServiceHostEnvVar = "ML_PIPELINE_SERVICE_HOST" + MLPipelineServicePortGRPCEnvVar = "ML_PIPELINE_SERVICE_PORT_GRPC" + MLPipelineTLSEnabledEnvVar = "ML_PIPELINE_TLS_ENABLED" + DefaultMLPipelineTLSEnabled = false +) // env vars in metadata-grpc-configmap is defined in component package var metadataConfigIsOptional bool = true @@ -42,3 +58,105 @@ var commonEnvs = []k8score.EnvVar{{ }, }, }} + +var MLPipelineServiceEnv = []k8score.EnvVar{{ + Name: "ML_PIPELINE_SERVICE_HOST", + Value: GetMLPipelineServiceHost(), +}, { + Name: "ML_PIPELINE_SERVICE_PORT_GRPC", + Value: GetMLPipelineServicePortGRPC(), +}} + +func GetMLPipelineServiceTLSEnabled() (bool, error) { + mlPipelineServiceTLSEnabledStr := os.Getenv(MLPipelineTLSEnabledEnvVar) + if mlPipelineServiceTLSEnabledStr == "" { + return DefaultMLPipelineTLSEnabled, nil + } + mlPipelineServiceTLSEnabled, err := strconv.ParseBool(os.Getenv(MLPipelineTLSEnabledEnvVar)) + if err != nil { + return false, err + } + return mlPipelineServiceTLSEnabled, nil +} + +func GetMLPipelineServiceHost() string { + mlPipelineServiceHost := os.Getenv(MLPipelineServiceHostEnvVar) + if mlPipelineServiceHost == "" { + return DefaultMLPipelineServiceHost + } + return mlPipelineServiceHost +} + +func 
GetMLPipelineServicePortGRPC() string { + mlPipelineServicePortGRPC := os.Getenv(MLPipelineServicePortGRPCEnvVar) + if mlPipelineServicePortGRPC == "" { + return DefaultMLPipelineServicePortGRPC + } + return mlPipelineServicePortGRPC +} + +// ConfigureCABundle adds CABundle environment variables and volume mounts +// if CA Bundle env vars are specified. +func ConfigureCABundle(tmpl *wfapi.Template) { + caBundleCfgMapName := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_NAME") + caBundleCfgMapKey := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_KEY") + caBundleMountPath := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_MOUNTPATH") + if caBundleCfgMapName != "" && caBundleCfgMapKey != "" { + caFile := fmt.Sprintf("%s/%s", caBundleMountPath, caBundleCfgMapKey) + var certDirectories = []string{ + caBundleMountPath, + "/etc/ssl/certs", + "/etc/pki/tls/certs", + } + // Add to REQUESTS_CA_BUNDLE for python request library. + // As many python web based libraries utilize this, we add it here so the user + // does not have to manually include this in the user pipeline. + // Note: for packages like Boto3, even though it is documented to use AWS_CA_BUNDLE, + // we found the python boto3 client only works if we include REQUESTS_CA_BUNDLE. + // https://requests.readthedocs.io/en/latest/user/advanced/#ssl-cert-verification + // https://github.com/aws/aws-cli/issues/3425 + tmpl.Container.Env = append(tmpl.Container.Env, k8score.EnvVar{ + Name: "REQUESTS_CA_BUNDLE", + Value: caFile, + }) + // For AWS utilities like cli, and packages. + tmpl.Container.Env = append(tmpl.Container.Env, k8score.EnvVar{ + Name: "AWS_CA_BUNDLE", + Value: caFile, + }) + // OpenSSL default cert file env variable. + // Similar to AWS_CA_BUNDLE, the SSL_CERT_DIR equivalent for paths had unyielding + // results, even after rehashing. 
+ // https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_default_verify_paths.html + tmpl.Container.Env = append(tmpl.Container.Env, k8score.EnvVar{ + Name: "SSL_CERT_FILE", + Value: caFile, + }) + sslCertDir := strings.Join(certDirectories, ":") + tmpl.Container.Env = append(tmpl.Container.Env, k8score.EnvVar{ + Name: "SSL_CERT_DIR", + Value: sslCertDir, + }) + volume := k8score.Volume{ + Name: volumeNameCABUndle, + VolumeSource: k8score.VolumeSource{ + ConfigMap: &k8score.ConfigMapVolumeSource{ + LocalObjectReference: k8score.LocalObjectReference{ + Name: caBundleCfgMapName, + }, + }, + }, + } + + tmpl.Volumes = append(tmpl.Volumes, volume) + + volumeMount := k8score.VolumeMount{ + Name: volumeNameCABUndle, + MountPath: caFile, + SubPath: caBundleCfgMapKey, + } + + tmpl.Container.VolumeMounts = append(tmpl.Container.VolumeMounts, volumeMount) + + } +} diff --git a/backend/src/v2/compiler/argocompiler/container.go b/backend/src/v2/compiler/argocompiler/container.go index 9d000a90211..7b12ca174d1 100644 --- a/backend/src/v2/compiler/argocompiler/container.go +++ b/backend/src/v2/compiler/argocompiler/container.go @@ -15,33 +15,31 @@ package argocompiler import ( - "fmt" - "os" - "strings" - wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" "github.com/kubeflow/pipelines/backend/src/v2/component" k8score "k8s.io/api/core/v1" + "os" + "strconv" ) const ( - volumeNameKFPLauncher = "kfp-launcher" - volumeNameCABUndle = "ca-bundle" - DefaultLauncherImage = "gcr.io/ml-pipeline/kfp-launcher@sha256:80cf120abd125db84fa547640fd6386c4b2a26936e0c2b04a7d3634991a850a4" - LauncherImageEnvVar = "V2_LAUNCHER_IMAGE" - DefaultDriverImage = "gcr.io/ml-pipeline/kfp-driver@sha256:8e60086b04d92b657898a310ca9757631d58547e76bbbb8bfc376d654bef1707" - DriverImageEnvVar = "V2_DRIVER_IMAGE" - gcsScratchLocation = "/gcs" - gcsScratchName = "gcs-scratch" - s3ScratchLocation = "/s3" - s3ScratchName = "s3-scratch" - minioScratchLocation = "/minio" - minioScratchName = "minio-scratch" - dotLocalScratchLocation = "/.local" - dotLocalScratchName = "dot-local-scratch" - dotCacheScratchLocation = "/.cache" - dotCacheScratchName = "dot-cache-scratch" + volumeNameKFPLauncher = "kfp-launcher" + volumeNameCABUndle = "ca-bundle" + DefaultLauncherImage = "gcr.io/ml-pipeline/kfp-launcher@sha256:80cf120abd125db84fa547640fd6386c4b2a26936e0c2b04a7d3634991a850a4" + LauncherImageEnvVar = "V2_LAUNCHER_IMAGE" + DefaultDriverImage = "gcr.io/ml-pipeline/kfp-driver@sha256:8e60086b04d92b657898a310ca9757631d58547e76bbbb8bfc376d654bef1707" + DriverImageEnvVar = "V2_DRIVER_IMAGE" + gcsScratchLocation = "/gcs" + gcsScratchName = "gcs-scratch" + s3ScratchLocation = "/s3" + s3ScratchName = "s3-scratch" + minioScratchLocation = "/minio" + minioScratchName = "minio-scratch" + dotLocalScratchLocation = "/.local" + dotLocalScratchName = "dot-local-scratch" + dotCacheScratchLocation = "/.cache" + dotCacheScratchName = "dot-cache-scratch" dotConfigScratchLocation = "/.config" dotConfigScratchName = "dot-config-scratch" ) @@ -150,6 +148,7 @@ func (c *workflowCompiler) addContainerDriverTemplate() string { Container: &k8score.Container{ Image: GetDriverImage(), Command: []string{"driver"}, + Env: MLPipelineServiceEnv, Args: []string{ "--type", "CONTAINER", "--pipeline_name", c.spec.GetPipelineInfo().GetName(), @@ -163,10 +162,14 @@ func (c *workflowCompiler) addContainerDriverTemplate() string { "--pod_spec_patch_path", outputPath(paramPodSpecPatch), "--condition_path", 
outputPath(paramCondition), "--kubernetes_config", inputValue(paramKubernetesConfig), + "--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled), }, Resources: driverResources, }, } + + ConfigureCABundle(t) + c.templates[name] = t c.wf.Spec.Templates = append(c.wf.Spec.Templates, *t) return name @@ -352,70 +355,10 @@ func (c *workflowCompiler) addContainerExecutorTemplate() string { }, }, EnvFrom: []k8score.EnvFromSource{metadataEnvFrom}, - Env: commonEnvs, + Env: append(commonEnvs, MLPipelineServiceEnv...), }, } - caBundleCfgMapName := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_NAME") - caBundleCfgMapKey := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_CONFIGMAP_KEY") - caBundleMountPath := os.Getenv("ARTIFACT_COPY_STEP_CABUNDLE_MOUNTPATH") - if caBundleCfgMapName != "" && caBundleCfgMapKey != "" { - caFile := fmt.Sprintf("%s/%s", caBundleMountPath, caBundleCfgMapKey) - var certDirectories = []string{ - caBundleMountPath, - "/etc/ssl/certs", - "/etc/pki/tls/certs", - } - // Add to REQUESTS_CA_BUNDLE for python request library. - // As many python web based libraries utilize this, we add it here so the user - // does not have to manually include this in the user pipeline. - // Note: for packages like Boto3, even though it is documented to use AWS_CA_BUNDLE, - // we found the python boto3 client only works if we include REQUESTS_CA_BUNDLE. - // https://requests.readthedocs.io/en/latest/user/advanced/#ssl-cert-verification - // https://github.com/aws/aws-cli/issues/3425 - executor.Container.Env = append(executor.Container.Env, k8score.EnvVar{ - Name: "REQUESTS_CA_BUNDLE", - Value: caFile, - }) - // For AWS utilities like cli, and packages. - executor.Container.Env = append(executor.Container.Env, k8score.EnvVar{ - Name: "AWS_CA_BUNDLE", - Value: caFile, - }) - // OpenSSL default cert file env variable. - // Similar to AWS_CA_BUNDLE, the SSL_CERT_DIR equivalent for paths had unyielding - // results, even after rehashing. 
- // https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_set_default_verify_paths.html - executor.Container.Env = append(executor.Container.Env, k8score.EnvVar{ - Name: "SSL_CERT_FILE", - Value: caFile, - }) - sslCertDir := strings.Join(certDirectories, ":") - executor.Container.Env = append(executor.Container.Env, k8score.EnvVar{ - Name: "SSL_CERT_DIR", - Value: sslCertDir, - }) - volume := k8score.Volume{ - Name: volumeNameCABUndle, - VolumeSource: k8score.VolumeSource{ - ConfigMap: &k8score.ConfigMapVolumeSource{ - LocalObjectReference: k8score.LocalObjectReference{ - Name: caBundleCfgMapName, - }, - }, - }, - } - - executor.Volumes = append(executor.Volumes, volume) - - volumeMount := k8score.VolumeMount{ - Name: volumeNameCABUndle, - MountPath: caFile, - SubPath: caBundleCfgMapKey, - } - - executor.Container.VolumeMounts = append(executor.Container.VolumeMounts, volumeMount) - - } + ConfigureCABundle(executor) c.templates[nameContainerImpl] = executor c.wf.Spec.Templates = append(c.wf.Spec.Templates, *container, *executor) return nameContainerExecutor diff --git a/backend/src/v2/compiler/argocompiler/dag.go b/backend/src/v2/compiler/argocompiler/dag.go index b334c4beb5f..36a239667e3 100644 --- a/backend/src/v2/compiler/argocompiler/dag.go +++ b/backend/src/v2/compiler/argocompiler/dag.go @@ -16,6 +16,7 @@ package argocompiler import ( "fmt" "sort" + "strconv" "strings" wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" @@ -428,6 +429,7 @@ func (c *workflowCompiler) addDAGDriverTemplate() string { Container: &k8score.Container{ Image: c.driverImage, Command: []string{"driver"}, + Env: MLPipelineServiceEnv, Args: []string{ "--type", inputValue(paramDriverType), "--pipeline_name", c.spec.GetPipelineInfo().GetName(), @@ -440,10 +442,12 @@ func (c *workflowCompiler) addDAGDriverTemplate() string { "--execution_id_path", outputPath(paramExecutionID), "--iteration_count_path", outputPath(paramIterationCount), "--condition_path", outputPath(paramCondition), + "--mlPipelineServiceTLSEnabled", strconv.FormatBool(c.mlPipelineServiceTLSEnabled), }, Resources: driverResources, }, } + ConfigureCABundle(t) c.templates[name] = t c.wf.Spec.Templates = append(c.wf.Spec.Templates, *t) return name diff --git a/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml b/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml index e3b427d2455..d4cd73085df 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/create_mount_delete_dynamic_pvc.yaml @@ -65,10 +65,17 @@ spec: - '{{outputs.parameters.condition.path}}' - --kubernetes_config - '{{inputs.parameters.kubernetes-config}}' + - "--mlPipelineServiceTLSEnabled" + - "false" command: - driver image: gcr.io/ml-pipeline/kfp-driver name: "" + env: + - name: ML_PIPELINE_SERVICE_HOST + value: ml-pipeline.kubeflow.svc.cluster.local + - name: ML_PIPELINE_SERVICE_PORT_GRPC + value: '8887' resources: limits: cpu: 500m @@ -132,6 +139,10 @@ spec: valueFrom: fieldRef: fieldPath: metadata.uid + - name: ML_PIPELINE_SERVICE_HOST + value: ml-pipeline.kubeflow.svc.cluster.local + - name: ML_PIPELINE_SERVICE_PORT_GRPC + value: '8887' envFrom: - configMapRef: name: metadata-grpc-configmap @@ -299,10 +310,17 @@ spec: - '{{outputs.parameters.iteration-count.path}}' - --condition_path - '{{outputs.parameters.condition.path}}' + - "--mlPipelineServiceTLSEnabled" + - "false" command: - driver image: 
gcr.io/ml-pipeline/kfp-driver name: "" + env: + - name: ML_PIPELINE_SERVICE_HOST + value: ml-pipeline.kubeflow.svc.cluster.local + - name: ML_PIPELINE_SERVICE_PORT_GRPC + value: '8887' resources: limits: cpu: 500m diff --git a/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml b/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml index 5685ece5de5..e285ad07188 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/hello_world.yaml @@ -48,9 +48,16 @@ spec: - '{{outputs.parameters.condition.path}}' - --kubernetes_config - '{{inputs.parameters.kubernetes-config}}' + - "--mlPipelineServiceTLSEnabled" + - "false" command: - driver image: gcr.io/ml-pipeline/kfp-driver + env: + - name: ML_PIPELINE_SERVICE_HOST + value: ml-pipeline.kubeflow.svc.cluster.local + - name: ML_PIPELINE_SERVICE_PORT_GRPC + value: '8887' name: "" resources: limits: @@ -115,6 +122,10 @@ spec: valueFrom: fieldRef: fieldPath: metadata.uid + - name: ML_PIPELINE_SERVICE_HOST + value: ml-pipeline.kubeflow.svc.cluster.local + - name: ML_PIPELINE_SERVICE_PORT_GRPC + value: '8887' envFrom: - configMapRef: name: metadata-grpc-configmap @@ -229,6 +240,13 @@ spec: - '{{outputs.parameters.iteration-count.path}}' - --condition_path - '{{outputs.parameters.condition.path}}' + - "--mlPipelineServiceTLSEnabled" + - "false" + env: + - name: ML_PIPELINE_SERVICE_HOST + value: ml-pipeline.kubeflow.svc.cluster.local + - name: ML_PIPELINE_SERVICE_PORT_GRPC + value: '8887' command: - driver image: gcr.io/ml-pipeline/kfp-driver diff --git a/backend/src/v2/compiler/argocompiler/testdata/importer.yaml b/backend/src/v2/compiler/argocompiler/testdata/importer.yaml index d0e6ef6eaae..0e2d30a12b2 100644 --- a/backend/src/v2/compiler/argocompiler/testdata/importer.yaml +++ b/backend/src/v2/compiler/argocompiler/testdata/importer.yaml @@ -118,10 +118,17 @@ spec: - '{{outputs.parameters.iteration-count.path}}' - --condition_path - '{{outputs.parameters.condition.path}}' + - "--mlPipelineServiceTLSEnabled" + - "false" command: - driver image: gcr.io/ml-pipeline/kfp-driver name: "" + env: + - name: ML_PIPELINE_SERVICE_HOST + value: ml-pipeline.kubeflow.svc.cluster.local + - name: ML_PIPELINE_SERVICE_PORT_GRPC + value: '8887' resources: limits: cpu: 500m diff --git a/backend/src/v2/component/launcher_v2.go b/backend/src/v2/component/launcher_v2.go index 9a3b3824488..b7682d5a4e5 100644 --- a/backend/src/v2/component/launcher_v2.go +++ b/backend/src/v2/component/launcher_v2.go @@ -52,6 +52,8 @@ type LauncherV2Options struct { MLMDServerPort, PipelineName, RunID string + // set to true if ml pipeline server is serving over tls + MLPipelineTLSEnabled bool } type LauncherV2 struct { @@ -112,7 +114,7 @@ func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, co if err != nil { return nil, err } - cacheClient, err := cacheutils.NewClient() + cacheClient, err := cacheutils.NewClient(opts.MLPipelineTLSEnabled) if err != nil { return nil, err } diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 1433cd33b49..b2f0e15c6a0 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -74,6 +74,9 @@ type Options struct { // optional, allows to specify kubernetes-specific executor config KubernetesExecutorConfig *kubernetesplatform.KubernetesExecutorConfig + + // set to true if ml pipeline server is serving over tls + MLPipelineTLSEnabled bool } // Identifying information used for error messages @@ 
-336,7 +339,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl return execution, nil } - podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID) + podSpec, err := initPodSpecPatch(opts.Container, opts.Component, executorInput, execution.ID, opts.PipelineName, opts.RunID, opts.MLPipelineTLSEnabled) if err != nil { return execution, err } @@ -369,6 +372,7 @@ func initPodSpecPatch( executionID int64, pipelineName string, runID string, + mlPipelineTLSEnabled bool, ) (*k8score.PodSpec, error) { executorInputJSON, err := protojson.Marshal(executorInput) if err != nil { @@ -407,6 +411,8 @@ func initPodSpecPatch( fmt.Sprintf("$(%s)", component.EnvMetadataHost), "--mlmd_server_port", fmt.Sprintf("$(%s)", component.EnvMetadataPort), + "--mlPipelineServiceTLSEnabled", + fmt.Sprintf("%v", mlPipelineTLSEnabled), "--", // separater before user command and args } res := k8score.ResourceRequirements{ diff --git a/backend/src/v2/driver/driver_test.go b/backend/src/v2/driver/driver_test.go index f95e67cf7ca..34ed4d13bb3 100644 --- a/backend/src/v2/driver/driver_test.go +++ b/backend/src/v2/driver/driver_test.go @@ -241,7 +241,7 @@ func Test_initPodSpecPatch_acceleratorConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID) + podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false) if tt.wantErr { assert.Nil(t, podSpec) assert.NotNil(t, err) @@ -403,7 +403,7 @@ func Test_initPodSpecPatch_resourceRequests(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID) + podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID, false) assert.Nil(t, err) assert.NotEmpty(t, podSpec) podSpecString, err := json.Marshal(podSpec) @@ -530,7 +530,7 @@ func Test_extendPodSpecPatch_Secret(t *testing.T) { { Name: "secret1", VolumeSource: k8score.VolumeSource{ - Secret: &k8score.SecretVolumeSource{SecretName: "secret1", Optional: &[]bool{false}[0],}, + Secret: &k8score.SecretVolumeSource{SecretName: "secret1", Optional: &[]bool{false}[0]}, }, }, }, @@ -728,7 +728,7 @@ func Test_extendPodSpecPatch_ConfigMap(t *testing.T) { VolumeSource: k8score.VolumeSource{ ConfigMap: &k8score.ConfigMapVolumeSource{ LocalObjectReference: k8score.LocalObjectReference{Name: "cm1"}, - Optional: &[]bool{false}[0],}, + Optional: &[]bool{false}[0]}, }, }, }, diff --git a/frontend/server/configs.ts b/frontend/server/configs.ts index c2d3ef30c15..d58f6ed9d52 100644 --- a/frontend/server/configs.ts +++ b/frontend/server/configs.ts @@ -75,6 +75,8 @@ export function loadConfigs(argv: string[], env: ProcessEnv): UIConfigs { ML_PIPELINE_SERVICE_HOST = 'localhost', /** API service will listen to this port */ ML_PIPELINE_SERVICE_PORT = '3001', + /** API service will listen via this transfer protocol */ + ML_PIPELINE_SERVICE_SCHEME = "http", /** path to viewer:tensorboard pod template spec */ VIEWER_TENSORBOARD_POD_TEMPLATE_SPEC_PATH, /** Tensorflow image used for tensorboard viewer */ @@ 
-170,6 +172,7 @@ export function loadConfigs(argv: string[], env: ProcessEnv): UIConfigs { pipeline: { host: ML_PIPELINE_SERVICE_HOST, port: ML_PIPELINE_SERVICE_PORT, + schema: ML_PIPELINE_SERVICE_SCHEME, }, server: { apiVersion1Prefix, @@ -232,6 +235,7 @@ export interface HttpConfigs { export interface PipelineConfigs { host: string; port: string | number; + schema: string; } export interface ViewerTensorboardConfig { podTemplateSpec?: object; diff --git a/frontend/server/utils.ts b/frontend/server/utils.ts index 6d317473788..14cabb3512d 100644 --- a/frontend/server/utils.ts +++ b/frontend/server/utils.ts @@ -20,7 +20,7 @@ export function getAddress({ host, port, namespace, - schema = 'http', + schema, }: { host: string; port?: string | number;
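For reference, a minimal self-contained sketch (not part of the patch) of the client-side pattern this change repeats in the persistence agent, driver, launcher, and cache client: a TLS flag passed as a string is parsed with strconv.ParseBool, and the resulting bool selects between insecure.NewCredentials() and credentials.NewTLS(&tls.Config{}) for the gRPC dial. The flag name mirrors the patch; the target address below is an illustrative placeholder.

    // Sketch of the TLS-aware gRPC dial pattern used across the patched clients.
    // Assumptions: the flag name matches the patch; the endpoint address is a placeholder.
    package main

    import (
        "crypto/tls"
        "flag"
        "log"
        "strconv"

        "google.golang.org/grpc"
        "google.golang.org/grpc/credentials"
        "google.golang.org/grpc/credentials/insecure"
    )

    var tlsEnabledStr = flag.String("mlPipelineServiceTLSEnabled", "false",
        "Set to 'true' if the ML pipeline API server serves over TLS.")

    func main() {
        flag.Parse()

        // The flag arrives as a string (it is threaded through workflow args), so parse it explicitly.
        tlsEnabled, err := strconv.ParseBool(*tlsEnabledStr)
        if err != nil {
            log.Fatalf("invalid bool value for mlPipelineServiceTLSEnabled: %v", err)
        }

        // Default to plaintext; switch to TLS transport credentials when enabled.
        creds := insecure.NewCredentials()
        if tlsEnabled {
            creds = credentials.NewTLS(&tls.Config{})
        }

        conn, err := grpc.Dial("ml-pipeline.kubeflow.svc.cluster.local:8887",
            grpc.WithTransportCredentials(creds))
        if err != nil {
            log.Fatalf("failed to dial ML pipeline gRPC endpoint: %v", err)
        }
        defer conn.Close()
    }

With an empty tls.Config the client verifies the server certificate against the system root pool, which is why the compiler changes also wire the CA-bundle environment variables and volume mounts into the driver templates via ConfigureCABundle.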