Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add tls support for apiserver http/grpc #40

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions backend/src/agent/persistence/client/pipeline_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,20 @@ func NewPipelineClient(
basePath string,
mlPipelineServiceName string,
mlPipelineServiceHttpPort string,
mlPipelineServiceGRPCPort string) (*PipelineClient, error) {
mlPipelineServiceGRPCPort string,
mlPipelineServiceTLSEnabled bool) (*PipelineClient, error) {
httpAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceHttpPort)
grpcAddress := fmt.Sprintf(addressTemp, mlPipelineServiceName, mlPipelineServiceGRPCPort)
err := util.WaitForAPIAvailable(initializeTimeout, basePath, httpAddress)
scheme := "http"
if mlPipelineServiceTLSEnabled {
scheme = "https"
}
err := util.WaitForAPIAvailable(initializeTimeout, basePath, httpAddress, scheme)
if err != nil {
return nil, errors.Wrapf(err,
"Failed to initialize pipeline client. Error: %s", err.Error())
}
connection, err := util.GetRpcConnection(grpcAddress)
connection, err := util.GetRpcConnection(grpcAddress, mlPipelineServiceTLSEnabled)
if err != nil {
return nil, errors.Wrapf(err,
"Failed to get RPC connection. Error: %s", err.Error())
Expand Down
42 changes: 26 additions & 16 deletions backend/src/agent/persistence/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package main

import (
"flag"
"strconv"
"time"

"github.com/kubeflow/pipelines/backend/src/agent/persistence/client"
Expand All @@ -29,21 +30,22 @@ import (
)

var (
masterURL string
kubeconfig string
initializeTimeout time.Duration
timeout time.Duration
mlPipelineAPIServerName string
mlPipelineAPIServerPort string
mlPipelineAPIServerBasePath string
mlPipelineServiceHttpPort string
mlPipelineServiceGRPCPort string
namespace string
ttlSecondsAfterWorkflowFinish int64
numWorker int
clientQPS float64
clientBurst int
saTokenRefreshIntervalInSecs int64
masterURL string
kubeconfig string
initializeTimeout time.Duration
timeout time.Duration
mlPipelineAPIServerName string
mlPipelineAPIServerPort string
mlPipelineAPIServerBasePath string
mlPipelineServiceHttpPort string
mlPipelineServiceGRPCPort string
mlPipelineServiceTLSEnabledStr string
namespace string
ttlSecondsAfterWorkflowFinish int64
numWorker int
clientQPS float64
clientBurst int
saTokenRefreshIntervalInSecs int64
)

const (
Expand All @@ -55,6 +57,7 @@ const (
mlPipelineAPIServerNameFlagName = "mlPipelineAPIServerName"
mlPipelineAPIServerHttpPortFlagName = "mlPipelineServiceHttpPort"
mlPipelineAPIServerGRPCPortFlagName = "mlPipelineServiceGRPCPort"
mlPipelineAPIServerTLSEnabled = "mlPipelineServiceTLSEnabled"
namespaceFlagName = "namespace"
ttlSecondsAfterWorkflowFinishFlagName = "ttlSecondsAfterWorkflowFinish"
numWorkerName = "numWorker"
Expand Down Expand Up @@ -102,14 +105,20 @@ func main() {
log.Fatalf("Error starting Service Account Token Refresh Ticker due to: %v", err)
}

mlPipelineServiceTLSEnabled, err := strconv.ParseBool(mlPipelineServiceTLSEnabledStr)
if err != nil {
log.Fatalf("Error parsing boolean flag %s, please provide a valid bool value (true/false). %v", mlPipelineAPIServerTLSEnabled, err)
}

pipelineClient, err := client.NewPipelineClient(
initializeTimeout,
timeout,
tokenRefresher,
mlPipelineAPIServerBasePath,
mlPipelineAPIServerName,
mlPipelineServiceHttpPort,
mlPipelineServiceGRPCPort)
mlPipelineServiceGRPCPort,
mlPipelineServiceTLSEnabled)
if err != nil {
log.Fatalf("Error creating ML pipeline API Server client: %v", err)
}
Expand All @@ -136,6 +145,7 @@ func init() {
flag.StringVar(&mlPipelineAPIServerName, mlPipelineAPIServerNameFlagName, "ml-pipeline", "Name of the ML pipeline API server.")
flag.StringVar(&mlPipelineServiceHttpPort, mlPipelineAPIServerHttpPortFlagName, "8888", "Http Port of the ML pipeline API server.")
flag.StringVar(&mlPipelineServiceGRPCPort, mlPipelineAPIServerGRPCPortFlagName, "8887", "GRPC Port of the ML pipeline API server.")
flag.StringVar(&mlPipelineServiceTLSEnabledStr, mlPipelineAPIServerTLSEnabled, "false", "Set to 'true' if mlpipeline api server serves over TLS (default: 'false').")
flag.StringVar(&mlPipelineAPIServerBasePath, mlPipelineAPIServerBasePathFlagName,
"/apis/v1beta1", "The base path for the ML pipeline API server.")
flag.StringVar(&namespace, namespaceFlagName, "", "The namespace name used for Kubernetes informers to obtain the listers.")
Expand Down
111 changes: 87 additions & 24 deletions backend/src/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ package main

import (
"context"
"crypto/tls"
"encoding/json"
"flag"
"fmt"
"github.com/kubeflow/pipelines/backend/src/apiserver/client"
"google.golang.org/grpc/credentials"
"io"
"io/ioutil"
"math"
Expand Down Expand Up @@ -52,21 +54,49 @@ var (
httpPortFlag = flag.String("httpPortFlag", ":8888", "Http Proxy Port")
configPath = flag.String("config", "", "Path to JSON file containing config")
sampleConfigPath = flag.String("sampleconfig", "", "Path to samples")
tlsCertPath = flag.String("tlsCertPath", "", "Path to the public tls cert.")
tlsCertKeyPath = flag.String("tlsCertKeyPath", "", "Path to the private tls key cert.")
collectMetricsFlag = flag.Bool("collectMetricsFlag", true, "Whether to collect Prometheus metrics in API server.")
)

// RegisterHttpHandlerFromEndpoint is the signature of a function that
// registers a service's HTTP handlers on mux. It is called with a context,
// the mux to register on, the endpoint address, and the gRPC dial options to
// use, and returns an error if registration fails.
type RegisterHttpHandlerFromEndpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error

// initCerts builds the server TLS configuration from the tlsCertPath and
// tlsCertKeyPath flags.
//
// It returns (nil, nil) when neither flag is set, meaning no TLS configuration
// was requested. Setting exactly one of the two flags is an error, since a
// certificate and its private key are both required. When both are set, the
// key pair is loaded from disk and returned as the only certificate of a new
// tls.Config; a load failure is returned wrapped with the offending paths.
func initCerts() (*tls.Config, error) {
	switch {
	case *tlsCertPath == "" && *tlsCertKeyPath == "":
		// User can choose not to provide certs
		return nil, nil
	case *tlsCertPath == "":
		return nil, fmt.Errorf("Missing tlsCertPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required.")
	case *tlsCertKeyPath == "":
		return nil, fmt.Errorf("Missing tlsCertKeyPath when specifying cert paths, both tlsCertPath and tlsCertKeyPath are required.")
	}

	serverCert, err := tls.LoadX509KeyPair(*tlsCertPath, *tlsCertKeyPath)
	if err != nil {
		// Include both paths so the operator can tell which files were rejected.
		return nil, fmt.Errorf("loading TLS key pair (cert %q, key %q): %w", *tlsCertPath, *tlsCertKeyPath, err)
	}

	config := &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		// Pin the floor explicitly: older Go releases default servers to
		// TLS 1.0, which is deprecated. gRPC over HTTP/2 needs TLS 1.2+ anyway.
		MinVersion: tls.VersionTLS12,
	}
	glog.Info("TLS cert key/pair loaded.")
	return config, nil
}

func main() {
flag.Parse()

initConfig()
clientManager := cm.NewClientManager()

tlsConfig, err := initCerts()
if err != nil {
glog.Fatalf("Failed to parse Cert paths. Err: %v", err)
}

resourceManager := resource.NewResourceManager(
&clientManager,
&resource.ResourceManagerOptions{CollectMetrics: *collectMetricsFlag},
)
err := loadSamples(resourceManager)
err = loadSamples(resourceManager)
if err != nil {
glog.Fatalf("Failed to load samples. Err: %v", err)
}
Expand All @@ -78,8 +108,8 @@ func main() {
}
}

go startRpcServer(resourceManager)
startHttpProxy(resourceManager)
go startRpcServer(resourceManager, tlsConfig)
startHttpProxy(resourceManager, tlsConfig)

clientManager.Close()
}
Expand All @@ -93,13 +123,25 @@ func grpcCustomMatcher(key string) (string, bool) {
return strings.ToLower(key), false
}

func startRpcServer(resourceManager *resource.ResourceManager) {
glog.Info("Starting RPC server")
func startRpcServer(resourceManager *resource.ResourceManager, tlsConfig *tls.Config) {
var s *grpc.Server
if tlsConfig != nil {
glog.Info("Starting RPC server (TLS enabled)")
tlsCredentials := credentials.NewTLS(tlsConfig)
s = grpc.NewServer(
grpc.Creds(tlsCredentials),
grpc.UnaryInterceptor(apiServerInterceptor),
grpc.MaxRecvMsgSize(math.MaxInt32),
)
} else {
glog.Info("Starting RPC server")
s = grpc.NewServer(grpc.UnaryInterceptor(apiServerInterceptor), grpc.MaxRecvMsgSize(math.MaxInt32))
}

listener, err := net.Listen("tcp", *rpcPortFlag)
if err != nil {
glog.Fatalf("Failed to start RPC server: %v", err)
}
s := grpc.NewServer(grpc.UnaryInterceptor(apiServerInterceptor), grpc.MaxRecvMsgSize(math.MaxInt32))

sharedExperimentServer := server.NewExperimentServer(resourceManager, &server.ExperimentServerOptions{CollectMetrics: *collectMetricsFlag})
sharedPipelineServer := server.NewPipelineServer(
Expand Down Expand Up @@ -141,30 +183,29 @@ func startRpcServer(resourceManager *resource.ResourceManager) {
glog.Info("RPC server started")
}

func startHttpProxy(resourceManager *resource.ResourceManager) {
glog.Info("Starting Http Proxy")
func startHttpProxy(resourceManager *resource.ResourceManager, tlsConfig *tls.Config) {

ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()

// Create gRPC HTTP MUX and register services for v1beta1 api.
runtimeMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(grpcCustomMatcher))
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterTaskServiceHandlerFromEndpoint, "TaskService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux, tlsConfig)
HumairAK marked this conversation as resolved.
Show resolved Hide resolved
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterTaskServiceHandlerFromEndpoint, "TaskService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv1beta1.RegisterAuthServiceHandlerFromEndpoint, "AuthService", ctx, runtimeMux, tlsConfig)

// Create gRPC HTTP MUX and register services for v2beta1 api.
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterArtifactServiceHandlerFromEndpoint, "ArtifactService", ctx, runtimeMux)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterExperimentServiceHandlerFromEndpoint, "ExperimentService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterPipelineServiceHandlerFromEndpoint, "PipelineService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRecurringRunServiceHandlerFromEndpoint, "RecurringRunService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, runtimeMux, tlsConfig)
registerHttpHandlerFromEndpoint(apiv2beta1.RegisterArtifactServiceHandlerFromEndpoint, "ArtifactService", ctx, runtimeMux, tlsConfig)

// Create a top level mux to include both pipeline upload server and gRPC servers.
topMux := mux.NewRouter()
Expand Down Expand Up @@ -197,13 +238,35 @@ func startHttpProxy(resourceManager *resource.ResourceManager) {
// Register a handler for Prometheus to poll.
topMux.Handle("/metrics", promhttp.Handler())

http.ListenAndServe(*httpPortFlag, topMux)
if tlsConfig != nil {
glog.Info("Starting Https Proxy")
https := http.Server{
TLSConfig: tlsConfig,
Addr: *httpPortFlag,
Handler: topMux,
}
https.ListenAndServeTLS("", "")
} else {
glog.Info("Starting Http Proxy")
http.ListenAndServe(*httpPortFlag, topMux)
}

glog.Info("Http Proxy started")
}

func registerHttpHandlerFromEndpoint(handler RegisterHttpHandlerFromEndpoint, serviceName string, ctx context.Context, mux *runtime.ServeMux) {
func registerHttpHandlerFromEndpoint(handler RegisterHttpHandlerFromEndpoint, serviceName string, ctx context.Context, mux *runtime.ServeMux, tlsConfig *tls.Config) {
endpoint := "localhost" + *rpcPortFlag
opts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))}
var opts []grpc.DialOption
if tlsConfig != nil {
// local client connections via http proxy to grpc should not require tls
tlsConfig.InsecureSkipVerify = true
opts = []grpc.DialOption{
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
}
} else {
opts = []grpc.DialOption{grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))}
}

if err := handler(ctx, mux, endpoint, opts); err != nil {
glog.Fatalf("Failed to register %v handler: %v", serviceName, err)
Expand Down
20 changes: 16 additions & 4 deletions backend/src/common/util/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
package util

import (
"crypto/tls"
"fmt"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"net/http"
"strings"
"time"
Expand All @@ -28,9 +31,9 @@ import (
"k8s.io/client-go/tools/clientcmd"
)

func WaitForAPIAvailable(initializeTimeout time.Duration, basePath string, apiAddress string) error {
func WaitForAPIAvailable(initializeTimeout time.Duration, basePath string, apiAddress string, scheme string) error {
operation := func() error {
response, err := http.Get(fmt.Sprintf("http://%s%s/healthz", apiAddress, basePath))
response, err := http.Get(fmt.Sprintf("%s://%s%s/healthz", scheme, apiAddress, basePath))
if err != nil {
return err
}
Expand Down Expand Up @@ -74,8 +77,17 @@ func GetKubernetesClientFromClientConfig(clientConfig clientcmd.ClientConfig) (
return clientSet, config, namespace, nil
}

func GetRpcConnection(address string) (*grpc.ClientConn, error) {
conn, err := grpc.Dial(address, grpc.WithInsecure())
HumairAK marked this conversation as resolved.
Show resolved Hide resolved
func GetRpcConnection(address string, tlsEnabled bool) (*grpc.ClientConn, error) {
creds := insecure.NewCredentials()
if tlsEnabled {
config := &tls.Config{}
creds = credentials.NewTLS(config)
}

conn, err := grpc.Dial(
address,
grpc.WithTransportCredentials(creds),
)
if err != nil {
return nil, errors.Wrapf(err, "Failed to create gRPC connection")
}
Expand Down
18 changes: 16 additions & 2 deletions backend/src/v2/cacheutils/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package cacheutils
import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"encoding/json"
"fmt"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"os"

"google.golang.org/grpc"
Expand Down Expand Up @@ -111,10 +114,21 @@ type Client struct {
}

// NewClient creates a Client.
func NewClient() (*Client, error) {
func NewClient(mlPipelineServiceTLSEnabled bool) (*Client, error) {
creds := insecure.NewCredentials()
if mlPipelineServiceTLSEnabled {
config := &tls.Config{
InsecureSkipVerify: false,
}
creds = credentials.NewTLS(config)
}
cacheEndPoint := cacheDefaultEndpoint()
glog.Infof("Connecting to cache endpoint %s", cacheEndPoint)
conn, err := grpc.Dial(cacheEndPoint, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)), grpc.WithInsecure())
conn, err := grpc.Dial(
cacheEndPoint,
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxClientGRPCMessageSize)),
grpc.WithTransportCredentials(creds),
)
if err != nil {
return nil, fmt.Errorf("metadata.NewClient() failed: %w", err)
}
Expand Down
Loading
Loading