diff --git a/lib/util/security/cert.go b/lib/util/security/cert.go index 132ec38b..00db7202 100644 --- a/lib/util/security/cert.go +++ b/lib/util/security/cert.go @@ -174,14 +174,14 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) { VerifyPeerCertificate: ci.verifyPeerCertificate, } - var certPEM, keyPEM, caPEM []byte + var certPEM, keyPEM []byte var err error if autoCerts { dur, err := time.ParseDuration(ci.cfg.AutoExpireDuration) if err != nil { dur = DefaultCertExpiration } - certPEM, keyPEM, caPEM, err = CreateTempTLS(ci.cfg.RSAKeySize, dur) + certPEM, keyPEM, _, err = CreateTempTLS(ci.cfg.RSAKeySize, dur) if err != nil { return nil, err } @@ -194,15 +194,6 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) { if err != nil { return nil, err } - if !ci.cfg.HasCA() { - lg.Warn("no CA, server will not authenticate clients (connection is still secured)") - return tcfg, nil - } else { - caPEM, err = os.ReadFile(ci.cfg.CA) - if err != nil { - return nil, err - } - } } cert, err := tls.X509KeyPair(certPEM, keyPEM) @@ -218,14 +209,23 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) { } ci.cert.Store(&cert) - if len(caPEM) != 0 { - cas, err := ci.loadCA(caPEM) - if err != nil { - return nil, errors.WithStack(err) - } - ci.ca.Store(cas) - tcfg.ClientAuth = tls.RequireAnyClientCert + if !ci.cfg.HasCA() { + lg.Warn("no CA, server will not authenticate clients (connection is still secured)") + return tcfg, nil + } + + caPEM, err := os.ReadFile(ci.cfg.CA) + if err != nil { + return nil, err + } + + cas, err := ci.loadCA(caPEM) + if err != nil { + return nil, errors.WithStack(err) } + ci.ca.Store(cas) + tcfg.ClientAuth = tls.RequireAnyClientCert + return tcfg, nil } diff --git a/lib/util/security/cert_test.go b/lib/util/security/cert_test.go new file mode 100644 index 00000000..47accaca --- /dev/null +++ b/lib/util/security/cert_test.go @@ -0,0 +1,200 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package security + +import ( + "crypto/tls" + "path/filepath" + "testing" + "time" + + "github.com/pingcap/TiProxy/lib/config" + "github.com/pingcap/TiProxy/lib/util/logger" + "github.com/stretchr/testify/require" +) + +func TestCertServer(t *testing.T) { + logger := logger.CreateLoggerForTest(t) + tmpdir := t.TempDir() + certPath := filepath.Join(tmpdir, "cert") + keyPath := filepath.Join(tmpdir, "key") + caPath := filepath.Join(tmpdir, "ca") + + require.NoError(t, createTLSCertificates(logger, certPath, keyPath, caPath, 0, time.Hour)) + + type certCase struct { + config.TLSConfig + server bool + checker func(*testing.T, *tls.Config, *CertInfo) + err string + } + + cases := []certCase{ + { + server: true, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.Nil(t, c) + require.Nil(t, ci.ca.Load()) + require.Nil(t, ci.cert.Load()) + }, + err: "", + }, + { + server: true, + TLSConfig: config.TLSConfig{ + CA: caPath, + }, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.Nil(t, c) + require.Nil(t, ci.ca.Load()) + require.Nil(t, ci.cert.Load()) + }, + err: "", + }, + { + server: true, + TLSConfig: config.TLSConfig{ + AutoCerts: true, + }, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.NotNil(t, c) + require.Nil(t, ci.ca.Load()) + require.NotNil(t, ci.cert.Load()) + }, + err: "", + }, + { + server: true, + TLSConfig: config.TLSConfig{ + Cert: certPath, + Key: keyPath, + }, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.NotNil(t, c) + require.Nil(t, ci.ca.Load()) + require.NotNil(t, ci.cert.Load()) + }, + err: "", + }, + { + server: true, + TLSConfig: config.TLSConfig{ + Cert: certPath, + Key: keyPath, + CA: caPath, + }, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.NotNil(t, c) + require.NotNil(t, ci.ca.Load()) + require.NotNil(t, ci.cert.Load()) + }, + err: "", + }, + { + server: true, + TLSConfig: config.TLSConfig{ + AutoCerts: true, + CA: caPath, + }, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.NotNil(t, c) + require.NotNil(t, ci.ca.Load()) + require.NotNil(t, ci.cert.Load()) + }, + err: "", + }, + { + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.Nil(t, c) + require.Nil(t, ci.ca.Load()) + require.Nil(t, ci.cert.Load()) + }, + err: "", + }, + { + TLSConfig: config.TLSConfig{ + Cert: certPath, + Key: keyPath, + }, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.Nil(t, c) + require.Nil(t, ci.ca.Load()) + require.Nil(t, ci.cert.Load()) + }, + err: "", + }, + { + TLSConfig: config.TLSConfig{ + SkipCA: true, + }, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.NotNil(t, c) + require.Nil(t, ci.ca.Load()) + require.Nil(t, ci.cert.Load()) + }, + err: "", + }, + { + TLSConfig: config.TLSConfig{ + SkipCA: true, + Cert: certPath, + }, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.NotNil(t, c) + require.Nil(t, ci.ca.Load()) + require.Nil(t, ci.cert.Load()) + }, + err: "", + }, + { + TLSConfig: config.TLSConfig{ + CA: caPath, + }, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.NotNil(t, c) + require.NotNil(t, ci.ca.Load()) + require.Nil(t, ci.cert.Load()) + }, + err: "", + }, + { + TLSConfig: config.TLSConfig{ + Cert: certPath, + Key: keyPath, + CA: caPath, + }, + checker: func(t *testing.T, c *tls.Config, ci *CertInfo) { + require.NotNil(t, c) + require.NotNil(t, ci.ca.Load()) + require.NotNil(t, ci.cert.Load()) + }, + err: "", + }, + } + + for _, tc := range cases { + ci, tcfg, err := NewCert(logger, tc.TLSConfig, tc.server) + if len(tc.err) > 0 { + require.Nil(t, ci) + require.ErrorContains(t, err, tc.err) + } else { + require.NotNil(t, ci) + require.NoError(t, err) + } + if tc.checker != nil { + tc.checker(t, tcfg, ci) + } + } +}