Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxy: optional server-side verification #166

Merged
merged 4 commits into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions conf/proxy.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@ rsa-key-size = 4_096
# skip-ca = trure
# client object:
# 1. requires: ca or skip-ca(skip verify server certs)
# 2. optionally: cert/key will be used if server asks
# 2. optionally: cert/key will be used if server asks, i.e. server-side client verification
# 3. useless/forbid: auto-certs
# server object:
# 1. requires: cert/key or auto-certs(generate a temporary cert, mostly for testing)
# 2. optionally: ca will enable server-side client verification.
# 3. useless/forbid: skip-ca
# 2. optionally: ca will enable server-side client verification. If skip-ca is true with non-empty ca, server will only verify clients if it can provide any cert. Otherwise, clients must provide a cert.
# peer object:
# 1. requires: cert/key/ca or auto-certs
# 2. useless/forbid: skip-ca
Expand Down
13 changes: 11 additions & 2 deletions lib/util/security/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
if err != nil {
dur = DefaultCertExpiration
}
certPEM, keyPEM, _, err = CreateTempTLS(ci.cfg.RSAKeySize, dur)
certPEM, keyPEM, _, err = createTempTLS(ci.cfg.RSAKeySize, dur)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -224,13 +224,22 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
return nil, errors.WithStack(err)
}
ci.ca.Store(cas)
tcfg.ClientAuth = tls.RequireAnyClientCert

if ci.cfg.SkipCA {
tcfg.ClientAuth = tls.VerifyClientCertIfGiven
} else {
tcfg.ClientAuth = tls.RequireAnyClientCert
}

return tcfg, nil
}

func (ci *CertInfo) buildClientConfig(lg *zap.Logger) (*tls.Config, error) {
lg = lg.With(zap.String("tls", "client"), zap.Any("cfg", ci.cfg))
if ci.cfg.AutoCerts {
lg.Info("specified auto-certs in a client tls config, ignored")
}

if !ci.cfg.HasCA() {
if ci.cfg.SkipCA {
// still enable TLS without verify server certs
Expand Down
20 changes: 19 additions & 1 deletion lib/util/security/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestCertServer(t *testing.T) {
keyPath := filepath.Join(tmpdir, "key")
caPath := filepath.Join(tmpdir, "ca")

require.NoError(t, createTLSCertificates(logger, certPath, keyPath, caPath, 0, time.Hour))
require.NoError(t, CreateTLSCertificates(logger, certPath, keyPath, caPath, 0, time.Hour))

type certCase struct {
config.TLSConfig
Expand Down Expand Up @@ -97,6 +97,23 @@ func TestCertServer(t *testing.T) {
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.Equal(t, tls.RequireAnyClientCert, c.ClientAuth)
require.NotNil(t, ci.ca.Load())
require.NotNil(t, ci.cert.Load())
},
err: "",
},
{
server: true,
TLSConfig: config.TLSConfig{
Cert: certPath,
Key: keyPath,
CA: caPath,
SkipCA: true,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.Equal(t, tls.VerifyClientCertIfGiven, c.ClientAuth)
require.NotNil(t, ci.ca.Load())
require.NotNil(t, ci.cert.Load())
},
Expand All @@ -110,6 +127,7 @@ func TestCertServer(t *testing.T) {
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.Equal(t, tls.RequireAnyClientCert, c.ClientAuth)
require.NotNil(t, ci.ca.Load())
require.NotNil(t, ci.cert.Load())
},
Expand Down
20 changes: 4 additions & 16 deletions lib/util/security/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (

const DefaultCertExpiration = 10 * 365 * 24 * time.Hour

func createTLSCertificates(logger *zap.Logger, certpath, keypath, capath string, rsaKeySize int, expiration time.Duration) error {
func CreateTLSCertificates(logger *zap.Logger, certpath, keypath, capath string, rsaKeySize int, expiration time.Duration) error {
logger = logger.With(zap.String("cert", certpath), zap.String("key", keypath), zap.String("ca", capath), zap.Int("rsaKeySize", rsaKeySize))

_, e1 := os.Stat(certpath)
Expand Down Expand Up @@ -66,7 +66,7 @@ func createTLSCertificates(logger *zap.Logger, certpath, keypath, capath string,
}
}

certPEM, keyPEM, caPEM, err := CreateTempTLS(rsaKeySize, expiration)
certPEM, keyPEM, caPEM, err := createTempTLS(rsaKeySize, expiration)
if err != nil {
return err
}
Expand All @@ -87,19 +87,7 @@ func createTLSCertificates(logger *zap.Logger, certpath, keypath, capath string,
return nil
}

func AutoTLS(logger *zap.Logger, scfg *config.TLSConfig, autoca bool, workdir, mod string, keySize int) error {
scfg.Cert = filepath.Join(workdir, mod, "cert.pem")
scfg.Key = filepath.Join(workdir, mod, "key.pem")
if autoca {
scfg.CA = filepath.Join(workdir, mod, "ca.pem")
}
if err := createTLSCertificates(logger, scfg.Cert, scfg.Key, scfg.CA, keySize, DefaultCertExpiration); err != nil {
return errors.WithStack(err)
}
return nil
}

func CreateTempTLS(rsaKeySize int, expiration time.Duration) ([]byte, []byte, []byte, error) {
func createTempTLS(rsaKeySize int, expiration time.Duration) ([]byte, []byte, []byte, error) {
if rsaKeySize < 1024 {
rsaKeySize = 1024
}
Expand Down Expand Up @@ -194,7 +182,7 @@ func CreateTempTLS(rsaKeySize int, expiration time.Duration) ([]byte, []byte, []

// CreateTLSConfigForTest is from https://gist.github.com/shaneutt/5e1995295cff6721c89a71d13a71c251.
func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) {
certPEM, keyPEM, caPEM, uerr := CreateTempTLS(0, DefaultCertExpiration)
certPEM, keyPEM, caPEM, uerr := createTempTLS(0, DefaultCertExpiration)
if uerr != nil {
err = uerr
return
Expand Down
2 changes: 1 addition & 1 deletion lib/util/security/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

func BenchmarkCreateTLS(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _, _, err := CreateTempTLS(0, DefaultCertExpiration)
_, _, _, err := createTempTLS(0, DefaultCertExpiration)
require.Nil(b, err)
}
}
17 changes: 5 additions & 12 deletions pkg/manager/cert/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ import (
)

const (
defaultRetryInterval = 1 * time.Hour
defaultAutoCertInterval = 30 * 24 * time.Hour
defaultRetryInterval = 1 * time.Hour
)

// CertManager reloads certs and offers interfaces for fetching TLS configs.
Expand All @@ -45,18 +44,16 @@ type CertManager struct {
sqlTLS *security.CertInfo // proxy -> tidb sql port
sqlTLSConfig *tls.Config

cancel context.CancelFunc
wg waitgroup.WaitGroup
retryInterval atomic.Int64
autoCertInterval atomic.Int64
logger *zap.Logger
cancel context.CancelFunc
wg waitgroup.WaitGroup
retryInterval atomic.Int64
logger *zap.Logger
}

// NewCertManager creates a new CertManager.
func NewCertManager() *CertManager {
cm := &CertManager{}
cm.SetRetryInterval(defaultRetryInterval)
cm.SetAutoCertInterval(defaultAutoCertInterval)
return cm
}

Expand Down Expand Up @@ -90,10 +87,6 @@ func (cm *CertManager) SetRetryInterval(interval time.Duration) {
cm.retryInterval.Store(int64(interval))
}

func (cm *CertManager) SetAutoCertInterval(interval time.Duration) {
cm.autoCertInterval.Store(int64(interval))
}

func (cm *CertManager) ServerTLS() *tls.Config {
return cm.serverTLSConfig
}
Expand Down
Loading