Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lightning: reload certificate for new connection #33865

Merged
merged 11 commits into from
Apr 20, 2022
39 changes: 22 additions & 17 deletions br/pkg/lightning/common/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"google.golang.org/grpc/credentials"
)

// TLS
type TLS struct {
caPath string
certPath string
Expand All @@ -50,16 +49,6 @@ func ToTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
return nil, nil
}

// Load the client certificates from disk
var certificates []tls.Certificate
if len(certPath) != 0 && len(keyPath) != 0 {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, errors.Annotate(err, "could not load client key pair")
}
certificates = []tls.Certificate{cert}
}

// Create a certificate pool from CA
certPool := x509.NewCertPool()
ca, err := os.ReadFile(caPath)
Expand All @@ -72,12 +61,28 @@ func ToTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
return nil, errors.New("failed to append ca certs")
}

return &tls.Config{
Certificates: certificates,
RootCAs: certPool,
NextProtos: []string{"h2", "http/1.1"}, // specify `h2` to let Go use HTTP/2.
MinVersion: tls.VersionTLS12,
}, nil
tlsConfig := &tls.Config{
RootCAs: certPool,
NextProtos: []string{"h2", "http/1.1"}, // specify `h2` to let Go use HTTP/2.
MinVersion: tls.VersionTLS12,
}

if len(certPath) != 0 && len(keyPath) != 0 {
loadCert := func() (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, errors.Annotate(err, "could not load client key pair")
}
return &cert, nil
}
tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return loadCert()
}
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return loadCert()
}
}
return tlsConfig, nil
}

// NewTLS constructs a new HTTP client with TLS configured with the CA,
Expand Down
18 changes: 16 additions & 2 deletions br/pkg/lightning/common/security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,29 @@ func TestInvalidTLS(t *testing.T) {
_, err = common.NewTLS(caPath, "", "", "localhost")
require.Regexp(t, "failed to append ca certs", err.Error())

err = os.WriteFile(caPath, []byte(`-----BEGIN CERTIFICATE-----
MIIBITCBxwIUf04/Hucshr7AynmgF8JeuFUEf9EwCgYIKoZIzj0EAwIwEzERMA8G
A1UEAwwIYnJfdGVzdHMwHhcNMjIwNDEzMDcyNDQxWhcNMjIwNDE1MDcyNDQxWjAT
MREwDwYDVQQDDAhicl90ZXN0czBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABL+X
wczUg0AbaFFaCI+FAk3K9vbB9JeIORgGKS+F1TKip5tvm96g7S5lq8SgY38SXVc3
0yS3YqWZqnRjWi+sLwIwCgYIKoZIzj0EAwIDSQAwRgIhAJcpSwsUhqkM08LK1gYC
ze4ZnCkwJdP2VdpI3WZsoI7zAiEAjP8X1c0iFwYxdAbQAveX+9msVrzyUpZOohi4
RtgQTNI=
-----END CERTIFICATE-----
`), 0o644)
require.NoError(t, err)

certPath := filepath.Join(tempDir, "test.pem")
keyPath := filepath.Join(tempDir, "test.key")
_, err = common.NewTLS(caPath, certPath, keyPath, "localhost")
tls, err := common.NewTLS(caPath, certPath, keyPath, "localhost")
_, err = tls.TLSConfig().GetCertificate(nil)
require.Regexp(t, "could not load client key pair: open.*", err.Error())

err = os.WriteFile(certPath, []byte("invalid cert content"), 0o644)
require.NoError(t, err)
err = os.WriteFile(keyPath, []byte("invalid key content"), 0o600)
require.NoError(t, err)
_, err = common.NewTLS(caPath, certPath, keyPath, "localhost")
tls, err = common.NewTLS(caPath, certPath, keyPath, "localhost")
lance6716 marked this conversation as resolved.
Show resolved Hide resolved
_, err = tls.TLSConfig().GetCertificate(nil)
require.Regexp(t, "could not load client key pair: tls.*", err.Error())
}
81 changes: 81 additions & 0 deletions br/pkg/lightning/lightning.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ package lightning
import (
"compress/gzip"
"context"
"crypto/ecdsa"
"crypto/rand"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -317,6 +321,29 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti
failpoint.Return(nil)
})

failpoint.Inject("SetCertExpiredSoon", func(val failpoint.Value) {
rootKeyPath := val.(string)
rootCaPath := taskCfg.Security.CAPath
keyPath := taskCfg.Security.KeyPath
certPath := taskCfg.Security.CertPath
certBytes, err := os.ReadFile(certPath)
if err != nil {
panic(err)
}
if err := os.WriteFile(certPath+".old", certBytes, 0o600); err != nil {
panic(err)
}
if err := updateCertExpiry(rootKeyPath, rootCaPath, keyPath, certPath, time.Second*10); err != nil {
sleepymole marked this conversation as resolved.
Show resolved Hide resolved
panic(err)
}
// Must restore the original cert before the new cert is expired.
time.AfterFunc(time.Second*5, func() {
if err := os.Rename(certPath+".old", certPath); err != nil {
panic(err)
}
})
})

if err := taskCfg.TiDB.Security.RegisterMySQL(); err != nil {
return common.ErrInvalidTLSConfig.Wrap(err)
}
Expand Down Expand Up @@ -905,3 +932,57 @@ func SwitchMode(ctx context.Context, cfg *config.Config, tls *common.TLS, mode s
},
)
}

func updateCertExpiry(rootKeyPath, rootCaPath, keyPath, certPath string, expiry time.Duration) error {
rootKey, err := parsePrivateKey(rootKeyPath)
if err != nil {
return err
}
rootCaPem, err := os.ReadFile(rootCaPath)
if err != nil {
return err
}
rootCaDer, _ := pem.Decode(rootCaPem)
rootCa, err := x509.ParseCertificate(rootCaDer.Bytes)
if err != nil {
return err
}
key, err := parsePrivateKey(keyPath)
if err != nil {
return err
}
certPem, err := os.ReadFile(certPath)
if err != nil {
panic(err)
}
certDer, _ := pem.Decode(certPem)
cert, err := x509.ParseCertificate(certDer.Bytes)
if err != nil {
return err
}
cert.NotBefore = time.Now()
cert.NotAfter = time.Now().Add(expiry)
derBytes, err := x509.CreateCertificate(rand.Reader, cert, rootCa, &key.PublicKey, rootKey)
if err != nil {
return err
}
return os.WriteFile(certPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), 0o600)
}

func parsePrivateKey(keyPath string) (*ecdsa.PrivateKey, error) {
keyPemBlock, err := os.ReadFile(keyPath)
if err != nil {
return nil, err
}
var keyDERBlock *pem.Block
for {
keyDERBlock, keyPemBlock = pem.Decode(keyPemBlock)
if keyDERBlock == nil {
return nil, errors.New("failed to find PEM block with type ending in \"PRIVATE KEY\"")
}
if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
break
}
}
return x509.ParseECPrivateKey(keyDERBlock.Bytes)
}
Empty file.
1 change: 1 addition & 0 deletions br/tests/lightning_reload_cert/data/test-schema-create.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE DATABASE test;
1 change: 1 addition & 0 deletions br/tests/lightning_reload_cert/data/test.t-schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE t(a INT PRIMARY KEY, b int);
1 change: 1 addition & 0 deletions br/tests/lightning_reload_cert/data/test.t.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
INSERT INTO t VALUES (1,1);
26 changes: 26 additions & 0 deletions br/tests/lightning_reload_cert/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/bin/bash
#
# Copyright 2022 PingCAP, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -eux

cp "$TEST_DIR/certs/lightning.pem" "$TEST_DIR/certs/lightning-valid.pem"
trap 'mv "$TEST_DIR/certs/lightning-valid.pem" "$TEST_DIR/certs/lightning.pem"' EXIT

# shellcheck disable=SC2089
export GO_FAILPOINTS="github.com/pingcap/tidb/br/pkg/lightning/SetCertExpiredSoon=return(\"$TEST_DIR/certs/ca.key\")"
export GO_FAILPOINTS="${GO_FAILPOINTS};github.com/pingcap/tidb/br/pkg/lightning/restore/SlowDownWriteRows=sleep(15000)"

run_lightning --backend='local'
sleepymole marked this conversation as resolved.
Show resolved Hide resolved