Skip to content

Commit

Permalink
Add config and handling for remote cluster cert (#2475)
Browse files Browse the repository at this point in the history
  • Loading branch information
meiliang86 committed Feb 23, 2022
1 parent e7852f7 commit 66ebbd4
Show file tree
Hide file tree
Showing 9 changed files with 198 additions and 60 deletions.
2 changes: 2 additions & 0 deletions common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ type (
Frontend GroupTLS `yaml:"frontend"`
// SystemWorker controls TLS setting for System Workers connecting to Frontend.
SystemWorker WorkerTLS `yaml:"systemWorker"`
// RemoteClusters controls TLS settings for connecting to remote clusters, keyed by the remote cluster's host name.
RemoteClusters map[string]GroupTLS `yaml:"remoteClusters"`
// ExpirationChecks defines settings for periodic checks for expiration of certificates
ExpirationChecks CertExpirationValidation `yaml:"expirationChecks"`
// Interval between refreshes of certificates loaded from files
Expand Down
4 changes: 2 additions & 2 deletions common/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type (
GetInternodeGRPCServerOptions() ([]grpc.ServerOption, error)
GetGRPCListener() net.Listener
GetRingpopChannel() *tchannel.Channel
CreateFrontendGRPCConnection(hostName string) *grpc.ClientConn
CreateInternodeGRPCConnection(hostName string) *grpc.ClientConn
CreateFrontendGRPCConnection(rpcAddress string) *grpc.ClientConn
CreateInternodeGRPCConnection(rpcAddress string) *grpc.ClientConn
}
)
102 changes: 85 additions & 17 deletions common/rpc/encryption/localStoreTlsProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,18 @@ type localStoreTlsProvider struct {

settings *config.RootTLS

internodeCertProvider CertProvider
internodeClientCertProvider CertProvider
frontendCertProvider CertProvider
workerCertProvider CertProvider

frontendPerHostCertProviderMap *localStorePerHostCertProviderMap

cachedInternodeServerConfig *tls.Config
cachedInternodeClientConfig *tls.Config
cachedFrontendServerConfig *tls.Config
cachedFrontendClientConfig *tls.Config
internodeCertProvider CertProvider
internodeClientCertProvider CertProvider
frontendCertProvider CertProvider
workerCertProvider CertProvider
remoteClusterClientCertProvider map[string]CertProvider
frontendPerHostCertProviderMap *localStorePerHostCertProviderMap

cachedInternodeServerConfig *tls.Config
cachedInternodeClientConfig *tls.Config
cachedFrontendServerConfig *tls.Config
cachedFrontendClientConfig *tls.Config
cachedRemoteClusterClientConfig map[string]*tls.Config

ticker *time.Ticker
logger log.Logger
Expand All @@ -84,17 +85,24 @@ func NewLocalStoreTlsProvider(tlsConfig *config.RootTLS, scope metrics.Scope, lo
workerProvider = internodeWorkerProvider
}

// Build one client cert provider per configured remote cluster, keyed by host name.
remoteClusterClientCertProvider := make(map[string]CertProvider, len(tlsConfig.RemoteClusters))
for hostname, groupTLS := range tlsConfig.RemoteClusters {
	// Take a per-iteration copy before handing out its address. Before Go 1.22
	// the range variable is a single variable reused across iterations, so
	// passing &groupTLS directly would give every provider a pointer to the
	// same storage, which ends up holding the last entry's settings.
	// NOTE(review): this matters if certProviderFactory retains the pointer — confirm.
	groupTLS := groupTLS
	remoteClusterClientCertProvider[hostname] = certProviderFactory(&groupTLS, nil, nil, tlsConfig.RefreshInterval, logger)
}

provider := &localStoreTlsProvider{
internodeCertProvider: internodeProvider,
internodeClientCertProvider: internodeProvider,
frontendCertProvider: certProviderFactory(&tlsConfig.Frontend, nil, nil, tlsConfig.RefreshInterval, logger),
workerCertProvider: workerProvider,
frontendPerHostCertProviderMap: newLocalStorePerHostCertProviderMap(
tlsConfig.Frontend.PerHostOverrides, certProviderFactory, tlsConfig.RefreshInterval, logger),
RWMutex: sync.RWMutex{},
settings: tlsConfig,
scope: scope,
logger: logger,
remoteClusterClientCertProvider: remoteClusterClientCertProvider,
RWMutex: sync.RWMutex{},
settings: tlsConfig,
scope: scope,
logger: logger,
cachedRemoteClusterClientConfig: make(map[string]*tls.Config),
}
provider.initialize()
return provider, nil
Expand Down Expand Up @@ -155,6 +163,26 @@ func (s *localStoreTlsProvider) GetFrontendClientConfig() (*tls.Config, error) {
)
}

// GetRemoteClusterClientConfig returns the client-side TLS configuration used
// to connect to the remote cluster identified by hostname.
//
// It returns (nil, nil) when settings.RemoteClusters has no entry for hostname
// or when client TLS is not enabled for that entry. Otherwise the config is
// obtained through the per-hostname cache.
func (s *localStoreTlsProvider) GetRemoteClusterClientConfig(hostname string) (*tls.Config, error) {
	remoteTLS, found := s.settings.RemoteClusters[hostname]
	if !found {
		return nil, nil
	}

	// Constructor handed to the cache helper, which decides whether to invoke it.
	buildConfig := func() (*tls.Config, error) {
		return newClientTLSConfig(
			s.remoteClusterClientCertProvider[hostname],
			remoteTLS.Client.ServerName,
			remoteTLS.Server.RequireClientAuth,
			false, // isWorker
			!remoteTLS.Client.DisableHostVerification,
		)
	}

	return s.getOrCreateRemoteClusterClientConfig(hostname, buildConfig, remoteTLS.IsClientEnabled())
}

func (s *localStoreTlsProvider) GetFrontendServerConfig() (*tls.Config, error) {
return s.getOrCreateConfig(
&s.cachedFrontendServerConfig,
Expand Down Expand Up @@ -239,6 +267,41 @@ func (s *localStoreTlsProvider) getOrCreateConfig(
return *cachedConfig, nil
}

// getOrCreateRemoteClusterClientConfig returns the cached client TLS config
// for hostname, building and caching it with configConstructor on first use.
//
// When isEnabled is false it returns (nil, nil) without touching the cache.
// If configConstructor fails, its error is returned and nothing is cached.
func (s *localStoreTlsProvider) getOrCreateRemoteClusterClientConfig(
	hostname string,
	configConstructor tlsConfigConstructor,
	isEnabled bool,
) (*tls.Config, error) {
	if !isEnabled {
		return nil, nil
	}

	// Fast path: look up the cache under the read lock.
	s.RLock()
	cached, found := s.cachedRemoteClusterClientConfig[hostname]
	s.RUnlock()
	if found {
		return cached, nil
	}

	// Miss: take the write lock and look again, since another caller may have
	// populated the entry between releasing the read lock and acquiring this one.
	s.Lock()
	defer s.Unlock()
	if cached, found = s.cachedRemoteClusterClientConfig[hostname]; found {
		return cached, nil
	}

	// Still absent: construct the config and remember it for later calls.
	built, err := configConstructor()
	if err != nil {
		return nil, err
	}
	s.cachedRemoteClusterClientConfig[hostname] = built
	return built, nil
}

func newServerTLSConfig(
certProvider CertProvider,
perHostCertProviderMap PerHostCertProviderMap,
Expand Down Expand Up @@ -321,8 +384,13 @@ func getServerTLSConfigFromCertProvider(
logger), nil
}

func newClientTLSConfig(clientProvider CertProvider, serverName string, isAuthRequired bool,
isWorker bool, enableHostVerification bool) (*tls.Config, error) {
func newClientTLSConfig(
clientProvider CertProvider,
serverName string,
isAuthRequired bool,
isWorker bool,
enableHostVerification bool,
) (*tls.Config, error) {
// Optional ServerCA for client if not already trusted by host
serverCa, err := clientProvider.FetchServerRootCAsForClient(isWorker)
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions common/rpc/encryption/testDynamicTLSConfigProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ func (t *TestDynamicTLSConfigProvider) GetExpiringCerts(timeWindow time.Duration
panic("implement me")
}

// GetRemoteClusterClientConfig is a placeholder on the test provider for the
// TLSConfigProvider method of the same name. It is not implemented: calling
// it always panics.
func (t *TestDynamicTLSConfigProvider) GetRemoteClusterClientConfig(hostName string) (*tls.Config, error) {
// TODO: implement me
panic("implement me")
}

var _ TLSConfigProvider = (*TestDynamicTLSConfigProvider)(nil)

func NewTestDynamicTLSConfigProvider(
Expand Down
1 change: 1 addition & 0 deletions common/rpc/encryption/tlsFactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type (
GetInternodeClientConfig() (*tls.Config, error)
GetFrontendServerConfig() (*tls.Config, error)
GetFrontendClientConfig() (*tls.Config, error)
GetRemoteClusterClientConfig(hostname string) (*tls.Config, error)
GetExpiringCerts(timeWindow time.Duration) (expiring CertExpirationMap, expired CertExpirationMap, err error)
}

Expand Down
60 changes: 44 additions & 16 deletions common/rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,24 @@ import (
"sync"

"github.com/uber/tchannel-go"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"

"go.temporal.io/server/common/cluster"
"go.temporal.io/server/common/config"
"go.temporal.io/server/common/convert"
"go.temporal.io/server/common/dynamicconfig"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/log/tag"
"go.temporal.io/server/common/rpc/encryption"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

// RPCFactory is an implementation of service.RPCFactory interface
type RPCFactory struct {
config *config.RPC
serviceName string
logger log.Logger
dc *dynamicconfig.Collection
config *config.RPC
serviceName string
logger log.Logger
dc *dynamicconfig.Collection
clusterMetadata *cluster.Config

sync.Mutex
grpcListener net.Listener
Expand All @@ -57,13 +58,21 @@ type RPCFactory struct {

// NewFactory builds a new RPCFactory
// conforming to the underlying configuration
func NewFactory(cfg *config.RPC, sName string, logger log.Logger, tlsProvider encryption.TLSConfigProvider, dc *dynamicconfig.Collection) *RPCFactory {
func NewFactory(
cfg *config.RPC,
sName string,
logger log.Logger,
tlsProvider encryption.TLSConfigProvider,
dc *dynamicconfig.Collection,
clusterMetadata *cluster.Config,
) *RPCFactory {
return &RPCFactory{
config: cfg,
serviceName: sName,
logger: logger,
dc: dc,
tlsFactory: tlsProvider,
config: cfg,
serviceName: sName,
logger: logger,
dc: dc,
tlsFactory: tlsProvider,
clusterMetadata: clusterMetadata,
}
}

Expand Down Expand Up @@ -92,6 +101,14 @@ func (d *RPCFactory) GetFrontendClientTlsConfig() (*tls.Config, error) {
return nil, nil
}

// GetRemoteClusterClientConfig returns the client TLS configuration for the
// remote cluster identified by hostname by delegating to the factory's TLS
// provider. It returns (nil, nil) when no TLS provider is set.
func (d *RPCFactory) GetRemoteClusterClientConfig(hostname string) (*tls.Config, error) {
	if d.tlsFactory == nil {
		return nil, nil
	}
	return d.tlsFactory.GetRemoteClusterClientConfig(hostname)
}

func (d *RPCFactory) GetInternodeGRPCServerOptions() ([]grpc.ServerOption, error) {
var opts []grpc.ServerOption

Expand Down Expand Up @@ -237,18 +254,29 @@ func getListenIP(cfg *config.RPC, logger log.Logger) net.IP {
}

// CreateFrontendGRPCConnection creates connection for gRPC calls
func (d *RPCFactory) CreateFrontendGRPCConnection(hostName string) *grpc.ClientConn {
func (d *RPCFactory) CreateFrontendGRPCConnection(rpcAddress string) *grpc.ClientConn {
var tlsClientConfig *tls.Config
var err error
if d.tlsFactory != nil {
tlsClientConfig, err = d.tlsFactory.GetFrontendClientConfig()
currCluster := d.clusterMetadata.ClusterInformation[d.clusterMetadata.CurrentClusterName]

if currCluster.RPCAddress == rpcAddress {
tlsClientConfig, err = d.tlsFactory.GetFrontendClientConfig()
} else {
hostname, _, err2 := net.SplitHostPort(rpcAddress)
if err2 != nil {
d.logger.Fatal("Invalid rpcAddress for remote cluster", tag.Error(err2))
}
tlsClientConfig, err = d.tlsFactory.GetRemoteClusterClientConfig(hostname)
}

if err != nil {
d.logger.Fatal("Failed to create tls config for gRPC connection", tag.Error(err))
return nil
}
}

return d.dial(hostName, tlsClientConfig)
return d.dial(rpcAddress, tlsClientConfig)
}

// CreateInternodeGRPCConnection creates connection for gRPC calls
Expand Down
29 changes: 20 additions & 9 deletions common/rpc/test/rpc_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ import (
"context"
"crypto/tls"
"math/rand"
"net"
"strings"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"

"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/examples/helloworld/helloworld"

"go.temporal.io/server/common/cluster"
"go.temporal.io/server/common/config"
"go.temporal.io/server/common/convert"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/rpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/examples/helloworld/helloworld"
"google.golang.org/grpc/peer"
)

// HelloServer is used to implement helloworld.GreeterServer.
Expand All @@ -53,6 +53,7 @@ type ServerUsageType int32
const (
Frontend ServerUsageType = iota
Internode
RemoteCluster
)

const (
Expand Down Expand Up @@ -82,6 +83,10 @@ var (
BroadcastAddress: localhostIPv4,
},
}
clusterMetadata = &cluster.Config{
CurrentClusterName: "test",
ClusterInformation: map[string]cluster.ClusterInformation{"test": {RPCAddress: localhostIPv4 + ":1234"}},
}
)

func startHelloWorldServer(s suite.Suite, factory *TestFactory) (*grpc.Server, string) {
Expand Down Expand Up @@ -166,15 +171,21 @@ func dialHelloAndGetTLSInfo(
logger := log.NewNoopLogger()
var cfg *tls.Config
var err error
if serverType == Internode {
switch serverType {
case Internode:
cfg, err = clientFactory.GetInternodeClientTlsConfig()
} else {
case Frontend:
cfg, err = clientFactory.GetFrontendClientTlsConfig()
case RemoteCluster:
// Use a distinct name for the split error. Declaring `err` with := inside this
// case would shadow the function-level err, so the error returned by
// GetRemoteClusterClientConfig would be stored in the shadow and never reach
// the s.NoError(err) check that follows the switch.
host, _, splitErr := net.SplitHostPort(hostport)
s.NoError(splitErr)
cfg, err = clientFactory.GetRemoteClusterClientConfig(host)
}

s.NoError(err)

clientConn, err := rpc.Dial(hostport, cfg, logger)
s.NoError(err)

client := helloworld.NewGreeterClient(clientConn)

request := &helloworld.HelloRequest{Name: convert.Uint64ToString(rand.Uint64())}
Expand Down
Loading

0 comments on commit 66ebbd4

Please sign in to comment.