Skip to content

Commit

Permalink
security: fix server cert config (#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhebox authored Dec 27, 2022
1 parent f725184 commit 1649979
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 18 deletions.
36 changes: 18 additions & 18 deletions lib/util/security/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,14 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
VerifyPeerCertificate: ci.verifyPeerCertificate,
}

var certPEM, keyPEM, caPEM []byte
var certPEM, keyPEM []byte
var err error
if autoCerts {
dur, err := time.ParseDuration(ci.cfg.AutoExpireDuration)
if err != nil {
dur = DefaultCertExpiration
}
certPEM, keyPEM, caPEM, err = CreateTempTLS(ci.cfg.RSAKeySize, dur)
certPEM, keyPEM, _, err = CreateTempTLS(ci.cfg.RSAKeySize, dur)
if err != nil {
return nil, err
}
Expand All @@ -194,15 +194,6 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
if err != nil {
return nil, err
}
if !ci.cfg.HasCA() {
lg.Warn("no CA, server will not authenticate clients (connection is still secured)")
return tcfg, nil
} else {
caPEM, err = os.ReadFile(ci.cfg.CA)
if err != nil {
return nil, err
}
}
}

cert, err := tls.X509KeyPair(certPEM, keyPEM)
Expand All @@ -218,14 +209,23 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
}
ci.cert.Store(&cert)

if len(caPEM) != 0 {
cas, err := ci.loadCA(caPEM)
if err != nil {
return nil, errors.WithStack(err)
}
ci.ca.Store(cas)
tcfg.ClientAuth = tls.RequireAnyClientCert
if !ci.cfg.HasCA() {
lg.Warn("no CA, server will not authenticate clients (connection is still secured)")
return tcfg, nil
}

caPEM, err := os.ReadFile(ci.cfg.CA)
if err != nil {
return nil, err
}

cas, err := ci.loadCA(caPEM)
if err != nil {
return nil, errors.WithStack(err)
}
ci.ca.Store(cas)
tcfg.ClientAuth = tls.RequireAnyClientCert

return tcfg, nil
}

Expand Down
200 changes: 200 additions & 0 deletions lib/util/security/cert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package security

import (
"crypto/tls"
"path/filepath"
"testing"
"time"

"github.com/pingcap/TiProxy/lib/config"
"github.com/pingcap/TiProxy/lib/util/logger"
"github.com/stretchr/testify/require"
)

func TestCertServer(t *testing.T) {
logger := logger.CreateLoggerForTest(t)
tmpdir := t.TempDir()
certPath := filepath.Join(tmpdir, "cert")
keyPath := filepath.Join(tmpdir, "key")
caPath := filepath.Join(tmpdir, "ca")

require.NoError(t, createTLSCertificates(logger, certPath, keyPath, caPath, 0, time.Hour))

type certCase struct {
config.TLSConfig
server bool
checker func(*testing.T, *tls.Config, *CertInfo)
err string
}

cases := []certCase{
{
server: true,
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.Nil(t, c)
require.Nil(t, ci.ca.Load())
require.Nil(t, ci.cert.Load())
},
err: "",
},
{
server: true,
TLSConfig: config.TLSConfig{
CA: caPath,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.Nil(t, c)
require.Nil(t, ci.ca.Load())
require.Nil(t, ci.cert.Load())
},
err: "",
},
{
server: true,
TLSConfig: config.TLSConfig{
AutoCerts: true,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.Nil(t, ci.ca.Load())
require.NotNil(t, ci.cert.Load())
},
err: "",
},
{
server: true,
TLSConfig: config.TLSConfig{
Cert: certPath,
Key: keyPath,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.Nil(t, ci.ca.Load())
require.NotNil(t, ci.cert.Load())
},
err: "",
},
{
server: true,
TLSConfig: config.TLSConfig{
Cert: certPath,
Key: keyPath,
CA: caPath,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.NotNil(t, ci.ca.Load())
require.NotNil(t, ci.cert.Load())
},
err: "",
},
{
server: true,
TLSConfig: config.TLSConfig{
AutoCerts: true,
CA: caPath,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.NotNil(t, ci.ca.Load())
require.NotNil(t, ci.cert.Load())
},
err: "",
},
{
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.Nil(t, c)
require.Nil(t, ci.ca.Load())
require.Nil(t, ci.cert.Load())
},
err: "",
},
{
TLSConfig: config.TLSConfig{
Cert: certPath,
Key: keyPath,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.Nil(t, c)
require.Nil(t, ci.ca.Load())
require.Nil(t, ci.cert.Load())
},
err: "",
},
{
TLSConfig: config.TLSConfig{
SkipCA: true,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.Nil(t, ci.ca.Load())
require.Nil(t, ci.cert.Load())
},
err: "",
},
{
TLSConfig: config.TLSConfig{
SkipCA: true,
Cert: certPath,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.Nil(t, ci.ca.Load())
require.Nil(t, ci.cert.Load())
},
err: "",
},
{
TLSConfig: config.TLSConfig{
CA: caPath,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.NotNil(t, ci.ca.Load())
require.Nil(t, ci.cert.Load())
},
err: "",
},
{
TLSConfig: config.TLSConfig{
Cert: certPath,
Key: keyPath,
CA: caPath,
},
checker: func(t *testing.T, c *tls.Config, ci *CertInfo) {
require.NotNil(t, c)
require.NotNil(t, ci.ca.Load())
require.NotNil(t, ci.cert.Load())
},
err: "",
},
}

for _, tc := range cases {
ci, tcfg, err := NewCert(logger, tc.TLSConfig, tc.server)
if len(tc.err) > 0 {
require.Nil(t, ci)
require.ErrorContains(t, err, tc.err)
} else {
require.NotNil(t, ci)
require.NoError(t, err)
}
if tc.checker != nil {
tc.checker(t, tcfg, ci)
}
}
}

0 comments on commit 1649979

Please sign in to comment.