From c0ab948c6c05696ae3899ce9e0662142ac4fa77d Mon Sep 17 00:00:00 2001 From: djshow832 Date: Tue, 18 Oct 2022 11:53:58 +0800 Subject: [PATCH] manager: auto reload certs (#114) --- lib/util/security/tls.go | 30 +-- pkg/manager/cert/manager.go | 233 ++++++++++++++++++++ pkg/manager/cert/manager_test.go | 145 ++++++++++++ pkg/manager/namespace/manager.go | 18 +- pkg/manager/namespace/namespace.go | 16 +- pkg/manager/router/backend_observer.go | 14 +- pkg/manager/router/backend_observer_test.go | 6 +- pkg/proxy/proxy.go | 30 +-- pkg/server/server.go | 25 +-- 9 files changed, 427 insertions(+), 90 deletions(-) create mode 100644 pkg/manager/cert/manager.go create mode 100644 pkg/manager/cert/manager_test.go diff --git a/lib/util/security/tls.go b/lib/util/security/tls.go index f3ea9b17..d9dffe65 100644 --- a/lib/util/security/tls.go +++ b/lib/util/security/tls.go @@ -34,7 +34,9 @@ import ( "go.uber.org/zap" ) -func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath string, rsaKeySize int) error { +const DefaultCertExpiration = 10 * 365 * 24 * time.Hour + +func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath string, rsaKeySize int, expiration time.Duration) error { logger = logger.With(zap.String("cert", certpath), zap.String("key", keypath), zap.String("ca", capath), zap.Int("rsaKeySize", rsaKeySize)) _, e1 := os.Stat(certpath) @@ -64,7 +66,7 @@ func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath strin } } - certPEM, keyPEM, caPEM, err := CreateTempTLS() + certPEM, keyPEM, caPEM, err := CreateTempTLS(expiration) if err != nil { return err } @@ -86,20 +88,18 @@ func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath strin } func AutoTLS(logger *zap.Logger, scfg *config.TLSConfig, autoca bool, workdir, mod string, keySize int) error { - if !scfg.HasCert() && scfg.AutoCerts { - scfg.Cert = filepath.Join(workdir, mod, "cert.pem") - scfg.Key = filepath.Join(workdir, mod, "key.pem") - if autoca { - scfg.CA = filepath.Join(workdir, mod, "ca.pem") - } - if err := createTLSConfigificates(logger, scfg.Cert, scfg.Key, scfg.CA, keySize); err != nil { - return errors.WithStack(err) - } + scfg.Cert = filepath.Join(workdir, mod, "cert.pem") + scfg.Key = filepath.Join(workdir, mod, "key.pem") + if autoca { + scfg.CA = filepath.Join(workdir, mod, "ca.pem") + } + if err := createTLSConfigificates(logger, scfg.Cert, scfg.Key, scfg.CA, keySize, DefaultCertExpiration); err != nil { + return errors.WithStack(err) } return nil } -func CreateTempTLS() (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) { +func CreateTempTLS(expiration time.Duration) (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) { // set up our CA certificate ca := &x509.Certificate{ SerialNumber: big.NewInt(2019), @@ -112,7 +112,7 @@ func CreateTempTLS() (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) { PostalCode: []string{"94016"}, }, NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), + NotAfter: time.Now().Add(expiration), IsCA: true, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, @@ -153,7 +153,7 @@ func CreateTempTLS() (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) { }, IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), + NotAfter: time.Now().Add(expiration), SubjectKeyId: []byte{1, 2, 3, 4, 6}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature, @@ -190,7 +190,7 @@ func CreateTempTLS() (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) { // CreateTLSConfigForTest is from https://gist.github.com/shaneutt/5e1995295cff6721c89a71d13a71c251. func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) { - certPEM, keyPEM, caPEM, uerr := CreateTempTLS() + certPEM, keyPEM, caPEM, uerr := CreateTempTLS(DefaultCertExpiration) if uerr != nil { err = uerr return diff --git a/pkg/manager/cert/manager.go b/pkg/manager/cert/manager.go new file mode 100644 index 00000000..d56c6c99 --- /dev/null +++ b/pkg/manager/cert/manager.go @@ -0,0 +1,233 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cert + +import ( + "context" + "crypto/tls" + "sync/atomic" + "time" + + "github.com/pingcap/TiProxy/lib/config" + "github.com/pingcap/TiProxy/lib/util/errors" + "github.com/pingcap/TiProxy/lib/util/security" + "github.com/pingcap/TiProxy/lib/util/waitgroup" + "go.uber.org/zap" +) + +const ( + defaultRetryInterval = 1 * time.Hour + defaultAutoCertInterval = 30 * 24 * time.Hour +) + +// Security configurations don't support dynamically updating now. +type certInfo struct { + cfg config.TLSConfig + tlsConfig *tls.Config + certificate atomic.Pointer[tls.Certificate] + autoCert bool + autoCertExp time.Time +} + +func (ci *certInfo) getTLS() *tls.Config { + if ci.tlsConfig != nil { + return ci.tlsConfig.Clone() + } + return nil +} + +func (ci *certInfo) setTLS(tlsConfig *tls.Config) { + if tlsConfig != nil { + tlsConfig = tlsConfig.Clone() + if tlsConfig.Certificates != nil { + ci.certificate.Store(&tlsConfig.Certificates[0]) + // Doesn't support rotating CA now. It needs overwriting InsecureSkipVerify and VerifyPeerCertificate. + tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return ci.certificate.Load(), nil + } + tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + return ci.certificate.Load(), nil + } + tlsConfig.Certificates = nil + } + } + ci.tlsConfig = tlsConfig +} + +func (ci *certInfo) setAutoCertExp(exp time.Time) { + ci.autoCertExp = exp +} + +func (ci *certInfo) needRecreateCert(now time.Time) bool { + if !ci.autoCert { + return false + } + return now.After(ci.autoCertExp) +} + +// CertManager reloads certs and offers interfaces for fetching TLS configs. +// Currently, all the namespaces share the same certs but there might be per-namespace +// certs in the future. +type CertManager struct { + serverTLS certInfo // client / proxyctl -> proxy + peerTLS certInfo // proxy -> proxy + clusterTLS certInfo // proxy -> pd / tidb status port + sqlTLS certInfo // proxy -> tidb sql port + autoCertDir string + cancel context.CancelFunc + wg waitgroup.WaitGroup + retryInterval atomic.Int64 + autoCertInterval atomic.Int64 + cfg *config.Security + logger *zap.Logger +} + +// NewCertManager creates a new CertManager. +func NewCertManager() *CertManager { + cm := &CertManager{} + cm.SetRetryInterval(defaultRetryInterval) + cm.SetAutoCertInterval(defaultAutoCertInterval) + return cm +} + +func (cm *CertManager) Init(cfg *config.Config, logger *zap.Logger) error { + cm.cfg = &cfg.Security + cm.logger = logger + cm.autoCertDir = cfg.Workdir + cm.serverTLS = certInfo{ + cfg: cfg.Security.ServerTLS, + autoCert: !cfg.Security.ServerTLS.HasCert() && cfg.Security.ServerTLS.AutoCerts, + } + cm.peerTLS = certInfo{ + cfg: cfg.Security.PeerTLS, + autoCert: !cfg.Security.PeerTLS.HasCert() && cfg.Security.PeerTLS.AutoCerts, + } + cm.clusterTLS = certInfo{ + cfg: cfg.Security.ClusterTLS, + } + cm.sqlTLS = certInfo{ + cfg: cfg.Security.SQLTLS, + } + + if err := cm.load(); err != nil { + return err + } + + ctx, cancel := context.WithCancel(context.Background()) + cm.wg.Run(func() { + cm.reloadLoop(ctx) + }) + cm.cancel = cancel + return nil +} + +func (cm *CertManager) SetRetryInterval(interval time.Duration) { + cm.retryInterval.Store(int64(interval)) +} + +func (cm *CertManager) SetAutoCertInterval(interval time.Duration) { + cm.autoCertInterval.Store(int64(interval)) +} + +func (cm *CertManager) ServerTLS() *tls.Config { + return cm.serverTLS.getTLS() +} + +func (cm *CertManager) ClusterTLS() *tls.Config { + return cm.clusterTLS.getTLS() +} + +func (cm *CertManager) SQLTLS() *tls.Config { + return cm.sqlTLS.getTLS() +} + +// The proxy is supposed to be always online, so it should reload certs automatically, +// rather than reloading it by restarting the proxy. +// The proxy checks expiration time periodically and reloads certs in advance. If reloading +// fails or the cert is not replaced, it will retry in the next round. +func (cm *CertManager) reloadLoop(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(time.Duration(cm.retryInterval.Load())): + _ = cm.load() + } + } +} + +func (cm *CertManager) load() error { + errs := make([]error, 0, 4) + now := time.Now() + var err error + needReloadServer := false + if cm.serverTLS.autoCert && cm.serverTLS.needRecreateCert(now) { + if err = security.AutoTLS(cm.logger, &cm.serverTLS.cfg, false, cm.autoCertDir, "server", + cm.cfg.RSAKeySize); err != nil { + cm.logger.Error("creating server certs failed", zap.Error(err)) + errs = append(errs, err) + } else { + needReloadServer = true + } + } else if !cm.serverTLS.autoCert { + needReloadServer = true + } + if needReloadServer { + var tlsConfig *tls.Config + if tlsConfig, err = security.BuildServerTLSConfig(cm.logger, cm.serverTLS.cfg); err != nil { + cm.logger.Error("loading server certs failed", zap.Error(err)) + errs = append(errs, err) + } else { + cm.serverTLS.setTLS(tlsConfig) + cm.serverTLS.setAutoCertExp(now.Add(time.Duration(cm.autoCertInterval.Load()))) + } + } + + if cm.peerTLS.autoCert && cm.peerTLS.needRecreateCert(now) { + if err := security.AutoTLS(cm.logger, &cm.peerTLS.cfg, true, cm.autoCertDir, "peer", + cm.cfg.RSAKeySize); err != nil { + cm.logger.Error("creating peer certs failed", zap.Error(err)) + errs = append(errs, err) + } else { + cm.peerTLS.setAutoCertExp(now.Add(time.Duration(cm.autoCertInterval.Load()))) + } + } + + if tlsConfig, err := security.BuildClientTLSConfig(cm.logger, cm.sqlTLS.cfg); err != nil { + cm.logger.Error("loading sql certs failed", zap.Error(err)) + errs = append(errs, err) + } else { + cm.sqlTLS.setTLS(tlsConfig) + } + + if tlsConfig, err := security.BuildClientTLSConfig(cm.logger, cm.clusterTLS.cfg); err != nil { + cm.logger.Error("loading cluster certs failed", zap.Error(err)) + errs = append(errs, err) + } else { + cm.clusterTLS.setTLS(tlsConfig) + } + + if len(errs) != 0 { + return errors.Collect(errors.New("loading certs"), errs...) + } + return nil +} + +func (cm *CertManager) Close() { + if cm.cancel != nil { + cm.cancel() + } + cm.wg.Wait() +} diff --git a/pkg/manager/cert/manager_test.go b/pkg/manager/cert/manager_test.go new file mode 100644 index 00000000..b0b263d8 --- /dev/null +++ b/pkg/manager/cert/manager_test.go @@ -0,0 +1,145 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cert + +import ( + "bytes" + "testing" + "time" + + "github.com/pingcap/TiProxy/lib/config" + "github.com/pingcap/TiProxy/lib/util/logger" + "github.com/pingcap/TiProxy/lib/util/security" + "github.com/stretchr/testify/require" +) + +// Test that the certs are automatically created and reloaded. +func TestReloadCerts(t *testing.T) { + dir := t.TempDir() + lg := logger.CreateLoggerForTest(t) + sqlCfg := &config.TLSConfig{AutoCerts: true} + err := security.AutoTLS(lg, sqlCfg, true, dir, "sql", 1024) + require.NoError(t, err) + clusterCfg := &config.TLSConfig{AutoCerts: true} + err = security.AutoTLS(lg, clusterCfg, true, dir, "cluster", 1024) + require.NoError(t, err) + + cfg := &config.Config{ + Workdir: dir, + Security: config.Security{ + ServerTLS: config.TLSConfig{ + AutoCerts: true, + }, + PeerTLS: config.TLSConfig{ + AutoCerts: true, + }, + SQLTLS: *sqlCfg, + ClusterTLS: *clusterCfg, + }, + } + certMgr := NewCertManager() + certMgr.SetRetryInterval(100 * time.Millisecond) + certMgr.SetAutoCertInterval(50 * time.Millisecond) + err = certMgr.Init(cfg, lg) + require.NoError(t, err) + t.Cleanup(certMgr.Close) + + areCertsDifferent := func(before, after [4][][]byte) bool { + for i := 0; i < 4; i++ { + if len(before[i]) != len(after[i]) { + continue + } + if before[i] == nil || after[i] == nil { + continue + } + different := false + for j := 0; j < len(before[i]); j++ { + if !bytes.Equal(before[i][j], after[i][j]) { + different = true + break + } + } + if !different { + return false + } + } + return true + } + + var before = getAllCertificates(t, certMgr) + sqlCfg = &config.TLSConfig{AutoCerts: true} + err = security.AutoTLS(lg, sqlCfg, true, dir, "sql", 1024) + require.NoError(t, err) + clusterCfg = &config.TLSConfig{AutoCerts: true} + err = security.AutoTLS(lg, clusterCfg, true, dir, "cluster", 1024) + require.NoError(t, err) + + timer := time.NewTimer(10 * time.Second) + for { + select { + case <-timer.C: + t.Fatal("timeout") + case <-time.After(100 * time.Millisecond): + var after = getAllCertificates(t, certMgr) + if areCertsDifferent(before, after) { + return + } + } + } +} + +func getRawCertificate(t *testing.T, ci *certInfo) [][]byte { + tlsConfig := ci.getTLS() + if tlsConfig == nil { + return nil + } + cert, err := tlsConfig.GetCertificate(nil) + require.NoError(t, err) + return cert.Certificate +} + +func getAllCertificates(t *testing.T, certMgr *CertManager) [4][][]byte { + return [4][][]byte{ + getRawCertificate(t, &certMgr.serverTLS), + getRawCertificate(t, &certMgr.peerTLS), + getRawCertificate(t, &certMgr.sqlTLS), + getRawCertificate(t, &certMgr.clusterTLS), + } +} + +// Test that configuring no certs still works. +func TestReloadEmptyCerts(t *testing.T) { + dir := t.TempDir() + lg := logger.CreateLoggerForTest(t) + cfg := &config.Config{ + Workdir: dir, + Security: config.Security{ + ServerTLS: config.TLSConfig{}, + PeerTLS: config.TLSConfig{}, + SQLTLS: config.TLSConfig{}, + ClusterTLS: config.TLSConfig{}, + }, + } + certMgr := NewCertManager() + certMgr.SetRetryInterval(100 * time.Millisecond) + err := certMgr.Init(cfg, lg) + require.NoError(t, err) + t.Cleanup(certMgr.Close) + + rawCerts := getAllCertificates(t, certMgr) + for i := 0; i < len(rawCerts); i++ { + require.Nil(t, rawCerts[i]) + } +} diff --git a/pkg/manager/namespace/manager.go b/pkg/manager/namespace/manager.go index f4806d8c..08b01453 100644 --- a/pkg/manager/namespace/manager.go +++ b/pkg/manager/namespace/manager.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/TiProxy/lib/config" "github.com/pingcap/TiProxy/lib/util/errors" - "github.com/pingcap/TiProxy/lib/util/security" "github.com/pingcap/TiProxy/pkg/manager/router" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" @@ -39,6 +38,7 @@ type NamespaceManager struct { func NewNamespaceManager() *NamespaceManager { return &NamespaceManager{} } + func (mgr *NamespaceManager) buildNamespace(cfg *config.Namespace) (*Namespace, error) { logger := mgr.logger.With(zap.String("namespace", cfg.Namespace)) @@ -46,22 +46,10 @@ func (mgr *NamespaceManager) buildNamespace(cfg *config.Namespace) (*Namespace, if err != nil { return nil, errors.Errorf("build router error: %w", err) } - r := &Namespace{ + return &Namespace{ name: cfg.Namespace, router: rt, - } - - r.frontendTLS, err = security.BuildServerTLSConfig(logger, cfg.Frontend.Security) - if err != nil { - return nil, errors.Errorf("build frontend TLS error: %w", err) - } - - r.backendTLS, err = security.BuildClientTLSConfig(logger, cfg.Backend.Security) - if err != nil { - return nil, errors.Errorf("build backend TLS error: %w", err) - } - - return r, nil + }, nil } func (mgr *NamespaceManager) CommitNamespaces(nss []*config.Namespace, nss_delete []bool) error { diff --git a/pkg/manager/namespace/namespace.go b/pkg/manager/namespace/namespace.go index 30edd960..fc5aae50 100644 --- a/pkg/manager/namespace/namespace.go +++ b/pkg/manager/namespace/namespace.go @@ -16,30 +16,18 @@ package namespace import ( - "crypto/tls" - "github.com/pingcap/TiProxy/pkg/manager/router" ) type Namespace struct { - name string - router router.Router - frontendTLS *tls.Config - backendTLS *tls.Config + name string + router router.Router } func (n *Namespace) Name() string { return n.name } -func (n *Namespace) FrontendTLSConfig() *tls.Config { - return n.frontendTLS -} - -func (n *Namespace) BackendTLSConfig() *tls.Config { - return n.backendTLS -} - func (n *Namespace) GetRouter() router.Router { return n.router } diff --git a/pkg/manager/router/backend_observer.go b/pkg/manager/router/backend_observer.go index c7bf56fa..2f125156 100644 --- a/pkg/manager/router/backend_observer.go +++ b/pkg/manager/router/backend_observer.go @@ -25,8 +25,8 @@ import ( "github.com/pingcap/TiProxy/lib/config" "github.com/pingcap/TiProxy/lib/util/errors" - "github.com/pingcap/TiProxy/lib/util/security" "github.com/pingcap/TiProxy/lib/util/waitgroup" + "github.com/pingcap/TiProxy/pkg/manager/cert" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" "github.com/pingcap/tidb/domain/infosync" clientv3 "go.etcd.io/etcd/client/v3" @@ -142,7 +142,7 @@ type BackendObserver struct { } // InitEtcdClient initializes an etcd client that fetches TiDB instance topology from PD. -func InitEtcdClient(logger *zap.Logger, cfg *config.Config) (*clientv3.Client, error) { +func InitEtcdClient(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertManager) (*clientv3.Client, error) { pdAddr := cfg.Proxy.PDAddrs if len(pdAddr) == 0 { // use tidb server addresses directly @@ -150,14 +150,9 @@ func InitEtcdClient(logger *zap.Logger, cfg *config.Config) (*clientv3.Client, e } pdEndpoints := strings.Split(pdAddr, ",") logger.Info("connect PD servers", zap.Strings("addrs", pdEndpoints)) - tlsConfig, err := security.BuildClientTLSConfig(logger, cfg.Security.ClusterTLS) - if err != nil { - return nil, err - } - var etcdClient *clientv3.Client - etcdClient, err = clientv3.New(clientv3.Config{ + etcdClient, err := clientv3.New(clientv3.Config{ Endpoints: pdEndpoints, - TLS: tlsConfig, + TLS: certMgr.ClusterTLS(), Logger: logger.Named("etcdcli"), AutoSyncInterval: 30 * time.Second, DialTimeout: 5 * time.Second, @@ -166,7 +161,6 @@ func InitEtcdClient(logger *zap.Logger, cfg *config.Config) (*clientv3.Client, e Time: 10 * time.Second, Timeout: 3 * time.Second, }), - //grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), grpc.WithBlock(), grpc.WithConnectParams(grpc.ConnectParams{ Backoff: backoff.Config{ diff --git a/pkg/manager/router/backend_observer_test.go b/pkg/manager/router/backend_observer_test.go index 8a93ea1e..6c6cab43 100644 --- a/pkg/manager/router/backend_observer_test.go +++ b/pkg/manager/router/backend_observer_test.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/TiProxy/lib/config" "github.com/pingcap/TiProxy/lib/util/logger" "github.com/pingcap/TiProxy/lib/util/waitgroup" + "github.com/pingcap/TiProxy/pkg/manager/cert" "github.com/pingcap/tidb/domain/infosync" "github.com/stretchr/testify/require" clientv3 "go.etcd.io/etcd/client/v3" @@ -243,7 +244,10 @@ func createEtcdClient(t *testing.T, etcd *embed.Etcd) *clientv3.Client { PDAddrs: etcd.Clients[0].Addr().String(), }, } - client, err := InitEtcdClient(logger.CreateLoggerForTest(t), cfg) + certMgr := cert.NewCertManager() + err := certMgr.Init(cfg, logger.CreateLoggerForTest(t)) + require.NoError(t, err) + client, err := InitEtcdClient(logger.CreateLoggerForTest(t), cfg, certMgr) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, client.Close()) diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index e4587398..076811b1 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -16,14 +16,13 @@ package proxy import ( "context" - "crypto/tls" "net" "sync" "github.com/pingcap/TiProxy/lib/config" "github.com/pingcap/TiProxy/lib/util/errors" - "github.com/pingcap/TiProxy/lib/util/security" "github.com/pingcap/TiProxy/lib/util/waitgroup" + "github.com/pingcap/TiProxy/pkg/manager/cert" mgrns "github.com/pingcap/TiProxy/pkg/manager/namespace" "github.com/pingcap/TiProxy/pkg/metrics" "github.com/pingcap/TiProxy/pkg/proxy/client" @@ -41,23 +40,23 @@ type serverState struct { } type SQLServer struct { - listener net.Listener - logger *zap.Logger - nsmgr *mgrns.NamespaceManager - frontendTLSConfig *tls.Config - backendTLSConfig *tls.Config - wg waitgroup.WaitGroup + listener net.Listener + logger *zap.Logger + certMgr *cert.CertManager + nsmgr *mgrns.NamespaceManager + wg waitgroup.WaitGroup mu serverState } // NewSQLServer creates a new SQLServer. -func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, scfg config.Security, nsmgr *mgrns.NamespaceManager) (*SQLServer, error) { +func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, certMgr *cert.CertManager, nsmgr *mgrns.NamespaceManager) (*SQLServer, error) { var err error s := &SQLServer{ - logger: logger, - nsmgr: nsmgr, + logger: logger, + certMgr: certMgr, + nsmgr: nsmgr, mu: serverState{ connID: 0, clients: make(map[uint64]*client.ClientConnection), @@ -66,13 +65,6 @@ func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, scfg config.Securi s.reset(&cfg.ProxyServerOnline) - if s.frontendTLSConfig, err = security.BuildServerTLSConfig(logger, scfg.ServerTLS); err != nil { - return nil, err - } - if s.backendTLSConfig, err = security.BuildClientTLSConfig(logger, scfg.SQLTLS); err != nil { - return nil, err - } - s.listener, err = net.Listen("tcp", cfg.Addr) if err != nil { return nil, err @@ -130,7 +122,7 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) { connID := s.mu.connID s.mu.connID++ logger := s.logger.With(zap.Uint64("connID", connID), zap.String("remoteAddr", conn.RemoteAddr().String())) - clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.frontendTLSConfig, s.backendTLSConfig, s.nsmgr, connID, s.mu.proxyProtocol) + clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(), s.nsmgr, connID, s.mu.proxyProtocol) s.mu.clients[connID] = clientConn s.mu.Unlock() diff --git a/pkg/server/server.go b/pkg/server/server.go index 8821b614..24d55543 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/TiProxy/lib/util/errors" "github.com/pingcap/TiProxy/lib/util/security" "github.com/pingcap/TiProxy/lib/util/waitgroup" + "github.com/pingcap/TiProxy/pkg/manager/cert" mgrcfg "github.com/pingcap/TiProxy/pkg/manager/config" "github.com/pingcap/TiProxy/pkg/manager/logger" mgrns "github.com/pingcap/TiProxy/pkg/manager/namespace" @@ -49,6 +50,7 @@ type Server struct { NamespaceManager *mgrns.NamespaceManager MetricsManager *metrics.MetricsManager LoggerManager *logger.LoggerManager + CertManager *cert.CertManager ObserverClient *clientv3.Client // HTTP client Http *http.Client @@ -63,6 +65,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (srv *Server, err error) ConfigManager: mgrcfg.NewConfigManager(), MetricsManager: metrics.NewMetricsManager(), NamespaceManager: mgrns.NewNamespaceManager(), + CertManager: cert.NewCertManager(), } // set up logger @@ -71,15 +74,10 @@ func NewServer(ctx context.Context, cfg *config.Config) (srv *Server, err error) return } + // setup certs { - tlogger := lg.Named("tls") - // auto generate CA for serverTLS will break - if uerr := security.AutoTLS(tlogger, &cfg.Security.ServerTLS, false, cfg.Workdir, "server", cfg.Security.RSAKeySize); uerr != nil { - err = errors.WithStack(uerr) - return - } - if uerr := security.AutoTLS(tlogger, &cfg.Security.PeerTLS, true, cfg.Workdir, "peer", cfg.Security.RSAKeySize); uerr != nil { - err = errors.WithStack(uerr) + clogger := lg.Named("cert") + if err = srv.CertManager.Init(cfg, clogger); err != nil { return } } @@ -133,14 +131,9 @@ func NewServer(ctx context.Context, cfg *config.Config) (srv *Server, err error) // general cluster HTTP client { - clientTLS, uerr := security.BuildClientTLSConfig(lg.Named("http"), cfg.Security.ClusterTLS) - if uerr != nil { - err = errors.WithStack(err) - return - } srv.Http = &http.Client{ Transport: &http.Transport{ - TLSClientConfig: clientTLS, + TLSClientConfig: srv.CertManager.ClusterTLS(), }, } } @@ -176,7 +169,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (srv *Server, err error) // setup namespace manager { - srv.ObserverClient, err = router.InitEtcdClient(lg.Named("pd"), cfg) + srv.ObserverClient, err = router.InitEtcdClient(lg.Named("pd"), cfg, srv.CertManager) if err != nil { err = errors.WithStack(err) return @@ -198,7 +191,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (srv *Server, err error) // setup proxy server { - srv.Proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg.Proxy, cfg.Security, srv.NamespaceManager) + srv.Proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg.Proxy, srv.CertManager, srv.NamespaceManager) if err != nil { err = errors.WithStack(err) return