From d5ae67ceab5d3f818a90e35edeac34a544e2d8c0 Mon Sep 17 00:00:00 2001 From: Chunzhu Li Date: Tue, 6 Apr 2021 17:48:17 +0800 Subject: [PATCH 1/3] refine TLS configuration --- dm/master/server.go | 2 +- pkg/conn/basedb.go | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dm/master/server.go b/dm/master/server.go index 6707d7b6ac..b25f6e6fad 100644 --- a/dm/master/server.go +++ b/dm/master/server.go @@ -1230,7 +1230,7 @@ func enableTLS(tlsCfg *config.Security) bool { return false } - if len(tlsCfg.SSLCA) == 0 || len(tlsCfg.SSLCert) == 0 || len(tlsCfg.SSLKey) == 0 { + if len(tlsCfg.SSLCA) == 0 { return false } diff --git a/pkg/conn/basedb.go b/pkg/conn/basedb.go index 8112210f76..a546a2a785 100644 --- a/pkg/conn/basedb.go +++ b/pkg/conn/basedb.go @@ -64,8 +64,7 @@ func (d *DefaultDBProviderImpl) Apply(config config.DBConfig) (*BaseDB, error) { config.User, config.Password, config.Host, config.Port) doFuncInClose := func() {} - if config.Security != nil && len(config.Security.SSLCA) != 0 && - len(config.Security.SSLCert) != 0 && len(config.Security.SSLKey) != 0 { + if config.Security != nil && len(config.Security.SSLCA) != 0 { tlsConfig, err := toolutils.ToTLSConfig(config.Security.SSLCA, config.Security.SSLCert, config.Security.SSLKey) if err != nil { return nil, terror.ErrConnInvalidTLSConfig.Delegate(err) From d0a90a312d24becddffcd63f6745e51db81028b2 Mon Sep 17 00:00:00 2001 From: Chunzhu Li Date: Wed, 7 Apr 2021 19:47:45 +0800 Subject: [PATCH 2/3] fix tls problem for relay and syncer --- dumpling/dumpling.go | 7 ++++--- pkg/binlog/reader/tcp.go | 16 ++++++++++++++++ relay/relay.go | 33 +++++++++++++++------------------ relay/relay_test.go | 19 ++++++++++--------- relay/util_test.go | 5 +++-- syncer/heartbeat.go | 17 +++++++++-------- 6 files changed, 57 insertions(+), 40 deletions(-) diff --git a/dumpling/dumpling.go b/dumpling/dumpling.go index c94a8729bd..3bc250b087 100644 --- a/dumpling/dumpling.go +++ b/dumpling/dumpling.go @@ -15,7 +15,6 @@ package dumpling import ( "context" - "database/sql" "os" "strings" "time" @@ -31,6 +30,7 @@ import ( "github.com/pingcap/dm/dm/config" "github.com/pingcap/dm/dm/pb" "github.com/pingcap/dm/dm/unit" + "github.com/pingcap/dm/pkg/conn" "github.com/pingcap/dm/pkg/log" "github.com/pingcap/dm/pkg/terror" "github.com/pingcap/dm/pkg/utils" @@ -274,11 +274,12 @@ func (m *Dumpling) constructArgs() (*export.Config, error) { // detectSQLMode tries to detect SQL mode from upstream. If success, write it to LoaderConfig. // Because loader will use this SQL mode, we need to treat disable `EscapeBackslash` when NO_BACKSLASH_ESCAPES func (m *Dumpling) detectSQLMode(ctx context.Context) { - db, err := sql.Open("mysql", m.dumpConfig.GetDSN("")) + baseDB, err := conn.DefaultDBProvider.Apply(m.cfg.From) if err != nil { return } - defer db.Close() + defer baseDB.Close() + db := baseDB.DB sqlMode, err := utils.GetGlobalVariable(ctx, db, "sql_mode") if err != nil { diff --git a/pkg/binlog/reader/tcp.go b/pkg/binlog/reader/tcp.go index cb0e5b0d04..b4704218fa 100644 --- a/pkg/binlog/reader/tcp.go +++ b/pkg/binlog/reader/tcp.go @@ -18,8 +18,11 @@ import ( "database/sql" "encoding/json" "fmt" + "strconv" "sync" + "sync/atomic" + "github.com/go-sql-driver/mysql" gmysql "github.com/siddontang/go-mysql/mysql" "github.com/siddontang/go-mysql/replication" "go.uber.org/zap" @@ -31,6 +34,8 @@ import ( "github.com/pingcap/dm/pkg/utils" ) +var customID int64 + // TCPReader is a binlog event reader which read binlog events from a TCP stream. type TCPReader struct { syncerCfg replication.BinlogSyncerConfig @@ -121,6 +126,17 @@ func (r *TCPReader) Close() error { if connID > 0 { dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4", r.syncerCfg.User, r.syncerCfg.Password, r.syncerCfg.Host, r.syncerCfg.Port) + if r.syncerCfg.TLSConfig != nil { + tlsName := "replicate" + strconv.FormatInt(atomic.AddInt64(&customID, 1), 10) + err := mysql.RegisterTLSConfig(tlsName, r.syncerCfg.TLSConfig) + if err != nil { + return terror.WithScope( + terror.Annotatef(terror.DBErrorAdapt(err, terror.ErrDBDriverError), + "fail to register tls config", r.syncerCfg.Host, r.syncerCfg.Port), terror.ScopeUpstream) + } + dsn += "&tls=" + tlsName + defer mysql.DeregisterTLSConfig(tlsName) + } db, err := sql.Open("mysql", dsn) if err != nil { return terror.WithScope( diff --git a/relay/relay.go b/relay/relay.go index bc22f85a4a..f781de8163 100755 --- a/relay/relay.go +++ b/relay/relay.go @@ -16,7 +16,6 @@ package relay import ( "context" "crypto/tls" - "database/sql" "fmt" "os" "path/filepath" @@ -104,7 +103,7 @@ type Process interface { // Relay relays mysql binlog to local file. type Relay struct { - db *sql.DB + db *conn.BaseDB cfg *Config syncerCfg replication.BinlogSyncerConfig @@ -150,7 +149,7 @@ func (r *Relay) Init(ctx context.Context) (err error) { return terror.WithScope(err, terror.ScopeUpstream) } - r.db = db.DB + r.db = db rollbackHolder.Add(fr.FuncRollback{Name: "close-DB", Fn: r.closeDB}) if err2 := os.MkdirAll(r.cfg.RelayDir, 0755); err2 != nil { @@ -194,12 +193,12 @@ func (r *Relay) Process(ctx context.Context, pr chan pb.ProcessResult) { } func (r *Relay) process(ctx context.Context) error { - parser2, err := utils.GetParser(ctx, r.db) // refine to use user config later + parser2, err := utils.GetParser(ctx, r.db.DB) // refine to use user config later if err != nil { return err } - isNew, err := isNewServer(ctx, r.meta.UUID(), r.db, r.cfg.Flavor) + isNew, err := isNewServer(ctx, r.meta.UUID(), r.db.DB, r.cfg.Flavor) if err != nil { return err } @@ -379,7 +378,7 @@ func (r *Relay) tryRecoverLatestFile(ctx context.Context, parser2 *parser.Parser zap.Stringer("from position", latestPos), zap.Stringer("to position", result.LatestPos), log.WrapStringerField("from GTID set", latestGTID), log.WrapStringerField("to GTID set", result.LatestGTIDs)) if result.LatestGTIDs != nil { - dbConn, err2 := r.db.Conn(ctx) + dbConn, err2 := r.db.DB.Conn(ctx) if err2 != nil { return err2 } @@ -446,7 +445,7 @@ func (r *Relay) handleEvents(ctx context.Context, reader2 reader.Reader, transfo r.logger.Error("the requested binlog files have purged in the master server or the master server have switched, currently DM do no support to handle this error", zap.String("db host", cfg.Host), zap.Int("db port", cfg.Port), zap.Stringer("last pos", lastPos), log.ShortError(err)) // log the status for debug - pos, gs, err2 := utils.GetMasterStatus(ctx, r.db, r.cfg.Flavor) + pos, gs, err2 := utils.GetMasterStatus(ctx, r.db.DB, r.cfg.Flavor) if err2 == nil { r.logger.Info("current master status", zap.Stringer("position", pos), log.WrapStringerField("GTID sets", gs)) } @@ -485,7 +484,7 @@ func (r *Relay) handleEvents(ctx context.Context, reader2 reader.Reader, transfo // fake rotate event if _, ok := e.Event.(*replication.RotateEvent); ok && e.Header.Timestamp == 0 && e.Header.LogPos == 0 { - isNew, err2 := isNewServer(ctx, r.meta.UUID(), r.db, r.cfg.Flavor) + isNew, err2 := isNewServer(ctx, r.meta.UUID(), r.db.DB, r.cfg.Flavor) if err2 != nil { return err2 } @@ -561,7 +560,7 @@ func (r *Relay) tryUpdateActiveRelayLog(e *replication.BinlogEvent, filename str // reSetupMeta re-setup the metadata when switching to a new upstream master server. func (r *Relay) reSetupMeta(ctx context.Context) error { - uuid, err := utils.GetServerUUID(ctx, r.db, r.cfg.Flavor) + uuid, err := utils.GetServerUUID(ctx, r.db.DB, r.cfg.Flavor) if err != nil { return err } @@ -601,7 +600,7 @@ func (r *Relay) reSetupMeta(ctx context.Context) error { var latestPosName, latestGTIDStr string if (r.cfg.EnableGTID && len(r.cfg.BinlogGTID) == 0) || (!r.cfg.EnableGTID && len(r.cfg.BinLogName) == 0) { - latestPos, latestGTID, err2 := utils.GetMasterStatus(ctx, r.db, r.cfg.Flavor) + latestPos, latestGTID, err2 := utils.GetMasterStatus(ctx, r.db.DB, r.cfg.Flavor) if err2 != nil { return err2 } @@ -681,7 +680,7 @@ func (r *Relay) doIntervalOps(ctx context.Context) { return } ctx2, cancel2 := context.WithTimeout(ctx, utils.DefaultDBTimeout) - pos, _, err := utils.GetMasterStatus(ctx2, r.db, r.cfg.Flavor) + pos, _, err := utils.GetMasterStatus(ctx2, r.db.DB, r.cfg.Flavor) cancel2() if err != nil { r.logger.Warn("get master status", zap.Error(err)) @@ -721,7 +720,7 @@ func (r *Relay) setUpReader(ctx context.Context) (reader.Reader, error) { ctx2, cancel := context.WithTimeout(ctx, utils.DefaultDBTimeout) defer cancel() - randomServerID, err := utils.ReuseServerID(ctx2, r.cfg.ServerID, r.db) + randomServerID, err := utils.ReuseServerID(ctx2, r.cfg.ServerID, r.db.DB) if err != nil { // should never happened unless the master has too many slave return nil, terror.Annotate(err, "fail to get random server id for relay reader") @@ -830,7 +829,7 @@ func (r *Relay) Close() { // Status implements the dm.Unit interface. func (r *Relay) Status(ctx context.Context) interface{} { - masterPos, masterGTID, err := utils.GetMasterStatus(ctx, r.db, r.cfg.Flavor) + masterPos, masterGTID, err := utils.GetMasterStatus(ctx, r.db.DB, r.cfg.Flavor) if err != nil { r.logger.Warn("get master status", zap.Error(err)) } @@ -910,9 +909,7 @@ func (r *Relay) Reload(newCfg *Config) error { r.cfg.Charset = newCfg.Charset r.closeDB() - cfg := r.cfg.From - dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&interpolateParams=true&readTimeout=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, showStatusConnectionTimeout) - db, err := sql.Open("mysql", dbDSN) + db, err := conn.DefaultDBProvider.Apply(r.cfg.From) if err != nil { return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream) } @@ -991,7 +988,7 @@ func (r *Relay) setSyncConfig() error { func (r *Relay) adjustGTID(ctx context.Context, gset gtid.Set) (gtid.Set, error) { // setup a TCP binlog reader (because no relay can be used when upgrading). syncCfg := r.syncerCfg - randomServerID, err := utils.ReuseServerID(ctx, r.cfg.ServerID, r.db) + randomServerID, err := utils.ReuseServerID(ctx, r.cfg.ServerID, r.db.DB) if err != nil { return nil, terror.Annotate(err, "fail to get random server id when relay adjust gtid") } @@ -1003,7 +1000,7 @@ func (r *Relay) adjustGTID(ctx context.Context, gset gtid.Set) (gtid.Set, error) return nil, err } - dbConn, err2 := r.db.Conn(ctx) + dbConn, err2 := r.db.DB.Conn(ctx) if err2 != nil { return nil, err2 } diff --git a/relay/relay_test.go b/relay/relay_test.go index 18c801fab2..eaeca7cb06 100644 --- a/relay/relay_test.go +++ b/relay/relay_test.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/dm/dm/config" "github.com/pingcap/dm/pkg/binlog" "github.com/pingcap/dm/pkg/binlog/event" + "github.com/pingcap/dm/pkg/conn" "github.com/pingcap/dm/pkg/gtid" "github.com/pingcap/dm/pkg/log" "github.com/pingcap/dm/pkg/streamer" @@ -105,11 +106,10 @@ func getDBConfigForTest() config.DBConfig { } } -func openDBForTest() (*sql.DB, error) { +func openDBForTest() (*conn.BaseDB, error) { cfg := getDBConfigForTest() - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4", cfg.User, cfg.Password, cfg.Host, cfg.Port) - return sql.Open("mysql", dsn) + return conn.DefaultDBProvider.Apply(cfg) } // mockReader is used only for relay testing. @@ -512,7 +512,7 @@ func (t *testRelaySuite) TestReSetupMeta(c *C) { r.db.Close() r.db = nil }() - uuid, err := utils.GetServerUUID(ctx, r.db, r.cfg.Flavor) + uuid, err := utils.GetServerUUID(ctx, r.db.DB, r.cfg.Flavor) c.Assert(err, IsNil) // re-setup meta with start pos adjusted @@ -606,21 +606,22 @@ func (t *testRelaySuite) TestProcess(c *C) { ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second) defer cancel2() var connID uint32 + db := r.db.DB c.Assert(utils.WaitSomething(30, 100*time.Millisecond, func() bool { - connID, err = getBinlogDumpConnID(ctx2, r.db) + connID, err = getBinlogDumpConnID(ctx2, db) return err == nil }), IsTrue) - _, err = r.db.ExecContext(ctx2, fmt.Sprintf(`KILL %d`, connID)) + _, err = db.ExecContext(ctx2, fmt.Sprintf(`KILL %d`, connID)) c.Assert(err, IsNil) // execute a DDL again lastDDL := "CREATE DATABASE `test_relay_retry_db`" - _, err = r.db.ExecContext(ctx2, lastDDL) + _, err = db.ExecContext(ctx2, lastDDL) c.Assert(err, IsNil) defer func() { query := "DROP DATABASE IF EXISTS `test_relay_retry_db`" - _, err = r.db.ExecContext(ctx2, query) + _, err = db.ExecContext(ctx2, query) c.Assert(err, IsNil) }() @@ -644,7 +645,7 @@ func (t *testRelaySuite) TestProcess(c *C) { // check whether have binlog file in relay directory // and check for events already done in `TestHandleEvent` - uuid, err := utils.GetServerUUID(ctx2, r.db, r.cfg.Flavor) + uuid, err := utils.GetServerUUID(ctx2, db, r.cfg.Flavor) c.Assert(err, IsNil) files, err := streamer.CollectAllBinlogFiles(filepath.Join(relayCfg.RelayDir, fmt.Sprintf("%s.000001", uuid))) c.Assert(err, IsNil) diff --git a/relay/util_test.go b/relay/util_test.go index d095d68048..cde2ab876a 100644 --- a/relay/util_test.go +++ b/relay/util_test.go @@ -32,9 +32,10 @@ func (t *testUtilSuite) TestIsNewServer(c *C) { ctx, cancel := context.WithTimeout(context.Background(), utils.DefaultDBTimeout) defer cancel() - db, err := openDBForTest() + baseDB, err := openDBForTest() c.Assert(err, IsNil) - defer db.Close() + defer baseDB.Close() + db := baseDB.DB flavor := gmysql.MySQLFlavor diff --git a/syncer/heartbeat.go b/syncer/heartbeat.go index a211ff49a3..1f96afe509 100644 --- a/syncer/heartbeat.go +++ b/syncer/heartbeat.go @@ -27,6 +27,7 @@ import ( "go.uber.org/zap" "github.com/pingcap/dm/dm/config" + "github.com/pingcap/dm/pkg/conn" "github.com/pingcap/dm/pkg/log" "github.com/pingcap/dm/pkg/terror" ) @@ -83,7 +84,7 @@ type Heartbeat struct { schema string // for which schema the heartbeat table belongs to table string // for which table the heartbeat table belongs to - primary *sql.DB + primary *conn.BaseDB secondaryTs map[string]float64 // task-name => secondary (syncer) ts cancel context.CancelFunc @@ -122,12 +123,12 @@ func (h *Heartbeat) AddTask(name string) error { if h.primary == nil { // open DB dbCfg := h.cfg.primaryCfg - dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8&interpolateParams=true&readTimeout=1m", dbCfg.User, dbCfg.Password, dbCfg.Host, dbCfg.Port) - primary, err := sql.Open("mysql", dbDSN) + dbCfg.RawDBCfg.ReadTimeout = "1m" + baseDB, err := conn.DefaultDBProvider.Apply(dbCfg) if err != nil { return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream) } - h.primary = primary + h.primary = baseDB // init table err = h.init() @@ -275,7 +276,7 @@ func (h *Heartbeat) run(ctx context.Context) { // createTable creates heartbeat database if not exists in primary func (h *Heartbeat) createDatabase() error { createDatabase := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", h.schema) - _, err := h.primary.Exec(createDatabase) + _, err := h.primary.DB.Exec(createDatabase) h.logger.Info("create heartbeat schema", zap.String("sql", createDatabase)) return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream) } @@ -289,7 +290,7 @@ func (h *Heartbeat) createTable() error { PRIMARY KEY (server_id) )`, tableName) - _, err := h.primary.Exec(createTableStmt) + _, err := h.primary.DB.Exec(createTableStmt) h.logger.Info("create heartbeat table", zap.String("sql", createTableStmt)) return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream) } @@ -297,7 +298,7 @@ func (h *Heartbeat) createTable() error { // updateTS use `REPLACE` statement to insert or update ts func (h *Heartbeat) updateTS() error { query := fmt.Sprintf("REPLACE INTO `%s`.`%s` (`ts`, `server_id`) VALUES(UTC_TIMESTAMP(6), ?)", h.schema, h.table) - _, err := h.primary.Exec(query, h.cfg.serverID) + _, err := h.primary.DB.Exec(query, h.cfg.serverID) h.logger.Debug("update ts", zap.String("sql", query), zap.Uint32("server ID", h.cfg.serverID)) return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream) } @@ -330,7 +331,7 @@ func reportLag(taskName string, lag float64) { } func (h *Heartbeat) getPrimaryTS() (float64, error) { - return h.getTS(h.primary) + return h.getTS(h.primary.DB) } func (h *Heartbeat) getTS(db *sql.DB) (float64, error) { From 85ea9141a7b9974cbd6f95965ef079930bc7427a Mon Sep 17 00:00:00 2001 From: Chunzhu Li Date: Wed, 7 Apr 2021 21:01:30 +0800 Subject: [PATCH 3/3] fix lint --- relay/relay.go | 1 + 1 file changed, 1 insertion(+) diff --git a/relay/relay.go b/relay/relay.go index 00223a7699..dd1388df79 100755 --- a/relay/relay.go +++ b/relay/relay.go @@ -916,6 +916,7 @@ func (r *Relay) Reload(newCfg *Config) error { r.cfg.Charset = newCfg.Charset r.closeDB() + r.cfg.From.RawDBCfg.ReadTimeout = showStatusConnectionTimeout db, err := conn.DefaultDBProvider.Apply(r.cfg.From) if err != nil { return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream)