Skip to content
This repository has been archived by the owner on Nov 24, 2023. It is now read-only.

conn: refine TLS configuration #1560

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dm/master/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ func enableTLS(tlsCfg *config.Security) bool {
return false
}

if len(tlsCfg.SSLCA) == 0 || len(tlsCfg.SSLCert) == 0 || len(tlsCfg.SSLKey) == 0 {
if len(tlsCfg.SSLCA) == 0 {
Copy link

@coderplay coderplay Apr 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue in #1555 was a MySQL client side tls issue , but this line of change is for DM server side tls config that will effect on the AdvertiseAddr. IIUC, it's unrelated, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/pingcap/dm/pull/1560/files#diff-d42d00fe16fdbc10836731179db540c30ef59f1eb7789bf2ad3ea771554eceb5R67
The MySQL related file is in pkg/conn/basedb.go. I change this line because it's wrong for dm-master too.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does dm-worker need similar change?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just searched len(.*SSLCA), looks that's the only place left.

return false
}

Expand Down
7 changes: 4 additions & 3 deletions dumpling/dumpling.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ package dumpling

import (
"context"
"database/sql"
"os"
"strings"
"time"
Expand All @@ -31,6 +30,7 @@ import (
"github.com/pingcap/dm/dm/config"
"github.com/pingcap/dm/dm/pb"
"github.com/pingcap/dm/dm/unit"
"github.com/pingcap/dm/pkg/conn"
"github.com/pingcap/dm/pkg/log"
"github.com/pingcap/dm/pkg/terror"
"github.com/pingcap/dm/pkg/utils"
Expand Down Expand Up @@ -274,11 +274,12 @@ func (m *Dumpling) constructArgs() (*export.Config, error) {
// detectSQLMode tries to detect SQL mode from upstream. If success, write it to LoaderConfig.
// Because loader will use this SQL mode, we need to treat disable `EscapeBackslash` when NO_BACKSLASH_ESCAPES
func (m *Dumpling) detectSQLMode(ctx context.Context) {
db, err := sql.Open("mysql", m.dumpConfig.GetDSN(""))
baseDB, err := conn.DefaultDBProvider.Apply(m.cfg.From)
if err != nil {
return
}
defer db.Close()
defer baseDB.Close()
db := baseDB.DB

sqlMode, err := utils.GetGlobalVariable(ctx, db, "sql_mode")
if err != nil {
Expand Down
16 changes: 16 additions & 0 deletions pkg/binlog/reader/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ import (
"database/sql"
"encoding/json"
"fmt"
"strconv"
"sync"
"sync/atomic"

"github.com/go-sql-driver/mysql"
gmysql "github.com/siddontang/go-mysql/mysql"
"github.com/siddontang/go-mysql/replication"
"go.uber.org/zap"
Expand All @@ -31,6 +34,8 @@ import (
"github.com/pingcap/dm/pkg/utils"
)

var customID int64

// TCPReader is a binlog event reader which read binlog events from a TCP stream.
type TCPReader struct {
syncerCfg replication.BinlogSyncerConfig
Expand Down Expand Up @@ -121,6 +126,17 @@ func (r *TCPReader) Close() error {
if connID > 0 {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4",
r.syncerCfg.User, r.syncerCfg.Password, r.syncerCfg.Host, r.syncerCfg.Port)
if r.syncerCfg.TLSConfig != nil {
tlsName := "replicate" + strconv.FormatInt(atomic.AddInt64(&customID, 1), 10)
err := mysql.RegisterTLSConfig(tlsName, r.syncerCfg.TLSConfig)
if err != nil {
return terror.WithScope(
terror.Annotatef(terror.DBErrorAdapt(err, terror.ErrDBDriverError),
"fail to register tls config", r.syncerCfg.Host, r.syncerCfg.Port), terror.ScopeUpstream)
}
dsn += "&tls=" + tlsName
defer mysql.DeregisterTLSConfig(tlsName)
}
db, err := sql.Open("mysql", dsn)
if err != nil {
return terror.WithScope(
Expand Down
3 changes: 1 addition & 2 deletions pkg/conn/basedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ func (d *DefaultDBProviderImpl) Apply(config config.DBConfig) (*BaseDB, error) {
config.User, config.Password, config.Host, config.Port)

doFuncInClose := func() {}
if config.Security != nil && len(config.Security.SSLCA) != 0 &&
len(config.Security.SSLCert) != 0 && len(config.Security.SSLKey) != 0 {
if config.Security != nil && len(config.Security.SSLCA) != 0 {
tlsConfig, err := toolutils.ToTLSConfig(config.Security.SSLCA, config.Security.SSLCert, config.Security.SSLKey)
if err != nil {
return nil, terror.ErrConnInvalidTLSConfig.Delegate(err)
Expand Down
34 changes: 16 additions & 18 deletions relay/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package relay
import (
"context"
"crypto/tls"
"database/sql"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -106,7 +105,7 @@ type Process interface {

// Relay relays mysql binlog to local file.
type Relay struct {
db *sql.DB
db *conn.BaseDB
cfg *Config
syncerCfg replication.BinlogSyncerConfig

Expand Down Expand Up @@ -152,7 +151,7 @@ func (r *Relay) Init(ctx context.Context) (err error) {
return terror.WithScope(err, terror.ScopeUpstream)
}

r.db = db.DB
r.db = db
rollbackHolder.Add(fr.FuncRollback{Name: "close-DB", Fn: r.closeDB})

if err2 := os.MkdirAll(r.cfg.RelayDir, 0755); err2 != nil {
Expand Down Expand Up @@ -196,12 +195,12 @@ func (r *Relay) Process(ctx context.Context, pr chan pb.ProcessResult) {
}

func (r *Relay) process(ctx context.Context) error {
parser2, err := utils.GetParser(ctx, r.db) // refine to use user config later
parser2, err := utils.GetParser(ctx, r.db.DB) // refine to use user config later
if err != nil {
return err
}

isNew, err := isNewServer(ctx, r.meta.UUID(), r.db, r.cfg.Flavor)
isNew, err := isNewServer(ctx, r.meta.UUID(), r.db.DB, r.cfg.Flavor)
if err != nil {
return err
}
Expand Down Expand Up @@ -386,7 +385,7 @@ func (r *Relay) tryRecoverLatestFile(ctx context.Context, parser2 *parser.Parser
zap.Stringer("from position", latestPos), zap.Stringer("to position", result.LatestPos), log.WrapStringerField("from GTID set", latestGTID), log.WrapStringerField("to GTID set", result.LatestGTIDs))

if result.LatestGTIDs != nil {
dbConn, err2 := r.db.Conn(ctx)
dbConn, err2 := r.db.DB.Conn(ctx)
if err2 != nil {
return err2
}
Expand Down Expand Up @@ -453,7 +452,7 @@ func (r *Relay) handleEvents(ctx context.Context, reader2 reader.Reader, transfo
r.logger.Error("the requested binlog files have purged in the master server or the master server have switched, currently DM do no support to handle this error",
zap.String("db host", cfg.Host), zap.Int("db port", cfg.Port), zap.Stringer("last pos", lastPos), log.ShortError(err))
// log the status for debug
pos, gs, err2 := utils.GetMasterStatus(ctx, r.db, r.cfg.Flavor)
pos, gs, err2 := utils.GetMasterStatus(ctx, r.db.DB, r.cfg.Flavor)
if err2 == nil {
r.logger.Info("current master status", zap.Stringer("position", pos), log.WrapStringerField("GTID sets", gs))
}
Expand Down Expand Up @@ -492,7 +491,7 @@ func (r *Relay) handleEvents(ctx context.Context, reader2 reader.Reader, transfo

// fake rotate event
if _, ok := e.Event.(*replication.RotateEvent); ok && e.Header.Timestamp == 0 && e.Header.LogPos == 0 {
isNew, err2 := isNewServer(ctx, r.meta.UUID(), r.db, r.cfg.Flavor)
isNew, err2 := isNewServer(ctx, r.meta.UUID(), r.db.DB, r.cfg.Flavor)
if err2 != nil {
return err2
}
Expand Down Expand Up @@ -568,7 +567,7 @@ func (r *Relay) tryUpdateActiveRelayLog(e *replication.BinlogEvent, filename str

// reSetupMeta re-setup the metadata when switching to a new upstream master server.
func (r *Relay) reSetupMeta(ctx context.Context) error {
uuid, err := utils.GetServerUUID(ctx, r.db, r.cfg.Flavor)
uuid, err := utils.GetServerUUID(ctx, r.db.DB, r.cfg.Flavor)
if err != nil {
return err
}
Expand Down Expand Up @@ -608,7 +607,7 @@ func (r *Relay) reSetupMeta(ctx context.Context) error {

var latestPosName, latestGTIDStr string
if (r.cfg.EnableGTID && len(r.cfg.BinlogGTID) == 0) || (!r.cfg.EnableGTID && len(r.cfg.BinLogName) == 0) {
latestPos, latestGTID, err2 := utils.GetMasterStatus(ctx, r.db, r.cfg.Flavor)
latestPos, latestGTID, err2 := utils.GetMasterStatus(ctx, r.db.DB, r.cfg.Flavor)
if err2 != nil {
return err2
}
Expand Down Expand Up @@ -688,7 +687,7 @@ func (r *Relay) doIntervalOps(ctx context.Context) {
return
}
ctx2, cancel2 := context.WithTimeout(ctx, utils.DefaultDBTimeout)
pos, _, err := utils.GetMasterStatus(ctx2, r.db, r.cfg.Flavor)
pos, _, err := utils.GetMasterStatus(ctx2, r.db.DB, r.cfg.Flavor)
cancel2()
if err != nil {
r.logger.Warn("get master status", zap.Error(err))
Expand Down Expand Up @@ -728,7 +727,7 @@ func (r *Relay) setUpReader(ctx context.Context) (reader.Reader, error) {
ctx2, cancel := context.WithTimeout(ctx, utils.DefaultDBTimeout)
defer cancel()

randomServerID, err := utils.ReuseServerID(ctx2, r.cfg.ServerID, r.db)
randomServerID, err := utils.ReuseServerID(ctx2, r.cfg.ServerID, r.db.DB)
if err != nil {
// should never happened unless the master has too many slave
return nil, terror.Annotate(err, "fail to get random server id for relay reader")
Expand Down Expand Up @@ -837,7 +836,7 @@ func (r *Relay) Close() {

// Status implements the dm.Unit interface.
func (r *Relay) Status(ctx context.Context) interface{} {
masterPos, masterGTID, err := utils.GetMasterStatus(ctx, r.db, r.cfg.Flavor)
masterPos, masterGTID, err := utils.GetMasterStatus(ctx, r.db.DB, r.cfg.Flavor)
if err != nil {
r.logger.Warn("get master status", zap.Error(err))
}
Expand Down Expand Up @@ -917,9 +916,8 @@ func (r *Relay) Reload(newCfg *Config) error {
r.cfg.Charset = newCfg.Charset

r.closeDB()
cfg := r.cfg.From
dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4&interpolateParams=true&readTimeout=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, showStatusConnectionTimeout)
db, err := sql.Open("mysql", dbDSN)
r.cfg.From.RawDBCfg.ReadTimeout = showStatusConnectionTimeout
db, err := conn.DefaultDBProvider.Apply(r.cfg.From)
if err != nil {
return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream)
}
Expand Down Expand Up @@ -998,7 +996,7 @@ func (r *Relay) setSyncConfig() error {
func (r *Relay) adjustGTID(ctx context.Context, gset gtid.Set) (gtid.Set, error) {
// setup a TCP binlog reader (because no relay can be used when upgrading).
syncCfg := r.syncerCfg
randomServerID, err := utils.ReuseServerID(ctx, r.cfg.ServerID, r.db)
randomServerID, err := utils.ReuseServerID(ctx, r.cfg.ServerID, r.db.DB)
if err != nil {
return nil, terror.Annotate(err, "fail to get random server id when relay adjust gtid")
}
Expand All @@ -1010,7 +1008,7 @@ func (r *Relay) adjustGTID(ctx context.Context, gset gtid.Set) (gtid.Set, error)
return nil, err
}

dbConn, err2 := r.db.Conn(ctx)
dbConn, err2 := r.db.DB.Conn(ctx)
if err2 != nil {
return nil, err2
}
Expand Down
19 changes: 10 additions & 9 deletions relay/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/pingcap/dm/dm/config"
"github.com/pingcap/dm/pkg/binlog"
"github.com/pingcap/dm/pkg/binlog/event"
"github.com/pingcap/dm/pkg/conn"
"github.com/pingcap/dm/pkg/gtid"
"github.com/pingcap/dm/pkg/log"
"github.com/pingcap/dm/pkg/streamer"
Expand Down Expand Up @@ -105,11 +106,10 @@ func getDBConfigForTest() config.DBConfig {
}
}

func openDBForTest() (*sql.DB, error) {
func openDBForTest() (*conn.BaseDB, error) {
cfg := getDBConfigForTest()

dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4", cfg.User, cfg.Password, cfg.Host, cfg.Port)
return sql.Open("mysql", dsn)
return conn.DefaultDBProvider.Apply(cfg)
}

// mockReader is used only for relay testing.
Expand Down Expand Up @@ -512,7 +512,7 @@ func (t *testRelaySuite) TestReSetupMeta(c *C) {
r.db.Close()
r.db = nil
}()
uuid, err := utils.GetServerUUID(ctx, r.db, r.cfg.Flavor)
uuid, err := utils.GetServerUUID(ctx, r.db.DB, r.cfg.Flavor)
c.Assert(err, IsNil)

// re-setup meta with start pos adjusted
Expand Down Expand Up @@ -606,21 +606,22 @@ func (t *testRelaySuite) TestProcess(c *C) {
ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel2()
var connID uint32
db := r.db.DB
c.Assert(utils.WaitSomething(30, 100*time.Millisecond, func() bool {
connID, err = getBinlogDumpConnID(ctx2, r.db)
connID, err = getBinlogDumpConnID(ctx2, db)
return err == nil
}), IsTrue)
_, err = r.db.ExecContext(ctx2, fmt.Sprintf(`KILL %d`, connID))
_, err = db.ExecContext(ctx2, fmt.Sprintf(`KILL %d`, connID))
c.Assert(err, IsNil)

// execute a DDL again
lastDDL := "CREATE DATABASE `test_relay_retry_db`"
_, err = r.db.ExecContext(ctx2, lastDDL)
_, err = db.ExecContext(ctx2, lastDDL)
c.Assert(err, IsNil)

defer func() {
query := "DROP DATABASE IF EXISTS `test_relay_retry_db`"
_, err = r.db.ExecContext(ctx2, query)
_, err = db.ExecContext(ctx2, query)
c.Assert(err, IsNil)
}()

Expand All @@ -644,7 +645,7 @@ func (t *testRelaySuite) TestProcess(c *C) {

// check whether have binlog file in relay directory
// and check for events already done in `TestHandleEvent`
uuid, err := utils.GetServerUUID(ctx2, r.db, r.cfg.Flavor)
uuid, err := utils.GetServerUUID(ctx2, db, r.cfg.Flavor)
c.Assert(err, IsNil)
files, err := streamer.CollectAllBinlogFiles(filepath.Join(relayCfg.RelayDir, fmt.Sprintf("%s.000001", uuid)))
c.Assert(err, IsNil)
Expand Down
5 changes: 3 additions & 2 deletions relay/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ func (t *testUtilSuite) TestIsNewServer(c *C) {
ctx, cancel := context.WithTimeout(context.Background(), utils.DefaultDBTimeout)
defer cancel()

db, err := openDBForTest()
baseDB, err := openDBForTest()
c.Assert(err, IsNil)
defer db.Close()
defer baseDB.Close()
db := baseDB.DB

flavor := gmysql.MySQLFlavor

Expand Down
17 changes: 9 additions & 8 deletions syncer/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"go.uber.org/zap"

"github.com/pingcap/dm/dm/config"
"github.com/pingcap/dm/pkg/conn"
"github.com/pingcap/dm/pkg/log"
"github.com/pingcap/dm/pkg/terror"
)
Expand Down Expand Up @@ -83,7 +84,7 @@ type Heartbeat struct {
schema string // for which schema the heartbeat table belongs to
table string // for which table the heartbeat table belongs to

primary *sql.DB
primary *conn.BaseDB
secondaryTs map[string]float64 // task-name => secondary (syncer) ts

cancel context.CancelFunc
Expand Down Expand Up @@ -122,12 +123,12 @@ func (h *Heartbeat) AddTask(name string) error {
if h.primary == nil {
// open DB
dbCfg := h.cfg.primaryCfg
dbDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8&interpolateParams=true&readTimeout=1m", dbCfg.User, dbCfg.Password, dbCfg.Host, dbCfg.Port)
primary, err := sql.Open("mysql", dbDSN)
dbCfg.RawDBCfg.ReadTimeout = "1m"
baseDB, err := conn.DefaultDBProvider.Apply(dbCfg)
if err != nil {
return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream)
}
h.primary = primary
h.primary = baseDB

// init table
err = h.init()
Expand Down Expand Up @@ -275,7 +276,7 @@ func (h *Heartbeat) run(ctx context.Context) {
// createTable creates heartbeat database if not exists in primary
func (h *Heartbeat) createDatabase() error {
createDatabase := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s`", h.schema)
_, err := h.primary.Exec(createDatabase)
_, err := h.primary.DB.Exec(createDatabase)
h.logger.Info("create heartbeat schema", zap.String("sql", createDatabase))
return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream)
}
Expand All @@ -289,15 +290,15 @@ func (h *Heartbeat) createTable() error {
PRIMARY KEY (server_id)
)`, tableName)

_, err := h.primary.Exec(createTableStmt)
_, err := h.primary.DB.Exec(createTableStmt)
h.logger.Info("create heartbeat table", zap.String("sql", createTableStmt))
return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream)
}

// updateTS use `REPLACE` statement to insert or update ts
func (h *Heartbeat) updateTS() error {
query := fmt.Sprintf("REPLACE INTO `%s`.`%s` (`ts`, `server_id`) VALUES(UTC_TIMESTAMP(6), ?)", h.schema, h.table)
_, err := h.primary.Exec(query, h.cfg.serverID)
_, err := h.primary.DB.Exec(query, h.cfg.serverID)
h.logger.Debug("update ts", zap.String("sql", query), zap.Uint32("server ID", h.cfg.serverID))
return terror.WithScope(terror.DBErrorAdapt(err, terror.ErrDBDriverError), terror.ScopeUpstream)
}
Expand Down Expand Up @@ -330,7 +331,7 @@ func reportLag(taskName string, lag float64) {
}

func (h *Heartbeat) getPrimaryTS() (float64, error) {
return h.getTS(h.primary)
return h.getTS(h.primary.DB)
}

func (h *Heartbeat) getTS(db *sql.DB) (float64, error) {
Expand Down