diff --git a/drainer/loopbacksync/loopbacksync.go b/drainer/loopbacksync/loopbacksync.go index 9960bb180..01b29d7f9 100644 --- a/drainer/loopbacksync/loopbacksync.go +++ b/drainer/loopbacksync/loopbacksync.go @@ -16,6 +16,8 @@ package loopbacksync const ( //MarkTableName mark table name MarkTableName = "retl._drainer_repl_mark" + //ID syncer worker coroutine id + ID = "id" //ChannelID channel id ChannelID = "channel_id" //Val val diff --git a/pkg/loader/executor.go b/pkg/loader/executor.go index 4058bfece..6ca4a1fe5 100644 --- a/pkg/loader/executor.go +++ b/pkg/loader/executor.go @@ -18,6 +18,7 @@ import ( gosql "database/sql" "fmt" "strings" + "sync/atomic" "time" "github.com/pingcap/tidb-binlog/drainer/loopbacksync" @@ -32,11 +33,16 @@ import ( "golang.org/x/sync/errgroup" ) -var defaultBatchSize = 128 +var ( + defaultBatchSize = 128 + defaultWorkerCount = 16 + index int64 +) type executor struct { db *gosql.DB batchSize int + workerCount int info *loopbacksync.LoopBackSync queryHistogramVec *prometheus.HistogramVec refreshTableInfo func(schema string, table string) (info *tableInfo, err error) @@ -44,8 +50,9 @@ type executor struct { func newExecutor(db *gosql.DB) *executor { exe := &executor{ - db: db, - batchSize: defaultBatchSize, + db: db, + batchSize: defaultBatchSize, + workerCount: defaultWorkerCount, } return exe @@ -65,6 +72,10 @@ func (e *executor) setSyncInfo(info *loopbacksync.LoopBackSync) { e.info = info } +func (e *executor) setWorkerCount(workerCount int) { + e.workerCount = workerCount +} + func (e *executor) withQueryHistogramVec(queryHistogramVec *prometheus.HistogramVec) *executor { e.queryHistogramVec = queryHistogramVec return e @@ -119,16 +130,36 @@ func (e *executor) updateMark(channel string, tx *tx) error { if e.info == nil { return nil } - status := 1 - columns := fmt.Sprintf("(%s,%s,%s) VALUES(?,?,?)", loopbacksync.ChannelID, loopbacksync.Val, loopbacksync.ChannelInfo) var args []interface{} - sql := fmt.Sprintf("INSERT INTO %s%s on duplicate key update %s=%s+1;", loopbacksync.MarkTableName, columns, loopbacksync.Val, loopbacksync.Val) - args = append(args, e.info.ChannelID, status, channel) - _, err := tx.autoRollbackExec(sql, args...) + sql := fmt.Sprintf("update %s set %s=%s+1 where %s=? and %s=? limit 1;", loopbacksync.MarkTableName, loopbacksync.Val, loopbacksync.Val, loopbacksync.ID, loopbacksync.ChannelID) + args = append(args, e.addIndex(), e.info.ChannelID) + _, err1 := tx.autoRollbackExec(sql, args...) + if err1 != nil { + return errors.Trace(err1) + } + return nil +} + +func (e *executor) cleanChannelInfo() error { + if e.info == nil { + return nil + } + tx, err := e.begin() if err != nil { return errors.Trace(err) } - return nil + var args []interface{} + sql := fmt.Sprintf("delete from %s where %s=? ", loopbacksync.MarkTableName, loopbacksync.ChannelID) + args = append(args, e.info.ChannelID) + _, err1 := tx.autoRollbackExec(sql, args...) + if err1 != nil { + return errors.Trace(err1) + } + err2 := tx.commit() + return errors.Trace(err2) +} +func (e *executor) addIndex() int64 { + return atomic.AddInt64(&index, 1) % ((int64)(e.workerCount)) } // return a wrap of sql.Tx diff --git a/pkg/loader/load.go b/pkg/loader/load.go index a8493b2e9..8329f808f 100644 --- a/pkg/loader/load.go +++ b/pkg/loader/load.go @@ -505,16 +505,63 @@ func (s *loaderImpl) createMarkTable() error { return nil } +func (s *loaderImpl) initMarkTable() error { + if err := s.createMarkTable(); err != nil { + return errors.Trace(err) + } + return s.initMarkTableData() +} +func (s *loaderImpl) initMarkTableData() error { + tx, err := s.db.Begin() + if err != nil { + return errors.Trace(err) + } + status := 1 + channel := "" + var builder strings.Builder + holder := "(?,?,?,?)" + columns := fmt.Sprintf("(%s,%s,%s,%s) ", loopbacksync.ID, loopbacksync.ChannelID, loopbacksync.Val, loopbacksync.ChannelInfo) + builder.WriteString("REPLACE INTO " + loopbacksync.MarkTableName + columns + " VALUES ") + for i := 0; i < s.workerCount; i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.WriteString(holder) + } + var args []interface{} + for id := 0; id < s.workerCount; id++ { + args = append(args, id, s.loopBackSyncInfo.ChannelID, status, channel) + } + query := builder.String() + if _, err = tx.Exec(query, args...); err != nil { + log.Error("Exec fail, will rollback", zap.String("query", query), zap.Reflect("args", args), zap.Error(err)) + if rbErr := tx.Rollback(); rbErr != nil { + log.Error("Auto rollback", zap.Error(rbErr)) + } + return errors.Trace(err) + } + if err = tx.Commit(); err != nil { + return errors.Trace(err) + } + return nil +} + +func (s *loaderImpl) cleanChannelInfo() { + executor := s.getExecutor() + _ = executor.cleanChannelInfo() +} + // Run will quit when meet any error, or all the txn are drained func (s *loaderImpl) Run() error { if s.loopBackSyncInfo != nil && s.loopBackSyncInfo.LoopbackControl { - if err := s.createMarkTable(); err != nil { + if err := s.initMarkTable(); err != nil { return errors.Trace(err) } } txnManager := newTxnManager(1024, s.input) defer func() { log.Info("Run()... in Loader quit") + s.cleanChannelInfo() close(s.successTxn) txnManager.Close() }() @@ -624,6 +671,7 @@ func (s *loaderImpl) getExecutor() *executor { e = e.withRefreshTableInfo(s.refreshTableInfo) } e.setSyncInfo(s.loopBackSyncInfo) + e.setWorkerCount(s.workerCount) if s.metrics != nil && s.metrics.QueryHistogramVec != nil { e = e.withQueryHistogramVec(s.metrics.QueryHistogramVec) } diff --git a/pkg/loader/model.go b/pkg/loader/model.go index 05e05de06..9001ee457 100644 --- a/pkg/loader/model.go +++ b/pkg/loader/model.go @@ -192,7 +192,7 @@ func (dml *DML) updateSQL() (sql string, args []interface{}) { } func createMarkTableDDL() string { - sql := fmt.Sprintf("CREATE TABLE If Not Exists %s ( %s bigint primary key, %s bigint DEFAULT 0, %s varchar(64));", loopbacksync.MarkTableName, loopbacksync.ChannelID, loopbacksync.Val, loopbacksync.ChannelInfo) + sql := fmt.Sprintf("CREATE TABLE If Not Exists %s (%s bigint not null,%s bigint not null DEFAULT 0, %s bigint DEFAULT 0, %s varchar(64) ,PRIMARY KEY (%s,%s));", loopbacksync.MarkTableName, loopbacksync.ID, loopbacksync.ChannelID, loopbacksync.Val, loopbacksync.ChannelInfo, loopbacksync.ID, loopbacksync.ChannelID) return sql } diff --git a/pkg/loader/model_test.go b/pkg/loader/model_test.go index 8a83e5007..e4caf939c 100644 --- a/pkg/loader/model_test.go +++ b/pkg/loader/model_test.go @@ -242,11 +242,10 @@ func (s *SQLSuite) TestUpdateMarkSQL(c *check.C) { db, mock, err := sqlmock.New() c.Assert(err, check.IsNil) defer db.Close() - columns := fmt.Sprintf("(%s,%s,%s) VALUES(?,?,?)", loopbacksync.ChannelID, loopbacksync.Val, loopbacksync.ChannelInfo) - sql := fmt.Sprintf("INSERT INTO %s%s on duplicate key update %s=%s+1;", loopbacksync.MarkTableName, columns, loopbacksync.Val, loopbacksync.Val) + sql := fmt.Sprintf("update %s set %s=%s+1 where %s=? and %s=? limit 1;", loopbacksync.MarkTableName, loopbacksync.Val, loopbacksync.Val, loopbacksync.ID, loopbacksync.ChannelID) mock.ExpectBegin() mock.ExpectExec(regexp.QuoteMeta(sql)). - WithArgs(100, 1, "").WillReturnResult(sqlmock.NewResult(1, 1)) + WithArgs(1, 100).WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() e := newExecutor(db) tx, err := e.begin() @@ -261,6 +260,6 @@ func (s *SQLSuite) TestUpdateMarkSQL(c *check.C) { } func (s *SQLSuite) TestCreateMarkTable(c *check.C) { sql := createMarkTableDDL() - sql1 := fmt.Sprintf("CREATE TABLE If Not Exists %s ( %s bigint primary key, %s bigint DEFAULT 0, %s varchar(64));", loopbacksync.MarkTableName, loopbacksync.ChannelID, loopbacksync.Val, loopbacksync.ChannelInfo) + sql1 := fmt.Sprintf("CREATE TABLE If Not Exists %s (%s bigint not null,%s bigint not null DEFAULT 0, %s bigint DEFAULT 0, %s varchar(64) ,PRIMARY KEY (%s,%s));", loopbacksync.MarkTableName, loopbacksync.ID, loopbacksync.ChannelID, loopbacksync.Val, loopbacksync.ChannelInfo, loopbacksync.ID, loopbacksync.ChannelID) c.Assert(sql, check.Equals, sql1) }