Skip to content

Commit

Permalink
manager: auto reload certs (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Oct 18, 2022
1 parent 458b03c commit c0ab948
Show file tree
Hide file tree
Showing 9 changed files with 427 additions and 90 deletions.
30 changes: 15 additions & 15 deletions lib/util/security/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ import (
"go.uber.org/zap"
)

func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath string, rsaKeySize int) error {
const DefaultCertExpiration = 10 * 365 * 24 * time.Hour

func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath string, rsaKeySize int, expiration time.Duration) error {
logger = logger.With(zap.String("cert", certpath), zap.String("key", keypath), zap.String("ca", capath), zap.Int("rsaKeySize", rsaKeySize))

_, e1 := os.Stat(certpath)
Expand Down Expand Up @@ -64,7 +66,7 @@ func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath strin
}
}

certPEM, keyPEM, caPEM, err := CreateTempTLS()
certPEM, keyPEM, caPEM, err := CreateTempTLS(expiration)
if err != nil {
return err
}
Expand All @@ -86,20 +88,18 @@ func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath strin
}

func AutoTLS(logger *zap.Logger, scfg *config.TLSConfig, autoca bool, workdir, mod string, keySize int) error {
if !scfg.HasCert() && scfg.AutoCerts {
scfg.Cert = filepath.Join(workdir, mod, "cert.pem")
scfg.Key = filepath.Join(workdir, mod, "key.pem")
if autoca {
scfg.CA = filepath.Join(workdir, mod, "ca.pem")
}
if err := createTLSConfigificates(logger, scfg.Cert, scfg.Key, scfg.CA, keySize); err != nil {
return errors.WithStack(err)
}
scfg.Cert = filepath.Join(workdir, mod, "cert.pem")
scfg.Key = filepath.Join(workdir, mod, "key.pem")
if autoca {
scfg.CA = filepath.Join(workdir, mod, "ca.pem")
}
if err := createTLSConfigificates(logger, scfg.Cert, scfg.Key, scfg.CA, keySize, DefaultCertExpiration); err != nil {
return errors.WithStack(err)
}
return nil
}

func CreateTempTLS() (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) {
func CreateTempTLS(expiration time.Duration) (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) {
// set up our CA certificate
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Expand All @@ -112,7 +112,7 @@ func CreateTempTLS() (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) {
PostalCode: []string{"94016"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
NotAfter: time.Now().Add(expiration),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
Expand Down Expand Up @@ -153,7 +153,7 @@ func CreateTempTLS() (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) {
},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
NotAfter: time.Now().Add(expiration),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
Expand Down Expand Up @@ -190,7 +190,7 @@ func CreateTempTLS() (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) {

// CreateTLSConfigForTest is from https://gist.github.com/shaneutt/5e1995295cff6721c89a71d13a71c251.
func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) {
certPEM, keyPEM, caPEM, uerr := CreateTempTLS()
certPEM, keyPEM, caPEM, uerr := CreateTempTLS(DefaultCertExpiration)
if uerr != nil {
err = uerr
return
Expand Down
233 changes: 233 additions & 0 deletions pkg/manager/cert/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cert

import (
"context"
"crypto/tls"
"sync/atomic"
"time"

"github.com/pingcap/TiProxy/lib/config"
"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/lib/util/security"
"github.com/pingcap/TiProxy/lib/util/waitgroup"
"go.uber.org/zap"
)

const (
defaultRetryInterval = 1 * time.Hour
defaultAutoCertInterval = 30 * 24 * time.Hour
)

// Security configurations don't support dynamically updating now.
type certInfo struct {
cfg config.TLSConfig
tlsConfig *tls.Config
certificate atomic.Pointer[tls.Certificate]
autoCert bool
autoCertExp time.Time
}

func (ci *certInfo) getTLS() *tls.Config {
if ci.tlsConfig != nil {
return ci.tlsConfig.Clone()
}
return nil
}

func (ci *certInfo) setTLS(tlsConfig *tls.Config) {
if tlsConfig != nil {
tlsConfig = tlsConfig.Clone()
if tlsConfig.Certificates != nil {
ci.certificate.Store(&tlsConfig.Certificates[0])
// Doesn't support rotating CA now. It needs overwriting InsecureSkipVerify and VerifyPeerCertificate.
tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return ci.certificate.Load(), nil
}
tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return ci.certificate.Load(), nil
}
tlsConfig.Certificates = nil
}
}
ci.tlsConfig = tlsConfig
}

func (ci *certInfo) setAutoCertExp(exp time.Time) {
ci.autoCertExp = exp
}

func (ci *certInfo) needRecreateCert(now time.Time) bool {
if !ci.autoCert {
return false
}
return now.After(ci.autoCertExp)
}

// CertManager reloads certs and offers interfaces for fetching TLS configs.
// Currently, all the namespaces share the same certs but there might be per-namespace
// certs in the future.
type CertManager struct {
serverTLS certInfo // client / proxyctl -> proxy
peerTLS certInfo // proxy -> proxy
clusterTLS certInfo // proxy -> pd / tidb status port
sqlTLS certInfo // proxy -> tidb sql port
autoCertDir string
cancel context.CancelFunc
wg waitgroup.WaitGroup
retryInterval atomic.Int64
autoCertInterval atomic.Int64
cfg *config.Security
logger *zap.Logger
}

// NewCertManager creates a new CertManager.
func NewCertManager() *CertManager {
cm := &CertManager{}
cm.SetRetryInterval(defaultRetryInterval)
cm.SetAutoCertInterval(defaultAutoCertInterval)
return cm
}

func (cm *CertManager) Init(cfg *config.Config, logger *zap.Logger) error {
cm.cfg = &cfg.Security
cm.logger = logger
cm.autoCertDir = cfg.Workdir
cm.serverTLS = certInfo{
cfg: cfg.Security.ServerTLS,
autoCert: !cfg.Security.ServerTLS.HasCert() && cfg.Security.ServerTLS.AutoCerts,
}
cm.peerTLS = certInfo{
cfg: cfg.Security.PeerTLS,
autoCert: !cfg.Security.PeerTLS.HasCert() && cfg.Security.PeerTLS.AutoCerts,
}
cm.clusterTLS = certInfo{
cfg: cfg.Security.ClusterTLS,
}
cm.sqlTLS = certInfo{
cfg: cfg.Security.SQLTLS,
}

if err := cm.load(); err != nil {
return err
}

ctx, cancel := context.WithCancel(context.Background())
cm.wg.Run(func() {
cm.reloadLoop(ctx)
})
cm.cancel = cancel
return nil
}

func (cm *CertManager) SetRetryInterval(interval time.Duration) {
cm.retryInterval.Store(int64(interval))
}

func (cm *CertManager) SetAutoCertInterval(interval time.Duration) {
cm.autoCertInterval.Store(int64(interval))
}

func (cm *CertManager) ServerTLS() *tls.Config {
return cm.serverTLS.getTLS()
}

func (cm *CertManager) ClusterTLS() *tls.Config {
return cm.clusterTLS.getTLS()
}

func (cm *CertManager) SQLTLS() *tls.Config {
return cm.sqlTLS.getTLS()
}

// The proxy is supposed to be always online, so it should reload certs automatically,
// rather than reloading it by restarting the proxy.
// The proxy checks expiration time periodically and reloads certs in advance. If reloading
// fails or the cert is not replaced, it will retry in the next round.
func (cm *CertManager) reloadLoop(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-time.After(time.Duration(cm.retryInterval.Load())):
_ = cm.load()
}
}
}

func (cm *CertManager) load() error {
errs := make([]error, 0, 4)
now := time.Now()
var err error
needReloadServer := false
if cm.serverTLS.autoCert && cm.serverTLS.needRecreateCert(now) {
if err = security.AutoTLS(cm.logger, &cm.serverTLS.cfg, false, cm.autoCertDir, "server",
cm.cfg.RSAKeySize); err != nil {
cm.logger.Error("creating server certs failed", zap.Error(err))
errs = append(errs, err)
} else {
needReloadServer = true
}
} else if !cm.serverTLS.autoCert {
needReloadServer = true
}
if needReloadServer {
var tlsConfig *tls.Config
if tlsConfig, err = security.BuildServerTLSConfig(cm.logger, cm.serverTLS.cfg); err != nil {
cm.logger.Error("loading server certs failed", zap.Error(err))
errs = append(errs, err)
} else {
cm.serverTLS.setTLS(tlsConfig)
cm.serverTLS.setAutoCertExp(now.Add(time.Duration(cm.autoCertInterval.Load())))
}
}

if cm.peerTLS.autoCert && cm.peerTLS.needRecreateCert(now) {
if err := security.AutoTLS(cm.logger, &cm.peerTLS.cfg, true, cm.autoCertDir, "peer",
cm.cfg.RSAKeySize); err != nil {
cm.logger.Error("creating peer certs failed", zap.Error(err))
errs = append(errs, err)
} else {
cm.peerTLS.setAutoCertExp(now.Add(time.Duration(cm.autoCertInterval.Load())))
}
}

if tlsConfig, err := security.BuildClientTLSConfig(cm.logger, cm.sqlTLS.cfg); err != nil {
cm.logger.Error("loading sql certs failed", zap.Error(err))
errs = append(errs, err)
} else {
cm.sqlTLS.setTLS(tlsConfig)
}

if tlsConfig, err := security.BuildClientTLSConfig(cm.logger, cm.clusterTLS.cfg); err != nil {
cm.logger.Error("loading cluster certs failed", zap.Error(err))
errs = append(errs, err)
} else {
cm.clusterTLS.setTLS(tlsConfig)
}

if len(errs) != 0 {
return errors.Collect(errors.New("loading certs"), errs...)
}
return nil
}

func (cm *CertManager) Close() {
if cm.cancel != nil {
cm.cancel()
}
cm.wg.Wait()
}
Loading

0 comments on commit c0ab948

Please sign in to comment.