diff --git a/domain/domain_test.go b/domain/domain_test.go index 282cf1f145df5..3adbc096251f6 100644 --- a/domain/domain_test.go +++ b/domain/domain_test.go @@ -229,6 +229,8 @@ func (msm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, boo func (msm *mockSessionManager) Kill(cid uint64, query bool) {} +func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {} + func (*testSuite) TestT(c *C) { defer testleak.AfterTest(c)() store, err := mockstore.NewMockTikvStore() diff --git a/executor/executor_pkg_test.go b/executor/executor_pkg_test.go index 6794bba00467f..b1e9dc72b91fc 100644 --- a/executor/executor_pkg_test.go +++ b/executor/executor_pkg_test.go @@ -15,6 +15,7 @@ package executor import ( "context" + "crypto/tls" . "github.com/pingcap/check" "github.com/pingcap/parser/ast" @@ -73,6 +74,9 @@ func (msm *mockSessionManager) Kill(cid uint64, query bool) { } +func (msm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) { +} + func (s *testExecSuite) TestShowProcessList(c *C) { // Compose schema. names := []string{"Id", "User", "Host", "db", "Command", "Time", "State", "Info"} diff --git a/executor/explainfor_test.go b/executor/explainfor_test.go index 75905026bebf4..0efb530bbe1c3 100644 --- a/executor/explainfor_test.go +++ b/executor/explainfor_test.go @@ -14,6 +14,7 @@ package executor_test import ( + "crypto/tls" "fmt" . "github.com/pingcap/check" @@ -51,6 +52,9 @@ func (msm *mockSessionManager1) Kill(cid uint64, query bool) { } +func (msm *mockSessionManager1) UpdateTLSConfig(cfg *tls.Config) { +} + func (s *testSuite) TestExplainFor(c *C) { tkRoot := testkit.NewTestKitWithInit(c, s.store) tkUser := testkit.NewTestKitWithInit(c, s.store) diff --git a/executor/simple.go b/executor/simple.go index 12bc722e43d57..e6e2facd932b9 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" @@ -108,6 +109,8 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { err = e.executeUse(x) case *ast.FlushStmt: err = e.executeFlush(x) + case *ast.AlterInstanceStmt: + err = e.executeAlterInstance(x) case *ast.BeginStmt: err = e.executeBegin(ctx, x) case *ast.CommitStmt: @@ -1098,6 +1101,26 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error { return nil } +func (e *SimpleExec) executeAlterInstance(s *ast.AlterInstanceStmt) error { + if s.ReloadTLS { + logutil.BgLogger().Info("execute reload tls", zap.Bool("NoRollbackOnError", s.NoRollbackOnError)) + sm := e.ctx.GetSessionManager() + tlsCfg, err := util.LoadTLSCertificates( + variable.SysVars["ssl_ca"].Value, + variable.SysVars["ssl_key"].Value, + variable.SysVars["ssl_cert"].Value, + ) + if err != nil { + if !s.NoRollbackOnError { + return err + } + logutil.BgLogger().Warn("reload TLS fail but keep working without TLS due to 'no rollback on error'") + } + sm.UpdateTLSConfig(tlsCfg) + } + return nil +} + func (e *SimpleExec) executeDropStats(s *ast.DropStatsStmt) error { h := domain.GetDomain(e.ctx).StatsHandle() err := h.DeleteTableStatsFromKV(s.Table.TableInfo.ID) diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index 6af9cc2e1f17b..7e56a59c9731a 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -456,6 +456,8 @@ func (sm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool func (sm *mockSessionManager) Kill(connectionID uint64, query bool) {} +func (sm *mockSessionManager) UpdateTLSConfig(cfg *tls.Config) {} + func (s *testTableSuite) TestSomeTables(c *C) { tk := testkit.NewTestKit(c, s.store) diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 85f63016b4447..357a93a8946da 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -433,7 +433,7 @@ func (b *PlanBuilder) Build(ctx context.Context, node ast.Node) (Plan, error) { case *ast.AnalyzeTableStmt: return b.buildAnalyze(x) case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt, - *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, + *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt, *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.AlterInstanceStmt, *ast.GrantStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.RevokeStmt, *ast.KillStmt, *ast.DropStatsStmt, *ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt, *ast.ShutdownStmt: return b.buildSimple(node.(ast.StmtNode)) @@ -1684,6 +1684,9 @@ func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) { case *ast.FlushStmt: err := ErrSpecificAccessDenied.GenWithStackByArgs("RELOAD") b.visitInfo = appendVisitInfo(b.visitInfo, mysql.ReloadPriv, "", "", "", err) + case *ast.AlterInstanceStmt: + err := ErrSpecificAccessDenied.GenWithStack("ALTER INSTANCE") + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", err) case *ast.AlterUserStmt: err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err) diff --git a/server/conn.go b/server/conn.go index 4a8474cafd136..87b435a22c24c 100644 --- a/server/conn.go +++ b/server/conn.go @@ -500,23 +500,26 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con return err } - if (resp.Capability&mysql.ClientSSL > 0) && cc.server.tlsConfig != nil { - // The packet is a SSLRequest, let's switch to TLS. - if err = cc.upgradeToTLS(cc.server.tlsConfig); err != nil { - return err - } - // Read the following HandshakeResponse packet. - data, err = cc.readPacket() - if err != nil { - return err - } - if isOldVersion { - pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data) - } else { - pos, err = parseHandshakeResponseHeader(ctx, &resp, data) - } - if err != nil { - return err + if resp.Capability&mysql.ClientSSL > 0 { + tlsConfig := (*tls.Config)(atomic.LoadPointer(&cc.server.tlsConfig)) + if tlsConfig != nil { + // The packet is a SSLRequest, let's switch to TLS. + if err = cc.upgradeToTLS(tlsConfig); err != nil { + return err + } + // Read the following HandshakeResponse packet. + data, err = cc.readPacket() + if err != nil { + return err + } + if isOldVersion { + pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data) + } else { + pos, err = parseHandshakeResponseHeader(ctx, &resp, data) + } + if err != nil { + return err + } } } diff --git a/server/server.go b/server/server.go index 4eb1a359b4330..69d3c161ba460 100644 --- a/server/server.go +++ b/server/server.go @@ -31,13 +31,12 @@ package server import ( "context" "crypto/tls" - "crypto/x509" "fmt" "io" - "io/ioutil" "math/rand" "net" "net/http" + "unsafe" // For pprof _ "net/http/pprof" "os" @@ -104,7 +103,7 @@ const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | // Server is the MySQL protocol server type Server struct { cfg *config.Config - tlsConfig *tls.Config + tlsConfig unsafe.Pointer // *tls.Config driver IDriver listener net.Listener socket net.Listener @@ -209,7 +208,16 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { clients: make(map[uint32]*clientConn), stopListenerCh: make(chan struct{}, 1), } - s.loadTLSCertificates() + + tlsConfig, err := util.LoadTLSCertificates(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert) + if err != nil { + logutil.BgLogger().Error("secure connection cert/key/ca load fail", zap.Error(err)) + return nil, err + } + logutil.BgLogger().Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0)) + setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert) + atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig)) + setSystemTimeZoneVariable() s.capability = defaultCapability @@ -217,8 +225,6 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { s.capability |= mysql.ClientSSL } - var err error - if s.cfg.Host != "" && s.cfg.Port != 0 { addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) if s.listener, err = net.Listen("tcp", addr); err == nil { @@ -258,51 +264,12 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { return s, nil } -func (s *Server) loadTLSCertificates() { - defer func() { - if s.tlsConfig != nil { - logutil.BgLogger().Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0)) - variable.SysVars["have_openssl"].Value = "YES" - variable.SysVars["have_ssl"].Value = "YES" - variable.SysVars["ssl_cert"].Value = s.cfg.Security.SSLCert - variable.SysVars["ssl_key"].Value = s.cfg.Security.SSLKey - } else { - logutil.BgLogger().Warn("secure connection is not enabled") - } - }() - - if len(s.cfg.Security.SSLCert) == 0 || len(s.cfg.Security.SSLKey) == 0 { - s.tlsConfig = nil - return - } - - tlsCert, err := tls.LoadX509KeyPair(s.cfg.Security.SSLCert, s.cfg.Security.SSLKey) - if err != nil { - logutil.BgLogger().Warn("load x509 failed", zap.Error(err)) - s.tlsConfig = nil - return - } - - // Try loading CA cert. - clientAuthPolicy := tls.NoClientCert - var certPool *x509.CertPool - if len(s.cfg.Security.SSLCA) > 0 { - caCert, err := ioutil.ReadFile(s.cfg.Security.SSLCA) - if err != nil { - logutil.BgLogger().Warn("read file failed", zap.Error(err)) - } else { - certPool = x509.NewCertPool() - if certPool.AppendCertsFromPEM(caCert) { - clientAuthPolicy = tls.VerifyClientCertIfGiven - } - variable.SysVars["ssl_ca"].Value = s.cfg.Security.SSLCA - } - } - s.tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - ClientCAs: certPool, - ClientAuth: clientAuthPolicy, - } +func setSSLVariable(ca, key, cert string) { + variable.SysVars["have_openssl"].Value = "YES" + variable.SysVars["have_ssl"].Value = "YES" + variable.SysVars["ssl_cert"].Value = cert + variable.SysVars["ssl_key"].Value = key + variable.SysVars["ssl_ca"].Value = ca } // Run runs the server. @@ -564,6 +531,15 @@ func (s *Server) Kill(connectionID uint64, query bool) { killConn(conn) } +// UpdateTLSConfig implements the SessionManager interface. +func (s *Server) UpdateTLSConfig(cfg *tls.Config) { + atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(cfg)) +} + +func (s *Server) getTLSConfig() *tls.Config { + return (*tls.Config)(atomic.LoadPointer(&s.tlsConfig)) +} + func killConn(conn *clientConn) { sessVars := conn.ctx.GetSessionVars() atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1) diff --git a/server/server_test.go b/server/server_test.go index 22c9c452135c7..0728492af9af2 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -30,6 +30,7 @@ import ( "github.com/go-sql-driver/mysql" . "github.com/pingcap/check" + "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" tmysql "github.com/pingcap/parser/mysql" @@ -1106,10 +1107,26 @@ func (cli *testServerClient) runTestStmtCount(t *C) { } func (cli *testServerClient) runTestTLSConnection(t *C, overrider configOverrider) error { - db, err := sql.Open("mysql", cli.getDSN(overrider)) + dsn := cli.getDSN(overrider) + db, err := sql.Open("mysql", dsn) t.Assert(err, IsNil) defer db.Close() _, err = db.Exec("USE test") + if err != nil { + return errors.Annotate(err, "dsn:"+dsn) + } + return err +} + +func (cli *testServerClient) runReloadTLS(t *C, overrider configOverrider, errorNoRollback bool) error { + db, err := sql.Open("mysql", cli.getDSN(overrider)) + t.Assert(err, IsNil) + defer db.Close() + sql := "alter instance reload tls" + if errorNoRollback { + sql += " no rollback on error" + } + _, err = db.Exec(sql) return err } diff --git a/server/tidb_test.go b/server/tidb_test.go index 08c1e766b25b2..d1023fc745857 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -271,7 +271,7 @@ func (ts *tidbTestSuite) TestSocket(c *C) { // generateCert generates a private key and a certificate in PEM format based on parameters. // If parentCert and parentCertKey is specified, the new certificate will be signed by the parentCert. // Otherwise, the new certificate will be self-signed and is a CA. -func generateCert(sn int, commonName string, parentCert *x509.Certificate, parentCertKey *rsa.PrivateKey, outKeyFile string, outCertFile string) (*x509.Certificate, *rsa.PrivateKey, error) { +func generateCert(sn int, commonName string, parentCert *x509.Certificate, parentCertKey *rsa.PrivateKey, outKeyFile string, outCertFile string, opts ...func(c *x509.Certificate)) (*x509.Certificate, *rsa.PrivateKey, error) { privateKey, err := rsa.GenerateKey(rand.Reader, 528) if err != nil { return nil, nil, errors.Trace(err) @@ -288,6 +288,9 @@ func generateCert(sn int, commonName string, parentCert *x509.Certificate, paren ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, BasicConstraintsValid: true, } + for _, opt := range opts { + opt(&template) + } var parent *x509.Certificate var priv *rsa.PrivateKey @@ -369,7 +372,7 @@ func (ts *tidbTestSuite) TestSystemTimeZone(c *C) { tk.MustQuery("select @@system_time_zone").Check(tz1) } -func (ts *tidbTestSuite) TestTLS(c *C) { +func (ts *tidbTestSerialSuite) TestTLS(c *C) { // Generate valid TLS certificates. caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key.pem", "/tmp/ca-cert.pem") c.Assert(err, IsNil) @@ -403,7 +406,7 @@ func (ts *tidbTestSuite) TestTLS(c *C) { time.Sleep(time.Millisecond * 100) err = cli.runTestTLSConnection(c, connOverrider) // We should get ErrNoTLS. c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, mysql.ErrNoTLS.Error()) + c.Assert(errors.Cause(err).Error(), Equals, mysql.ErrNoTLS.Error()) server.Close() // Start the server with TLS but without CA, in this case the server will not verify client's certificate. @@ -460,6 +463,169 @@ func (ts *tidbTestSuite) TestTLS(c *C) { c.Assert(err, IsNil) cli.runTestRegression(c, connOverrider, "TLSRegression") server.Close() + + c.Assert(util.IsTLSExpiredError(errors.New("unknown test")), IsFalse) + c.Assert(util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.CANotAuthorizedForThisName}), IsFalse) + c.Assert(util.IsTLSExpiredError(x509.CertificateInvalidError{Reason: x509.Expired}), IsTrue) + + _, err = util.LoadTLSCertificates("", "wrong key", "wrong cert") + c.Assert(err, NotNil) + _, err = util.LoadTLSCertificates("wrong ca", "/tmp/server-key.pem", "/tmp/server-cert.pem") + c.Assert(err, NotNil) +} + +func (ts *tidbTestSerialSuite) TestReloadTLS(c *C) { + // Generate valid TLS certificates. + caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-reload.pem", "/tmp/ca-cert-reload.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload.pem", "/tmp/server-cert-reload.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key-reload.pem", "/tmp/client-cert-reload.pem") + c.Assert(err, IsNil) + err = registerTLSConfig("client-certificate-reload", "/tmp/ca-cert-reload.pem", "/tmp/client-cert-reload.pem", "/tmp/client-key-reload.pem", "tidb-server", true) + c.Assert(err, IsNil) + + defer func() { + os.Remove("/tmp/ca-key-reload.pem") + os.Remove("/tmp/ca-cert-reload.pem") + + os.Remove("/tmp/server-key-reload.pem") + os.Remove("/tmp/server-cert-reload.pem") + os.Remove("/tmp/client-key-reload.pem") + os.Remove("/tmp/client-cert-reload.pem") + }() + + // try old cert used in startup configuration. + cli := newTestServerClient() + cfg := config.NewConfig() + cfg.Port = cli.port + cfg.Status.ReportStatus = false + cfg.Security = config.Security{ + SSLCA: "/tmp/ca-cert-reload.pem", + SSLCert: "/tmp/server-cert-reload.pem", + SSLKey: "/tmp/server-key-reload.pem", + } + server, err := NewServer(cfg, ts.tidbdrv) + c.Assert(err, IsNil) + go server.Run() + time.Sleep(time.Millisecond * 100) + // The client provides a valid certificate. + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + } + err = cli.runTestTLSConnection(c, connOverrider) + c.Assert(err, IsNil) + + // try reload a valid cert. + tlsCfg := server.getTLSConfig() + cert, err := x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0]) + c.Assert(err, IsNil) + oldExpireTime := cert.NotAfter + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload2.pem", "/tmp/server-cert-reload2.pem", func(c *x509.Certificate) { + c.NotBefore = time.Now().Add(-24 * time.Hour).UTC() + c.NotAfter = time.Now().Add(1 * time.Hour).UTC() + }) + c.Assert(err, IsNil) + os.Rename("/tmp/server-key-reload2.pem", "/tmp/server-key-reload.pem") + os.Rename("/tmp/server-cert-reload2.pem", "/tmp/server-cert-reload.pem") + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "skip-verify" + } + err = cli.runReloadTLS(c, connOverrider, false) + c.Assert(err, IsNil) + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + } + err = cli.runTestTLSConnection(c, connOverrider) + c.Assert(err, IsNil) + + tlsCfg = server.getTLSConfig() + cert, err = x509.ParseCertificate(tlsCfg.Certificates[0].Certificate[0]) + c.Assert(err, IsNil) + newExpireTime := cert.NotAfter + c.Assert(newExpireTime.After(oldExpireTime), IsTrue) + + // try reload a expired cert. + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-reload3.pem", "/tmp/server-cert-reload3.pem", func(c *x509.Certificate) { + c.NotBefore = time.Now().Add(-24 * time.Hour).UTC() + c.NotAfter = c.NotBefore.Add(1 * time.Hour).UTC() + }) + c.Assert(err, IsNil) + os.Rename("/tmp/server-key-reload3.pem", "/tmp/server-key-reload.pem") + os.Rename("/tmp/server-cert-reload3.pem", "/tmp/server-cert-reload.pem") + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "skip-verify" + } + err = cli.runReloadTLS(c, connOverrider, false) + c.Assert(err, IsNil) + connOverrider = func(config *mysql.Config) { + config.TLSConfig = "client-certificate-reload" + } + err = cli.runTestTLSConnection(c, connOverrider) + c.Assert(err, NotNil) + c.Assert(util.IsTLSExpiredError(err), IsTrue, Commentf("real error is %+v", err)) + server.Close() +} + +func (ts *tidbTestSerialSuite) TestErrorNoRollback(c *C) { + // Generate valid TLS certificates. + caCert, caKey, err := generateCert(0, "TiDB CA", nil, nil, "/tmp/ca-key-rollback.pem", "/tmp/ca-cert-rollback.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(1, "tidb-server", caCert, caKey, "/tmp/server-key-rollback.pem", "/tmp/server-cert-rollback.pem") + c.Assert(err, IsNil) + _, _, err = generateCert(2, "SQL Client Certificate", caCert, caKey, "/tmp/client-key-rollback.pem", "/tmp/client-cert-rollback.pem") + c.Assert(err, IsNil) + err = registerTLSConfig("client-cert-rollback-test", "/tmp/ca-cert-rollback.pem", "/tmp/client-cert-rollback.pem", "/tmp/client-key-rollback.pem", "tidb-server", true) + c.Assert(err, IsNil) + + defer func() { + os.Remove("/tmp/ca-key-rollback.pem") + os.Remove("/tmp/ca-cert-rollback.pem") + + os.Remove("/tmp/server-key-rollback.pem") + os.Remove("/tmp/server-cert-rollback.pem") + os.Remove("/tmp/client-key-rollback.pem") + os.Remove("/tmp/client-cert-rollback.pem") + }() + + cli := newTestServerClient() + cfg := config.NewConfig() + cfg.Port = cli.port + cfg.Status.ReportStatus = false + + // test cannot startup with wrong tls config + cfg.Security = config.Security{ + SSLCA: "wrong path", + SSLCert: "wrong path", + SSLKey: "wrong path", + } + _, err = NewServer(cfg, ts.tidbdrv) + c.Assert(err, NotNil) + + // test reload tls fail with/without "error no rollback option" + cfg.Security = config.Security{ + SSLCA: "/tmp/ca-cert-rollback.pem", + SSLCert: "/tmp/server-cert-rollback.pem", + SSLKey: "/tmp/server-key-rollback.pem", + } + server, err := NewServer(cfg, ts.tidbdrv) + c.Assert(err, IsNil) + go server.Run() + time.Sleep(time.Millisecond * 100) + connOverrider := func(config *mysql.Config) { + config.TLSConfig = "client-cert-rollback-test" + } + err = cli.runTestTLSConnection(c, connOverrider) + c.Assert(err, IsNil) + os.Remove("/tmp/server-key-rollback.pem") + err = cli.runReloadTLS(c, connOverrider, false) + c.Assert(err, NotNil) + tlsCfg := server.getTLSConfig() + c.Assert(tlsCfg, NotNil) + err = cli.runReloadTLS(c, connOverrider, true) + c.Assert(err, IsNil) + tlsCfg = server.getTLSConfig() + c.Assert(tlsCfg, IsNil) } func (ts *tidbTestSuite) TestClientWithCollation(c *C) { diff --git a/util/misc.go b/util/misc.go index 6929e9d935d91..4d1042dc431a9 100644 --- a/util/misc.go +++ b/util/misc.go @@ -15,8 +15,10 @@ package util import ( "crypto/tls" + "crypto/x509" "crypto/x509/pkix" "fmt" + "io/ioutil" "runtime" "strconv" "strings" @@ -319,3 +321,50 @@ type SequenceTable interface { GetSequenceNextVal(dbName, seqName string) (int64, error) SetSequenceVal(newVal int64) (int64, bool, error) } + +// LoadTLSCertificates loads CA/KEY/CERT for special paths. +func LoadTLSCertificates(ca, key, cert string) (tlsConfig *tls.Config, err error) { + if len(cert) == 0 || len(key) == 0 { + return + } + + var tlsCert tls.Certificate + tlsCert, err = tls.LoadX509KeyPair(cert, key) + if err != nil { + logutil.BgLogger().Warn("load x509 failed", zap.Error(err)) + err = errors.Trace(err) + return + } + + // Try loading CA cert. + clientAuthPolicy := tls.NoClientCert + var certPool *x509.CertPool + if len(ca) > 0 { + var caCert []byte + caCert, err = ioutil.ReadFile(ca) + if err != nil { + logutil.BgLogger().Warn("read file failed", zap.Error(err)) + err = errors.Trace(err) + return + } + certPool = x509.NewCertPool() + if certPool.AppendCertsFromPEM(caCert) { + clientAuthPolicy = tls.VerifyClientCertIfGiven + } + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + ClientCAs: certPool, + ClientAuth: clientAuthPolicy, + } + return +} + +// IsTLSExpiredError checks error is caused by TLS expired. +func IsTLSExpiredError(err error) bool { + err = errors.Cause(err) + if inval, ok := err.(x509.CertificateInvalidError); !ok || inval.Reason != x509.Expired { + return false + } + return true +} diff --git a/util/processinfo.go b/util/processinfo.go index 113b81b6d5b78..3cd48429235a5 100644 --- a/util/processinfo.go +++ b/util/processinfo.go @@ -14,6 +14,7 @@ package util import ( + "crypto/tls" "fmt" "time" @@ -94,4 +95,5 @@ type SessionManager interface { ShowProcessList() map[uint64]*ProcessInfo GetProcessInfo(id uint64) (*ProcessInfo, bool) Kill(connectionID uint64, query bool) + UpdateTLSConfig(cfg *tls.Config) }