From d19162eb15b18f49763b60f56e4b3932577a0f6d Mon Sep 17 00:00:00 2001 From: GMHDBJD <35025882+GMHDBJD@users.noreply.github.com> Date: Wed, 21 Jun 2023 21:41:42 +0800 Subject: [PATCH 1/4] This is an automated cherry-pick of #44803 Signed-off-by: ti-chi-bot --- br/pkg/checksum/executor.go | 56 +- br/pkg/lightning/backend/local/BUILD.bazel | 6 + br/pkg/lightning/common/BUILD.bazel | 5 + br/pkg/lightning/common/common.go | 109 +++ br/pkg/lightning/common/util.go | 163 +++++ br/pkg/lightning/config/config.go | 2 + br/pkg/lightning/importer/checksum_helper.go | 89 +++ br/pkg/lightning/restore/checksum.go | 51 +- br/pkg/lightning/restore/checksum_test.go | 2 + .../lightning/restore/table_restore_test.go | 3 +- br/pkg/lightning/restore/tidb_test.go | 2 + br/tests/lightning_add_index/config1.toml | 6 + disttask/framework/dispatcher/dispatcher.go | 516 ++++++++++++++ disttask/framework/storage/task_table.go | 496 ++++++++++++++ disttask/importinto/BUILD.bazel | 80 +++ disttask/importinto/dispatcher.go | 647 ++++++++++++++++++ disttask/importinto/job.go | 279 ++++++++ disttask/importinto/subtask_executor.go | 240 +++++++ disttask/importinto/subtask_executor_test.go | 73 ++ executor/import_into.go | 302 ++++++++ executor/importer/BUILD.bazel | 104 +++ executor/importer/table_import.go | 565 +++++++++++++++ tests/realtikvtest/importintotest/job_test.go | 635 +++++++++++++++++ 23 files changed, 4424 insertions(+), 7 deletions(-) create mode 100644 br/pkg/lightning/common/common.go create mode 100644 br/pkg/lightning/importer/checksum_helper.go create mode 100644 br/tests/lightning_add_index/config1.toml create mode 100644 disttask/framework/dispatcher/dispatcher.go create mode 100644 disttask/framework/storage/task_table.go create mode 100644 disttask/importinto/BUILD.bazel create mode 100644 disttask/importinto/dispatcher.go create mode 100644 disttask/importinto/job.go create mode 100644 disttask/importinto/subtask_executor.go create mode 100644 disttask/importinto/subtask_executor_test.go create mode 100644 executor/import_into.go create mode 100644 executor/importer/BUILD.bazel create mode 100644 executor/importer/table_import.go create mode 100644 tests/realtikvtest/importintotest/job_test.go diff --git a/br/pkg/checksum/executor.go b/br/pkg/checksum/executor.go index c30ae49fccdca..eda610b8f8e41 100644 --- a/br/pkg/checksum/executor.go +++ b/br/pkg/checksum/executor.go @@ -26,7 +26,15 @@ type ExecutorBuilder struct { oldTable *metautil.Table +<<<<<<< HEAD concurrency uint +======= + concurrency uint + backoffWeight int + + oldKeyspace []byte + newKeyspace []byte +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) } // NewExecutorBuilder returns a new executor builder. @@ -51,13 +59,32 @@ func (builder *ExecutorBuilder) SetConcurrency(conc uint) *ExecutorBuilder { return builder } +<<<<<<< HEAD +======= +// SetBackoffWeight set the backoffWeight of the checksum executing. +func (builder *ExecutorBuilder) SetBackoffWeight(backoffWeight int) *ExecutorBuilder { + builder.backoffWeight = backoffWeight + return builder +} + +func (builder *ExecutorBuilder) SetOldKeyspace(keyspace []byte) *ExecutorBuilder { + builder.oldKeyspace = keyspace + return builder +} + +func (builder *ExecutorBuilder) SetNewKeyspace(keyspace []byte) *ExecutorBuilder { + builder.newKeyspace = keyspace + return builder +} + +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) // Build builds a checksum executor. 
func (builder *ExecutorBuilder) Build() (*Executor, error) { reqs, err := buildChecksumRequest(builder.table, builder.oldTable, builder.ts, builder.concurrency) if err != nil { return nil, errors.Trace(err) } - return &Executor{reqs: reqs}, nil + return &Executor{reqs: reqs, backoffWeight: builder.backoffWeight}, nil } func buildChecksumRequest( @@ -262,7 +289,8 @@ func updateChecksumResponse(resp, update *tipb.ChecksumResponse) { // Executor is a checksum executor. type Executor struct { - reqs []*kv.Request + reqs []*kv.Request + backoffWeight int } // Len returns the total number of checksum requests. @@ -308,7 +336,31 @@ func (exec *Executor) Execute( // // It is useful in TiDB, however, it's a place holder in BR. killed := uint32(0) +<<<<<<< HEAD resp, err := sendChecksumRequest(ctx, client, req, kv.NewVariables(&killed)) +======= + var ( + resp *tipb.ChecksumResponse + err error + ) + err = utils.WithRetry(ctx, func() error { + vars := kv.NewVariables(&killed) + if exec.backoffWeight > 0 { + vars.BackOffWeight = exec.backoffWeight + } + resp, err = sendChecksumRequest(ctx, client, req, vars) + failpoint.Inject("checksumRetryErr", func(val failpoint.Value) { + // first time reach here. return error + if val.(bool) { + err = errors.New("inject checksum error") + } + }) + if err != nil { + return errors.Trace(err) + } + return nil + }, &checksumBackoffer) +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/lightning/backend/local/BUILD.bazel b/br/pkg/lightning/backend/local/BUILD.bazel index 9524ab5febc2b..1273b807b0f0d 100644 --- a/br/pkg/lightning/backend/local/BUILD.bazel +++ b/br/pkg/lightning/backend/local/BUILD.bazel @@ -40,6 +40,11 @@ go_library( "//kv", "//parser/model", "//parser/mysql", +<<<<<<< HEAD +======= + "//sessionctx/variable", + "//store/pdtypes", +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) "//table", "//tablecodec", "//types", @@ -62,6 +67,7 @@ go_library( "@com_github_pingcap_kvproto//pkg/metapb", "@com_github_pingcap_kvproto//pkg/pdpb", "@com_github_tikv_client_go_v2//error", + "@com_github_tikv_client_go_v2//kv", "@com_github_tikv_client_go_v2//oracle", "@com_github_tikv_client_go_v2//tikv", "@com_github_tikv_pd_client//:client", diff --git a/br/pkg/lightning/common/BUILD.bazel b/br/pkg/lightning/common/BUILD.bazel index 2b36e457cd857..05d4729c89193 100644 --- a/br/pkg/lightning/common/BUILD.bazel +++ b/br/pkg/lightning/common/BUILD.bazel @@ -23,6 +23,11 @@ go_library( "//br/pkg/utils", "//errno", "//parser/model", +<<<<<<< HEAD +======= + "//parser/mysql", + "//sessionctx/variable", +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) "//store/driver/error", "//table/tables", "//util", diff --git a/br/pkg/lightning/common/common.go b/br/pkg/lightning/common/common.go new file mode 100644 index 0000000000000..aaf8860e4fb58 --- /dev/null +++ b/br/pkg/lightning/common/common.go @@ -0,0 +1,109 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta/autoid" + "github.com/pingcap/tidb/parser/model" +) + +const ( + // IndexEngineID is the engine ID for index engine. + IndexEngineID = -1 +) + +// DefaultImportantVariables is used in ObtainImportantVariables to retrieve the system +// variables from downstream which may affect KV encode result. The values record the default +// values if missing. +var DefaultImportantVariables = map[string]string{ + "max_allowed_packet": "67108864", + "div_precision_increment": "4", + "time_zone": "SYSTEM", + "lc_time_names": "en_US", + "default_week_format": "0", + "block_encryption_mode": "aes-128-ecb", + "group_concat_max_len": "1024", + "tidb_backoff_weight": "6", +} + +// DefaultImportVariablesTiDB is used in ObtainImportantVariables to retrieve the system +// variables from downstream in local/importer backend. The values record the default +// values if missing. +var DefaultImportVariablesTiDB = map[string]string{ + "tidb_row_format_version": "1", +} + +// AllocGlobalAutoID allocs N consecutive autoIDs from TiDB. +func AllocGlobalAutoID(ctx context.Context, n int64, store kv.Storage, dbID int64, + tblInfo *model.TableInfo) (autoIDBase, autoIDMax int64, err error) { + alloc, err := getGlobalAutoIDAlloc(store, dbID, tblInfo) + if err != nil { + return 0, 0, err + } + return alloc.Alloc(ctx, uint64(n), 1, 1) +} + +// RebaseGlobalAutoID rebase the autoID base to newBase. +func RebaseGlobalAutoID(ctx context.Context, newBase int64, store kv.Storage, dbID int64, + tblInfo *model.TableInfo) error { + alloc, err := getGlobalAutoIDAlloc(store, dbID, tblInfo) + if err != nil { + return err + } + return alloc.Rebase(ctx, newBase, false) +} + +func getGlobalAutoIDAlloc(store kv.Storage, dbID int64, tblInfo *model.TableInfo) (autoid.Allocator, error) { + if store == nil { + return nil, errors.New("internal error: kv store should not be nil") + } + if dbID == 0 { + return nil, errors.New("internal error: dbID should not be 0") + } + + // We don't need autoid cache here because we allocate all IDs at once. + // The argument for CustomAutoIncCacheOption is the cache step. Step 1 means no cache, + // but step 1 will enable an experimental feature, so we use step 2 here. + // + // See https://github.com/pingcap/tidb/issues/38442 for more details. + noCache := autoid.CustomAutoIncCacheOption(2) + tblVer := autoid.AllocOptionTableInfoVersion(tblInfo.Version) + + hasRowID := TableHasAutoRowID(tblInfo) + hasAutoIncID := tblInfo.GetAutoIncrementColInfo() != nil + hasAutoRandID := tblInfo.ContainsAutoRandomBits() + + // Current TiDB has some limitations for auto ID. + // 1. Auto increment ID and auto row ID are using the same RowID allocator. + // See https://github.com/pingcap/tidb/issues/982. + // 2. Auto random column must be a clustered primary key. That is to say, + // there is no implicit row ID for tables with auto random column. + // 3. There is at most one auto column in a table. + // Therefore, we assume there is only one auto column in a table and use RowID allocator if possible. 
+ switch { + case hasRowID || hasAutoIncID: + return autoid.NewAllocator(store, dbID, tblInfo.ID, tblInfo.IsAutoIncColUnsigned(), + autoid.RowIDAllocType, noCache, tblVer), nil + case hasAutoRandID: + return autoid.NewAllocator(store, dbID, tblInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), + autoid.AutoRandomType, noCache, tblVer), nil + default: + return nil, errors.Errorf("internal error: table %s has no auto ID", tblInfo.Name) + } +} diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go index 621c59d820e23..c068c3390dd35 100644 --- a/br/pkg/lightning/common/util.go +++ b/br/pkg/lightning/common/util.go @@ -36,6 +36,11 @@ import ( "github.com/pingcap/tidb/br/pkg/utils" tmysql "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/model" +<<<<<<< HEAD +======= + tmysql "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/sessionctx/variable" +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) "github.com/pingcap/tidb/table/tables" "go.uber.org/zap" ) @@ -428,3 +433,161 @@ func GetAutoRandomColumn(tblInfo *model.TableInfo) *model.ColumnInfo { } return nil } +<<<<<<< HEAD +======= + +// GetDropIndexInfos returns the index infos that need to be dropped and the remain indexes. +func GetDropIndexInfos( + tblInfo *model.TableInfo, +) (remainIndexes []*model.IndexInfo, dropIndexes []*model.IndexInfo) { + cols := tblInfo.Columns +loop: + for _, idxInfo := range tblInfo.Indices { + if idxInfo.State != model.StatePublic { + remainIndexes = append(remainIndexes, idxInfo) + continue + } + // Primary key is a cluster index. + if idxInfo.Primary && tblInfo.HasClusteredIndex() { + remainIndexes = append(remainIndexes, idxInfo) + continue + } + // Skip index that contains auto-increment column. + // Because auto colum must be defined as a key. + for _, idxCol := range idxInfo.Columns { + flag := cols[idxCol.Offset].GetFlag() + if tmysql.HasAutoIncrementFlag(flag) { + remainIndexes = append(remainIndexes, idxInfo) + continue loop + } + } + dropIndexes = append(dropIndexes, idxInfo) + } + return remainIndexes, dropIndexes +} + +// BuildDropIndexSQL builds the SQL statement to drop index. +func BuildDropIndexSQL(tableName string, idxInfo *model.IndexInfo) string { + if idxInfo.Primary { + return fmt.Sprintf("ALTER TABLE %s DROP PRIMARY KEY", tableName) + } + return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", tableName, EscapeIdentifier(idxInfo.Name.O)) +} + +// BuildAddIndexSQL builds the SQL statement to create missing indexes. +// It returns both a single SQL statement that creates all indexes at once, +// and a list of SQL statements that creates each index individually. +func BuildAddIndexSQL( + tableName string, + curTblInfo, + desiredTblInfo *model.TableInfo, +) (singleSQL string, multiSQLs []string) { + addIndexSpecs := make([]string, 0, len(desiredTblInfo.Indices)) +loop: + for _, desiredIdxInfo := range desiredTblInfo.Indices { + for _, curIdxInfo := range curTblInfo.Indices { + if curIdxInfo.Name.L == desiredIdxInfo.Name.L { + continue loop + } + } + + var buf bytes.Buffer + if desiredIdxInfo.Primary { + buf.WriteString("ADD PRIMARY KEY ") + } else if desiredIdxInfo.Unique { + buf.WriteString("ADD UNIQUE KEY ") + } else { + buf.WriteString("ADD KEY ") + } + // "primary" is a special name for primary key, we should not use it as index name. 
+ if desiredIdxInfo.Name.L != "primary" { + buf.WriteString(EscapeIdentifier(desiredIdxInfo.Name.O)) + } + + colStrs := make([]string, 0, len(desiredIdxInfo.Columns)) + for _, col := range desiredIdxInfo.Columns { + var colStr string + if desiredTblInfo.Columns[col.Offset].Hidden { + colStr = fmt.Sprintf("(%s)", desiredTblInfo.Columns[col.Offset].GeneratedExprString) + } else { + colStr = EscapeIdentifier(col.Name.O) + if col.Length != types.UnspecifiedLength { + colStr = fmt.Sprintf("%s(%s)", colStr, strconv.Itoa(col.Length)) + } + } + colStrs = append(colStrs, colStr) + } + fmt.Fprintf(&buf, "(%s)", strings.Join(colStrs, ",")) + + if desiredIdxInfo.Invisible { + fmt.Fprint(&buf, " INVISIBLE") + } + if desiredIdxInfo.Comment != "" { + fmt.Fprintf(&buf, ` COMMENT '%s'`, format.OutputFormat(desiredIdxInfo.Comment)) + } + addIndexSpecs = append(addIndexSpecs, buf.String()) + } + if len(addIndexSpecs) == 0 { + return "", nil + } + + singleSQL = fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(addIndexSpecs, ", ")) + for _, spec := range addIndexSpecs { + multiSQLs = append(multiSQLs, fmt.Sprintf("ALTER TABLE %s %s", tableName, spec)) + } + return singleSQL, multiSQLs +} + +// IsDupKeyError checks if err is a duplicate index error. +func IsDupKeyError(err error) bool { + if merr, ok := errors.Cause(err).(*mysql.MySQLError); ok { + switch merr.Number { + case errno.ErrDupKeyName, errno.ErrMultiplePriKey, errno.ErrDupUnique: + return true + } + } + return false +} + +// GetBackoffWeightFromDB gets the backoff weight from database. +func GetBackoffWeightFromDB(ctx context.Context, db *sql.DB) (int, error) { + val, err := getSessionVariable(ctx, db, variable.TiDBBackOffWeight) + if err != nil { + return 0, err + } + return strconv.Atoi(val) +} + +// copy from dbutil to avoid import cycle +func getSessionVariable(ctx context.Context, db *sql.DB, variable string) (value string, err error) { + query := fmt.Sprintf("SHOW VARIABLES LIKE '%s'", variable) + rows, err := db.QueryContext(ctx, query) + + if err != nil { + return "", errors.Trace(err) + } + defer rows.Close() + + // Show an example. 
+ /* + mysql> SHOW VARIABLES LIKE "binlog_format"; + +---------------+-------+ + | Variable_name | Value | + +---------------+-------+ + | binlog_format | ROW | + +---------------+-------+ + */ + + for rows.Next() { + if err = rows.Scan(&variable, &value); err != nil { + return "", errors.Trace(err) + } + } + + if err := rows.Err(); err != nil { + return "", errors.Trace(err) + } + + return value, nil +} +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) diff --git a/br/pkg/lightning/config/config.go b/br/pkg/lightning/config/config.go index e1031c760f749..e372ad3e1bb18 100644 --- a/br/pkg/lightning/config/config.go +++ b/br/pkg/lightning/config/config.go @@ -440,6 +440,7 @@ type PostRestore struct { Level1Compact bool `toml:"level-1-compact" json:"level-1-compact"` PostProcessAtLast bool `toml:"post-process-at-last" json:"post-process-at-last"` Compact bool `toml:"compact" json:"compact"` + ChecksumViaSQL bool `toml:"checksum-via-sql" json:"checksum-via-sql"` } type CSVConfig struct { @@ -745,6 +746,7 @@ func NewConfig() *Config { Checksum: OpLevelRequired, Analyze: OpLevelOptional, PostProcessAtLast: true, + ChecksumViaSQL: true, }, } } diff --git a/br/pkg/lightning/importer/checksum_helper.go b/br/pkg/lightning/importer/checksum_helper.go new file mode 100644 index 0000000000000..88bc40d5a72e1 --- /dev/null +++ b/br/pkg/lightning/importer/checksum_helper.go @@ -0,0 +1,89 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importer + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/lightning/backend/local" + "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/br/pkg/lightning/common" + "github.com/pingcap/tidb/br/pkg/lightning/config" + "github.com/pingcap/tidb/br/pkg/lightning/log" + "github.com/pingcap/tidb/br/pkg/lightning/metric" + "github.com/pingcap/tidb/br/pkg/pdutil" + "github.com/pingcap/tidb/kv" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" +) + +// NewChecksumManager creates a new checksum manager. 
+func NewChecksumManager(ctx context.Context, rc *Controller, store kv.Storage) (local.ChecksumManager, error) { + // if we don't need checksum, just return nil + if rc.cfg.TikvImporter.Backend == config.BackendTiDB || rc.cfg.PostRestore.Checksum == config.OpLevelOff { + return nil, nil + } + + pdAddr := rc.cfg.TiDB.PdAddr + pdVersion, err := pdutil.FetchPDVersion(ctx, rc.tls, pdAddr) + if err != nil { + return nil, errors.Trace(err) + } + + // for v4.0.0 or upper, we can use the gc ttl api + var manager local.ChecksumManager + if pdVersion.Major >= 4 && !rc.cfg.PostRestore.ChecksumViaSQL { + tlsOpt := rc.tls.ToPDSecurityOption() + pdCli, err := pd.NewClientWithContext(ctx, []string{pdAddr}, tlsOpt) + if err != nil { + return nil, errors.Trace(err) + } + + backoffWeight, err := common.GetBackoffWeightFromDB(ctx, rc.db) + // only set backoff weight when it's smaller than default value + if err == nil && backoffWeight >= local.DefaultBackoffWeight { + log.FromContext(ctx).Info("get tidb_backoff_weight", zap.Int("backoff_weight", backoffWeight)) + } else { + log.FromContext(ctx).Info("set tidb_backoff_weight to default", zap.Int("backoff_weight", local.DefaultBackoffWeight)) + backoffWeight = local.DefaultBackoffWeight + } + manager = local.NewTiKVChecksumManager(store.GetClient(), pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency), backoffWeight) + } else { + manager = local.NewTiDBChecksumExecutor(rc.db) + } + + return manager, nil +} + +// DoChecksum do checksum for tables. +// table should be in ., format. e.g. foo.bar +func DoChecksum(ctx context.Context, table *checkpoints.TidbTableInfo) (*local.RemoteChecksum, error) { + var err error + manager, ok := ctx.Value(&checksumManagerKey).(local.ChecksumManager) + if !ok { + return nil, errors.New("No gcLifeTimeManager found in context, check context initialization") + } + + task := log.FromContext(ctx).With(zap.String("table", table.Name)).Begin(zap.InfoLevel, "remote checksum") + + cs, err := manager.Checksum(ctx, table) + dur := task.End(zap.ErrorLevel, err) + if m, ok := metric.FromContext(ctx); ok { + m.ChecksumSecondsHistogram.Observe(dur.Seconds()) + } + + return cs, err +} diff --git a/br/pkg/lightning/restore/checksum.go b/br/pkg/lightning/restore/checksum.go index b30fe14e01fc1..b981d6759fdd3 100644 --- a/br/pkg/lightning/restore/checksum.go +++ b/br/pkg/lightning/restore/checksum.go @@ -33,8 +33,10 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/metric" "github.com/pingcap/tidb/br/pkg/pdutil" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/mathutil" "github.com/pingcap/tipb/go-tipb" + tikvstore "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/oracle" pd "github.com/tikv/pd/client" "go.uber.org/atomic" @@ -50,7 +52,14 @@ const ( var ( serviceSafePointTTL int64 = 10 * 60 // 10 min in seconds - minDistSQLScanConcurrency = 4 + // MinDistSQLScanConcurrency is the minimum value of tidb_distsql_scan_concurrency. + MinDistSQLScanConcurrency = 4 + + // DefaultBackoffWeight is the default value of tidb_backoff_weight for checksum. + // when TiKV client encounters an error of "region not leader", it will keep retrying every 500 ms. + // If it still fails after 2 * 20 = 40 seconds, it will return "region unavailable". + // If we increase the BackOffWeight to 6, then the TiKV client will keep retrying for 120 seconds. + DefaultBackoffWeight = 3 * tikvstore.DefBackOffWeight ) // RemoteChecksum represents a checksum result got from tidb. 
@@ -125,6 +134,15 @@ func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoi task := log.FromContext(ctx).With(zap.String("table", tableName)).Begin(zap.InfoLevel, "remote checksum") + conn, err := e.db.Conn(ctx) + if err != nil { + return nil, errors.Trace(err) + } + defer func() { + if err := conn.Close(); err != nil { + task.Warn("close connection failed", zap.Error(err)) + } + }() // ADMIN CHECKSUM TABLE
,
example. // mysql> admin checksum table test.t; // +---------+------------+---------------------+-----------+-------------+ @@ -132,9 +150,23 @@ func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoi // +---------+------------+---------------------+-----------+-------------+ // | test | t | 8520875019404689597 | 7296873 | 357601387 | // +---------+------------+---------------------+-----------+-------------+ + backoffWeight, err := common.GetBackoffWeightFromDB(ctx, e.db) + if err == nil && backoffWeight < DefaultBackoffWeight { + task.Info("increase tidb_backoff_weight", zap.Int("original", backoffWeight), zap.Int("new", DefaultBackoffWeight)) + // increase backoff weight + if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, DefaultBackoffWeight)); err != nil { + task.Warn("set tidb_backoff_weight failed", zap.Error(err)) + } else { + defer func() { + if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, backoffWeight)); err != nil { + task.Warn("recover tidb_backoff_weight failed", zap.Error(err)) + } + }() + } + } cs := RemoteChecksum{} - err = common.SQLWithRetry{DB: e.db, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum", + err = common.SQLWithRetry{DB: conn, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum", "ADMIN CHECKSUM TABLE "+tableName, &cs.Schema, &cs.Table, &cs.Checksum, &cs.TotalKVs, &cs.TotalBytes, ) dur := task.End(zap.ErrorLevel, err) @@ -257,20 +289,31 @@ type tikvChecksumManager struct { client kv.Client manager gcTTLManager distSQLScanConcurrency uint + backoffWeight int } +<<<<<<< HEAD:br/pkg/lightning/restore/checksum.go // newTiKVChecksumManager return a new tikv checksum manager func newTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint) *tikvChecksumManager { return &tikvChecksumManager{ +======= +var _ ChecksumManager = &TiKVChecksumManager{} + +// NewTiKVChecksumManager return a new tikv checksum manager +func NewTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint, backoffWeight int) *TiKVChecksumManager { + return &TiKVChecksumManager{ +>>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)):br/pkg/lightning/backend/local/checksum.go client: client, manager: newGCTTLManager(pdClient), distSQLScanConcurrency: distSQLScanConcurrency, + backoffWeight: backoffWeight, } } func (e *tikvChecksumManager) checksumDB(ctx context.Context, tableInfo *checkpoints.TidbTableInfo, ts uint64) (*RemoteChecksum, error) { executor, err := checksum.NewExecutorBuilder(tableInfo.Core, ts). SetConcurrency(e.distSQLScanConcurrency). + SetBackoffWeight(e.backoffWeight). 
Build() if err != nil { return nil, errors.Trace(err) @@ -302,8 +345,8 @@ func (e *tikvChecksumManager) checksumDB(ctx context.Context, tableInfo *checkpo if !common.IsRetryableError(err) { break } - if distSQLScanConcurrency > minDistSQLScanConcurrency { - distSQLScanConcurrency = mathutil.Max(distSQLScanConcurrency/2, minDistSQLScanConcurrency) + if distSQLScanConcurrency > MinDistSQLScanConcurrency { + distSQLScanConcurrency = mathutil.Max(distSQLScanConcurrency/2, MinDistSQLScanConcurrency) } } diff --git a/br/pkg/lightning/restore/checksum_test.go b/br/pkg/lightning/restore/checksum_test.go index 20acc23fe6be0..ba920ee58ed84 100644 --- a/br/pkg/lightning/restore/checksum_test.go +++ b/br/pkg/lightning/restore/checksum_test.go @@ -56,6 +56,7 @@ func TestDoChecksum(t *testing.T) { WithArgs("10m"). WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() + mock.ExpectClose() ctx := MockDoChecksumCtx(db) checksum, err := DoChecksum(ctx, &TidbTableInfo{DB: "test", Name: "t"}) @@ -216,6 +217,7 @@ func TestDoChecksumWithErrorAndLongOriginalLifetime(t *testing.T) { WithArgs("300h"). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectClose() + mock.ExpectClose() ctx := MockDoChecksumCtx(db) _, err = DoChecksum(ctx, &TidbTableInfo{DB: "test", Name: "t"}) diff --git a/br/pkg/lightning/restore/table_restore_test.go b/br/pkg/lightning/restore/table_restore_test.go index 17fb97e346e36..ad09add849a51 100644 --- a/br/pkg/lightning/restore/table_restore_test.go +++ b/br/pkg/lightning/restore/table_restore_test.go @@ -753,6 +753,7 @@ func (s *tableRestoreSuite) TestCompareChecksumSuccess() { WithArgs("10m"). WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() + mock.ExpectClose() ctx := MockDoChecksumCtx(db) remoteChecksum, err := DoChecksum(ctx, s.tr.tableInfo) @@ -783,7 +784,7 @@ func (s *tableRestoreSuite) TestCompareChecksumFailure() { WithArgs("10m"). 
WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() - + mock.ExpectClose() ctx := MockDoChecksumCtx(db) remoteChecksum, err := DoChecksum(ctx, s.tr.tableInfo) require.NoError(s.T(), err) diff --git a/br/pkg/lightning/restore/tidb_test.go b/br/pkg/lightning/restore/tidb_test.go index 9b204b2da22b1..b3ece883864f6 100644 --- a/br/pkg/lightning/restore/tidb_test.go +++ b/br/pkg/lightning/restore/tidb_test.go @@ -460,6 +460,7 @@ func TestObtainRowFormatVersionSucceed(t *testing.T) { sysVars := ObtainImportantVariables(ctx, s.tiGlue.GetSQLExecutor(), true) require.Equal(t, map[string]string{ + "tidb_backoff_weight": "6", "tidb_row_format_version": "2", "max_allowed_packet": "1073741824", "div_precision_increment": "10", @@ -487,6 +488,7 @@ func TestObtainRowFormatVersionFailure(t *testing.T) { sysVars := ObtainImportantVariables(ctx, s.tiGlue.GetSQLExecutor(), true) require.Equal(t, map[string]string{ + "tidb_backoff_weight": "6", "tidb_row_format_version": "1", "max_allowed_packet": "67108864", "div_precision_increment": "4", diff --git a/br/tests/lightning_add_index/config1.toml b/br/tests/lightning_add_index/config1.toml new file mode 100644 index 0000000000000..36b03d49a1117 --- /dev/null +++ b/br/tests/lightning_add_index/config1.toml @@ -0,0 +1,6 @@ +[tikv-importer] +backend = 'local' +add-index-by-sql = false + +[post-restore] +checksum-via-sql = false \ No newline at end of file diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go new file mode 100644 index 0000000000000..248b797c9b913 --- /dev/null +++ b/disttask/framework/dispatcher/dispatcher.go @@ -0,0 +1,516 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dispatcher + +import ( + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/domain/infosync" + "github.com/pingcap/tidb/resourcemanager/pool/spool" + "github.com/pingcap/tidb/resourcemanager/util" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" + tidbutil "github.com/pingcap/tidb/util" + disttaskutil "github.com/pingcap/tidb/util/disttask" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/syncutil" + "go.uber.org/zap" +) + +const ( + // DefaultSubtaskConcurrency is the default concurrency for handling subtask. + DefaultSubtaskConcurrency = 16 + // MaxSubtaskConcurrency is the maximum concurrency for handling subtask. + MaxSubtaskConcurrency = 256 +) + +var ( + // DefaultDispatchConcurrency is the default concurrency for handling global task. + DefaultDispatchConcurrency = 4 + checkTaskFinishedInterval = 500 * time.Millisecond + checkTaskRunningInterval = 300 * time.Millisecond + nonRetrySQLTime = 1 + retrySQLTimes = variable.DefTiDBDDLErrorCountLimit + retrySQLInterval = 500 * time.Millisecond +) + +// Dispatch defines the interface for operations inside a dispatcher. 
+type Dispatch interface { + // Start enables dispatching and monitoring mechanisms. + Start() + // GetAllSchedulerIDs gets handles the task's all available instances. + GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]string, error) + // Stop stops the dispatcher. + Stop() +} + +// TaskHandle provides the interface for operations needed by task flow handles. +type TaskHandle interface { + // GetAllSchedulerIDs gets handles the task's all scheduler instances. + GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]string, error) + // GetPreviousSubtaskMetas gets previous subtask metas. + GetPreviousSubtaskMetas(gTaskID int64, step int64) ([][]byte, error) + storage.SessionExecutor +} + +func (d *dispatcher) getRunningGTaskCnt() int { + d.runningGTasks.RLock() + defer d.runningGTasks.RUnlock() + return len(d.runningGTasks.taskIDs) +} + +func (d *dispatcher) setRunningGTask(gTask *proto.Task) { + d.runningGTasks.Lock() + d.runningGTasks.taskIDs[gTask.ID] = struct{}{} + d.runningGTasks.Unlock() + d.detectPendingGTaskCh <- gTask +} + +func (d *dispatcher) isRunningGTask(globalTaskID int64) bool { + d.runningGTasks.Lock() + defer d.runningGTasks.Unlock() + _, ok := d.runningGTasks.taskIDs[globalTaskID] + return ok +} + +func (d *dispatcher) delRunningGTask(globalTaskID int64) { + d.runningGTasks.Lock() + defer d.runningGTasks.Unlock() + delete(d.runningGTasks.taskIDs, globalTaskID) +} + +type dispatcher struct { + ctx context.Context + cancel context.CancelFunc + taskMgr *storage.TaskManager + wg tidbutil.WaitGroupWrapper + gPool *spool.Pool + + runningGTasks struct { + syncutil.RWMutex + taskIDs map[int64]struct{} + } + detectPendingGTaskCh chan *proto.Task +} + +// NewDispatcher creates a dispatcher struct. +func NewDispatcher(ctx context.Context, taskTable *storage.TaskManager) (Dispatch, error) { + dispatcher := &dispatcher{ + taskMgr: taskTable, + detectPendingGTaskCh: make(chan *proto.Task, DefaultDispatchConcurrency), + } + pool, err := spool.NewPool("dispatch_pool", int32(DefaultDispatchConcurrency), util.DistTask, spool.WithBlocking(true)) + if err != nil { + return nil, err + } + dispatcher.gPool = pool + dispatcher.ctx, dispatcher.cancel = context.WithCancel(ctx) + dispatcher.runningGTasks.taskIDs = make(map[int64]struct{}) + + return dispatcher, nil +} + +// Start implements Dispatch.Start interface. +func (d *dispatcher) Start() { + d.wg.Run(d.DispatchTaskLoop) + d.wg.Run(d.DetectTaskLoop) +} + +// Stop implements Dispatch.Stop interface. +func (d *dispatcher) Stop() { + d.cancel() + d.gPool.ReleaseAndWait() + d.wg.Wait() +} + +// DispatchTaskLoop dispatches the global tasks. +func (d *dispatcher) DispatchTaskLoop() { + logutil.BgLogger().Info("dispatch task loop start") + ticker := time.NewTicker(checkTaskRunningInterval) + defer ticker.Stop() + for { + select { + case <-d.ctx.Done(): + logutil.BgLogger().Info("dispatch task loop exits", zap.Error(d.ctx.Err()), zap.Int64("interval", int64(checkTaskRunningInterval)/1000000)) + return + case <-ticker.C: + cnt := d.getRunningGTaskCnt() + if d.checkConcurrencyOverflow(cnt) { + break + } + + // TODO: Consider getting these tasks, in addition to the task being worked on.. + gTasks, err := d.taskMgr.GetGlobalTasksInStates(proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStateCancelling) + if err != nil { + logutil.BgLogger().Warn("get unfinished(pending, running, reverting or cancelling) tasks failed", zap.Error(err)) + break + } + + // There are currently no global tasks to work on. 
+ if len(gTasks) == 0 { + break + } + for _, gTask := range gTasks { + // This global task is running, so no need to reprocess it. + if d.isRunningGTask(gTask.ID) { + continue + } + // the task is not in runningGTasks set when: + // owner changed or task is cancelled when status is pending. + if gTask.State == proto.TaskStateRunning || gTask.State == proto.TaskStateReverting || gTask.State == proto.TaskStateCancelling { + d.setRunningGTask(gTask) + cnt++ + continue + } + + if d.checkConcurrencyOverflow(cnt) { + break + } + + err = d.processNormalFlow(gTask) + logutil.BgLogger().Info("dispatch task loop", zap.Int64("task ID", gTask.ID), + zap.String("state", gTask.State), zap.Uint64("concurrency", gTask.Concurrency), zap.Error(err)) + if err != nil || gTask.IsFinished() { + continue + } + d.setRunningGTask(gTask) + cnt++ + } + } + } +} + +func (d *dispatcher) probeTask(gTask *proto.Task) (isFinished bool, subTaskErr [][]byte) { + // TODO: Consider putting the following operations into a transaction. + // TODO: Consider collect some information about the tasks. + if gTask.State != proto.TaskStateReverting { + // check if global task cancelling + cancelling, err := d.taskMgr.IsGlobalTaskCancelling(gTask.ID) + if err != nil { + logutil.BgLogger().Warn("check task cancelling failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) + return false, nil + } + + if cancelling { + return false, [][]byte{[]byte("cancel")} + } + // check subtasks failed. + cnt, err := d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateFailed) + if err != nil { + logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) + return false, nil + } + if cnt > 0 { + subTaskErr, err = d.taskMgr.CollectSubTaskError(gTask.ID) + if err != nil { + logutil.BgLogger().Warn("collect subtask error failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) + return false, nil + } + return false, subTaskErr + } + // check subtasks pending or running. + cnt, err = d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStatePending, proto.TaskStateRunning) + if err != nil { + logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) + return false, nil + } + if cnt > 0 { + return false, nil + } + return true, nil + } + + // if gTask.State == TaskStateReverting, if will not convert to TaskStateCancelling again. + cnt, err := d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateRevertPending, proto.TaskStateReverting) + if err != nil { + logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) + return false, nil + } + if cnt > 0 { + return false, nil + } + return true, nil +} + +// DetectTaskLoop monitors the status of the subtasks and processes them. +func (d *dispatcher) DetectTaskLoop() { + logutil.BgLogger().Info("detect task loop start") + for { + select { + case <-d.ctx.Done(): + logutil.BgLogger().Info("detect task loop exits", zap.Error(d.ctx.Err())) + return + case task := <-d.detectPendingGTaskCh: + // Using the pool with block, so it wouldn't return an error. + _ = d.gPool.Run(func() { d.detectTask(task) }) + } + } +} + +func (d *dispatcher) detectTask(gTask *proto.Task) { + ticker := time.NewTicker(checkTaskFinishedInterval) + defer ticker.Stop() + + for { + select { + case <-d.ctx.Done(): + logutil.BgLogger().Info("detect task exits", zap.Int64("task ID", gTask.ID), zap.Error(d.ctx.Err())) + return + case <-ticker.C: + // TODO: Consider actively obtaining information about task completion. 
+ stepIsFinished, errStr := d.probeTask(gTask) + // The global task isn't finished and not failed. + if !stepIsFinished && len(errStr) == 0 { + GetTaskFlowHandle(gTask.Type).OnTicker(d.ctx, gTask) + logutil.BgLogger().Debug("detect task, this task keeps current state", + zap.Int64("task-id", gTask.ID), zap.String("state", gTask.State)) + break + } + + err := d.processFlow(gTask, errStr) + if err == nil && gTask.IsFinished() { + logutil.BgLogger().Info("detect task, task is finished", + zap.Int64("task-id", gTask.ID), zap.String("state", gTask.State)) + d.delRunningGTask(gTask.ID) + return + } + if !d.isRunningGTask(gTask.ID) { + logutil.BgLogger().Info("detect task, this task can't run", + zap.Int64("task-id", gTask.ID), zap.String("state", gTask.State)) + } + } + } +} + +func (d *dispatcher) processFlow(gTask *proto.Task, errStr [][]byte) error { + if len(errStr) > 0 { + // Found an error when task is running. + logutil.BgLogger().Info("process flow, handle an error", zap.Int64("task-id", gTask.ID), zap.ByteStrings("err msg", errStr)) + return d.processErrFlow(gTask, errStr) + } + // previous step is finished. + if gTask.State == proto.TaskStateReverting { + // Finish the rollback step. + logutil.BgLogger().Info("process flow, update the task to reverted", zap.Int64("task-id", gTask.ID)) + return d.updateTask(gTask, proto.TaskStateReverted, nil, retrySQLTimes) + } + // Finish the normal step. + logutil.BgLogger().Info("process flow, process normal", zap.Int64("task-id", gTask.ID)) + return d.processNormalFlow(gTask) +} + +func (d *dispatcher) updateTask(gTask *proto.Task, gTaskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) { + prevState := gTask.State + gTask.State = gTaskState + for i := 0; i < retryTimes; i++ { + err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(gTask, newSubTasks, gTaskState == proto.TaskStateReverting) + if err == nil { + break + } + if i%10 == 0 { + logutil.BgLogger().Warn("updateTask first failed", zap.Int64("task-id", gTask.ID), + zap.String("previous state", prevState), zap.String("curr state", gTask.State), + zap.Int("retry times", retryTimes), zap.Error(err)) + } + time.Sleep(retrySQLInterval) + } + if err != nil && retryTimes != nonRetrySQLTime { + logutil.BgLogger().Warn("updateTask failed and delete running task info", zap.Int64("task-id", gTask.ID), + zap.String("previous state", prevState), zap.String("curr state", gTask.State), zap.Int("retry times", retryTimes), zap.Error(err)) + d.delRunningGTask(gTask.ID) + } + return err +} + +func (d *dispatcher) processErrFlow(gTask *proto.Task, receiveErr [][]byte) error { + // TODO: Maybe it gets GetTaskFlowHandle fails when rolling upgrades. + // 1. generate the needed global task meta and subTask meta (dist-plan). + meta, err := GetTaskFlowHandle(gTask.Type).ProcessErrFlow(d.ctx, d, gTask, receiveErr) + if err != nil { + logutil.BgLogger().Warn("handle error failed", zap.Error(err)) + return err + } + + // 2. dispatch revert dist-plan to EligibleInstances. 
+ return d.dispatchSubTask4Revert(gTask, meta) +} + +func (d *dispatcher) dispatchSubTask4Revert(gTask *proto.Task, meta []byte) error { + instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, gTask.ID) + if err != nil { + logutil.BgLogger().Warn("get global task's all instances failed", zap.Error(err)) + return err + } + + if len(instanceIDs) == 0 { + return d.updateTask(gTask, proto.TaskStateReverted, nil, retrySQLTimes) + } + + subTasks := make([]*proto.Subtask, 0, len(instanceIDs)) + for _, id := range instanceIDs { + subTasks = append(subTasks, proto.NewSubtask(gTask.ID, gTask.Type, id, meta)) + } + return d.updateTask(gTask, proto.TaskStateReverting, subTasks, retrySQLTimes) +} + +func (d *dispatcher) processNormalFlow(gTask *proto.Task) error { + // 1. generate the needed global task meta and subTask meta (dist-plan). + handle := GetTaskFlowHandle(gTask.Type) + if handle == nil { + logutil.BgLogger().Warn("gen gTask flow handle failed, this type handle doesn't register", zap.Int64("ID", gTask.ID), zap.String("type", gTask.Type)) + return d.updateTask(gTask, proto.TaskStateReverted, nil, retrySQLTimes) + } + metas, err := handle.ProcessNormalFlow(d.ctx, d, gTask) + if err != nil { + logutil.BgLogger().Warn("gen dist-plan failed", zap.Error(err)) + if handle.IsRetryableErr(err) { + return err + } + gTask.Error = []byte(err.Error()) + return d.updateTask(gTask, proto.TaskStateReverted, nil, retrySQLTimes) + } + logutil.BgLogger().Info("process normal flow", zap.Int64("task ID", gTask.ID), + zap.String("state", gTask.State), zap.Uint64("concurrency", gTask.Concurrency), zap.Int("subtasks", len(metas))) + + // 2. dispatch dist-plan to EligibleInstances. + return d.dispatchSubTask(gTask, handle, metas) +} + +func (d *dispatcher) dispatchSubTask(gTask *proto.Task, handle TaskFlowHandle, metas [][]byte) error { + // Adjust the global task's concurrency. + if gTask.Concurrency == 0 { + gTask.Concurrency = DefaultSubtaskConcurrency + } + if gTask.Concurrency > MaxSubtaskConcurrency { + gTask.Concurrency = MaxSubtaskConcurrency + } + + retryTimes := retrySQLTimes + // Special handling for the new tasks. + if gTask.State == proto.TaskStatePending { + // TODO: Consider using TS. + nowTime := time.Now().UTC() + gTask.StartTime = nowTime + gTask.State = proto.TaskStateRunning + gTask.StateUpdateTime = nowTime + retryTimes = nonRetrySQLTime + } + + if len(metas) == 0 { + gTask.StateUpdateTime = time.Now().UTC() + // Write the global task meta into the storage. + err := d.updateTask(gTask, proto.TaskStateSucceed, nil, retryTimes) + if err != nil { + logutil.BgLogger().Warn("update global task failed", zap.Error(err)) + return err + } + return nil + } + // select all available TiDB nodes for this global tasks. + serverNodes, err1 := handle.GetEligibleInstances(d.ctx, gTask) + logutil.BgLogger().Debug("eligible instances", zap.Int("num", len(serverNodes))) + + if err1 != nil { + return err1 + } + if len(serverNodes) == 0 { + return errors.New("no available TiDB node") + } + subTasks := make([]*proto.Subtask, 0, len(metas)) + for i, meta := range metas { + // we assign the subtask to the instance in a round-robin way. 
+ pos := i % len(serverNodes) + instanceID := disttaskutil.GenerateExecID(serverNodes[pos].IP, serverNodes[pos].Port) + logutil.BgLogger().Debug("create subtasks", + zap.Int("gTask.ID", int(gTask.ID)), zap.String("type", gTask.Type), zap.String("instanceID", instanceID)) + subTasks = append(subTasks, proto.NewSubtask(gTask.ID, gTask.Type, instanceID, meta)) + } + + return d.updateTask(gTask, gTask.State, subTasks, retrySQLTimes) +} + +// GenerateSchedulerNodes generate a eligible TiDB nodes. +func GenerateSchedulerNodes(ctx context.Context) ([]*infosync.ServerInfo, error) { + serverInfos, err := infosync.GetAllServerInfo(ctx) + if err != nil { + return nil, err + } + if len(serverInfos) == 0 { + return nil, errors.New("not found instance") + } + + serverNodes := make([]*infosync.ServerInfo, 0, len(serverInfos)) + for _, serverInfo := range serverInfos { + serverNodes = append(serverNodes, serverInfo) + } + return serverNodes, nil +} + +// GetAllSchedulerIDs gets all the scheduler IDs. +func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]string, error) { + serverInfos, err := infosync.GetAllServerInfo(ctx) + if err != nil { + return nil, err + } + if len(serverInfos) == 0 { + return nil, nil + } + + schedulerIDs, err := d.taskMgr.GetSchedulerIDsByTaskID(gTaskID) + if err != nil { + return nil, err + } + ids := make([]string, 0, len(schedulerIDs)) + for _, id := range schedulerIDs { + if ok := disttaskutil.MatchServerInfo(serverInfos, id); ok { + ids = append(ids, id) + } + } + return ids, nil +} + +func (d *dispatcher) GetPreviousSubtaskMetas(gTaskID int64, step int64) ([][]byte, error) { + previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(gTaskID, step) + if err != nil { + logutil.BgLogger().Warn("get previous succeed subtask failed", zap.Int64("ID", gTaskID), zap.Int64("step", step)) + return nil, err + } + previousSubtaskMetas := make([][]byte, 0, len(previousSubtasks)) + for _, subtask := range previousSubtasks { + previousSubtaskMetas = append(previousSubtaskMetas, subtask.Meta) + } + return previousSubtaskMetas, nil +} + +func (d *dispatcher) WithNewSession(fn func(se sessionctx.Context) error) error { + return d.taskMgr.WithNewSession(fn) +} + +func (d *dispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { + return d.taskMgr.WithNewTxn(ctx, fn) +} + +func (*dispatcher) checkConcurrencyOverflow(cnt int) bool { + if cnt >= DefaultDispatchConcurrency { + logutil.BgLogger().Info("dispatch task loop, running GTask cnt is more than concurrency", + zap.Int("running cnt", cnt), zap.Int("concurrency", DefaultDispatchConcurrency)) + return true + } + return false +} diff --git a/disttask/framework/storage/task_table.go b/disttask/framework/storage/task_table.go new file mode 100644 index 0000000000000..f394e99a7a540 --- /dev/null +++ b/disttask/framework/storage/task_table.go @@ -0,0 +1,496 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package storage + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +// SessionExecutor defines the interface for executing SQLs in a session. +type SessionExecutor interface { + // WithNewSession executes the function with a new session. + WithNewSession(fn func(se sessionctx.Context) error) error + // WithNewTxn executes the fn in a new transaction. + WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error +} + +// TaskManager is the manager of global/sub task. +type TaskManager struct { + ctx context.Context + sePool *pools.ResourcePool +} + +var _ SessionExecutor = &TaskManager{} + +var taskManagerInstance atomic.Pointer[TaskManager] + +var ( + // TestLastTaskID is used for test to set the last task ID. + TestLastTaskID atomic.Int64 +) + +// NewTaskManager creates a new task manager. +func NewTaskManager(ctx context.Context, sePool *pools.ResourcePool) *TaskManager { + ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) + return &TaskManager{ + ctx: ctx, + sePool: sePool, + } +} + +// GetTaskManager gets the task manager. +func GetTaskManager() (*TaskManager, error) { + v := taskManagerInstance.Load() + if v == nil { + return nil, errors.New("global task manager is not initialized") + } + return v, nil +} + +// SetTaskManager sets the task manager. +func SetTaskManager(is *TaskManager) { + taskManagerInstance.Store(is) +} + +// ExecSQL executes the sql and returns the result. +// TODO: consider retry. +func ExecSQL(ctx context.Context, se sessionctx.Context, sql string, args ...interface{}) ([]chunk.Row, error) { + rs, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...) + if err != nil { + return nil, err + } + if rs != nil { + defer terror.Call(rs.Close) + return sqlexec.DrainRecordSet(ctx, rs, 1024) + } + return nil, nil +} + +// row2GlobeTask converts a row to a global task. +func row2GlobeTask(r chunk.Row) *proto.Task { + task := &proto.Task{ + ID: r.GetInt64(0), + Key: r.GetString(1), + Type: r.GetString(2), + DispatcherID: r.GetString(3), + State: r.GetString(4), + Meta: r.GetBytes(7), + Concurrency: uint64(r.GetInt64(8)), + Step: r.GetInt64(9), + Error: r.GetBytes(10), + } + // TODO: convert to local time. + task.StartTime, _ = r.GetTime(5).GoTime(time.UTC) + task.StateUpdateTime, _ = r.GetTime(6).GoTime(time.UTC) + return task +} + +// WithNewSession executes the function with a new session. +func (stm *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) error { + se, err := stm.sePool.Get() + if err != nil { + return err + } + defer stm.sePool.Put(se) + return fn(se.(sessionctx.Context)) +} + +// WithNewTxn executes the fn in a new transaction. 
+func (stm *TaskManager) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { + ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) + return stm.WithNewSession(func(se sessionctx.Context) (err error) { + _, err = ExecSQL(ctx, se, "begin") + if err != nil { + return err + } + + success := false + defer func() { + sql := "rollback" + if success { + sql = "commit" + } + _, commitErr := ExecSQL(ctx, se, sql) + if err == nil && commitErr != nil { + err = commitErr + } + }() + + if err = fn(se); err != nil { + return err + } + + success = true + return nil + }) +} + +func (stm *TaskManager) executeSQLWithNewSession(ctx context.Context, sql string, args ...interface{}) (rs []chunk.Row, err error) { + err = stm.WithNewSession(func(se sessionctx.Context) error { + rs, err = ExecSQL(ctx, se, sql, args...) + return err + }) + + if err != nil { + return nil, err + } + + return +} + +// AddNewGlobalTask adds a new task to global task table. +func (stm *TaskManager) AddNewGlobalTask(key, tp string, concurrency int, meta []byte) (taskID int64, err error) { + err = stm.WithNewSession(func(se sessionctx.Context) error { + var err2 error + taskID, err2 = stm.AddGlobalTaskWithSession(se, key, tp, concurrency, meta) + return err2 + }) + return +} + +// AddGlobalTaskWithSession adds a new task to global task table with session. +func (stm *TaskManager) AddGlobalTaskWithSession(se sessionctx.Context, key, tp string, concurrency int, meta []byte) (taskID int64, err error) { + _, err = ExecSQL(stm.ctx, se, + `insert into mysql.tidb_global_task(task_key, type, state, concurrency, step, meta, state_update_time) + values (%?, %?, %?, %?, %?, %?, %?)`, + key, tp, proto.TaskStatePending, concurrency, proto.StepInit, meta, time.Now().UTC().String()) + if err != nil { + return 0, err + } + + rs, err := ExecSQL(stm.ctx, se, "select @@last_insert_id") + if err != nil { + return 0, err + } + + taskID = int64(rs[0].GetUint64(0)) + failpoint.Inject("testSetLastTaskID", func() { TestLastTaskID.Store(taskID) }) + + return taskID, nil +} + +// GetNewGlobalTask get a new task from global task table, it's used by dispatcher only. +func (stm *TaskManager) GetNewGlobalTask() (task *proto.Task, err error) { + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending) + if err != nil { + return task, err + } + + if len(rs) == 0 { + return nil, nil + } + + return row2GlobeTask(rs[0]), nil +} + +// GetGlobalTasksInStates gets the tasks in the states. +func (stm *TaskManager) GetGlobalTasksInStates(states ...interface{}) (task []*proto.Task, err error) { + if len(states) == 0 { + return task, nil + } + + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", states...) + if err != nil { + return task, err + } + + for _, r := range rs { + task = append(task, row2GlobeTask(r)) + } + return task, nil +} + +// GetGlobalTaskByID gets the task by the global task ID. 
+func (stm *TaskManager) GetGlobalTaskByID(taskID int64) (task *proto.Task, err error) { + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where id = %?", taskID) + if err != nil { + return task, err + } + if len(rs) == 0 { + return nil, nil + } + + return row2GlobeTask(rs[0]), nil +} + +// GetGlobalTaskByKey gets the task by the task key +func (stm *TaskManager) GetGlobalTaskByKey(key string) (task *proto.Task, err error) { + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where task_key = %?", key) + if err != nil { + return task, err + } + if len(rs) == 0 { + return nil, nil + } + + return row2GlobeTask(rs[0]), nil +} + +// row2SubTask converts a row to a subtask. +func row2SubTask(r chunk.Row) *proto.Subtask { + task := &proto.Subtask{ + ID: r.GetInt64(0), + Step: r.GetInt64(1), + Type: proto.Int2Type(int(r.GetInt64(5))), + SchedulerID: r.GetString(6), + State: r.GetString(8), + Meta: r.GetBytes(12), + StartTime: r.GetUint64(10), + } + tid, err := strconv.Atoi(r.GetString(3)) + if err != nil { + logutil.BgLogger().Warn("unexpected task ID", zap.String("task ID", r.GetString(3))) + } + task.TaskID = int64(tid) + return task +} + +// AddNewSubTask adds a new task to subtask table. +func (stm *TaskManager) AddNewSubTask(globalTaskID int64, step int64, designatedTiDBID string, meta []byte, tp string, isRevert bool) error { + st := proto.TaskStatePending + if isRevert { + st = proto.TaskStateRevertPending + } + + _, err := stm.executeSQLWithNewSession(stm.ctx, "insert into mysql.tidb_background_subtask(task_key, step, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?, %?)", globalTaskID, step, designatedTiDBID, meta, st, proto.Type2Int(tp), []byte{}) + if err != nil { + return err + } + + return nil +} + +// GetSubtaskInStates gets the subtask in the states. +func (stm *TaskManager) GetSubtaskInStates(tidbID string, taskID int64, states ...interface{}) (*proto.Subtask, error) { + args := []interface{}{tidbID, taskID} + args = append(args, states...) + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select * from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", args...) + if err != nil { + return nil, err + } + if len(rs) == 0 { + return nil, nil + } + return row2SubTask(rs[0]), nil +} + +// PrintSubtaskInfo log the subtask info by taskKey. +func (stm *TaskManager) PrintSubtaskInfo(taskKey int) { + rs, _ := stm.executeSQLWithNewSession(stm.ctx, + "select * from mysql.tidb_background_subtask where task_key = %?", taskKey) + + for _, r := range rs { + logutil.BgLogger().Info(fmt.Sprintf("subTask: %v\n", row2SubTask(r))) + } +} + +// GetSucceedSubtasksByStep gets the subtask in the success state. +func (stm *TaskManager) GetSucceedSubtasksByStep(taskID int64, step int64) ([]*proto.Subtask, error) { + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select * from mysql.tidb_background_subtask where task_key = %? and state = %? 
and step = %?", taskID, proto.TaskStateSucceed, step) + if err != nil { + return nil, err + } + if len(rs) == 0 { + return nil, nil + } + subtasks := make([]*proto.Subtask, 0, len(rs)) + for _, r := range rs { + subtasks = append(subtasks, row2SubTask(r)) + } + return subtasks, nil +} + +// GetSubtaskInStatesCnt gets the subtask count in the states. +func (stm *TaskManager) GetSubtaskInStatesCnt(taskID int64, states ...interface{}) (int64, error) { + args := []interface{}{taskID} + args = append(args, states...) + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select count(*) from mysql.tidb_background_subtask where task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", args...) + if err != nil { + return 0, err + } + + return rs[0].GetInt64(0), nil +} + +// CollectSubTaskError collects the subtask error. +func (stm *TaskManager) CollectSubTaskError(taskID int64) ([][]byte, error) { + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select error from mysql.tidb_background_subtask where task_key = %? AND state = %?", taskID, proto.TaskStateFailed) + if err != nil { + return nil, err + } + + subTaskErrors := make([][]byte, 0, len(rs)) + for _, err := range rs { + subTaskErrors = append(subTaskErrors, err.GetBytes(0)) + } + + return subTaskErrors, nil +} + +// HasSubtasksInStates checks if there are subtasks in the states. +func (stm *TaskManager) HasSubtasksInStates(tidbID string, taskID int64, states ...interface{}) (bool, error) { + args := []interface{}{tidbID, taskID} + args = append(args, states...) + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?) limit 1", args...) + if err != nil { + return false, err + } + + return len(rs) > 0, nil +} + +// UpdateSubtaskStateAndError updates the subtask state. +func (stm *TaskManager) UpdateSubtaskStateAndError(id int64, state string, subTaskErr string) error { + _, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set state = %?, error = %? where id = %?", state, subTaskErr, id) + return err +} + +// FinishSubtask updates the subtask meta and mark state to succeed. +func (stm *TaskManager) FinishSubtask(id int64, meta []byte) error { + _, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set meta = %?, state = %? where id = %?", meta, proto.TaskStateSucceed, id) + return err +} + +// UpdateSubtaskHeartbeat updates the heartbeat of the subtask. +func (stm *TaskManager) UpdateSubtaskHeartbeat(instanceID string, taskID int64, heartbeat time.Time) error { + _, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set exec_expired = %? where exec_id = %? and task_key = %?", heartbeat.String(), instanceID, taskID) + return err +} + +// DeleteSubtasksByTaskID deletes the subtask of the given global task ID. +func (stm *TaskManager) DeleteSubtasksByTaskID(taskID int64) error { + _, err := stm.executeSQLWithNewSession(stm.ctx, "delete from mysql.tidb_background_subtask where task_key = %?", taskID) + if err != nil { + return err + } + + return nil +} + +// GetSchedulerIDsByTaskID gets the scheduler IDs of the given global task ID. 
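+// The returned IDs are the distinct exec_id values of the task's subtasks, i.e. the TiDB
+// instances that have been assigned at least one subtask of this global task.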
+func (stm *TaskManager) GetSchedulerIDsByTaskID(taskID int64) ([]string, error) { + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select distinct(exec_id) from mysql.tidb_background_subtask where task_key = %?", taskID) + if err != nil { + return nil, err + } + if len(rs) == 0 { + return nil, nil + } + + instanceIDs := make([]string, 0, len(rs)) + for _, r := range rs { + id := r.GetString(0) + instanceIDs = append(instanceIDs, id) + } + + return instanceIDs, nil +} + +// UpdateGlobalTaskAndAddSubTasks update the global task and add new subtasks +func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, isSubtaskRevert bool) error { + return stm.WithNewTxn(stm.ctx, func(se sessionctx.Context) error { + _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? where id = %?", + gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, gTask.Error, gTask.ID) + if err != nil { + return err + } + + failpoint.Inject("MockUpdateTaskErr", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("updateTaskErr")) + } + }) + + subtaskState := proto.TaskStatePending + if isSubtaskRevert { + subtaskState = proto.TaskStateRevertPending + } + + for _, subtask := range subtasks { + // TODO: insert subtasks in batch + _, err = ExecSQL(stm.ctx, se, "insert into mysql.tidb_background_subtask(step, task_key, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?, %?)", + gTask.Step, gTask.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}) + if err != nil { + return err + } + } + + return nil + }) +} + +// CancelGlobalTask cancels global task +func (stm *TaskManager) CancelGlobalTask(taskID int64) error { + _, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_global_task set state=%? where id=%? and state in (%?, %?)", + proto.TaskStateCancelling, taskID, proto.TaskStatePending, proto.TaskStateRunning, + ) + return err +} + +// CancelGlobalTaskByKeySession cancels global task by key using input session +func (stm *TaskManager) CancelGlobalTaskByKeySession(se sessionctx.Context, taskKey string) error { + _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state=%? where task_key=%? and state in (%?, %?)", + proto.TaskStateCancelling, taskKey, proto.TaskStatePending, proto.TaskStateRunning) + return err +} + +// IsGlobalTaskCancelling checks whether the task state is cancelling +func (stm *TaskManager) IsGlobalTaskCancelling(taskID int64) (bool, error) { + rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_global_task where id=%? and state = %?", + taskID, proto.TaskStateCancelling, + ) + + if err != nil { + return false, err + } + + return len(rs) > 0, nil +} + +// GetSubtasksByStep gets subtasks of global task by step +func (stm *TaskManager) GetSubtasksByStep(taskID, step int64) ([]*proto.Subtask, error) { + rs, err := stm.executeSQLWithNewSession(stm.ctx, + "select * from mysql.tidb_background_subtask where task_key = %? 
and step = %?", + taskID, step) + if err != nil { + return nil, err + } + if len(rs) == 0 { + return nil, nil + } + subtasks := make([]*proto.Subtask, 0, len(rs)) + for _, r := range rs { + subtasks = append(subtasks, row2SubTask(r)) + } + return subtasks, nil +} diff --git a/disttask/importinto/BUILD.bazel b/disttask/importinto/BUILD.bazel new file mode 100644 index 0000000000000..eabe0b7ecc10a --- /dev/null +++ b/disttask/importinto/BUILD.bazel @@ -0,0 +1,80 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "importinto", + srcs = [ + "dispatcher.go", + "job.go", + "proto.go", + "scheduler.go", + "subtask_executor.go", + "wrapper.go", + ], + importpath = "github.com/pingcap/tidb/disttask/importinto", + visibility = ["//visibility:public"], + deps = [ + "//br/pkg/lightning/backend", + "//br/pkg/lightning/backend/kv", + "//br/pkg/lightning/backend/local", + "//br/pkg/lightning/checkpoints", + "//br/pkg/lightning/common", + "//br/pkg/lightning/config", + "//br/pkg/lightning/mydump", + "//br/pkg/lightning/verification", + "//br/pkg/utils", + "//disttask/framework/dispatcher", + "//disttask/framework/handle", + "//disttask/framework/proto", + "//disttask/framework/scheduler", + "//disttask/framework/storage", + "//domain/infosync", + "//errno", + "//executor/asyncloaddata", + "//executor/importer", + "//kv", + "//parser/ast", + "//parser/mysql", + "//sessionctx", + "//sessionctx/variable", + "//table/tables", + "//util/dbterror/exeerrors", + "//util/etcd", + "//util/logutil", + "//util/mathutil", + "//util/sqlexec", + "@com_github_go_sql_driver_mysql//:mysql", + "@com_github_google_uuid//:uuid", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_tikv_client_go_v2//util", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "importinto_test", + timeout = "short", + srcs = [ + "dispatcher_test.go", + "subtask_executor_test.go", + ], + embed = [":importinto"], + flaky = True, + race = "on", + deps = [ + "//br/pkg/lightning/verification", + "//disttask/framework/proto", + "//disttask/framework/storage", + "//domain/infosync", + "//executor/importer", + "//parser/model", + "//testkit", + "//util/logutil", + "@com_github_ngaut_pools//:pools", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_stretchr_testify//require", + "@com_github_stretchr_testify//suite", + "@com_github_tikv_client_go_v2//util", + ], +) diff --git a/disttask/importinto/dispatcher.go b/disttask/importinto/dispatcher.go new file mode 100644 index 0000000000000..3f4d6822bc55d --- /dev/null +++ b/disttask/importinto/dispatcher.go @@ -0,0 +1,647 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importinto + +import ( + "context" + "encoding/json" + "strconv" + "strings" + "sync" + "time" + + dmysql "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/br/pkg/lightning/common" + "github.com/pingcap/tidb/br/pkg/lightning/config" + verify "github.com/pingcap/tidb/br/pkg/lightning/verification" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/disttask/framework/dispatcher" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/domain/infosync" + "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/table/tables" + "github.com/pingcap/tidb/util/etcd" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +const ( + registerTaskTTL = 10 * time.Minute + refreshTaskTTLInterval = 3 * time.Minute + registerTimeout = 5 * time.Second +) + +// NewTaskRegisterWithTTL is the ctor for TaskRegister. +// It is exported for testing. +var NewTaskRegisterWithTTL = utils.NewTaskRegisterWithTTL + +type taskInfo struct { + taskID int64 + + // operation on taskInfo is run inside detect-task goroutine, so no need to synchronize. + lastRegisterTime time.Time + + // initialized lazily in register() + etcdClient *etcd.Client + taskRegister utils.TaskRegister +} + +func (t *taskInfo) register(ctx context.Context) { + if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval { + return + } + + if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval { + return + } + logger := logutil.BgLogger().With(zap.Int64("task-id", t.taskID)) + if t.taskRegister == nil { + client, err := importer.GetEtcdClient() + if err != nil { + logger.Warn("get etcd client failed", zap.Error(err)) + return + } + t.etcdClient = client + t.taskRegister = NewTaskRegisterWithTTL(client.GetClient(), registerTaskTTL, + utils.RegisterImportInto, strconv.FormatInt(t.taskID, 10)) + } + timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout) + defer cancel() + if err := t.taskRegister.RegisterTaskOnce(timeoutCtx); err != nil { + logger.Warn("register task failed", zap.Error(err)) + } else { + logger.Info("register task to pd or refresh lease success") + } + // we set it even if register failed, TTL is 10min, refresh interval is 3min, + // we can try 2 times before the lease is expired. + t.lastRegisterTime = time.Now() +} + +func (t *taskInfo) close(ctx context.Context) { + logger := logutil.BgLogger().With(zap.Int64("task-id", t.taskID)) + if t.taskRegister != nil { + timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout) + defer cancel() + if err := t.taskRegister.Close(timeoutCtx); err != nil { + logger.Warn("unregister task failed", zap.Error(err)) + } else { + logger.Info("unregister task success") + } + t.taskRegister = nil + } + if t.etcdClient != nil { + if err := t.etcdClient.Close(); err != nil { + logger.Warn("close etcd client failed", zap.Error(err)) + } + t.etcdClient = nil + } +} + +type flowHandle struct { + mu sync.RWMutex + // NOTE: there's no need to sync for below 2 fields actually, since we add a restriction that only one + // task can be running at a time. 
but we might support task queuing in the future, leave it for now. + // the last time we switch TiKV into IMPORT mode, this is a global operation, do it for one task makes + // no difference to do it for all tasks. So we do not need to record the switch time for each task. + lastSwitchTime atomic.Time + // taskInfoMap is a map from taskID to taskInfo + taskInfoMap sync.Map + + // currTaskID is the taskID of the current running task. + // It may be changed when we switch to a new task or switch to a new owner. + currTaskID atomic.Int64 + disableTiKVImportMode atomic.Bool +} + +var _ dispatcher.TaskFlowHandle = (*flowHandle)(nil) + +func (h *flowHandle) OnTicker(ctx context.Context, task *proto.Task) { + // only switch TiKV mode or register task when task is running + if task.State != proto.TaskStateRunning { + return + } + h.switchTiKVMode(ctx, task) + h.registerTask(ctx, task) +} + +func (h *flowHandle) switchTiKVMode(ctx context.Context, task *proto.Task) { + h.updateCurrentTask(task) + // only import step need to switch to IMPORT mode, + // If TiKV is in IMPORT mode during checksum, coprocessor will time out. + if h.disableTiKVImportMode.Load() || task.Step != StepImport { + return + } + + if time.Since(h.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval { + return + } + + h.mu.Lock() + defer h.mu.Unlock() + if time.Since(h.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval { + return + } + + logger := logutil.BgLogger().With(zap.Int64("task-id", task.ID)) + switcher, err := importer.GetTiKVModeSwitcher(logger) + if err != nil { + logger.Warn("get tikv mode switcher failed", zap.Error(err)) + return + } + switcher.ToImportMode(ctx) + h.lastSwitchTime.Store(time.Now()) +} + +func (h *flowHandle) registerTask(ctx context.Context, task *proto.Task) { + val, _ := h.taskInfoMap.LoadOrStore(task.ID, &taskInfo{taskID: task.ID}) + info := val.(*taskInfo) + info.register(ctx) +} + +func (h *flowHandle) unregisterTask(ctx context.Context, task *proto.Task) { + if val, loaded := h.taskInfoMap.LoadAndDelete(task.ID); loaded { + info := val.(*taskInfo) + info.close(ctx) + } +} + +func (h *flowHandle) ProcessNormalFlow(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task) ( + resSubtaskMeta [][]byte, err error) { + logger := logutil.BgLogger().With( + zap.String("type", gTask.Type), + zap.Int64("task-id", gTask.ID), + zap.String("step", stepStr(gTask.Step)), + ) + taskMeta := &TaskMeta{} + err = json.Unmarshal(gTask.Meta, taskMeta) + if err != nil { + return nil, err + } + logger.Info("process normal flow") + + defer func() { + // currently, framework will take the task as finished when err is not nil or resSubtaskMeta is empty. + taskFinished := err == nil && len(resSubtaskMeta) == 0 + if taskFinished { + // todo: we're not running in a transaction with task update + if err2 := h.finishJob(ctx, handle, gTask, taskMeta); err2 != nil { + err = err2 + } + } else if err != nil && !h.IsRetryableErr(err) { + if err2 := h.failJob(ctx, handle, gTask, taskMeta, logger, err.Error()); err2 != nil { + // todo: we're not running in a transaction with task update, there might be case + // failJob return error, but task update succeed. 
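+			// In that case the job table and the global task state may briefly disagree; the
+			// failure is only logged and the original err is still returned to the framework.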
+ logger.Error("call failJob failed", zap.Error(err2)) + } + } + }() + + switch gTask.Step { + case proto.StepInit: + if err := preProcess(ctx, handle, gTask, taskMeta, logger); err != nil { + return nil, err + } + if err = startJob(ctx, handle, taskMeta); err != nil { + return nil, err + } + subtaskMetas, err := generateImportStepMetas(ctx, taskMeta) + if err != nil { + return nil, err + } + logger.Info("move to import step", zap.Any("subtask-count", len(subtaskMetas))) + metaBytes := make([][]byte, 0, len(subtaskMetas)) + for _, subtaskMeta := range subtaskMetas { + bs, err := json.Marshal(subtaskMeta) + if err != nil { + return nil, err + } + metaBytes = append(metaBytes, bs) + } + gTask.Step = StepImport + return metaBytes, nil + case StepImport: + h.switchTiKV2NormalMode(ctx, gTask, logger) + failpoint.Inject("clearLastSwitchTime", func() { + h.lastSwitchTime.Store(time.Time{}) + }) + stepMeta, err2 := toPostProcessStep(handle, gTask, taskMeta) + if err2 != nil { + return nil, err2 + } + if err = job2Step(ctx, taskMeta, importer.JobStepValidating); err != nil { + return nil, err + } + logger.Info("move to post-process step ", zap.Any("result", taskMeta.Result), + zap.Any("step-meta", stepMeta)) + bs, err := json.Marshal(stepMeta) + if err != nil { + return nil, err + } + failpoint.Inject("failWhenDispatchPostProcessSubtask", func() { + failpoint.Return(nil, errors.New("injected error after StepImport")) + }) + gTask.Step = StepPostProcess + return [][]byte{bs}, nil + case StepPostProcess: + return nil, nil + default: + return nil, errors.Errorf("unknown step %d", gTask.Step) + } +} + +func (h *flowHandle) ProcessErrFlow(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, receiveErr [][]byte) ([]byte, error) { + logger := logutil.BgLogger().With( + zap.String("type", gTask.Type), + zap.Int64("task-id", gTask.ID), + zap.String("step", stepStr(gTask.Step)), + ) + logger.Info("process error flow", zap.ByteStrings("error-message", receiveErr)) + taskMeta := &TaskMeta{} + err := json.Unmarshal(gTask.Meta, taskMeta) + if err != nil { + return nil, err + } + errStrs := make([]string, 0, len(receiveErr)) + for _, errStr := range receiveErr { + errStrs = append(errStrs, string(errStr)) + } + if err = h.failJob(ctx, handle, gTask, taskMeta, logger, strings.Join(errStrs, "; ")); err != nil { + return nil, err + } + + gTask.Error = receiveErr[0] + + errStr := string(receiveErr[0]) + // do nothing if the error is resumable + if isResumableErr(errStr) { + return nil, nil + } + + if gTask.Step == StepImport { + err = rollback(ctx, handle, gTask, logger) + if err != nil { + // TODO: add error code according to spec. + gTask.Error = []byte(errStr + ", " + err.Error()) + } + } + return nil, err +} + +func (*flowHandle) GetEligibleInstances(ctx context.Context, gTask *proto.Task) ([]*infosync.ServerInfo, error) { + taskMeta := &TaskMeta{} + err := json.Unmarshal(gTask.Meta, taskMeta) + if err != nil { + return nil, err + } + if len(taskMeta.EligibleInstances) > 0 { + return taskMeta.EligibleInstances, nil + } + return dispatcher.GenerateSchedulerNodes(ctx) +} + +func (*flowHandle) IsRetryableErr(error) bool { + // TODO: check whether the error is retryable. 
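+	// One possible implementation (illustrative only) would be to name the parameter and
+	// delegate to common.IsRetryableError(err), which this package already uses when
+	// retrying the remote checksum; for now every error is treated as non-retryable.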
+ return false +} + +func (h *flowHandle) switchTiKV2NormalMode(ctx context.Context, task *proto.Task, logger *zap.Logger) { + h.updateCurrentTask(task) + if h.disableTiKVImportMode.Load() { + return + } + + h.mu.Lock() + defer h.mu.Unlock() + + switcher, err := importer.GetTiKVModeSwitcher(logger) + if err != nil { + logger.Warn("get tikv mode switcher failed", zap.Error(err)) + return + } + switcher.ToNormalMode(ctx) + + // clear it, so next task can switch TiKV mode again. + h.lastSwitchTime.Store(time.Time{}) +} + +func (h *flowHandle) updateCurrentTask(task *proto.Task) { + if h.currTaskID.Swap(task.ID) != task.ID { + taskMeta := &TaskMeta{} + if err := json.Unmarshal(task.Meta, taskMeta); err == nil { + h.disableTiKVImportMode.Store(taskMeta.Plan.DisableTiKVImportMode) + } + } +} + +// preProcess does the pre-processing for the task. +func preProcess(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta, logger *zap.Logger) error { + logger.Info("pre process") + // TODO: drop table indexes depends on the option. + // if err := dropTableIndexes(ctx, handle, taskMeta, logger); err != nil { + // return err + // } + return updateMeta(gTask, taskMeta) +} + +// nolint:deadcode +func dropTableIndexes(ctx context.Context, handle dispatcher.TaskHandle, taskMeta *TaskMeta, logger *zap.Logger) error { + tblInfo := taskMeta.Plan.TableInfo + tableName := common.UniqueTable(taskMeta.Plan.DBName, tblInfo.Name.L) + + remainIndexes, dropIndexes := common.GetDropIndexInfos(tblInfo) + for _, idxInfo := range dropIndexes { + sqlStr := common.BuildDropIndexSQL(tableName, idxInfo) + if err := executeSQL(ctx, handle, logger, sqlStr); err != nil { + if merr, ok := errors.Cause(err).(*dmysql.MySQLError); ok { + switch merr.Number { + case errno.ErrCantDropFieldOrKey, errno.ErrDropIndexNeededInForeignKey: + remainIndexes = append(remainIndexes, idxInfo) + logger.Warn("can't drop index, skip", zap.String("index", idxInfo.Name.O), zap.Error(err)) + continue + } + } + return err + } + } + if len(remainIndexes) < len(tblInfo.Indices) { + taskMeta.Plan.TableInfo = taskMeta.Plan.TableInfo.Clone() + taskMeta.Plan.TableInfo.Indices = remainIndexes + } + return nil +} + +// nolint:deadcode +func createTableIndexes(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) error { + tableName := common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) + singleSQL, multiSQLs := common.BuildAddIndexSQL(tableName, taskMeta.Plan.TableInfo, taskMeta.Plan.DesiredTableInfo) + logger.Info("build add index sql", zap.String("singleSQL", singleSQL), zap.Strings("multiSQLs", multiSQLs)) + if len(multiSQLs) == 0 { + return nil + } + + err := executeSQL(ctx, executor, logger, singleSQL) + if err == nil { + return nil + } + if !common.IsDupKeyError(err) { + // TODO: refine err msg and error code according to spec. + return errors.Errorf("Failed to create index: %v, please execute the SQL manually, sql: %s", err, singleSQL) + } + if len(multiSQLs) == 1 { + return nil + } + logger.Warn("cannot add all indexes in one statement, try to add them one by one", zap.Strings("sqls", multiSQLs), zap.Error(err)) + + for i, ddl := range multiSQLs { + err := executeSQL(ctx, executor, logger, ddl) + if err != nil && !common.IsDupKeyError(err) { + // TODO: refine err msg and error code according to spec. 
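+			// multiSQLs[i:] are the ADD INDEX statements that have not been applied, so the
+			// message lists only the remaining ones for the user to run manually.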
+ return errors.Errorf("Failed to create index: %v, please execute the SQLs manually, sqls: %s", err, strings.Join(multiSQLs[i:], ";")) + } + } + return nil +} + +// TODO: return the result of sql. +func executeSQL(ctx context.Context, executor storage.SessionExecutor, logger *zap.Logger, sql string, args ...interface{}) (err error) { + logger.Info("execute sql", zap.String("sql", sql), zap.Any("args", args)) + return executor.WithNewSession(func(se sessionctx.Context) error { + _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...) + return err + }) +} + +func updateMeta(gTask *proto.Task, taskMeta *TaskMeta) error { + bs, err := json.Marshal(taskMeta) + if err != nil { + return err + } + gTask.Meta = bs + return nil +} + +func buildController(taskMeta *TaskMeta) (*importer.LoadDataController, error) { + idAlloc := kv.NewPanickingAllocators(0) + tbl, err := tables.TableFromMeta(idAlloc, taskMeta.Plan.TableInfo) + if err != nil { + return nil, err + } + + astArgs, err := importer.ASTArgsFromStmt(taskMeta.Stmt) + if err != nil { + return nil, err + } + controller, err := importer.NewLoadDataController(&taskMeta.Plan, tbl, astArgs) + if err != nil { + return nil, err + } + return controller, nil +} + +// todo: converting back and forth, we should unify struct and remove this function later. +func toChunkMap(engineCheckpoints map[int32]*checkpoints.EngineCheckpoint) map[int32][]Chunk { + chunkMap := make(map[int32][]Chunk, len(engineCheckpoints)) + for id, ecp := range engineCheckpoints { + chunkMap[id] = make([]Chunk, 0, len(ecp.Chunks)) + for _, chunkCheckpoint := range ecp.Chunks { + chunkMap[id] = append(chunkMap[id], toChunk(*chunkCheckpoint)) + } + } + return chunkMap +} + +func generateImportStepMetas(ctx context.Context, taskMeta *TaskMeta) (subtaskMetas []*ImportStepMeta, err error) { + var chunkMap map[int32][]Chunk + if len(taskMeta.ChunkMap) > 0 { + chunkMap = taskMeta.ChunkMap + } else { + controller, err2 := buildController(taskMeta) + if err2 != nil { + return nil, err2 + } + if err2 = controller.InitDataFiles(ctx); err2 != nil { + return nil, err2 + } + + engineCheckpoints, err2 := controller.PopulateChunks(ctx) + if err2 != nil { + return nil, err2 + } + chunkMap = toChunkMap(engineCheckpoints) + } + for id := range chunkMap { + if id == common.IndexEngineID { + continue + } + subtaskMeta := &ImportStepMeta{ + ID: id, + Chunks: chunkMap[id], + } + subtaskMetas = append(subtaskMetas, subtaskMeta) + } + return subtaskMetas, nil +} + +// we will update taskMeta in place and make gTask.Meta point to the new taskMeta. 
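+// It merges the import-step subtask results (read/loaded row counts and column sizes) into
+// taskMeta.Result and combines their KV checksums into the returned PostProcessStepMeta.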
+func toPostProcessStep(handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) (*PostProcessStepMeta, error) { + metas, err := handle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step) + if err != nil { + return nil, err + } + + subtaskMetas := make([]*ImportStepMeta, 0, len(metas)) + for _, bs := range metas { + var subtaskMeta ImportStepMeta + if err := json.Unmarshal(bs, &subtaskMeta); err != nil { + return nil, err + } + subtaskMetas = append(subtaskMetas, &subtaskMeta) + } + var localChecksum verify.KVChecksum + columnSizeMap := make(map[int64]int64) + for _, subtaskMeta := range subtaskMetas { + checksum := verify.MakeKVChecksum(subtaskMeta.Checksum.Size, subtaskMeta.Checksum.KVs, subtaskMeta.Checksum.Sum) + localChecksum.Add(&checksum) + + taskMeta.Result.ReadRowCnt += subtaskMeta.Result.ReadRowCnt + taskMeta.Result.LoadedRowCnt += subtaskMeta.Result.LoadedRowCnt + for key, val := range subtaskMeta.Result.ColSizeMap { + columnSizeMap[key] += val + } + } + taskMeta.Result.ColSizeMap = columnSizeMap + if err2 := updateMeta(gTask, taskMeta); err2 != nil { + return nil, err2 + } + return &PostProcessStepMeta{ + Checksum: Checksum{ + Size: localChecksum.SumSize(), + KVs: localChecksum.SumKVS(), + Sum: localChecksum.Sum(), + }, + }, nil +} + +func startJob(ctx context.Context, handle dispatcher.TaskHandle, taskMeta *TaskMeta) error { + failpoint.Inject("syncBeforeJobStarted", func() { + TestSyncChan <- struct{}{} + <-TestSyncChan + }) + err := handle.WithNewSession(func(se sessionctx.Context) error { + exec := se.(sqlexec.SQLExecutor) + return importer.StartJob(ctx, exec, taskMeta.JobID) + }) + failpoint.Inject("syncAfterJobStarted", func() { + TestSyncChan <- struct{}{} + }) + return err +} + +func job2Step(ctx context.Context, taskMeta *TaskMeta, step string) error { + globalTaskManager, err := storage.GetTaskManager() + if err != nil { + return err + } + // todo: use dispatcher.TaskHandle + // we might call this in scheduler later, there's no dispatcher.TaskHandle, so we use globalTaskManager here. 
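+	// For example, the dispatcher calls job2Step(ctx, taskMeta, importer.JobStepValidating)
+	// after the import step finishes, right before the post-process subtask is dispatched.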
+ return globalTaskManager.WithNewSession(func(se sessionctx.Context) error { + exec := se.(sqlexec.SQLExecutor) + return importer.Job2Step(ctx, exec, taskMeta.JobID, step) + }) +} + +func (h *flowHandle) finishJob(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error { + h.unregisterTask(ctx, gTask) + redactSensitiveInfo(gTask, taskMeta) + summary := &importer.JobSummary{ImportedRows: taskMeta.Result.LoadedRowCnt} + return handle.WithNewSession(func(se sessionctx.Context) error { + exec := se.(sqlexec.SQLExecutor) + return importer.FinishJob(ctx, exec, taskMeta.JobID, summary) + }) +} + +func (h *flowHandle) failJob(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, + taskMeta *TaskMeta, logger *zap.Logger, errorMsg string) error { + h.switchTiKV2NormalMode(ctx, gTask, logger) + h.unregisterTask(ctx, gTask) + redactSensitiveInfo(gTask, taskMeta) + return handle.WithNewSession(func(se sessionctx.Context) error { + exec := se.(sqlexec.SQLExecutor) + return importer.FailJob(ctx, exec, taskMeta.JobID, errorMsg) + }) +} + +func redactSensitiveInfo(gTask *proto.Task, taskMeta *TaskMeta) { + taskMeta.Stmt = "" + taskMeta.Plan.Path = ast.RedactURL(taskMeta.Plan.Path) + if err := updateMeta(gTask, taskMeta); err != nil { + // marshal failed, should not happen + logutil.BgLogger().Warn("failed to update task meta", zap.Error(err)) + } +} + +// isResumableErr checks whether it's possible to rely on checkpoint to re-import data after the error has been fixed. +func isResumableErr(string) bool { + // TODO: add more cases + return false +} + +func rollback(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, logger *zap.Logger) (err error) { + taskMeta := &TaskMeta{} + err = json.Unmarshal(gTask.Meta, taskMeta) + if err != nil { + return err + } + + logger.Info("rollback") + + // // TODO: create table indexes depends on the option. + // // create table indexes even if the rollback is failed. + // defer func() { + // err2 := createTableIndexes(ctx, handle, taskMeta, logger) + // err = multierr.Append(err, err2) + // }() + + tableName := common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) + // truncate the table + return executeSQL(ctx, handle, logger, "TRUNCATE "+tableName) +} + +func stepStr(step int64) string { + switch step { + case proto.StepInit: + return "init" + case StepImport: + return "import" + case StepPostProcess: + return "postprocess" + default: + return "unknown" + } +} + +func init() { + dispatcher.RegisterTaskFlowHandle(proto.ImportInto, &flowHandle{}) +} diff --git a/disttask/importinto/job.go b/disttask/importinto/job.go new file mode 100644 index 0000000000000..64b61048d8c88 --- /dev/null +++ b/disttask/importinto/job.go @@ -0,0 +1,279 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importinto + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/disttask/framework/handle" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/domain/infosync" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/dbterror/exeerrors" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "go.uber.org/zap" +) + +// DistImporter is a JobImporter for distributed IMPORT INTO. +type DistImporter struct { + *importer.JobImportParam + plan *importer.Plan + stmt string + logger *zap.Logger + // the instance to import data, used for single-node import, nil means import data on all instances. + instance *infosync.ServerInfo + // the files to import, when import from server file, we need to pass those file to the framework. + chunkMap map[int32][]Chunk + sourceFileSize int64 + // only set after submit task + jobID int64 + taskID int64 +} + +// NewDistImporter creates a new DistImporter. +func NewDistImporter(param *importer.JobImportParam, plan *importer.Plan, stmt string, sourceFileSize int64) (*DistImporter, error) { + return &DistImporter{ + JobImportParam: param, + plan: plan, + stmt: stmt, + logger: logutil.BgLogger(), + sourceFileSize: sourceFileSize, + }, nil +} + +// NewDistImporterCurrNode creates a new DistImporter to import data on current node. +func NewDistImporterCurrNode(param *importer.JobImportParam, plan *importer.Plan, stmt string, sourceFileSize int64) (*DistImporter, error) { + serverInfo, err := infosync.GetServerInfo() + if err != nil { + return nil, err + } + return &DistImporter{ + JobImportParam: param, + plan: plan, + stmt: stmt, + logger: logutil.BgLogger(), + instance: serverInfo, + sourceFileSize: sourceFileSize, + }, nil +} + +// NewDistImporterServerFile creates a new DistImporter to import given files on current node. +// we also run import on current node. +// todo: merge all 3 ctor into one. +func NewDistImporterServerFile(param *importer.JobImportParam, plan *importer.Plan, stmt string, ecp map[int32]*checkpoints.EngineCheckpoint, sourceFileSize int64) (*DistImporter, error) { + distImporter, err := NewDistImporterCurrNode(param, plan, stmt, sourceFileSize) + if err != nil { + return nil, err + } + distImporter.chunkMap = toChunkMap(ecp) + return distImporter, nil +} + +// Param implements JobImporter.Param. +func (ti *DistImporter) Param() *importer.JobImportParam { + return ti.JobImportParam +} + +// Import implements JobImporter.Import. +func (*DistImporter) Import() { + // todo: remove it +} + +// ImportTask import task. +func (ti *DistImporter) ImportTask(task *proto.Task) { + ti.logger.Info("start distribute IMPORT INTO") + ti.Group.Go(func() error { + defer close(ti.Done) + // task is run using distribute framework, so we only wait for the task to finish. + return handle.WaitGlobalTask(ti.GroupCtx, task) + }) +} + +// Result implements JobImporter.Result. 
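+// The summary message reuses the LOAD DATA result line (records/deleted/skipped/warnings);
+// deleted and skipped are always 0 for now since there is no duplicate detection yet.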
+func (ti *DistImporter) Result() importer.JobImportResult { + var result importer.JobImportResult + taskMeta, err := getTaskMeta(ti.jobID) + if err != nil { + result.Msg = err.Error() + return result + } + + var ( + numWarnings uint64 + numRecords uint64 + numDeletes uint64 + numSkipped uint64 + ) + numRecords = taskMeta.Result.ReadRowCnt + // todo: we don't have a strict REPLACE or IGNORE mode in physical mode, so we can't get the numDeletes/numSkipped. + // we can have it when there's duplicate detection. + msg := fmt.Sprintf(mysql.MySQLErrName[mysql.ErrLoadInfo].Raw, numRecords, numDeletes, numSkipped, numWarnings) + return importer.JobImportResult{ + Msg: msg, + Affected: taskMeta.Result.ReadRowCnt, + ColSizeMap: taskMeta.Result.ColSizeMap, + } +} + +// Close implements the io.Closer interface. +func (*DistImporter) Close() error { + return nil +} + +// SubmitTask submits a task to the distribute framework. +func (ti *DistImporter) SubmitTask(ctx context.Context) (int64, *proto.Task, error) { + var instances []*infosync.ServerInfo + if ti.instance != nil { + instances = append(instances, ti.instance) + } + // we use globalTaskManager to submit task, user might not have the privilege to system tables. + globalTaskManager, err := storage.GetTaskManager() + if err != nil { + return 0, nil, err + } + + var jobID, taskID int64 + plan := ti.plan + if err = globalTaskManager.WithNewTxn(ctx, func(se sessionctx.Context) error { + var err2 error + exec := se.(sqlexec.SQLExecutor) + // If 2 client try to execute IMPORT INTO concurrently, there's chance that both of them will pass the check. + // We can enforce ONLY one import job running by: + // - using LOCK TABLES, but it requires enable-table-lock=true, it's not enabled by default. + // - add a key to PD as a distributed lock, but it's a little complex, and we might support job queuing later. + // So we only add this simple soft check here and doc it. + activeJobCnt, err2 := importer.GetActiveJobCnt(ctx, exec) + if err2 != nil { + return err2 + } + if activeJobCnt > 0 { + return exeerrors.ErrLoadDataPreCheckFailed.FastGenByArgs("there's pending or running jobs") + } + jobID, err2 = importer.CreateJob(ctx, exec, plan.DBName, plan.TableInfo.Name.L, plan.TableInfo.ID, + plan.User, plan.Parameters, ti.sourceFileSize) + if err2 != nil { + return err2 + } + task := TaskMeta{ + JobID: jobID, + Plan: *plan, + Stmt: ti.stmt, + EligibleInstances: instances, + ChunkMap: ti.chunkMap, + } + taskMeta, err2 := json.Marshal(task) + if err2 != nil { + return err2 + } + taskID, err2 = globalTaskManager.AddGlobalTaskWithSession(se, TaskKey(jobID), proto.ImportInto, + int(plan.ThreadCnt), taskMeta) + if err2 != nil { + return err2 + } + return nil + }); err != nil { + return 0, nil, err + } + + globalTask, err := globalTaskManager.GetGlobalTaskByID(taskID) + if err != nil { + return 0, nil, err + } + if globalTask == nil { + return 0, nil, errors.Errorf("cannot find global task with ID %d", taskID) + } + // update logger with task id. + ti.jobID = jobID + ti.taskID = taskID + ti.logger = ti.logger.With(zap.Int64("task-id", globalTask.ID)) + + ti.logger.Info("job submitted to global task queue", zap.Int64("job-id", jobID)) + + return jobID, globalTask, nil +} + +func (*DistImporter) taskKey() string { + // task key is meaningless to IMPORT INTO, so we use a random uuid. + return fmt.Sprintf("%s/%s", proto.ImportInto, uuid.New().String()) +} + +// JobID returns the job id. 
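+// It is zero until SubmitTask succeeds.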
+func (ti *DistImporter) JobID() int64 { + return ti.jobID +} + +func getTaskMeta(jobID int64) (*TaskMeta, error) { + globalTaskManager, err := storage.GetTaskManager() + if err != nil { + return nil, err + } + taskKey := TaskKey(jobID) + globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey) + if err != nil { + return nil, err + } + if globalTask == nil { + return nil, errors.Errorf("cannot find global task with key %s", taskKey) + } + var taskMeta TaskMeta + if err := json.Unmarshal(globalTask.Meta, &taskMeta); err != nil { + return nil, err + } + return &taskMeta, nil +} + +// GetTaskImportedRows gets the number of imported rows of a job. +// Note: for finished job, we can get the number of imported rows from task meta. +func GetTaskImportedRows(jobID int64) (uint64, error) { + globalTaskManager, err := storage.GetTaskManager() + if err != nil { + return 0, err + } + taskKey := TaskKey(jobID) + globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey) + if err != nil { + return 0, err + } + if globalTask == nil { + return 0, errors.Errorf("cannot find global task with key %s", taskKey) + } + subtasks, err := globalTaskManager.GetSubtasksByStep(globalTask.ID, StepImport) + if err != nil { + return 0, err + } + var importedRows uint64 + for _, subtask := range subtasks { + var subtaskMeta ImportStepMeta + if err2 := json.Unmarshal(subtask.Meta, &subtaskMeta); err2 != nil { + return 0, err2 + } + importedRows += subtaskMeta.Result.LoadedRowCnt + } + return importedRows, nil +} + +// TaskKey returns the task key for a job. +func TaskKey(jobID int64) string { + return fmt.Sprintf("%s/%d", proto.ImportInto, jobID) +} diff --git a/disttask/importinto/subtask_executor.go b/disttask/importinto/subtask_executor.go new file mode 100644 index 0000000000000..be6de9a75d0c0 --- /dev/null +++ b/disttask/importinto/subtask_executor.go @@ -0,0 +1,240 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importinto + +import ( + "context" + "strconv" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/lightning/backend/local" + "github.com/pingcap/tidb/br/pkg/lightning/common" + "github.com/pingcap/tidb/br/pkg/lightning/config" + verify "github.com/pingcap/tidb/br/pkg/lightning/verification" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/scheduler" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/mathutil" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +// TestSyncChan is used to test. +var TestSyncChan = make(chan struct{}) + +// ImportMinimalTaskExecutor is a minimal task executor for IMPORT INTO. 
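+// It processes one chunk of the import: importer.ProcessChunk writes the chunk into the
+// shared data and index engines, and the resulting KV checksum is accumulated into the
+// subtask's SharedVars.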
+type ImportMinimalTaskExecutor struct { + mTtask *importStepMinimalTask +} + +// Run implements the SubtaskExecutor.Run interface. +func (e *ImportMinimalTaskExecutor) Run(ctx context.Context) error { + logger := logutil.BgLogger().With(zap.String("type", proto.ImportInto), zap.Int64("table-id", e.mTtask.Plan.TableInfo.ID)) + logger.Info("run minimal task") + failpoint.Inject("waitBeforeSortChunk", func() { + time.Sleep(3 * time.Second) + }) + failpoint.Inject("errorWhenSortChunk", func() { + failpoint.Return(errors.New("occur an error when sort chunk")) + }) + failpoint.Inject("syncBeforeSortChunk", func() { + TestSyncChan <- struct{}{} + <-TestSyncChan + }) + chunkCheckpoint := toChunkCheckpoint(e.mTtask.Chunk) + sharedVars := e.mTtask.SharedVars + if err := importer.ProcessChunk(ctx, &chunkCheckpoint, sharedVars.TableImporter, sharedVars.DataEngine, sharedVars.IndexEngine, sharedVars.Progress, logger); err != nil { + return err + } + + sharedVars.mu.Lock() + defer sharedVars.mu.Unlock() + sharedVars.Checksum.Add(&chunkCheckpoint.Checksum) + return nil +} + +type postProcessMinimalTaskExecutor struct { + mTask *postProcessStepMinimalTask +} + +func (e *postProcessMinimalTaskExecutor) Run(ctx context.Context) error { + mTask := e.mTask + failpoint.Inject("waitBeforePostProcess", func() { + time.Sleep(5 * time.Second) + }) + return postProcess(ctx, mTask.taskMeta, &mTask.meta, mTask.logger) +} + +// postProcess does the post-processing for the task. +func postProcess(ctx context.Context, taskMeta *TaskMeta, subtaskMeta *PostProcessStepMeta, logger *zap.Logger) (err error) { + failpoint.Inject("syncBeforePostProcess", func() { + TestSyncChan <- struct{}{} + <-TestSyncChan + }) + + logger.Info("post process") + + // TODO: create table indexes depends on the option. + // create table indexes even if the post process is failed. 
+ // defer func() { + // err2 := createTableIndexes(ctx, globalTaskManager, taskMeta, logger) + // err = multierr.Append(err, err2) + // }() + + return verifyChecksum(ctx, taskMeta, subtaskMeta, logger) +} + +func verifyChecksum(ctx context.Context, taskMeta *TaskMeta, subtaskMeta *PostProcessStepMeta, logger *zap.Logger) error { + if taskMeta.Plan.Checksum == config.OpLevelOff { + return nil + } + localChecksum := verify.MakeKVChecksum(subtaskMeta.Checksum.Size, subtaskMeta.Checksum.KVs, subtaskMeta.Checksum.Sum) + logger.Info("local checksum", zap.Object("checksum", &localChecksum)) + + failpoint.Inject("waitCtxDone", func() { + <-ctx.Done() + }) + + globalTaskManager, err := storage.GetTaskManager() + if err != nil { + return err + } + remoteChecksum, err := checksumTable(ctx, globalTaskManager, taskMeta, logger) + if err != nil { + return err + } + if !remoteChecksum.IsEqual(&localChecksum) { + err2 := common.ErrChecksumMismatch.GenWithStackByArgs( + remoteChecksum.Checksum, localChecksum.Sum(), + remoteChecksum.TotalKVs, localChecksum.SumKVS(), + remoteChecksum.TotalBytes, localChecksum.SumSize(), + ) + if taskMeta.Plan.Checksum == config.OpLevelOptional { + logger.Warn("verify checksum failed, but checksum is optional, will skip it", zap.Error(err2)) + err2 = nil + } + return err2 + } + logger.Info("checksum pass", zap.Object("local", &localChecksum)) + return nil +} + +func checksumTable(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) (*local.RemoteChecksum, error) { + var ( + tableName = common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) + sql = "ADMIN CHECKSUM TABLE " + tableName + maxErrorRetryCount = 3 + distSQLScanConcurrencyFactor = 1 + remoteChecksum *local.RemoteChecksum + txnErr error + ) + + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + for i := 0; i < maxErrorRetryCount; i++ { + txnErr = executor.WithNewTxn(ctx, func(se sessionctx.Context) error { + // increase backoff weight + if err := setBackoffWeight(se, taskMeta, logger); err != nil { + logger.Warn("set tidb_backoff_weight failed", zap.Error(err)) + } + + distSQLScanConcurrency := se.GetSessionVars().DistSQLScanConcurrency() + se.GetSessionVars().SetDistSQLScanConcurrency(mathutil.Max(distSQLScanConcurrency/distSQLScanConcurrencyFactor, local.MinDistSQLScanConcurrency)) + defer func() { + se.GetSessionVars().SetDistSQLScanConcurrency(distSQLScanConcurrency) + }() + + rs, err := storage.ExecSQL(ctx, se, sql) + if err != nil { + return err + } + if len(rs) < 1 { + return errors.New("empty checksum result") + } + + failpoint.Inject("errWhenChecksum", func() { + if i == 0 { + failpoint.Return(errors.New("occur an error when checksum, coprocessor task terminated due to exceeding the deadline")) + } + }) + + // ADMIN CHECKSUM TABLE .
example. + // mysql> admin checksum table test.t; + // +---------+------------+---------------------+-----------+-------------+ + // | Db_name | Table_name | Checksum_crc64_xor | Total_kvs | Total_bytes | + // +---------+------------+---------------------+-----------+-------------+ + // | test | t | 8520875019404689597 | 7296873 | 357601387 | + // +---------+------------+------------- + remoteChecksum = &local.RemoteChecksum{ + Schema: rs[0].GetString(0), + Table: rs[0].GetString(1), + Checksum: rs[0].GetUint64(2), + TotalKVs: rs[0].GetUint64(3), + TotalBytes: rs[0].GetUint64(4), + } + return nil + }) + if !common.IsRetryableError(txnErr) { + break + } + distSQLScanConcurrencyFactor *= 2 + logger.Warn("retry checksum table", zap.Int("retry count", i+1), zap.Error(txnErr)) + } + return remoteChecksum, txnErr +} + +// TestChecksumTable is used to test checksum table in unit test. +func TestChecksumTable(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) (*local.RemoteChecksum, error) { + return checksumTable(ctx, executor, taskMeta, logger) +} + +func setBackoffWeight(se sessionctx.Context, taskMeta *TaskMeta, logger *zap.Logger) error { + backoffWeight := local.DefaultBackoffWeight + if val, ok := taskMeta.Plan.ImportantSysVars[variable.TiDBBackOffWeight]; ok { + if weight, err := strconv.Atoi(val); err == nil && weight > backoffWeight { + backoffWeight = weight + } + } + logger.Info("set backoff weight", zap.Int("weight", backoffWeight)) + return se.GetSessionVars().SetSystemVar(variable.TiDBBackOffWeight, strconv.Itoa(backoffWeight)) +} + +func init() { + scheduler.RegisterSubtaskExectorConstructor(proto.ImportInto, StepImport, + // The order of the subtask executors is the same as the order of the subtasks. + func(minimalTask proto.MinimalTask, step int64) (scheduler.SubtaskExecutor, error) { + task, ok := minimalTask.(*importStepMinimalTask) + if !ok { + return nil, errors.Errorf("invalid task type %T", minimalTask) + } + return &ImportMinimalTaskExecutor{mTtask: task}, nil + }, + ) + scheduler.RegisterSubtaskExectorConstructor(proto.ImportInto, StepPostProcess, + func(minimalTask proto.MinimalTask, step int64) (scheduler.SubtaskExecutor, error) { + mTask, ok := minimalTask.(*postProcessStepMinimalTask) + if !ok { + return nil, errors.Errorf("invalid task type %T", minimalTask) + } + return &postProcessMinimalTaskExecutor{mTask: mTask}, nil + }, + ) +} diff --git a/disttask/importinto/subtask_executor_test.go b/disttask/importinto/subtask_executor_test.go new file mode 100644 index 0000000000000..4596ffc795aa2 --- /dev/null +++ b/disttask/importinto/subtask_executor_test.go @@ -0,0 +1,73 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importinto_test + +import ( + "context" + "testing" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/failpoint" + verify "github.com/pingcap/tidb/br/pkg/lightning/verification" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/disttask/importinto" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/util/logutil" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/util" +) + +func TestChecksumTable(t *testing.T) { + ctx := context.Background() + store := testkit.CreateMockStore(t) + gtk := testkit.NewTestKit(t, store) + pool := pools.NewResourcePool(func() (pools.Resource, error) { + return gtk.Session(), nil + }, 1, 1, time.Second) + defer pool.Close() + mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool) + + taskMeta := &importinto.TaskMeta{ + Plan: importer.Plan{ + DBName: "db", + TableInfo: &model.TableInfo{ + Name: model.NewCIStr("tb"), + }, + }, + } + // fake result + localChecksum := verify.MakeKVChecksum(1, 1, 1) + gtk.MustExec("create database db") + gtk.MustExec("create table db.tb(id int)") + gtk.MustExec("insert into db.tb values(1)") + remoteChecksum, err := importinto.TestChecksumTable(ctx, mgr, taskMeta, logutil.BgLogger()) + require.NoError(t, err) + require.True(t, remoteChecksum.IsEqual(&localChecksum)) + // again + remoteChecksum, err = importinto.TestChecksumTable(ctx, mgr, taskMeta, logutil.BgLogger()) + require.NoError(t, err) + require.True(t, remoteChecksum.IsEqual(&localChecksum)) + + _ = failpoint.Enable("github.com/pingcap/tidb/disttask/importinto/errWhenChecksum", `return(true)`) + defer func() { + _ = failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/errWhenChecksum") + }() + remoteChecksum, err = importinto.TestChecksumTable(ctx, mgr, taskMeta, logutil.BgLogger()) + require.NoError(t, err) + require.True(t, remoteChecksum.IsEqual(&localChecksum)) +} diff --git a/executor/import_into.go b/executor/import_into.go new file mode 100644 index 0000000000000..92f16fb13f611 --- /dev/null +++ b/executor/import_into.go @@ -0,0 +1,302 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package executor + +import ( + "context" + "sync/atomic" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/disttask/framework/proto" + fstorage "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/disttask/importinto" + "github.com/pingcap/tidb/executor/asyncloaddata" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/mysql" + plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/privilege" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/sessiontxn" + "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/dbterror/exeerrors" + "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/sqlexec" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +var ( + // TestDetachedTaskFinished is a flag for test. + TestDetachedTaskFinished atomic.Bool + // TestCancelFunc for test. + TestCancelFunc context.CancelFunc +) + +const unknownImportedRowCount = -1 + +// ImportIntoExec represents a IMPORT INTO executor. +type ImportIntoExec struct { + baseExecutor + userSctx sessionctx.Context + importPlan *importer.Plan + controller *importer.LoadDataController + stmt string + + dataFilled bool +} + +var ( + _ Executor = (*ImportIntoExec)(nil) +) + +func newImportIntoExec(b baseExecutor, userSctx sessionctx.Context, plan *plannercore.ImportInto, tbl table.Table) ( + *ImportIntoExec, error) { + importPlan, err := importer.NewImportPlan(userSctx, plan, tbl) + if err != nil { + return nil, err + } + astArgs := importer.ASTArgsFromImportPlan(plan) + controller, err := importer.NewLoadDataController(importPlan, tbl, astArgs) + if err != nil { + return nil, err + } + return &ImportIntoExec{ + baseExecutor: b, + userSctx: userSctx, + importPlan: importPlan, + controller: controller, + stmt: plan.Stmt, + }, nil +} + +// Next implements the Executor Next interface. +func (e *ImportIntoExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + req.GrowAndReset(e.maxChunkSize) + ctx = kv.WithInternalSourceType(ctx, kv.InternalImportInto) + if e.dataFilled { + // need to return an empty req to indicate all results have been written + return nil + } + if err2 := e.controller.InitDataFiles(ctx); err2 != nil { + return err2 + } + + // must use a new session to pre-check, else the stmt in show processlist will be changed. + newSCtx, err2 := CreateSession(e.userSctx) + if err2 != nil { + return err2 + } + defer CloseSession(newSCtx) + sqlExec := newSCtx.(sqlexec.SQLExecutor) + if err2 = e.controller.CheckRequirements(ctx, sqlExec); err2 != nil { + return err2 + } + + failpoint.Inject("cancellableCtx", func() { + // KILL is not implemented in testkit, so we use a fail-point to simulate it. + newCtx, cancel := context.WithCancel(ctx) + ctx = newCtx + TestCancelFunc = cancel + }) + // todo: we don't need Job now, remove it later. 
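+	// For DETACHED imports the job must outlive this statement, so the errgroup below is
+	// bound to context.Background() instead of the statement's ctx.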
+ parentCtx := ctx + if e.controller.Detached { + parentCtx = context.Background() + } + group, groupCtx := errgroup.WithContext(parentCtx) + param := &importer.JobImportParam{ + Job: &asyncloaddata.Job{}, + Group: group, + GroupCtx: groupCtx, + Done: make(chan struct{}), + Progress: asyncloaddata.NewProgress(false), + } + distImporter, err := e.getJobImporter(ctx, param) + if err != nil { + return err + } + defer func() { + _ = distImporter.Close() + }() + param.Progress.SourceFileSize = e.controller.TotalFileSize + jobID, task, err := distImporter.SubmitTask(ctx) + if err != nil { + return err + } + + if e.controller.Detached { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalImportInto) + se, err := CreateSession(e.userSctx) + if err != nil { + return err + } + go func() { + defer CloseSession(se) + // error is stored in system table, so we can ignore it here + //nolint: errcheck + _ = e.doImport(ctx, se, distImporter, task) + failpoint.Inject("testDetachedTaskFinished", func() { + TestDetachedTaskFinished.Store(true) + }) + }() + return e.fillJobInfo(ctx, jobID, req) + } + if err = e.doImport(ctx, e.userSctx, distImporter, task); err != nil { + return err + } + return e.fillJobInfo(ctx, jobID, req) +} + +func (e *ImportIntoExec) fillJobInfo(ctx context.Context, jobID int64, req *chunk.Chunk) error { + e.dataFilled = true + // we use globalTaskManager to get job, user might not have the privilege to system tables. + globalTaskManager, err := fstorage.GetTaskManager() + if err != nil { + return err + } + var info *importer.JobInfo + if err = globalTaskManager.WithNewSession(func(se sessionctx.Context) error { + sqlExec := se.(sqlexec.SQLExecutor) + var err2 error + info, err2 = importer.GetJob(ctx, sqlExec, jobID, e.ctx.GetSessionVars().User.String(), false) + return err2 + }); err != nil { + return err + } + fillOneImportJobInfo(info, req, unknownImportedRowCount) + return nil +} + +func (e *ImportIntoExec) getJobImporter(ctx context.Context, param *importer.JobImportParam) (*importinto.DistImporter, error) { + importFromServer, err := storage.IsLocalPath(e.controller.Path) + if err != nil { + // since we have checked this during creating controller, this should not happen. + return nil, exeerrors.ErrLoadDataInvalidURI.FastGenByArgs(err.Error()) + } + logutil.Logger(ctx).Info("get job importer", zap.Stringer("param", e.controller.Parameters), + zap.Bool("dist-task-enabled", variable.EnableDistTask.Load())) + if importFromServer { + ecp, err2 := e.controller.PopulateChunks(ctx) + if err2 != nil { + return nil, err2 + } + return importinto.NewDistImporterServerFile(param, e.importPlan, e.stmt, ecp, e.controller.TotalFileSize) + } + // if tidb_enable_dist_task=true, we import distributively, otherwise we import on current node. + if variable.EnableDistTask.Load() { + return importinto.NewDistImporter(param, e.importPlan, e.stmt, e.controller.TotalFileSize) + } + return importinto.NewDistImporterCurrNode(param, e.importPlan, e.stmt, e.controller.TotalFileSize) +} + +func (e *ImportIntoExec) doImport(ctx context.Context, se sessionctx.Context, distImporter *importinto.DistImporter, task *proto.Task) error { + distImporter.ImportTask(task) + group := distImporter.Param().Group + err := group.Wait() + // when user KILL the connection, the ctx will be canceled, we need to cancel the import job. 
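+	// errors.Cause walks the cause chain, so a context.Canceled wrapped by errors.Trace or
+	// errors.Annotate is still recognized here.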
+ if errors.Cause(err) == context.Canceled {
+ globalTaskManager, err2 := fstorage.GetTaskManager()
+ if err2 != nil {
+ return err2
+ }
+ // use background, since ctx is canceled already.
+ return cancelImportJob(context.Background(), globalTaskManager, distImporter.JobID())
+ }
+ if err2 := flushStats(ctx, se, e.importPlan.TableInfo.ID, distImporter.Result()); err2 != nil {
+ logutil.Logger(ctx).Error("flush stats failed", zap.Error(err2))
+ }
+ return err
+}
+
+// ImportIntoActionExec represents an import into action executor.
+type ImportIntoActionExec struct {
+ baseExecutor
+ tp ast.ImportIntoActionTp
+ jobID int64
+}
+
+var (
+ _ Executor = (*ImportIntoActionExec)(nil)
+)
+
+// Next implements the Executor Next interface.
+func (e *ImportIntoActionExec) Next(ctx context.Context, _ *chunk.Chunk) error {
+ ctx = kv.WithInternalSourceType(ctx, kv.InternalImportInto)
+
+ var hasSuperPriv bool
+ if pm := privilege.GetPrivilegeManager(e.ctx); pm != nil {
+ hasSuperPriv = pm.RequestVerification(e.ctx.GetSessionVars().ActiveRoles, "", "", "", mysql.SuperPriv)
+ }
+ // we use sessionCtx from GetTaskManager, since the user ctx might not have enough privileges.
+ globalTaskManager, err := fstorage.GetTaskManager()
+ if err != nil {
+ return err
+ }
+ if err = e.checkPrivilegeAndStatus(ctx, globalTaskManager, hasSuperPriv); err != nil {
+ return err
+ }
+
+ logutil.Logger(ctx).Info("import into action", zap.Int64("jobID", e.jobID), zap.Any("action", e.tp))
+ return cancelImportJob(ctx, globalTaskManager, e.jobID)
+}
+
+func (e *ImportIntoActionExec) checkPrivilegeAndStatus(ctx context.Context, manager *fstorage.TaskManager, hasSuperPriv bool) error {
+ var info *importer.JobInfo
+ if err := manager.WithNewSession(func(se sessionctx.Context) error {
+ exec := se.(sqlexec.SQLExecutor)
+ var err2 error
+ info, err2 = importer.GetJob(ctx, exec, e.jobID, e.ctx.GetSessionVars().User.String(), hasSuperPriv)
+ return err2
+ }); err != nil {
+ return err
+ }
+ if !info.CanCancel() {
+ return exeerrors.ErrLoadDataInvalidOperation.FastGenByArgs("CANCEL")
+ }
+ return nil
+}
+
+// flushStats flushes the stats of the table.
+func flushStats(ctx context.Context, se sessionctx.Context, tableID int64, result importer.JobImportResult) error {
+ if err := sessiontxn.NewTxn(ctx, se); err != nil {
+ return err
+ }
+ sessionVars := se.GetSessionVars()
+ sessionVars.TxnCtxMu.Lock()
+ defer sessionVars.TxnCtxMu.Unlock()
+ sessionVars.TxnCtx.UpdateDeltaForTable(tableID, int64(result.Affected), int64(result.Affected), result.ColSizeMap)
+ se.StmtCommit(ctx)
+ return se.CommitTxn(ctx)
+}
+
+func cancelImportJob(ctx context.Context, manager *fstorage.TaskManager, jobID int64) error {
+ // todo: cancel is an async operation, we don't wait here now, maybe add a wait syntax later.
+ // todo: after CANCEL, the user can see the job status is Canceled immediately, but the job might still be running.
+ // and the state of the framework task might become finished, since the framework doesn't enforce the state change DAG when updating the task.
+ // todo: add a CANCELLING status?
+ return manager.WithNewTxn(ctx, func(se sessionctx.Context) error { + exec := se.(sqlexec.SQLExecutor) + if err2 := importer.CancelJob(ctx, exec, jobID); err2 != nil { + return err2 + } + return manager.CancelGlobalTaskByKeySession(se, importinto.TaskKey(jobID)) + }) +} diff --git a/executor/importer/BUILD.bazel b/executor/importer/BUILD.bazel new file mode 100644 index 0000000000000..2cb8288221492 --- /dev/null +++ b/executor/importer/BUILD.bazel @@ -0,0 +1,104 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "importer", + srcs = [ + "chunk_process.go", + "engine_process.go", + "import.go", + "job.go", + "kv_encode.go", + "precheck.go", + "table_import.go", + ], + importpath = "github.com/pingcap/tidb/executor/importer", + visibility = ["//visibility:public"], + deps = [ + "//br/pkg/lightning/backend", + "//br/pkg/lightning/backend/encode", + "//br/pkg/lightning/backend/kv", + "//br/pkg/lightning/backend/local", + "//br/pkg/lightning/checkpoints", + "//br/pkg/lightning/common", + "//br/pkg/lightning/config", + "//br/pkg/lightning/log", + "//br/pkg/lightning/mydump", + "//br/pkg/lightning/verification", + "//br/pkg/storage", + "//br/pkg/streamhelper", + "//br/pkg/utils", + "//config", + "//executor/asyncloaddata", + "//expression", + "//kv", + "//meta/autoid", + "//parser", + "//parser/ast", + "//parser/format", + "//parser/model", + "//parser/mysql", + "//parser/terror", + "//planner/core", + "//sessionctx", + "//sessionctx/stmtctx", + "//sessionctx/variable", + "//table", + "//table/tables", + "//tablecodec", + "//types", + "//util", + "//util/chunk", + "//util/dbterror", + "//util/dbterror/exeerrors", + "//util/etcd", + "//util/filter", + "//util/intest", + "//util/logutil", + "//util/sqlexec", + "//util/stringutil", + "//util/syncutil", + "@com_github_docker_go_units//:go-units", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_log//:log", + "@com_github_tikv_client_go_v2//config", + "@com_github_tikv_client_go_v2//tikv", + "@com_github_tikv_client_go_v2//util", + "@org_golang_x_exp//slices", + "@org_golang_x_sync//errgroup", + "@org_uber_go_multierr//:multierr", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "importer_test", + timeout = "short", + srcs = [ + "import_test.go", + "job_test.go", + "table_import_test.go", + ], + embed = [":importer"], + flaky = True, + race = "on", + shard_count = 11, + deps = [ + "//br/pkg/errors", + "//br/pkg/lightning/config", + "//config", + "//expression", + "//parser", + "//parser/ast", + "//planner/core", + "//testkit", + "//util/dbterror/exeerrors", + "//util/logutil", + "//util/mock", + "//util/sqlexec", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_stretchr_testify//require", + "@org_uber_go_zap//:zap", + ], +) diff --git a/executor/importer/table_import.go b/executor/importer/table_import.go new file mode 100644 index 0000000000000..086e648328c3c --- /dev/null +++ b/executor/importer/table_import.go @@ -0,0 +1,565 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importer + +import ( + "context" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "strconv" + "sync" + "time" + + "github.com/docker/go-units" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/br/pkg/lightning/backend" + "github.com/pingcap/tidb/br/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/br/pkg/lightning/backend/local" + "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/br/pkg/lightning/common" + "github.com/pingcap/tidb/br/pkg/lightning/config" + "github.com/pingcap/tidb/br/pkg/lightning/log" + "github.com/pingcap/tidb/br/pkg/lightning/mydump" + "github.com/pingcap/tidb/br/pkg/storage" + tidb "github.com/pingcap/tidb/config" + tidbkv "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" + "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/syncutil" + "go.uber.org/multierr" + "go.uber.org/zap" +) + +// NewTiKVModeSwitcher make it a var, so we can mock it in tests. +var NewTiKVModeSwitcher = local.NewTiKVModeSwitcher + +var ( + // CheckDiskQuotaInterval is the default time interval to check disk quota. + // TODO: make it dynamically adjusting according to the speed of import and the disk size. + CheckDiskQuotaInterval = time.Minute +) + +// prepareSortDir creates a new directory for import, remove previous sort directory if exists. +func prepareSortDir(e *LoadDataController, taskID int64, tidbCfg *tidb.Config) (string, error) { + sortPathSuffix := "import-" + strconv.Itoa(int(tidbCfg.Port)) + importDir := filepath.Join(tidbCfg.TempDir, sortPathSuffix) + sortDir := filepath.Join(importDir, strconv.FormatInt(taskID, 10)) + + if info, err := os.Stat(importDir); err != nil || !info.IsDir() { + if err != nil && !os.IsNotExist(err) { + e.logger.Error("stat import dir failed", zap.String("import_dir", importDir), zap.Error(err)) + return "", errors.Trace(err) + } + if info != nil && !info.IsDir() { + e.logger.Warn("import dir is not a dir, remove it", zap.String("import_dir", importDir)) + if err := os.RemoveAll(importDir); err != nil { + return "", errors.Trace(err) + } + } + e.logger.Info("import dir not exists, create it", zap.String("import_dir", importDir)) + if err := os.MkdirAll(importDir, 0o700); err != nil { + e.logger.Error("failed to make dir", zap.String("import_dir", importDir), zap.Error(err)) + return "", errors.Trace(err) + } + } + + // todo: remove this after we support checkpoint + if _, err := os.Stat(sortDir); err != nil { + if !os.IsNotExist(err) { + e.logger.Error("stat sort dir failed", zap.String("sort_dir", sortDir), zap.Error(err)) + return "", errors.Trace(err) + } + } else { + e.logger.Warn("sort dir already exists, remove it", zap.String("sort_dir", sortDir)) + if err := os.RemoveAll(sortDir); err != nil { + return "", errors.Trace(err) + } + } + return sortDir, nil +} + +// GetTiKVModeSwitcher creates a new TiKV mode switcher. 
+func GetTiKVModeSwitcher(logger *zap.Logger) (local.TiKVModeSwitcher, error) { + tidbCfg := tidb.GetGlobalConfig() + hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) + tls, err := common.NewTLS( + tidbCfg.Security.ClusterSSLCA, + tidbCfg.Security.ClusterSSLCert, + tidbCfg.Security.ClusterSSLKey, + hostPort, + nil, nil, nil, + ) + if err != nil { + return nil, err + } + return NewTiKVModeSwitcher(tls, tidbCfg.Path, logger), nil +} + +func getCachedKVStoreFrom(pdAddr string, tls *common.TLS) (tidbkv.Storage, error) { + // Disable GC because TiDB enables GC already. + keySpaceName := tidb.GetGlobalKeyspaceName() + // the kv store we get is a cached store, so we can't close it. + kvStore, err := GetKVStore(fmt.Sprintf("tikv://%s?disableGC=true&keyspaceName=%s", pdAddr, keySpaceName), tls.ToTiKVSecurityConfig()) + if err != nil { + return nil, errors.Trace(err) + } + return kvStore, nil +} + +// NewTableImporter creates a new table importer. +func NewTableImporter(param *JobImportParam, e *LoadDataController, taskID int64) (ti *TableImporter, err error) { + idAlloc := kv.NewPanickingAllocators(0) + tbl, err := tables.TableFromMeta(idAlloc, e.Table.Meta()) + if err != nil { + return nil, errors.Annotatef(err, "failed to tables.TableFromMeta %s", e.Table.Meta().Name) + } + + tidbCfg := tidb.GetGlobalConfig() + // todo: we only need to prepare this once on each node(we might call it 3 times in distribution framework) + dir, err := prepareSortDir(e, taskID, tidbCfg) + if err != nil { + return nil, err + } + + hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) + tls, err := common.NewTLS( + tidbCfg.Security.ClusterSSLCA, + tidbCfg.Security.ClusterSSLCert, + tidbCfg.Security.ClusterSSLKey, + hostPort, + nil, nil, nil, + ) + if err != nil { + return nil, err + } + + // no need to close kvStore, since it's a cached store. 
+ kvStore, err := getCachedKVStoreFrom(tidbCfg.Path, tls) + if err != nil { + return nil, errors.Trace(err) + } + + backendConfig := local.BackendConfig{ + PDAddr: tidbCfg.Path, + LocalStoreDir: dir, + MaxConnPerStore: config.DefaultRangeConcurrency, + ConnCompressType: config.CompressionNone, + WorkerConcurrency: config.DefaultRangeConcurrency * 2, + KVWriteBatchSize: config.KVWriteBatchSize, + RegionSplitBatchSize: config.DefaultRegionSplitBatchSize, + RegionSplitConcurrency: runtime.GOMAXPROCS(0), + // enable after we support checkpoint + CheckpointEnabled: false, + MemTableSize: config.DefaultEngineMemCacheSize, + LocalWriterMemCacheSize: int64(config.DefaultLocalWriterMemCacheSize), + ShouldCheckTiKV: true, + DupeDetectEnabled: false, + DuplicateDetectOpt: local.DupDetectOpt{ReportErrOnDup: false}, + StoreWriteBWLimit: int(e.MaxWriteSpeed), + MaxOpenFiles: int(util.GenRLimit("table_import")), + KeyspaceName: tidb.GetGlobalKeyspaceName(), + PausePDSchedulerScope: config.PausePDSchedulerScopeTable, + } + + // todo: use a real region size getter + regionSizeGetter := &local.TableRegionSizeGetterImpl{} + localBackend, err := local.NewBackend(param.GroupCtx, tls, backendConfig, regionSizeGetter) + if err != nil { + return nil, err + } + + return &TableImporter{ + JobImportParam: param, + LoadDataController: e, + backend: localBackend, + tableInfo: &checkpoints.TidbTableInfo{ + ID: e.Table.Meta().ID, + Name: e.Table.Meta().Name.O, + Core: e.Table.Meta(), + }, + encTable: tbl, + dbID: e.DBID, + store: e.dataStore, + kvStore: kvStore, + logger: e.logger, + // this is the value we use for 50TiB data parallel import. + // this might not be the optimal value. + // todo: use different default for single-node import and distributed import. + regionSplitSize: 2 * int64(config.SplitRegionSize), + regionSplitKeys: 2 * int64(config.SplitRegionKeys), + diskQuota: adjustDiskQuota(int64(e.DiskQuota), dir, e.logger), + diskQuotaLock: new(syncutil.RWMutex), + }, nil +} + +// TableImporter is a table importer. +type TableImporter struct { + *JobImportParam + *LoadDataController + backend *local.Backend + tableInfo *checkpoints.TidbTableInfo + // this table has a separate id allocator used to record the max row id allocated. + encTable table.Table + dbID int64 + + store storage.ExternalStorage + // the kv store we get is a cached store, so we can't close it. + kvStore tidbkv.Storage + logger *zap.Logger + regionSplitSize int64 + regionSplitKeys int64 + // the smallest auto-generated ID in current import. + // if there's no auto-generated id column or the column value is not auto-generated, it will be 0. + lastInsertID uint64 + diskQuota int64 + diskQuotaLock *syncutil.RWMutex +} + +func (ti *TableImporter) getParser(ctx context.Context, chunk *checkpoints.ChunkCheckpoint) (mydump.Parser, error) { + info := LoadDataReaderInfo{ + Opener: func(ctx context.Context) (io.ReadSeekCloser, error) { + reader, err := mydump.OpenReader(ctx, &chunk.FileMeta, ti.dataStore) + if err != nil { + return nil, errors.Trace(err) + } + return reader, nil + }, + Remote: &chunk.FileMeta, + } + parser, err := ti.LoadDataController.GetParser(ctx, info) + if err != nil { + return nil, err + } + // todo: when support checkpoint, we should set pos too. + // WARN: parser.SetPos can only be set before we read anything now. should fix it before set pos. 
+ parser.SetRowID(chunk.Chunk.PrevRowIDMax) + return parser, nil +} + +func (ti *TableImporter) getKVEncoder(chunk *checkpoints.ChunkCheckpoint) (kvEncoder, error) { + cfg := &encode.EncodingConfig{ + SessionOptions: encode.SessionOptions{ + SQLMode: ti.SQLMode, + Timestamp: chunk.Timestamp, + SysVars: ti.ImportantSysVars, + AutoRandomSeed: chunk.Chunk.PrevRowIDMax, + }, + Path: chunk.FileMeta.Path, + Table: ti.encTable, + Logger: log.Logger{Logger: ti.logger.With(zap.String("path", chunk.FileMeta.Path))}, + } + return newTableKVEncoder(cfg, ti) +} + +// PopulateChunks populates chunks from table regions. +// in dist framework, this should be done in the tidb node which is responsible for splitting job into subtasks +// then table-importer handles data belongs to the subtask. +func (e *LoadDataController) PopulateChunks(ctx context.Context) (ecp map[int32]*checkpoints.EngineCheckpoint, err error) { + task := log.BeginTask(e.logger, "populate chunks") + defer func() { + task.End(zap.ErrorLevel, err) + }() + + tableMeta := &mydump.MDTableMeta{ + DB: e.DBName, + Name: e.Table.Meta().Name.O, + DataFiles: e.toMyDumpFiles(), + } + dataDivideCfg := &mydump.DataDivideConfig{ + ColumnCnt: len(e.Table.Meta().Columns), + EngineDataSize: int64(config.DefaultBatchSize), + MaxChunkSize: int64(config.MaxRegionSize), + Concurrency: int(e.ThreadCnt), + EngineConcurrency: config.DefaultTableConcurrency, + IOWorkers: nil, + Store: e.dataStore, + TableMeta: tableMeta, + } + tableRegions, err2 := mydump.MakeTableRegions(ctx, dataDivideCfg) + + if err2 != nil { + e.logger.Error("populate chunks failed", zap.Error(err2)) + return nil, err2 + } + + var maxRowID int64 + timestamp := time.Now().Unix() + tableCp := &checkpoints.TableCheckpoint{ + Engines: map[int32]*checkpoints.EngineCheckpoint{}, + } + for _, region := range tableRegions { + engine, found := tableCp.Engines[region.EngineID] + if !found { + engine = &checkpoints.EngineCheckpoint{ + Status: checkpoints.CheckpointStatusLoaded, + } + tableCp.Engines[region.EngineID] = engine + } + ccp := &checkpoints.ChunkCheckpoint{ + Key: checkpoints.ChunkCheckpointKey{ + Path: region.FileMeta.Path, + Offset: region.Chunk.Offset, + }, + FileMeta: region.FileMeta, + ColumnPermutation: nil, + Chunk: region.Chunk, + Timestamp: timestamp, + } + engine.Chunks = append(engine.Chunks, ccp) + if region.Chunk.RowIDMax > maxRowID { + maxRowID = region.Chunk.RowIDMax + } + } + + if common.TableHasAutoID(e.Table.Meta()) { + tidbCfg := tidb.GetGlobalConfig() + hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) + tls, err4 := common.NewTLS( + tidbCfg.Security.ClusterSSLCA, + tidbCfg.Security.ClusterSSLCert, + tidbCfg.Security.ClusterSSLKey, + hostPort, + nil, nil, nil, + ) + if err4 != nil { + return nil, err4 + } + + // no need to close kvStore, since it's a cached store. 
+ kvStore, err4 := getCachedKVStoreFrom(tidbCfg.Path, tls) + if err4 != nil { + return nil, errors.Trace(err4) + } + if err3 := common.RebaseGlobalAutoID(ctx, 0, kvStore, e.DBID, e.Table.Meta()); err3 != nil { + return nil, errors.Trace(err3) + } + newMinRowID, _, err3 := common.AllocGlobalAutoID(ctx, maxRowID, kvStore, e.DBID, e.Table.Meta()) + if err3 != nil { + return nil, errors.Trace(err3) + } + e.rebaseChunkRowID(newMinRowID, tableCp.Engines) + } + + // Add index engine checkpoint + tableCp.Engines[common.IndexEngineID] = &checkpoints.EngineCheckpoint{Status: checkpoints.CheckpointStatusLoaded} + return tableCp.Engines, nil +} + +func (*LoadDataController) rebaseChunkRowID(rowIDBase int64, engines map[int32]*checkpoints.EngineCheckpoint) { + if rowIDBase == 0 { + return + } + for _, engine := range engines { + for _, chunk := range engine.Chunks { + chunk.Chunk.PrevRowIDMax += rowIDBase + chunk.Chunk.RowIDMax += rowIDBase + } + } +} + +// a simplified version of EstimateCompactionThreshold +func (ti *TableImporter) getTotalRawFileSize(indexCnt int64) int64 { + var totalSize int64 + for _, file := range ti.dataFiles { + size := file.RealSize + if file.Type == mydump.SourceTypeParquet { + // parquet file is compressed, thus estimates with a factor of 2 + size *= 2 + } + totalSize += size + } + return totalSize * indexCnt +} + +// OpenIndexEngine opens an index engine. +func (ti *TableImporter) OpenIndexEngine(ctx context.Context, engineID int32) (*backend.OpenedEngine, error) { + idxEngineCfg := &backend.EngineConfig{ + TableInfo: ti.tableInfo, + } + idxCnt := len(ti.tableInfo.Core.Indices) + if !common.TableHasAutoRowID(ti.tableInfo.Core) { + idxCnt-- + } + // todo: getTotalRawFileSize returns size of all data files, but in distributed framework, + // we create one index engine for each engine, should reflect this in the future. + threshold := local.EstimateCompactionThreshold2(ti.getTotalRawFileSize(int64(idxCnt))) + idxEngineCfg.Local = backend.LocalEngineConfig{ + Compact: threshold > 0, + CompactConcurrency: 4, + CompactThreshold: threshold, + } + fullTableName := ti.fullTableName() + // todo: cleanup all engine data on any error since we don't support checkpoint for now + // some return path, didn't make sure all data engine and index engine are cleaned up. + // maybe we can add this in upper level to clean the whole local-sort directory + mgr := backend.MakeEngineManager(ti.backend) + return mgr.OpenEngine(ctx, idxEngineCfg, fullTableName, engineID) +} + +// OpenDataEngine opens a data engine. +func (ti *TableImporter) OpenDataEngine(ctx context.Context, engineID int32) (*backend.OpenedEngine, error) { + dataEngineCfg := &backend.EngineConfig{ + TableInfo: ti.tableInfo, + } + // todo: support checking IsRowOrdered later. + //if ti.tableMeta.IsRowOrdered { + // dataEngineCfg.Local.Compact = true + // dataEngineCfg.Local.CompactConcurrency = 4 + // dataEngineCfg.Local.CompactThreshold = local.CompactionUpperThreshold + //} + mgr := backend.MakeEngineManager(ti.backend) + return mgr.OpenEngine(ctx, dataEngineCfg, ti.fullTableName(), engineID) +} + +// ImportAndCleanup imports the engine and cleanup the engine data. +func (ti *TableImporter) ImportAndCleanup(ctx context.Context, closedEngine *backend.ClosedEngine) (int64, error) { + var kvCount int64 + importErr := closedEngine.Import(ctx, ti.regionSplitSize, ti.regionSplitKeys) + if closedEngine.GetID() != common.IndexEngineID { + // todo: change to a finer-grain progress later. 
+ // each row is encoded into 1 data key + kvCount = ti.backend.GetImportedKVCount(closedEngine.GetUUID()) + } + // todo: if we need support checkpoint, engine should not be cleanup if import failed. + cleanupErr := closedEngine.Cleanup(ctx) + return kvCount, multierr.Combine(importErr, cleanupErr) +} + +// FullTableName return FQDN of the table. +func (ti *TableImporter) fullTableName() string { + return common.UniqueTable(ti.DBName, ti.Table.Meta().Name.O) +} + +// Close implements the io.Closer interface. +func (ti *TableImporter) Close() error { + ti.backend.Close() + return nil +} + +func (ti *TableImporter) setLastInsertID(id uint64) { + // todo: if we run concurrently, we should use atomic operation here. + if id == 0 { + return + } + if ti.lastInsertID == 0 || id < ti.lastInsertID { + ti.lastInsertID = id + } +} + +// CheckDiskQuota checks disk quota. +func (ti *TableImporter) CheckDiskQuota(ctx context.Context) { + var locker sync.Locker + lockDiskQuota := func() { + if locker == nil { + ti.diskQuotaLock.Lock() + locker = ti.diskQuotaLock + } + } + unlockDiskQuota := func() { + if locker != nil { + locker.Unlock() + locker = nil + } + } + + defer unlockDiskQuota() + + for { + select { + case <-ctx.Done(): + return + case <-time.After(CheckDiskQuotaInterval): + } + + largeEngines, inProgressLargeEngines, totalDiskSize, totalMemSize := local.CheckDiskQuota(ti.backend, ti.diskQuota) + if len(largeEngines) == 0 && inProgressLargeEngines == 0 { + unlockDiskQuota() + continue + } + + ti.logger.Warn("disk quota exceeded", + zap.Int64("diskSize", totalDiskSize), + zap.Int64("memSize", totalMemSize), + zap.Int64("quota", ti.diskQuota), + zap.Int("largeEnginesCount", len(largeEngines)), + zap.Int("inProgressLargeEnginesCount", inProgressLargeEngines)) + + lockDiskQuota() + + if len(largeEngines) == 0 { + ti.logger.Warn("all large engines are already importing, keep blocking all writes") + continue + } + + if err := ti.backend.FlushAllEngines(ctx); err != nil { + ti.logger.Error("flush engine for disk quota failed, check again later", log.ShortError(err)) + unlockDiskQuota() + continue + } + + // at this point, all engines are synchronized on disk. + // we then import every large engines one by one and complete. + // if any engine failed to import, we just try again next time, since the data are still intact. + var importErr error + for _, engine := range largeEngines { + // Use a larger split region size to avoid split the same region by many times. + if err := ti.backend.UnsafeImportAndReset( + ctx, + engine, + int64(config.SplitRegionSize)*int64(config.MaxSplitRegionSizeRatio), + int64(config.SplitRegionKeys)*int64(config.MaxSplitRegionSizeRatio), + ); err != nil { + importErr = multierr.Append(importErr, err) + } + } + if importErr != nil { + // discuss: should we return the error and cancel the import? 
+ ti.logger.Error("import large engines failed, check again later", log.ShortError(importErr)) + } + unlockDiskQuota() + } +} + +func adjustDiskQuota(diskQuota int64, sortDir string, logger *zap.Logger) int64 { + sz, err := common.GetStorageSize(sortDir) + if err != nil { + logger.Warn("failed to get storage size", zap.Error(err)) + if diskQuota != 0 { + return diskQuota + } + logger.Info("use default quota instead", zap.Int64("quota", int64(DefaultDiskQuota))) + return int64(DefaultDiskQuota) + } + + maxDiskQuota := int64(float64(sz.Capacity) * 0.8) + switch { + case diskQuota == 0: + logger.Info("use 0.8 of the storage size as default disk quota", + zap.String("quota", units.HumanSize(float64(maxDiskQuota)))) + return maxDiskQuota + case diskQuota > maxDiskQuota: + logger.Warn("disk quota is larger than 0.8 of the storage size, use 0.8 of the storage size instead", + zap.String("quota", units.HumanSize(float64(maxDiskQuota)))) + return maxDiskQuota + default: + return diskQuota + } +} diff --git a/tests/realtikvtest/importintotest/job_test.go b/tests/realtikvtest/importintotest/job_test.go new file mode 100644 index 0000000000000..82397c946fa8e --- /dev/null +++ b/tests/realtikvtest/importintotest/job_test.go @@ -0,0 +1,635 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importintotest + +import ( + "context" + "fmt" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/docker/go-units" + "github.com/fsouza/fake-gcs-server/fakestorage" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/lightning/config" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/disttask/framework/scheduler" + "github.com/pingcap/tidb/disttask/framework/storage" + "github.com/pingcap/tidb/disttask/importinto" + "github.com/pingcap/tidb/executor" + "github.com/pingcap/tidb/executor/importer" + "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/util/dbterror/exeerrors" +) + +func (s *mockGCSSuite) compareJobInfoWithoutTime(jobInfo *importer.JobInfo, row []interface{}) { + s.Equal(strconv.Itoa(int(jobInfo.ID)), row[0]) + + urlExpected, err := url.Parse(jobInfo.Parameters.FileLocation) + s.NoError(err) + urlGot, err := url.Parse(fmt.Sprintf("%v", row[1])) + s.NoError(err) + // order of query parameters might change + s.Equal(urlExpected.Query(), urlGot.Query()) + urlExpected.RawQuery, urlGot.RawQuery = "", "" + s.Equal(urlExpected.String(), urlGot.String()) + + s.Equal(utils.EncloseDBAndTable(jobInfo.TableSchema, jobInfo.TableName), row[2]) + s.Equal(strconv.Itoa(int(jobInfo.TableID)), row[3]) + s.Equal(jobInfo.Step, row[4]) + s.Equal(jobInfo.Status, row[5]) + s.Equal(units.HumanSize(float64(jobInfo.SourceFileSize)), row[6]) + if jobInfo.Summary == nil { + s.Equal("", row[7].(string)) + } else { + s.Equal(strconv.Itoa(int(jobInfo.Summary.ImportedRows)), row[7]) + } + s.Regexp(jobInfo.ErrorMessage, row[8]) + s.Equal(jobInfo.CreatedBy, row[12]) +} + +func (s *mockGCSSuite) TestShowJob() { + s.tk.MustExec("delete from mysql.tidb_import_jobs") + s.prepareAndUseDB("test_show_job") + s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") + s.tk.MustExec("CREATE TABLE t2 (i INT PRIMARY KEY);") + s.tk.MustExec("CREATE TABLE t3 (i INT PRIMARY KEY);") + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-job", Name: "t.csv"}, + Content: []byte("1\n2"), + }) + s.T().Cleanup(func() { + _ = s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil) + }) + // create 2 user which don't have system table privileges + s.tk.MustExec(`DROP USER IF EXISTS 'test_show_job1'@'localhost';`) + s.tk.MustExec(`CREATE USER 'test_show_job1'@'localhost';`) + s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_show_job.* to 'test_show_job1'@'localhost'`) + s.tk.MustExec(`DROP USER IF EXISTS 'test_show_job2'@'localhost';`) + s.tk.MustExec(`CREATE USER 'test_show_job2'@'localhost';`) + s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_show_job.* to 'test_show_job2'@'localhost'`) + do, err := session.GetDomain(s.store) + s.NoError(err) + tableID1 := do.MustGetTableID(s.T(), "test_show_job", "t1") + tableID2 := do.MustGetTableID(s.T(), "test_show_job", "t2") + tableID3 := do.MustGetTableID(s.T(), "test_show_job", "t3") + + // show non-exists job + err = s.tk.QueryToErr("show import job 9999999999") + s.ErrorIs(err, exeerrors.ErrLoadDataJobNotFound) + + // test show job by id using test_show_job1 + s.enableFailpoint("github.com/pingcap/tidb/executor/importer/setLastImportJobID", `return(true)`) + 
s.enableFailpoint("github.com/pingcap/tidb/disttask/framework/storage/testSetLastTaskID", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/parser/ast/forceRedactURL", "return(true)") + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_show_job1", Hostname: "localhost"}, nil, nil, nil)) + result1 := s.tk.MustQuery(fmt.Sprintf(`import into t1 FROM 'gs://test-show-job/t.csv?access-key=aaaaaa&secret-access-key=bbbbbb&endpoint=%s'`, + gcsEndpoint)).Rows() + s.Len(result1, 1) + s.tk.MustQuery("select * from t1").Check(testkit.Rows("1", "2")) + rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(rows, 1) + s.Equal(result1, rows) + jobInfo := &importer.JobInfo{ + ID: importer.TestLastImportJobID.Load(), + TableSchema: "test_show_job", + TableName: "t1", + TableID: tableID1, + CreatedBy: "test_show_job1@localhost", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test-show-job/t.csv?access-key=xxxxxx&secret-access-key=xxxxxx&endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "finished", + Step: "", + Summary: &importer.JobSummary{ + ImportedRows: 2, + }, + ErrorMessage: "", + } + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + + // test show job by id using test_show_job2 + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_show_job2", Hostname: "localhost"}, nil, nil, nil)) + result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test-show-job/t.csv?endpoint=%s'`, gcsEndpoint)).Rows() + s.tk.MustQuery("select * from t2").Check(testkit.Rows("1", "2")) + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(rows, 1) + s.Equal(result2, rows) + jobInfo.ID = importer.TestLastImportJobID.Load() + jobInfo.TableName = "t2" + jobInfo.TableID = tableID2 + jobInfo.CreatedBy = "test_show_job2@localhost" + jobInfo.Parameters.FileLocation = fmt.Sprintf(`gs://test-show-job/t.csv?endpoint=%s`, gcsEndpoint) + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + rows = s.tk.MustQuery("show import jobs").Rows() + s.Len(rows, 1) + s.Equal(result2, rows) + + // show import jobs with root + checkJobsMatch := func(rows [][]interface{}) { + s.GreaterOrEqual(len(rows), 2) // other cases may create import jobs + var matched int + for _, r := range rows { + if r[0] == result1[0][0] { + s.Equal(result1[0], r) + matched++ + } + if r[0] == result2[0][0] { + s.Equal(result2[0], r) + matched++ + } + } + s.Equal(2, matched) + } + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + rows = s.tk.MustQuery("show import jobs").Rows() + checkJobsMatch(rows) + // show import job by id with root + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(rows, 1) + s.Equal(result2, rows) + jobInfo.ID = importer.TestLastImportJobID.Load() + jobInfo.TableName = "t2" + jobInfo.TableID = tableID2 + jobInfo.CreatedBy = "test_show_job2@localhost" + jobInfo.Parameters.FileLocation = fmt.Sprintf(`gs://test-show-job/t.csv?endpoint=%s`, gcsEndpoint) + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + + // grant SUPER to test_show_job2, now it can see all jobs + s.tk.MustExec(`GRANT SUPER on *.* to 'test_show_job2'@'localhost'`) + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_show_job2", Hostname: "localhost"}, nil, nil, nil)) + rows = s.tk.MustQuery("show import jobs").Rows() + 
checkJobsMatch(rows) + + // show running jobs with 2 subtasks + s.enableFailpoint("github.com/pingcap/tidb/disttask/framework/scheduler/syncAfterSubtaskFinish", `return(true)`) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-job", Name: "t2.csv"}, + Content: []byte("3\n4"), + }) + backup4 := config.DefaultBatchSize + config.DefaultBatchSize = 1 + s.T().Cleanup(func() { + config.DefaultBatchSize = backup4 + }) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + // wait first subtask finish + <-scheduler.TestSyncChan + + jobInfo = &importer.JobInfo{ + ID: importer.TestLastImportJobID.Load(), + TableSchema: "test_show_job", + TableName: "t3", + TableID: tableID3, + CreatedBy: "test_show_job2@localhost", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test-show-job/t*.csv?access-key=xxxxxx&secret-access-key=xxxxxx&endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 6, + Status: "running", + Step: "importing", + Summary: &importer.JobSummary{ + ImportedRows: 2, + }, + ErrorMessage: "", + } + tk2 := testkit.NewTestKit(s.T(), s.store) + rows = tk2.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(rows, 1) + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + // show processlist, should be redacted too + procRows := tk2.MustQuery("show full processlist").Rows() + + var got bool + for _, r := range procRows { + user := r[1].(string) + sql := r[7].(string) + if user == "test_show_job2" && strings.Contains(sql, "IMPORT INTO") { + s.Contains(sql, "access-key=xxxxxx") + s.Contains(sql, "secret-access-key=xxxxxx") + s.NotContains(sql, "aaaaaa") + s.NotContains(sql, "bbbbbb") + got = true + } + } + s.True(got) + + // resume the scheduler + scheduler.TestSyncChan <- struct{}{} + // wait second subtask finish + <-scheduler.TestSyncChan + rows = tk2.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(rows, 1) + jobInfo.Summary.ImportedRows = 4 + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + // resume the scheduler, need disable failpoint first, otherwise the post-process subtask will be blocked + s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/syncAfterSubtaskFinish")) + scheduler.TestSyncChan <- struct{}{} + }() + s.tk.MustQuery(fmt.Sprintf(`import into t3 FROM 'gs://test-show-job/t*.csv?access-key=aaaaaa&secret-access-key=bbbbbb&endpoint=%s' with thread=1`, gcsEndpoint)) + wg.Wait() + s.tk.MustQuery("select * from t3").Sort().Check(testkit.Rows("1", "2", "3", "4")) +} + +func (s *mockGCSSuite) TestShowDetachedJob() { + s.prepareAndUseDB("show_detached_job") + s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") + s.tk.MustExec("CREATE TABLE t2 (i INT PRIMARY KEY);") + s.tk.MustExec("CREATE TABLE t3 (i INT PRIMARY KEY);") + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-detached-job", Name: "t.csv"}, + Content: []byte("1\n2"), + }) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-detached-job", Name: "t2.csv"}, + Content: []byte("1\n1"), + }) + do, err := session.GetDomain(s.store) + s.NoError(err) + tableID1 := do.MustGetTableID(s.T(), "show_detached_job", "t1") + tableID2 := do.MustGetTableID(s.T(), "show_detached_job", "t2") + tableID3 := do.MustGetTableID(s.T(), "show_detached_job", "t3") + + jobInfo := &importer.JobInfo{ + 
TableSchema: "show_detached_job", + TableName: "t1", + TableID: tableID1, + CreatedBy: "root@%", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test-show-detached-job/t.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "pending", + Step: "", + } + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + result1 := s.tk.MustQuery(fmt.Sprintf(`import into t1 FROM 'gs://test-show-detached-job/t.csv?endpoint=%s' with detached`, + gcsEndpoint)).Rows() + s.Len(result1, 1) + jobID1, err := strconv.Atoi(result1[0][0].(string)) + s.NoError(err) + jobInfo.ID = int64(jobID1) + s.compareJobInfoWithoutTime(jobInfo, result1[0]) + + s.Eventually(func() bool { + rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID1)).Rows() + return rows[0][5] == "finished" + }, 10*time.Second, 500*time.Millisecond) + rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID1)).Rows() + s.Len(rows, 1) + jobInfo.Status = "finished" + jobInfo.Summary = &importer.JobSummary{ + ImportedRows: 2, + } + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + s.tk.MustQuery("select * from t1").Check(testkit.Rows("1", "2")) + + // job fail with checksum mismatch + result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test-show-detached-job/t2.csv?endpoint=%s' with detached`, + gcsEndpoint)).Rows() + s.Len(result2, 1) + jobID2, err := strconv.Atoi(result2[0][0].(string)) + s.NoError(err) + jobInfo = &importer.JobInfo{ + ID: int64(jobID2), + TableSchema: "show_detached_job", + TableName: "t2", + TableID: tableID2, + CreatedBy: "root@%", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test-show-detached-job/t2.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "pending", + Step: "", + } + s.compareJobInfoWithoutTime(jobInfo, result2[0]) + s.Eventually(func() bool { + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() + return rows[0][5] == "failed" + }, 10*time.Second, 500*time.Millisecond) + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() + s.Len(rows, 1) + jobInfo.Status = "failed" + jobInfo.Step = importer.JobStepValidating + jobInfo.ErrorMessage = `\[Lighting:Restore:ErrChecksumMismatch]checksum mismatched remote vs local.*` + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + + // subtask fail with error + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/errorWhenSortChunk", "return(true)") + result3 := s.tk.MustQuery(fmt.Sprintf(`import into t3 FROM 'gs://test-show-detached-job/t.csv?endpoint=%s' with detached`, + gcsEndpoint)).Rows() + s.Len(result3, 1) + jobID3, err := strconv.Atoi(result3[0][0].(string)) + s.NoError(err) + jobInfo = &importer.JobInfo{ + ID: int64(jobID3), + TableSchema: "show_detached_job", + TableName: "t3", + TableID: tableID3, + CreatedBy: "root@%", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test-show-detached-job/t.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "pending", + Step: "", + } + s.compareJobInfoWithoutTime(jobInfo, result3[0]) + s.Eventually(func() bool { + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID3)).Rows() + return rows[0][5] == "failed" + }, 10*time.Second, 500*time.Millisecond) + rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID3)).Rows() + s.Len(rows, 1) + jobInfo.Status = "failed" + 
jobInfo.Step = importer.JobStepImporting + jobInfo.ErrorMessage = `occur an error when sort chunk.*` + s.compareJobInfoWithoutTime(jobInfo, rows[0]) +} + +func (s *mockGCSSuite) TestCancelJob() { + s.prepareAndUseDB("test_cancel_job") + s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") + s.tk.MustExec("CREATE TABLE t2 (i INT PRIMARY KEY);") + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test_cancel_job", Name: "t.csv"}, + Content: []byte("1\n2"), + }) + s.T().Cleanup(func() { + _ = s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil) + }) + s.tk.MustExec(`DROP USER IF EXISTS 'test_cancel_job1'@'localhost';`) + s.tk.MustExec(`CREATE USER 'test_cancel_job1'@'localhost';`) + s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_cancel_job.* to 'test_cancel_job1'@'localhost'`) + s.tk.MustExec(`DROP USER IF EXISTS 'test_cancel_job2'@'localhost';`) + s.tk.MustExec(`CREATE USER 'test_cancel_job2'@'localhost';`) + s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_cancel_job.* to 'test_cancel_job2'@'localhost'`) + do, err := session.GetDomain(s.store) + s.NoError(err) + tableID1 := do.MustGetTableID(s.T(), "test_cancel_job", "t1") + tableID2 := do.MustGetTableID(s.T(), "test_cancel_job", "t2") + + // cancel non-exists job + err = s.tk.ExecToErr("cancel import job 9999999999") + s.ErrorIs(err, exeerrors.ErrLoadDataJobNotFound) + + getTask := func(jobID int64) *proto.Task { + globalTaskManager, err := storage.GetTaskManager() + s.NoError(err) + taskKey := importinto.TaskKey(jobID) + globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey) + s.NoError(err) + return globalTask + } + + // cancel a running job created by self + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/waitBeforeSortChunk", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncAfterJobStarted", "return(true)") + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_cancel_job1", Hostname: "localhost"}, nil, nil, nil)) + result1 := s.tk.MustQuery(fmt.Sprintf(`import into t1 FROM 'gs://test_cancel_job/t.csv?endpoint=%s' with detached`, + gcsEndpoint)).Rows() + s.Len(result1, 1) + jobID1, err := strconv.Atoi(result1[0][0].(string)) + s.NoError(err) + // wait job started + <-importinto.TestSyncChan + // dist framework has bug, the cancelled status might be overridden by running status, + // so we wait it turn running before cancel, see https://github.com/pingcap/tidb/issues/44443 + time.Sleep(3 * time.Second) + s.tk.MustExec(fmt.Sprintf("cancel import job %d", jobID1)) + rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID1)).Rows() + s.Len(rows, 1) + jobInfo := &importer.JobInfo{ + ID: int64(jobID1), + TableSchema: "test_cancel_job", + TableName: "t1", + TableID: tableID1, + CreatedBy: "test_cancel_job1@localhost", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test_cancel_job/t.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "cancelled", + Step: importer.JobStepImporting, + ErrorMessage: "cancelled by user", + } + s.compareJobInfoWithoutTime(jobInfo, rows[0]) + s.Eventually(func() bool { + task := getTask(int64(jobID1)) + return task.State == proto.TaskStateReverted + }, 10*time.Second, 500*time.Millisecond) + + // cancel again, should fail + s.ErrorIs(s.tk.ExecToErr(fmt.Sprintf("cancel import job %d", jobID1)), exeerrors.ErrLoadDataInvalidOperation) + + // cancel 
a job created by test_cancel_job1 using test_cancel_job2, should fail + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_cancel_job2", Hostname: "localhost"}, nil, nil, nil)) + s.ErrorIs(s.tk.ExecToErr(fmt.Sprintf("cancel import job %d", jobID1)), core.ErrSpecificAccessDenied) + // cancel by root, should pass privilege check + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + s.ErrorIs(s.tk.ExecToErr(fmt.Sprintf("cancel import job %d", jobID1)), exeerrors.ErrLoadDataInvalidOperation) + + // cancel job in post-process phase, using test_cancel_job2 + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_cancel_job2", Hostname: "localhost"}, nil, nil, nil)) + s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/waitBeforeSortChunk")) + s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/syncAfterJobStarted")) + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncBeforePostProcess", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/waitCtxDone", "return(true)") + result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test_cancel_job/t.csv?endpoint=%s' with detached`, + gcsEndpoint)).Rows() + s.Len(result2, 1) + jobID2, err := strconv.Atoi(result2[0][0].(string)) + s.NoError(err) + // wait job reach post-process phase + <-importinto.TestSyncChan + s.tk.MustExec(fmt.Sprintf("cancel import job %d", jobID2)) + // resume the job + importinto.TestSyncChan <- struct{}{} + rows2 := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() + s.Len(rows2, 1) + jobInfo = &importer.JobInfo{ + ID: int64(jobID2), + TableSchema: "test_cancel_job", + TableName: "t2", + TableID: tableID2, + CreatedBy: "test_cancel_job2@localhost", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://test_cancel_job/t.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "cancelled", + Step: importer.JobStepValidating, + ErrorMessage: "cancelled by user", + } + s.compareJobInfoWithoutTime(jobInfo, rows2[0]) + globalTaskManager, err := storage.GetTaskManager() + s.NoError(err) + taskKey := importinto.TaskKey(int64(jobID2)) + s.NoError(err) + s.Eventually(func() bool { + globalTask, err2 := globalTaskManager.GetGlobalTaskByKey(taskKey) + s.NoError(err2) + subtasks, err2 := globalTaskManager.GetSubtasksByStep(globalTask.ID, importinto.StepPostProcess) + s.NoError(err2) + s.Len(subtasks, 2) // framework will generate a subtask when canceling + var cancelled bool + for _, st := range subtasks { + if st.State == proto.TaskStateCanceled { + cancelled = true + break + } + } + return globalTask.State == proto.TaskStateReverted && cancelled + }, 5*time.Second, 1*time.Second) + + // todo: enable it when https://github.com/pingcap/tidb/issues/44443 fixed + //// cancel a pending job created by test_cancel_job2 using root + //s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/syncAfterJobStarted")) + //s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncBeforeJobStarted", "return(true)") + //result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test_cancel_job/t.csv?endpoint=%s' with detached`, + // gcsEndpoint)).Rows() + //s.Len(result2, 1) + //jobID2, err := strconv.Atoi(result2[0][0].(string)) + //s.NoError(err) + //// wait job reached to the point before job started + //<-loaddata.TestSyncChan + 
//s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + //s.tk.MustExec(fmt.Sprintf("cancel import job %d", jobID2)) + //// resume the job + //loaddata.TestSyncChan <- struct{}{} + //rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() + //s.Len(rows, 1) + //jobInfo = &importer.JobInfo{ + // ID: int64(jobID2), + // TableSchema: "test_cancel_job", + // TableName: "t2", + // TableID: tableID2, + // CreatedBy: "test_cancel_job2@localhost", + // Parameters: importer.ImportParameters{ + // FileLocation: fmt.Sprintf(`gs://test_cancel_job/t.csv?endpoint=%s`, gcsEndpoint), + // Format: importer.DataFormatCSV, + // }, + // SourceFileSize: 3, + // Status: "cancelled", + // Step: "", + // ErrorMessage: "cancelled by user", + //} + //s.compareJobInfoWithoutTime(jobInfo, rows[0]) + //s.Eventually(func() bool { + // task := getTask(int64(jobID2)) + // return task.State == proto.TaskStateReverted + //}, 10*time.Second, 500*time.Millisecond) +} + +func (s *mockGCSSuite) TestJobFailWhenDispatchSubtask() { + s.prepareAndUseDB("fail_job_after_import") + s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "fail_job_after_import", Name: "t.csv"}, + Content: []byte("1\n2"), + }) + do, err := session.GetDomain(s.store) + s.NoError(err) + tableID1 := do.MustGetTableID(s.T(), "fail_job_after_import", "t1") + + jobInfo := &importer.JobInfo{ + TableSchema: "fail_job_after_import", + TableName: "t1", + TableID: tableID1, + CreatedBy: "root@%", + Parameters: importer.ImportParameters{ + FileLocation: fmt.Sprintf(`gs://fail_job_after_import/t.csv?endpoint=%s`, gcsEndpoint), + Format: importer.DataFormatCSV, + }, + SourceFileSize: 3, + Status: "failed", + Step: importer.JobStepValidating, + ErrorMessage: "injected error after StepImport", + } + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/failWhenDispatchPostProcessSubtask", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/executor/importer/setLastImportJobID", `return(true)`) + s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) + err = s.tk.QueryToErr(fmt.Sprintf(`import into t1 FROM 'gs://fail_job_after_import/t.csv?endpoint=%s'`, gcsEndpoint)) + s.ErrorContains(err, "injected error after StepImport") + result1 := s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() + s.Len(result1, 1) + jobID1, err := strconv.Atoi(result1[0][0].(string)) + s.NoError(err) + jobInfo.ID = int64(jobID1) + s.compareJobInfoWithoutTime(jobInfo, result1[0]) +} + +func (s *mockGCSSuite) TestKillBeforeFinish() { + s.cleanupSysTables() + s.tk.MustExec("DROP DATABASE IF EXISTS kill_job;") + s.tk.MustExec("CREATE DATABASE kill_job;") + s.tk.MustExec(`CREATE TABLE kill_job.t (a INT, b INT, c int);`) + s.server.CreateObject(fakestorage.Object{ + ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-load", Name: "t-1.tsv"}, + Content: []byte("1,11,111"), + }) + + s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncBeforeSortChunk", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/executor/cancellableCtx", "return(true)") + s.enableFailpoint("github.com/pingcap/tidb/executor/importer/setLastImportJobID", `return(true)`) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + sql := fmt.Sprintf(`IMPORT INTO kill_job.t FROM 'gs://test-load/t-*.tsv?endpoint=%s'`, 
gcsEndpoint) + err := s.tk.QueryToErr(sql) + s.ErrorIs(errors.Cause(err), context.Canceled) + }() + // wait for the task reach sort chunk + <-importinto.TestSyncChan + // cancel the job + executor.TestCancelFunc() + // continue the execution + importinto.TestSyncChan <- struct{}{} + wg.Wait() + jobID := importer.TestLastImportJobID.Load() + rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID)).Rows() + s.Len(rows, 1) + s.Equal("cancelled", rows[0][5]) + globalTaskManager, err := storage.GetTaskManager() + s.NoError(err) + taskKey := importinto.TaskKey(jobID) + s.NoError(err) + s.Eventually(func() bool { + globalTask, err2 := globalTaskManager.GetGlobalTaskByKey(taskKey) + s.NoError(err2) + return globalTask.State == proto.TaskStateReverted + }, 5*time.Second, 1*time.Second) +} From 7a2bdadc51650c595006a00065fba9c0a4ce487d Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Thu, 29 Jun 2023 17:19:35 +0800 Subject: [PATCH 2/4] Revert "This is an automated cherry-pick of #44803" This reverts commit d19162eb15b18f49763b60f56e4b3932577a0f6d. --- br/pkg/checksum/executor.go | 56 +- br/pkg/lightning/backend/local/BUILD.bazel | 6 - br/pkg/lightning/common/BUILD.bazel | 5 - br/pkg/lightning/common/common.go | 109 --- br/pkg/lightning/common/util.go | 163 ----- br/pkg/lightning/config/config.go | 2 - br/pkg/lightning/importer/checksum_helper.go | 89 --- br/pkg/lightning/restore/checksum.go | 51 +- br/pkg/lightning/restore/checksum_test.go | 2 - .../lightning/restore/table_restore_test.go | 3 +- br/pkg/lightning/restore/tidb_test.go | 2 - br/tests/lightning_add_index/config1.toml | 6 - disttask/framework/dispatcher/dispatcher.go | 516 -------------- disttask/framework/storage/task_table.go | 496 -------------- disttask/importinto/BUILD.bazel | 80 --- disttask/importinto/dispatcher.go | 647 ------------------ disttask/importinto/job.go | 279 -------- disttask/importinto/subtask_executor.go | 240 ------- disttask/importinto/subtask_executor_test.go | 73 -- executor/import_into.go | 302 -------- executor/importer/BUILD.bazel | 104 --- executor/importer/table_import.go | 565 --------------- tests/realtikvtest/importintotest/job_test.go | 635 ----------------- 23 files changed, 7 insertions(+), 4424 deletions(-) delete mode 100644 br/pkg/lightning/common/common.go delete mode 100644 br/pkg/lightning/importer/checksum_helper.go delete mode 100644 br/tests/lightning_add_index/config1.toml delete mode 100644 disttask/framework/dispatcher/dispatcher.go delete mode 100644 disttask/framework/storage/task_table.go delete mode 100644 disttask/importinto/BUILD.bazel delete mode 100644 disttask/importinto/dispatcher.go delete mode 100644 disttask/importinto/job.go delete mode 100644 disttask/importinto/subtask_executor.go delete mode 100644 disttask/importinto/subtask_executor_test.go delete mode 100644 executor/import_into.go delete mode 100644 executor/importer/BUILD.bazel delete mode 100644 executor/importer/table_import.go delete mode 100644 tests/realtikvtest/importintotest/job_test.go diff --git a/br/pkg/checksum/executor.go b/br/pkg/checksum/executor.go index eda610b8f8e41..c30ae49fccdca 100644 --- a/br/pkg/checksum/executor.go +++ b/br/pkg/checksum/executor.go @@ -26,15 +26,7 @@ type ExecutorBuilder struct { oldTable *metautil.Table -<<<<<<< HEAD concurrency uint -======= - concurrency uint - backoffWeight int - - oldKeyspace []byte - newKeyspace []byte ->>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) } // NewExecutorBuilder returns a new executor builder. 
@@ -59,32 +51,13 @@ func (builder *ExecutorBuilder) SetConcurrency(conc uint) *ExecutorBuilder { return builder } -<<<<<<< HEAD -======= -// SetBackoffWeight set the backoffWeight of the checksum executing. -func (builder *ExecutorBuilder) SetBackoffWeight(backoffWeight int) *ExecutorBuilder { - builder.backoffWeight = backoffWeight - return builder -} - -func (builder *ExecutorBuilder) SetOldKeyspace(keyspace []byte) *ExecutorBuilder { - builder.oldKeyspace = keyspace - return builder -} - -func (builder *ExecutorBuilder) SetNewKeyspace(keyspace []byte) *ExecutorBuilder { - builder.newKeyspace = keyspace - return builder -} - ->>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) // Build builds a checksum executor. func (builder *ExecutorBuilder) Build() (*Executor, error) { reqs, err := buildChecksumRequest(builder.table, builder.oldTable, builder.ts, builder.concurrency) if err != nil { return nil, errors.Trace(err) } - return &Executor{reqs: reqs, backoffWeight: builder.backoffWeight}, nil + return &Executor{reqs: reqs}, nil } func buildChecksumRequest( @@ -289,8 +262,7 @@ func updateChecksumResponse(resp, update *tipb.ChecksumResponse) { // Executor is a checksum executor. type Executor struct { - reqs []*kv.Request - backoffWeight int + reqs []*kv.Request } // Len returns the total number of checksum requests. @@ -336,31 +308,7 @@ func (exec *Executor) Execute( // // It is useful in TiDB, however, it's a place holder in BR. killed := uint32(0) -<<<<<<< HEAD resp, err := sendChecksumRequest(ctx, client, req, kv.NewVariables(&killed)) -======= - var ( - resp *tipb.ChecksumResponse - err error - ) - err = utils.WithRetry(ctx, func() error { - vars := kv.NewVariables(&killed) - if exec.backoffWeight > 0 { - vars.BackOffWeight = exec.backoffWeight - } - resp, err = sendChecksumRequest(ctx, client, req, vars) - failpoint.Inject("checksumRetryErr", func(val failpoint.Value) { - // first time reach here. 
return error - if val.(bool) { - err = errors.New("inject checksum error") - } - }) - if err != nil { - return errors.Trace(err) - } - return nil - }, &checksumBackoffer) ->>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/lightning/backend/local/BUILD.bazel b/br/pkg/lightning/backend/local/BUILD.bazel index 1273b807b0f0d..9524ab5febc2b 100644 --- a/br/pkg/lightning/backend/local/BUILD.bazel +++ b/br/pkg/lightning/backend/local/BUILD.bazel @@ -40,11 +40,6 @@ go_library( "//kv", "//parser/model", "//parser/mysql", -<<<<<<< HEAD -======= - "//sessionctx/variable", - "//store/pdtypes", ->>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) "//table", "//tablecodec", "//types", @@ -67,7 +62,6 @@ go_library( "@com_github_pingcap_kvproto//pkg/metapb", "@com_github_pingcap_kvproto//pkg/pdpb", "@com_github_tikv_client_go_v2//error", - "@com_github_tikv_client_go_v2//kv", "@com_github_tikv_client_go_v2//oracle", "@com_github_tikv_client_go_v2//tikv", "@com_github_tikv_pd_client//:client", diff --git a/br/pkg/lightning/common/BUILD.bazel b/br/pkg/lightning/common/BUILD.bazel index 05d4729c89193..2b36e457cd857 100644 --- a/br/pkg/lightning/common/BUILD.bazel +++ b/br/pkg/lightning/common/BUILD.bazel @@ -23,11 +23,6 @@ go_library( "//br/pkg/utils", "//errno", "//parser/model", -<<<<<<< HEAD -======= - "//parser/mysql", - "//sessionctx/variable", ->>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) "//store/driver/error", "//table/tables", "//util", diff --git a/br/pkg/lightning/common/common.go b/br/pkg/lightning/common/common.go deleted file mode 100644 index aaf8860e4fb58..0000000000000 --- a/br/pkg/lightning/common/common.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package common - -import ( - "context" - - "github.com/pingcap/errors" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/meta/autoid" - "github.com/pingcap/tidb/parser/model" -) - -const ( - // IndexEngineID is the engine ID for index engine. - IndexEngineID = -1 -) - -// DefaultImportantVariables is used in ObtainImportantVariables to retrieve the system -// variables from downstream which may affect KV encode result. The values record the default -// values if missing. -var DefaultImportantVariables = map[string]string{ - "max_allowed_packet": "67108864", - "div_precision_increment": "4", - "time_zone": "SYSTEM", - "lc_time_names": "en_US", - "default_week_format": "0", - "block_encryption_mode": "aes-128-ecb", - "group_concat_max_len": "1024", - "tidb_backoff_weight": "6", -} - -// DefaultImportVariablesTiDB is used in ObtainImportantVariables to retrieve the system -// variables from downstream in local/importer backend. The values record the default -// values if missing. 
-var DefaultImportVariablesTiDB = map[string]string{ - "tidb_row_format_version": "1", -} - -// AllocGlobalAutoID allocs N consecutive autoIDs from TiDB. -func AllocGlobalAutoID(ctx context.Context, n int64, store kv.Storage, dbID int64, - tblInfo *model.TableInfo) (autoIDBase, autoIDMax int64, err error) { - alloc, err := getGlobalAutoIDAlloc(store, dbID, tblInfo) - if err != nil { - return 0, 0, err - } - return alloc.Alloc(ctx, uint64(n), 1, 1) -} - -// RebaseGlobalAutoID rebase the autoID base to newBase. -func RebaseGlobalAutoID(ctx context.Context, newBase int64, store kv.Storage, dbID int64, - tblInfo *model.TableInfo) error { - alloc, err := getGlobalAutoIDAlloc(store, dbID, tblInfo) - if err != nil { - return err - } - return alloc.Rebase(ctx, newBase, false) -} - -func getGlobalAutoIDAlloc(store kv.Storage, dbID int64, tblInfo *model.TableInfo) (autoid.Allocator, error) { - if store == nil { - return nil, errors.New("internal error: kv store should not be nil") - } - if dbID == 0 { - return nil, errors.New("internal error: dbID should not be 0") - } - - // We don't need autoid cache here because we allocate all IDs at once. - // The argument for CustomAutoIncCacheOption is the cache step. Step 1 means no cache, - // but step 1 will enable an experimental feature, so we use step 2 here. - // - // See https://github.com/pingcap/tidb/issues/38442 for more details. - noCache := autoid.CustomAutoIncCacheOption(2) - tblVer := autoid.AllocOptionTableInfoVersion(tblInfo.Version) - - hasRowID := TableHasAutoRowID(tblInfo) - hasAutoIncID := tblInfo.GetAutoIncrementColInfo() != nil - hasAutoRandID := tblInfo.ContainsAutoRandomBits() - - // Current TiDB has some limitations for auto ID. - // 1. Auto increment ID and auto row ID are using the same RowID allocator. - // See https://github.com/pingcap/tidb/issues/982. - // 2. Auto random column must be a clustered primary key. That is to say, - // there is no implicit row ID for tables with auto random column. - // 3. There is at most one auto column in a table. - // Therefore, we assume there is only one auto column in a table and use RowID allocator if possible. - switch { - case hasRowID || hasAutoIncID: - return autoid.NewAllocator(store, dbID, tblInfo.ID, tblInfo.IsAutoIncColUnsigned(), - autoid.RowIDAllocType, noCache, tblVer), nil - case hasAutoRandID: - return autoid.NewAllocator(store, dbID, tblInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), - autoid.AutoRandomType, noCache, tblVer), nil - default: - return nil, errors.Errorf("internal error: table %s has no auto ID", tblInfo.Name) - } -} diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go index c068c3390dd35..621c59d820e23 100644 --- a/br/pkg/lightning/common/util.go +++ b/br/pkg/lightning/common/util.go @@ -36,11 +36,6 @@ import ( "github.com/pingcap/tidb/br/pkg/utils" tmysql "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/model" -<<<<<<< HEAD -======= - tmysql "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/tidb/sessionctx/variable" ->>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) "github.com/pingcap/tidb/table/tables" "go.uber.org/zap" ) @@ -433,161 +428,3 @@ func GetAutoRandomColumn(tblInfo *model.TableInfo) *model.ColumnInfo { } return nil } -<<<<<<< HEAD -======= - -// GetDropIndexInfos returns the index infos that need to be dropped and the remain indexes. 
-func GetDropIndexInfos( - tblInfo *model.TableInfo, -) (remainIndexes []*model.IndexInfo, dropIndexes []*model.IndexInfo) { - cols := tblInfo.Columns -loop: - for _, idxInfo := range tblInfo.Indices { - if idxInfo.State != model.StatePublic { - remainIndexes = append(remainIndexes, idxInfo) - continue - } - // Primary key is a cluster index. - if idxInfo.Primary && tblInfo.HasClusteredIndex() { - remainIndexes = append(remainIndexes, idxInfo) - continue - } - // Skip index that contains auto-increment column. - // Because auto colum must be defined as a key. - for _, idxCol := range idxInfo.Columns { - flag := cols[idxCol.Offset].GetFlag() - if tmysql.HasAutoIncrementFlag(flag) { - remainIndexes = append(remainIndexes, idxInfo) - continue loop - } - } - dropIndexes = append(dropIndexes, idxInfo) - } - return remainIndexes, dropIndexes -} - -// BuildDropIndexSQL builds the SQL statement to drop index. -func BuildDropIndexSQL(tableName string, idxInfo *model.IndexInfo) string { - if idxInfo.Primary { - return fmt.Sprintf("ALTER TABLE %s DROP PRIMARY KEY", tableName) - } - return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", tableName, EscapeIdentifier(idxInfo.Name.O)) -} - -// BuildAddIndexSQL builds the SQL statement to create missing indexes. -// It returns both a single SQL statement that creates all indexes at once, -// and a list of SQL statements that creates each index individually. -func BuildAddIndexSQL( - tableName string, - curTblInfo, - desiredTblInfo *model.TableInfo, -) (singleSQL string, multiSQLs []string) { - addIndexSpecs := make([]string, 0, len(desiredTblInfo.Indices)) -loop: - for _, desiredIdxInfo := range desiredTblInfo.Indices { - for _, curIdxInfo := range curTblInfo.Indices { - if curIdxInfo.Name.L == desiredIdxInfo.Name.L { - continue loop - } - } - - var buf bytes.Buffer - if desiredIdxInfo.Primary { - buf.WriteString("ADD PRIMARY KEY ") - } else if desiredIdxInfo.Unique { - buf.WriteString("ADD UNIQUE KEY ") - } else { - buf.WriteString("ADD KEY ") - } - // "primary" is a special name for primary key, we should not use it as index name. - if desiredIdxInfo.Name.L != "primary" { - buf.WriteString(EscapeIdentifier(desiredIdxInfo.Name.O)) - } - - colStrs := make([]string, 0, len(desiredIdxInfo.Columns)) - for _, col := range desiredIdxInfo.Columns { - var colStr string - if desiredTblInfo.Columns[col.Offset].Hidden { - colStr = fmt.Sprintf("(%s)", desiredTblInfo.Columns[col.Offset].GeneratedExprString) - } else { - colStr = EscapeIdentifier(col.Name.O) - if col.Length != types.UnspecifiedLength { - colStr = fmt.Sprintf("%s(%s)", colStr, strconv.Itoa(col.Length)) - } - } - colStrs = append(colStrs, colStr) - } - fmt.Fprintf(&buf, "(%s)", strings.Join(colStrs, ",")) - - if desiredIdxInfo.Invisible { - fmt.Fprint(&buf, " INVISIBLE") - } - if desiredIdxInfo.Comment != "" { - fmt.Fprintf(&buf, ` COMMENT '%s'`, format.OutputFormat(desiredIdxInfo.Comment)) - } - addIndexSpecs = append(addIndexSpecs, buf.String()) - } - if len(addIndexSpecs) == 0 { - return "", nil - } - - singleSQL = fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(addIndexSpecs, ", ")) - for _, spec := range addIndexSpecs { - multiSQLs = append(multiSQLs, fmt.Sprintf("ALTER TABLE %s %s", tableName, spec)) - } - return singleSQL, multiSQLs -} - -// IsDupKeyError checks if err is a duplicate index error. 
-func IsDupKeyError(err error) bool { - if merr, ok := errors.Cause(err).(*mysql.MySQLError); ok { - switch merr.Number { - case errno.ErrDupKeyName, errno.ErrMultiplePriKey, errno.ErrDupUnique: - return true - } - } - return false -} - -// GetBackoffWeightFromDB gets the backoff weight from database. -func GetBackoffWeightFromDB(ctx context.Context, db *sql.DB) (int, error) { - val, err := getSessionVariable(ctx, db, variable.TiDBBackOffWeight) - if err != nil { - return 0, err - } - return strconv.Atoi(val) -} - -// copy from dbutil to avoid import cycle -func getSessionVariable(ctx context.Context, db *sql.DB, variable string) (value string, err error) { - query := fmt.Sprintf("SHOW VARIABLES LIKE '%s'", variable) - rows, err := db.QueryContext(ctx, query) - - if err != nil { - return "", errors.Trace(err) - } - defer rows.Close() - - // Show an example. - /* - mysql> SHOW VARIABLES LIKE "binlog_format"; - +---------------+-------+ - | Variable_name | Value | - +---------------+-------+ - | binlog_format | ROW | - +---------------+-------+ - */ - - for rows.Next() { - if err = rows.Scan(&variable, &value); err != nil { - return "", errors.Trace(err) - } - } - - if err := rows.Err(); err != nil { - return "", errors.Trace(err) - } - - return value, nil -} ->>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)) diff --git a/br/pkg/lightning/config/config.go b/br/pkg/lightning/config/config.go index e372ad3e1bb18..e1031c760f749 100644 --- a/br/pkg/lightning/config/config.go +++ b/br/pkg/lightning/config/config.go @@ -440,7 +440,6 @@ type PostRestore struct { Level1Compact bool `toml:"level-1-compact" json:"level-1-compact"` PostProcessAtLast bool `toml:"post-process-at-last" json:"post-process-at-last"` Compact bool `toml:"compact" json:"compact"` - ChecksumViaSQL bool `toml:"checksum-via-sql" json:"checksum-via-sql"` } type CSVConfig struct { @@ -746,7 +745,6 @@ func NewConfig() *Config { Checksum: OpLevelRequired, Analyze: OpLevelOptional, PostProcessAtLast: true, - ChecksumViaSQL: true, }, } } diff --git a/br/pkg/lightning/importer/checksum_helper.go b/br/pkg/lightning/importer/checksum_helper.go deleted file mode 100644 index 88bc40d5a72e1..0000000000000 --- a/br/pkg/lightning/importer/checksum_helper.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importer - -import ( - "context" - - "github.com/pingcap/errors" - "github.com/pingcap/tidb/br/pkg/lightning/backend/local" - "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/br/pkg/lightning/common" - "github.com/pingcap/tidb/br/pkg/lightning/config" - "github.com/pingcap/tidb/br/pkg/lightning/log" - "github.com/pingcap/tidb/br/pkg/lightning/metric" - "github.com/pingcap/tidb/br/pkg/pdutil" - "github.com/pingcap/tidb/kv" - pd "github.com/tikv/pd/client" - "go.uber.org/zap" -) - -// NewChecksumManager creates a new checksum manager. 
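// A minimal standalone sketch (separate from the code in this diff) of the
// checksum-via-SQL path that NewChecksumManager below can fall back to: the
// TiDB executor computes the checksum with ADMIN CHECKSUM TABLE over a plain
// SQL connection. The package and helper names here are hypothetical, and the
// column order follows the ADMIN CHECKSUM result shown later in this patch.
package checksumsketch

import (
	"context"
	"database/sql"
	"fmt"
)

// remoteChecksum mirrors the five columns returned by ADMIN CHECKSUM TABLE.
type remoteChecksum struct {
	Schema, Table                  string
	Checksum, TotalKVs, TotalBytes uint64
}

// runRemoteChecksum runs ADMIN CHECKSUM TABLE and scans the single result row.
func runRemoteChecksum(ctx context.Context, db *sql.DB, schema, table string) (*remoteChecksum, error) {
	cs := &remoteChecksum{}
	query := fmt.Sprintf("ADMIN CHECKSUM TABLE `%s`.`%s`", schema, table)
	err := db.QueryRowContext(ctx, query).Scan(&cs.Schema, &cs.Table, &cs.Checksum, &cs.TotalKVs, &cs.TotalBytes)
	if err != nil {
		return nil, err
	}
	return cs, nil
}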
-func NewChecksumManager(ctx context.Context, rc *Controller, store kv.Storage) (local.ChecksumManager, error) { - // if we don't need checksum, just return nil - if rc.cfg.TikvImporter.Backend == config.BackendTiDB || rc.cfg.PostRestore.Checksum == config.OpLevelOff { - return nil, nil - } - - pdAddr := rc.cfg.TiDB.PdAddr - pdVersion, err := pdutil.FetchPDVersion(ctx, rc.tls, pdAddr) - if err != nil { - return nil, errors.Trace(err) - } - - // for v4.0.0 or upper, we can use the gc ttl api - var manager local.ChecksumManager - if pdVersion.Major >= 4 && !rc.cfg.PostRestore.ChecksumViaSQL { - tlsOpt := rc.tls.ToPDSecurityOption() - pdCli, err := pd.NewClientWithContext(ctx, []string{pdAddr}, tlsOpt) - if err != nil { - return nil, errors.Trace(err) - } - - backoffWeight, err := common.GetBackoffWeightFromDB(ctx, rc.db) - // only set backoff weight when it's smaller than default value - if err == nil && backoffWeight >= local.DefaultBackoffWeight { - log.FromContext(ctx).Info("get tidb_backoff_weight", zap.Int("backoff_weight", backoffWeight)) - } else { - log.FromContext(ctx).Info("set tidb_backoff_weight to default", zap.Int("backoff_weight", local.DefaultBackoffWeight)) - backoffWeight = local.DefaultBackoffWeight - } - manager = local.NewTiKVChecksumManager(store.GetClient(), pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency), backoffWeight) - } else { - manager = local.NewTiDBChecksumExecutor(rc.db) - } - - return manager, nil -} - -// DoChecksum do checksum for tables. -// table should be in .
, format. e.g. foo.bar -func DoChecksum(ctx context.Context, table *checkpoints.TidbTableInfo) (*local.RemoteChecksum, error) { - var err error - manager, ok := ctx.Value(&checksumManagerKey).(local.ChecksumManager) - if !ok { - return nil, errors.New("No gcLifeTimeManager found in context, check context initialization") - } - - task := log.FromContext(ctx).With(zap.String("table", table.Name)).Begin(zap.InfoLevel, "remote checksum") - - cs, err := manager.Checksum(ctx, table) - dur := task.End(zap.ErrorLevel, err) - if m, ok := metric.FromContext(ctx); ok { - m.ChecksumSecondsHistogram.Observe(dur.Seconds()) - } - - return cs, err -} diff --git a/br/pkg/lightning/restore/checksum.go b/br/pkg/lightning/restore/checksum.go index b981d6759fdd3..b30fe14e01fc1 100644 --- a/br/pkg/lightning/restore/checksum.go +++ b/br/pkg/lightning/restore/checksum.go @@ -33,10 +33,8 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/metric" "github.com/pingcap/tidb/br/pkg/pdutil" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/mathutil" "github.com/pingcap/tipb/go-tipb" - tikvstore "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/oracle" pd "github.com/tikv/pd/client" "go.uber.org/atomic" @@ -52,14 +50,7 @@ const ( var ( serviceSafePointTTL int64 = 10 * 60 // 10 min in seconds - // MinDistSQLScanConcurrency is the minimum value of tidb_distsql_scan_concurrency. - MinDistSQLScanConcurrency = 4 - - // DefaultBackoffWeight is the default value of tidb_backoff_weight for checksum. - // when TiKV client encounters an error of "region not leader", it will keep retrying every 500 ms. - // If it still fails after 2 * 20 = 40 seconds, it will return "region unavailable". - // If we increase the BackOffWeight to 6, then the TiKV client will keep retrying for 120 seconds. - DefaultBackoffWeight = 3 * tikvstore.DefBackOffWeight + minDistSQLScanConcurrency = 4 ) // RemoteChecksum represents a checksum result got from tidb. @@ -134,15 +125,6 @@ func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoi task := log.FromContext(ctx).With(zap.String("table", tableName)).Begin(zap.InfoLevel, "remote checksum") - conn, err := e.db.Conn(ctx) - if err != nil { - return nil, errors.Trace(err) - } - defer func() { - if err := conn.Close(); err != nil { - task.Warn("close connection failed", zap.Error(err)) - } - }() // ADMIN CHECKSUM TABLE
<table>,<table>
example. // mysql> admin checksum table test.t; // +---------+------------+---------------------+-----------+-------------+ @@ -150,23 +132,9 @@ func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoi // +---------+------------+---------------------+-----------+-------------+ // | test | t | 8520875019404689597 | 7296873 | 357601387 | // +---------+------------+---------------------+-----------+-------------+ - backoffWeight, err := common.GetBackoffWeightFromDB(ctx, e.db) - if err == nil && backoffWeight < DefaultBackoffWeight { - task.Info("increase tidb_backoff_weight", zap.Int("original", backoffWeight), zap.Int("new", DefaultBackoffWeight)) - // increase backoff weight - if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, DefaultBackoffWeight)); err != nil { - task.Warn("set tidb_backoff_weight failed", zap.Error(err)) - } else { - defer func() { - if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, backoffWeight)); err != nil { - task.Warn("recover tidb_backoff_weight failed", zap.Error(err)) - } - }() - } - } cs := RemoteChecksum{} - err = common.SQLWithRetry{DB: conn, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum", + err = common.SQLWithRetry{DB: e.db, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum", "ADMIN CHECKSUM TABLE "+tableName, &cs.Schema, &cs.Table, &cs.Checksum, &cs.TotalKVs, &cs.TotalBytes, ) dur := task.End(zap.ErrorLevel, err) @@ -289,31 +257,20 @@ type tikvChecksumManager struct { client kv.Client manager gcTTLManager distSQLScanConcurrency uint - backoffWeight int } -<<<<<<< HEAD:br/pkg/lightning/restore/checksum.go // newTiKVChecksumManager return a new tikv checksum manager func newTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint) *tikvChecksumManager { return &tikvChecksumManager{ -======= -var _ ChecksumManager = &TiKVChecksumManager{} - -// NewTiKVChecksumManager return a new tikv checksum manager -func NewTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint, backoffWeight int) *TiKVChecksumManager { - return &TiKVChecksumManager{ ->>>>>>> 89bf7432279 (importinto/lightning: do remote checksum via sql (#44803)):br/pkg/lightning/backend/local/checksum.go client: client, manager: newGCTTLManager(pdClient), distSQLScanConcurrency: distSQLScanConcurrency, - backoffWeight: backoffWeight, } } func (e *tikvChecksumManager) checksumDB(ctx context.Context, tableInfo *checkpoints.TidbTableInfo, ts uint64) (*RemoteChecksum, error) { executor, err := checksum.NewExecutorBuilder(tableInfo.Core, ts). SetConcurrency(e.distSQLScanConcurrency). - SetBackoffWeight(e.backoffWeight). 
Build() if err != nil { return nil, errors.Trace(err) @@ -345,8 +302,8 @@ func (e *tikvChecksumManager) checksumDB(ctx context.Context, tableInfo *checkpo if !common.IsRetryableError(err) { break } - if distSQLScanConcurrency > MinDistSQLScanConcurrency { - distSQLScanConcurrency = mathutil.Max(distSQLScanConcurrency/2, MinDistSQLScanConcurrency) + if distSQLScanConcurrency > minDistSQLScanConcurrency { + distSQLScanConcurrency = mathutil.Max(distSQLScanConcurrency/2, minDistSQLScanConcurrency) } } diff --git a/br/pkg/lightning/restore/checksum_test.go b/br/pkg/lightning/restore/checksum_test.go index ba920ee58ed84..20acc23fe6be0 100644 --- a/br/pkg/lightning/restore/checksum_test.go +++ b/br/pkg/lightning/restore/checksum_test.go @@ -56,7 +56,6 @@ func TestDoChecksum(t *testing.T) { WithArgs("10m"). WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() - mock.ExpectClose() ctx := MockDoChecksumCtx(db) checksum, err := DoChecksum(ctx, &TidbTableInfo{DB: "test", Name: "t"}) @@ -217,7 +216,6 @@ func TestDoChecksumWithErrorAndLongOriginalLifetime(t *testing.T) { WithArgs("300h"). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectClose() - mock.ExpectClose() ctx := MockDoChecksumCtx(db) _, err = DoChecksum(ctx, &TidbTableInfo{DB: "test", Name: "t"}) diff --git a/br/pkg/lightning/restore/table_restore_test.go b/br/pkg/lightning/restore/table_restore_test.go index ad09add849a51..17fb97e346e36 100644 --- a/br/pkg/lightning/restore/table_restore_test.go +++ b/br/pkg/lightning/restore/table_restore_test.go @@ -753,7 +753,6 @@ func (s *tableRestoreSuite) TestCompareChecksumSuccess() { WithArgs("10m"). WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() - mock.ExpectClose() ctx := MockDoChecksumCtx(db) remoteChecksum, err := DoChecksum(ctx, s.tr.tableInfo) @@ -784,7 +783,7 @@ func (s *tableRestoreSuite) TestCompareChecksumFailure() { WithArgs("10m"). 
WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() - mock.ExpectClose() + ctx := MockDoChecksumCtx(db) remoteChecksum, err := DoChecksum(ctx, s.tr.tableInfo) require.NoError(s.T(), err) diff --git a/br/pkg/lightning/restore/tidb_test.go b/br/pkg/lightning/restore/tidb_test.go index b3ece883864f6..9b204b2da22b1 100644 --- a/br/pkg/lightning/restore/tidb_test.go +++ b/br/pkg/lightning/restore/tidb_test.go @@ -460,7 +460,6 @@ func TestObtainRowFormatVersionSucceed(t *testing.T) { sysVars := ObtainImportantVariables(ctx, s.tiGlue.GetSQLExecutor(), true) require.Equal(t, map[string]string{ - "tidb_backoff_weight": "6", "tidb_row_format_version": "2", "max_allowed_packet": "1073741824", "div_precision_increment": "10", @@ -488,7 +487,6 @@ func TestObtainRowFormatVersionFailure(t *testing.T) { sysVars := ObtainImportantVariables(ctx, s.tiGlue.GetSQLExecutor(), true) require.Equal(t, map[string]string{ - "tidb_backoff_weight": "6", "tidb_row_format_version": "1", "max_allowed_packet": "67108864", "div_precision_increment": "4", diff --git a/br/tests/lightning_add_index/config1.toml b/br/tests/lightning_add_index/config1.toml deleted file mode 100644 index 36b03d49a1117..0000000000000 --- a/br/tests/lightning_add_index/config1.toml +++ /dev/null @@ -1,6 +0,0 @@ -[tikv-importer] -backend = 'local' -add-index-by-sql = false - -[post-restore] -checksum-via-sql = false \ No newline at end of file diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go deleted file mode 100644 index 248b797c9b913..0000000000000 --- a/disttask/framework/dispatcher/dispatcher.go +++ /dev/null @@ -1,516 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dispatcher - -import ( - "context" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/tidb/disttask/framework/proto" - "github.com/pingcap/tidb/disttask/framework/storage" - "github.com/pingcap/tidb/domain/infosync" - "github.com/pingcap/tidb/resourcemanager/pool/spool" - "github.com/pingcap/tidb/resourcemanager/util" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/variable" - tidbutil "github.com/pingcap/tidb/util" - disttaskutil "github.com/pingcap/tidb/util/disttask" - "github.com/pingcap/tidb/util/logutil" - "github.com/pingcap/tidb/util/syncutil" - "go.uber.org/zap" -) - -const ( - // DefaultSubtaskConcurrency is the default concurrency for handling subtask. - DefaultSubtaskConcurrency = 16 - // MaxSubtaskConcurrency is the maximum concurrency for handling subtask. - MaxSubtaskConcurrency = 256 -) - -var ( - // DefaultDispatchConcurrency is the default concurrency for handling global task. - DefaultDispatchConcurrency = 4 - checkTaskFinishedInterval = 500 * time.Millisecond - checkTaskRunningInterval = 300 * time.Millisecond - nonRetrySQLTime = 1 - retrySQLTimes = variable.DefTiDBDDLErrorCountLimit - retrySQLInterval = 500 * time.Millisecond -) - -// Dispatch defines the interface for operations inside a dispatcher. 
-type Dispatch interface { - // Start enables dispatching and monitoring mechanisms. - Start() - // GetAllSchedulerIDs gets handles the task's all available instances. - GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]string, error) - // Stop stops the dispatcher. - Stop() -} - -// TaskHandle provides the interface for operations needed by task flow handles. -type TaskHandle interface { - // GetAllSchedulerIDs gets handles the task's all scheduler instances. - GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]string, error) - // GetPreviousSubtaskMetas gets previous subtask metas. - GetPreviousSubtaskMetas(gTaskID int64, step int64) ([][]byte, error) - storage.SessionExecutor -} - -func (d *dispatcher) getRunningGTaskCnt() int { - d.runningGTasks.RLock() - defer d.runningGTasks.RUnlock() - return len(d.runningGTasks.taskIDs) -} - -func (d *dispatcher) setRunningGTask(gTask *proto.Task) { - d.runningGTasks.Lock() - d.runningGTasks.taskIDs[gTask.ID] = struct{}{} - d.runningGTasks.Unlock() - d.detectPendingGTaskCh <- gTask -} - -func (d *dispatcher) isRunningGTask(globalTaskID int64) bool { - d.runningGTasks.Lock() - defer d.runningGTasks.Unlock() - _, ok := d.runningGTasks.taskIDs[globalTaskID] - return ok -} - -func (d *dispatcher) delRunningGTask(globalTaskID int64) { - d.runningGTasks.Lock() - defer d.runningGTasks.Unlock() - delete(d.runningGTasks.taskIDs, globalTaskID) -} - -type dispatcher struct { - ctx context.Context - cancel context.CancelFunc - taskMgr *storage.TaskManager - wg tidbutil.WaitGroupWrapper - gPool *spool.Pool - - runningGTasks struct { - syncutil.RWMutex - taskIDs map[int64]struct{} - } - detectPendingGTaskCh chan *proto.Task -} - -// NewDispatcher creates a dispatcher struct. -func NewDispatcher(ctx context.Context, taskTable *storage.TaskManager) (Dispatch, error) { - dispatcher := &dispatcher{ - taskMgr: taskTable, - detectPendingGTaskCh: make(chan *proto.Task, DefaultDispatchConcurrency), - } - pool, err := spool.NewPool("dispatch_pool", int32(DefaultDispatchConcurrency), util.DistTask, spool.WithBlocking(true)) - if err != nil { - return nil, err - } - dispatcher.gPool = pool - dispatcher.ctx, dispatcher.cancel = context.WithCancel(ctx) - dispatcher.runningGTasks.taskIDs = make(map[int64]struct{}) - - return dispatcher, nil -} - -// Start implements Dispatch.Start interface. -func (d *dispatcher) Start() { - d.wg.Run(d.DispatchTaskLoop) - d.wg.Run(d.DetectTaskLoop) -} - -// Stop implements Dispatch.Stop interface. -func (d *dispatcher) Stop() { - d.cancel() - d.gPool.ReleaseAndWait() - d.wg.Wait() -} - -// DispatchTaskLoop dispatches the global tasks. -func (d *dispatcher) DispatchTaskLoop() { - logutil.BgLogger().Info("dispatch task loop start") - ticker := time.NewTicker(checkTaskRunningInterval) - defer ticker.Stop() - for { - select { - case <-d.ctx.Done(): - logutil.BgLogger().Info("dispatch task loop exits", zap.Error(d.ctx.Err()), zap.Int64("interval", int64(checkTaskRunningInterval)/1000000)) - return - case <-ticker.C: - cnt := d.getRunningGTaskCnt() - if d.checkConcurrencyOverflow(cnt) { - break - } - - // TODO: Consider getting these tasks, in addition to the task being worked on.. - gTasks, err := d.taskMgr.GetGlobalTasksInStates(proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStateCancelling) - if err != nil { - logutil.BgLogger().Warn("get unfinished(pending, running, reverting or cancelling) tasks failed", zap.Error(err)) - break - } - - // There are currently no global tasks to work on. 
- if len(gTasks) == 0 { - break - } - for _, gTask := range gTasks { - // This global task is running, so no need to reprocess it. - if d.isRunningGTask(gTask.ID) { - continue - } - // the task is not in runningGTasks set when: - // owner changed or task is cancelled when status is pending. - if gTask.State == proto.TaskStateRunning || gTask.State == proto.TaskStateReverting || gTask.State == proto.TaskStateCancelling { - d.setRunningGTask(gTask) - cnt++ - continue - } - - if d.checkConcurrencyOverflow(cnt) { - break - } - - err = d.processNormalFlow(gTask) - logutil.BgLogger().Info("dispatch task loop", zap.Int64("task ID", gTask.ID), - zap.String("state", gTask.State), zap.Uint64("concurrency", gTask.Concurrency), zap.Error(err)) - if err != nil || gTask.IsFinished() { - continue - } - d.setRunningGTask(gTask) - cnt++ - } - } - } -} - -func (d *dispatcher) probeTask(gTask *proto.Task) (isFinished bool, subTaskErr [][]byte) { - // TODO: Consider putting the following operations into a transaction. - // TODO: Consider collect some information about the tasks. - if gTask.State != proto.TaskStateReverting { - // check if global task cancelling - cancelling, err := d.taskMgr.IsGlobalTaskCancelling(gTask.ID) - if err != nil { - logutil.BgLogger().Warn("check task cancelling failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) - return false, nil - } - - if cancelling { - return false, [][]byte{[]byte("cancel")} - } - // check subtasks failed. - cnt, err := d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateFailed) - if err != nil { - logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) - return false, nil - } - if cnt > 0 { - subTaskErr, err = d.taskMgr.CollectSubTaskError(gTask.ID) - if err != nil { - logutil.BgLogger().Warn("collect subtask error failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) - return false, nil - } - return false, subTaskErr - } - // check subtasks pending or running. - cnt, err = d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStatePending, proto.TaskStateRunning) - if err != nil { - logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) - return false, nil - } - if cnt > 0 { - return false, nil - } - return true, nil - } - - // if gTask.State == TaskStateReverting, if will not convert to TaskStateCancelling again. - cnt, err := d.taskMgr.GetSubtaskInStatesCnt(gTask.ID, proto.TaskStateRevertPending, proto.TaskStateReverting) - if err != nil { - logutil.BgLogger().Warn("check task failed", zap.Int64("task ID", gTask.ID), zap.Error(err)) - return false, nil - } - if cnt > 0 { - return false, nil - } - return true, nil -} - -// DetectTaskLoop monitors the status of the subtasks and processes them. -func (d *dispatcher) DetectTaskLoop() { - logutil.BgLogger().Info("detect task loop start") - for { - select { - case <-d.ctx.Done(): - logutil.BgLogger().Info("detect task loop exits", zap.Error(d.ctx.Err())) - return - case task := <-d.detectPendingGTaskCh: - // Using the pool with block, so it wouldn't return an error. - _ = d.gPool.Run(func() { d.detectTask(task) }) - } - } -} - -func (d *dispatcher) detectTask(gTask *proto.Task) { - ticker := time.NewTicker(checkTaskFinishedInterval) - defer ticker.Stop() - - for { - select { - case <-d.ctx.Done(): - logutil.BgLogger().Info("detect task exits", zap.Int64("task ID", gTask.ID), zap.Error(d.ctx.Err())) - return - case <-ticker.C: - // TODO: Consider actively obtaining information about task completion. 
- stepIsFinished, errStr := d.probeTask(gTask) - // The global task isn't finished and not failed. - if !stepIsFinished && len(errStr) == 0 { - GetTaskFlowHandle(gTask.Type).OnTicker(d.ctx, gTask) - logutil.BgLogger().Debug("detect task, this task keeps current state", - zap.Int64("task-id", gTask.ID), zap.String("state", gTask.State)) - break - } - - err := d.processFlow(gTask, errStr) - if err == nil && gTask.IsFinished() { - logutil.BgLogger().Info("detect task, task is finished", - zap.Int64("task-id", gTask.ID), zap.String("state", gTask.State)) - d.delRunningGTask(gTask.ID) - return - } - if !d.isRunningGTask(gTask.ID) { - logutil.BgLogger().Info("detect task, this task can't run", - zap.Int64("task-id", gTask.ID), zap.String("state", gTask.State)) - } - } - } -} - -func (d *dispatcher) processFlow(gTask *proto.Task, errStr [][]byte) error { - if len(errStr) > 0 { - // Found an error when task is running. - logutil.BgLogger().Info("process flow, handle an error", zap.Int64("task-id", gTask.ID), zap.ByteStrings("err msg", errStr)) - return d.processErrFlow(gTask, errStr) - } - // previous step is finished. - if gTask.State == proto.TaskStateReverting { - // Finish the rollback step. - logutil.BgLogger().Info("process flow, update the task to reverted", zap.Int64("task-id", gTask.ID)) - return d.updateTask(gTask, proto.TaskStateReverted, nil, retrySQLTimes) - } - // Finish the normal step. - logutil.BgLogger().Info("process flow, process normal", zap.Int64("task-id", gTask.ID)) - return d.processNormalFlow(gTask) -} - -func (d *dispatcher) updateTask(gTask *proto.Task, gTaskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) { - prevState := gTask.State - gTask.State = gTaskState - for i := 0; i < retryTimes; i++ { - err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(gTask, newSubTasks, gTaskState == proto.TaskStateReverting) - if err == nil { - break - } - if i%10 == 0 { - logutil.BgLogger().Warn("updateTask first failed", zap.Int64("task-id", gTask.ID), - zap.String("previous state", prevState), zap.String("curr state", gTask.State), - zap.Int("retry times", retryTimes), zap.Error(err)) - } - time.Sleep(retrySQLInterval) - } - if err != nil && retryTimes != nonRetrySQLTime { - logutil.BgLogger().Warn("updateTask failed and delete running task info", zap.Int64("task-id", gTask.ID), - zap.String("previous state", prevState), zap.String("curr state", gTask.State), zap.Int("retry times", retryTimes), zap.Error(err)) - d.delRunningGTask(gTask.ID) - } - return err -} - -func (d *dispatcher) processErrFlow(gTask *proto.Task, receiveErr [][]byte) error { - // TODO: Maybe it gets GetTaskFlowHandle fails when rolling upgrades. - // 1. generate the needed global task meta and subTask meta (dist-plan). - meta, err := GetTaskFlowHandle(gTask.Type).ProcessErrFlow(d.ctx, d, gTask, receiveErr) - if err != nil { - logutil.BgLogger().Warn("handle error failed", zap.Error(err)) - return err - } - - // 2. dispatch revert dist-plan to EligibleInstances. 
- return d.dispatchSubTask4Revert(gTask, meta) -} - -func (d *dispatcher) dispatchSubTask4Revert(gTask *proto.Task, meta []byte) error { - instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, gTask.ID) - if err != nil { - logutil.BgLogger().Warn("get global task's all instances failed", zap.Error(err)) - return err - } - - if len(instanceIDs) == 0 { - return d.updateTask(gTask, proto.TaskStateReverted, nil, retrySQLTimes) - } - - subTasks := make([]*proto.Subtask, 0, len(instanceIDs)) - for _, id := range instanceIDs { - subTasks = append(subTasks, proto.NewSubtask(gTask.ID, gTask.Type, id, meta)) - } - return d.updateTask(gTask, proto.TaskStateReverting, subTasks, retrySQLTimes) -} - -func (d *dispatcher) processNormalFlow(gTask *proto.Task) error { - // 1. generate the needed global task meta and subTask meta (dist-plan). - handle := GetTaskFlowHandle(gTask.Type) - if handle == nil { - logutil.BgLogger().Warn("gen gTask flow handle failed, this type handle doesn't register", zap.Int64("ID", gTask.ID), zap.String("type", gTask.Type)) - return d.updateTask(gTask, proto.TaskStateReverted, nil, retrySQLTimes) - } - metas, err := handle.ProcessNormalFlow(d.ctx, d, gTask) - if err != nil { - logutil.BgLogger().Warn("gen dist-plan failed", zap.Error(err)) - if handle.IsRetryableErr(err) { - return err - } - gTask.Error = []byte(err.Error()) - return d.updateTask(gTask, proto.TaskStateReverted, nil, retrySQLTimes) - } - logutil.BgLogger().Info("process normal flow", zap.Int64("task ID", gTask.ID), - zap.String("state", gTask.State), zap.Uint64("concurrency", gTask.Concurrency), zap.Int("subtasks", len(metas))) - - // 2. dispatch dist-plan to EligibleInstances. - return d.dispatchSubTask(gTask, handle, metas) -} - -func (d *dispatcher) dispatchSubTask(gTask *proto.Task, handle TaskFlowHandle, metas [][]byte) error { - // Adjust the global task's concurrency. - if gTask.Concurrency == 0 { - gTask.Concurrency = DefaultSubtaskConcurrency - } - if gTask.Concurrency > MaxSubtaskConcurrency { - gTask.Concurrency = MaxSubtaskConcurrency - } - - retryTimes := retrySQLTimes - // Special handling for the new tasks. - if gTask.State == proto.TaskStatePending { - // TODO: Consider using TS. - nowTime := time.Now().UTC() - gTask.StartTime = nowTime - gTask.State = proto.TaskStateRunning - gTask.StateUpdateTime = nowTime - retryTimes = nonRetrySQLTime - } - - if len(metas) == 0 { - gTask.StateUpdateTime = time.Now().UTC() - // Write the global task meta into the storage. - err := d.updateTask(gTask, proto.TaskStateSucceed, nil, retryTimes) - if err != nil { - logutil.BgLogger().Warn("update global task failed", zap.Error(err)) - return err - } - return nil - } - // select all available TiDB nodes for this global tasks. - serverNodes, err1 := handle.GetEligibleInstances(d.ctx, gTask) - logutil.BgLogger().Debug("eligible instances", zap.Int("num", len(serverNodes))) - - if err1 != nil { - return err1 - } - if len(serverNodes) == 0 { - return errors.New("no available TiDB node") - } - subTasks := make([]*proto.Subtask, 0, len(metas)) - for i, meta := range metas { - // we assign the subtask to the instance in a round-robin way. 
- pos := i % len(serverNodes) - instanceID := disttaskutil.GenerateExecID(serverNodes[pos].IP, serverNodes[pos].Port) - logutil.BgLogger().Debug("create subtasks", - zap.Int("gTask.ID", int(gTask.ID)), zap.String("type", gTask.Type), zap.String("instanceID", instanceID)) - subTasks = append(subTasks, proto.NewSubtask(gTask.ID, gTask.Type, instanceID, meta)) - } - - return d.updateTask(gTask, gTask.State, subTasks, retrySQLTimes) -} - -// GenerateSchedulerNodes generate a eligible TiDB nodes. -func GenerateSchedulerNodes(ctx context.Context) ([]*infosync.ServerInfo, error) { - serverInfos, err := infosync.GetAllServerInfo(ctx) - if err != nil { - return nil, err - } - if len(serverInfos) == 0 { - return nil, errors.New("not found instance") - } - - serverNodes := make([]*infosync.ServerInfo, 0, len(serverInfos)) - for _, serverInfo := range serverInfos { - serverNodes = append(serverNodes, serverInfo) - } - return serverNodes, nil -} - -// GetAllSchedulerIDs gets all the scheduler IDs. -func (d *dispatcher) GetAllSchedulerIDs(ctx context.Context, gTaskID int64) ([]string, error) { - serverInfos, err := infosync.GetAllServerInfo(ctx) - if err != nil { - return nil, err - } - if len(serverInfos) == 0 { - return nil, nil - } - - schedulerIDs, err := d.taskMgr.GetSchedulerIDsByTaskID(gTaskID) - if err != nil { - return nil, err - } - ids := make([]string, 0, len(schedulerIDs)) - for _, id := range schedulerIDs { - if ok := disttaskutil.MatchServerInfo(serverInfos, id); ok { - ids = append(ids, id) - } - } - return ids, nil -} - -func (d *dispatcher) GetPreviousSubtaskMetas(gTaskID int64, step int64) ([][]byte, error) { - previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(gTaskID, step) - if err != nil { - logutil.BgLogger().Warn("get previous succeed subtask failed", zap.Int64("ID", gTaskID), zap.Int64("step", step)) - return nil, err - } - previousSubtaskMetas := make([][]byte, 0, len(previousSubtasks)) - for _, subtask := range previousSubtasks { - previousSubtaskMetas = append(previousSubtaskMetas, subtask.Meta) - } - return previousSubtaskMetas, nil -} - -func (d *dispatcher) WithNewSession(fn func(se sessionctx.Context) error) error { - return d.taskMgr.WithNewSession(fn) -} - -func (d *dispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { - return d.taskMgr.WithNewTxn(ctx, fn) -} - -func (*dispatcher) checkConcurrencyOverflow(cnt int) bool { - if cnt >= DefaultDispatchConcurrency { - logutil.BgLogger().Info("dispatch task loop, running GTask cnt is more than concurrency", - zap.Int("running cnt", cnt), zap.Int("concurrency", DefaultDispatchConcurrency)) - return true - } - return false -} diff --git a/disttask/framework/storage/task_table.go b/disttask/framework/storage/task_table.go deleted file mode 100644 index f394e99a7a540..0000000000000 --- a/disttask/framework/storage/task_table.go +++ /dev/null @@ -1,496 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
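// A minimal, self-contained sketch of the round-robin placement performed by
// dispatchSubTask above: the i-th subtask meta goes to eligible node i mod
// len(nodes). The node/subtask types and the "ip:port" exec ID format are
// simplified stand-ins for the framework's proto types and
// disttaskutil.GenerateExecID, not the real implementation.
package main

import "fmt"

type node struct {
	ip   string
	port int
}

type subtask struct {
	execID string
	meta   []byte
}

func assignRoundRobin(metas [][]byte, nodes []node) []subtask {
	subtasks := make([]subtask, 0, len(metas))
	for i, meta := range metas {
		n := nodes[i%len(nodes)] // round-robin over the eligible nodes
		subtasks = append(subtasks, subtask{execID: fmt.Sprintf("%s:%d", n.ip, n.port), meta: meta})
	}
	return subtasks
}

func main() {
	nodes := []node{{ip: "10.0.0.1", port: 4000}, {ip: "10.0.0.2", port: 4000}}
	for _, st := range assignRoundRobin([][]byte{[]byte("m0"), []byte("m1"), []byte("m2")}, nodes) {
		fmt.Printf("%s -> %s\n", st.meta, st.execID)
	}
}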
- -package storage - -import ( - "context" - "fmt" - "strconv" - "strings" - "sync/atomic" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/disttask/framework/proto" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/parser/terror" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/logutil" - "github.com/pingcap/tidb/util/sqlexec" - "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -// SessionExecutor defines the interface for executing SQLs in a session. -type SessionExecutor interface { - // WithNewSession executes the function with a new session. - WithNewSession(fn func(se sessionctx.Context) error) error - // WithNewTxn executes the fn in a new transaction. - WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error -} - -// TaskManager is the manager of global/sub task. -type TaskManager struct { - ctx context.Context - sePool *pools.ResourcePool -} - -var _ SessionExecutor = &TaskManager{} - -var taskManagerInstance atomic.Pointer[TaskManager] - -var ( - // TestLastTaskID is used for test to set the last task ID. - TestLastTaskID atomic.Int64 -) - -// NewTaskManager creates a new task manager. -func NewTaskManager(ctx context.Context, sePool *pools.ResourcePool) *TaskManager { - ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) - return &TaskManager{ - ctx: ctx, - sePool: sePool, - } -} - -// GetTaskManager gets the task manager. -func GetTaskManager() (*TaskManager, error) { - v := taskManagerInstance.Load() - if v == nil { - return nil, errors.New("global task manager is not initialized") - } - return v, nil -} - -// SetTaskManager sets the task manager. -func SetTaskManager(is *TaskManager) { - taskManagerInstance.Store(is) -} - -// ExecSQL executes the sql and returns the result. -// TODO: consider retry. -func ExecSQL(ctx context.Context, se sessionctx.Context, sql string, args ...interface{}) ([]chunk.Row, error) { - rs, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...) - if err != nil { - return nil, err - } - if rs != nil { - defer terror.Call(rs.Close) - return sqlexec.DrainRecordSet(ctx, rs, 1024) - } - return nil, nil -} - -// row2GlobeTask converts a row to a global task. -func row2GlobeTask(r chunk.Row) *proto.Task { - task := &proto.Task{ - ID: r.GetInt64(0), - Key: r.GetString(1), - Type: r.GetString(2), - DispatcherID: r.GetString(3), - State: r.GetString(4), - Meta: r.GetBytes(7), - Concurrency: uint64(r.GetInt64(8)), - Step: r.GetInt64(9), - Error: r.GetBytes(10), - } - // TODO: convert to local time. - task.StartTime, _ = r.GetTime(5).GoTime(time.UTC) - task.StateUpdateTime, _ = r.GetTime(6).GoTime(time.UTC) - return task -} - -// WithNewSession executes the function with a new session. -func (stm *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) error { - se, err := stm.sePool.Get() - if err != nil { - return err - } - defer stm.sePool.Put(se) - return fn(se.(sessionctx.Context)) -} - -// WithNewTxn executes the fn in a new transaction. 
-func (stm *TaskManager) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { - ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) - return stm.WithNewSession(func(se sessionctx.Context) (err error) { - _, err = ExecSQL(ctx, se, "begin") - if err != nil { - return err - } - - success := false - defer func() { - sql := "rollback" - if success { - sql = "commit" - } - _, commitErr := ExecSQL(ctx, se, sql) - if err == nil && commitErr != nil { - err = commitErr - } - }() - - if err = fn(se); err != nil { - return err - } - - success = true - return nil - }) -} - -func (stm *TaskManager) executeSQLWithNewSession(ctx context.Context, sql string, args ...interface{}) (rs []chunk.Row, err error) { - err = stm.WithNewSession(func(se sessionctx.Context) error { - rs, err = ExecSQL(ctx, se, sql, args...) - return err - }) - - if err != nil { - return nil, err - } - - return -} - -// AddNewGlobalTask adds a new task to global task table. -func (stm *TaskManager) AddNewGlobalTask(key, tp string, concurrency int, meta []byte) (taskID int64, err error) { - err = stm.WithNewSession(func(se sessionctx.Context) error { - var err2 error - taskID, err2 = stm.AddGlobalTaskWithSession(se, key, tp, concurrency, meta) - return err2 - }) - return -} - -// AddGlobalTaskWithSession adds a new task to global task table with session. -func (stm *TaskManager) AddGlobalTaskWithSession(se sessionctx.Context, key, tp string, concurrency int, meta []byte) (taskID int64, err error) { - _, err = ExecSQL(stm.ctx, se, - `insert into mysql.tidb_global_task(task_key, type, state, concurrency, step, meta, state_update_time) - values (%?, %?, %?, %?, %?, %?, %?)`, - key, tp, proto.TaskStatePending, concurrency, proto.StepInit, meta, time.Now().UTC().String()) - if err != nil { - return 0, err - } - - rs, err := ExecSQL(stm.ctx, se, "select @@last_insert_id") - if err != nil { - return 0, err - } - - taskID = int64(rs[0].GetUint64(0)) - failpoint.Inject("testSetLastTaskID", func() { TestLastTaskID.Store(taskID) }) - - return taskID, nil -} - -// GetNewGlobalTask get a new task from global task table, it's used by dispatcher only. -func (stm *TaskManager) GetNewGlobalTask() (task *proto.Task, err error) { - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending) - if err != nil { - return task, err - } - - if len(rs) == 0 { - return nil, nil - } - - return row2GlobeTask(rs[0]), nil -} - -// GetGlobalTasksInStates gets the tasks in the states. -func (stm *TaskManager) GetGlobalTasksInStates(states ...interface{}) (task []*proto.Task, err error) { - if len(states) == 0 { - return task, nil - } - - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", states...) - if err != nil { - return task, err - } - - for _, r := range rs { - task = append(task, row2GlobeTask(r)) - } - return task, nil -} - -// GetGlobalTaskByID gets the task by the global task ID. 
-func (stm *TaskManager) GetGlobalTaskByID(taskID int64) (task *proto.Task, err error) { - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where id = %?", taskID) - if err != nil { - return task, err - } - if len(rs) == 0 { - return nil, nil - } - - return row2GlobeTask(rs[0]), nil -} - -// GetGlobalTaskByKey gets the task by the task key -func (stm *TaskManager) GetGlobalTaskByKey(key string) (task *proto.Task, err error) { - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where task_key = %?", key) - if err != nil { - return task, err - } - if len(rs) == 0 { - return nil, nil - } - - return row2GlobeTask(rs[0]), nil -} - -// row2SubTask converts a row to a subtask. -func row2SubTask(r chunk.Row) *proto.Subtask { - task := &proto.Subtask{ - ID: r.GetInt64(0), - Step: r.GetInt64(1), - Type: proto.Int2Type(int(r.GetInt64(5))), - SchedulerID: r.GetString(6), - State: r.GetString(8), - Meta: r.GetBytes(12), - StartTime: r.GetUint64(10), - } - tid, err := strconv.Atoi(r.GetString(3)) - if err != nil { - logutil.BgLogger().Warn("unexpected task ID", zap.String("task ID", r.GetString(3))) - } - task.TaskID = int64(tid) - return task -} - -// AddNewSubTask adds a new task to subtask table. -func (stm *TaskManager) AddNewSubTask(globalTaskID int64, step int64, designatedTiDBID string, meta []byte, tp string, isRevert bool) error { - st := proto.TaskStatePending - if isRevert { - st = proto.TaskStateRevertPending - } - - _, err := stm.executeSQLWithNewSession(stm.ctx, "insert into mysql.tidb_background_subtask(task_key, step, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?, %?)", globalTaskID, step, designatedTiDBID, meta, st, proto.Type2Int(tp), []byte{}) - if err != nil { - return err - } - - return nil -} - -// GetSubtaskInStates gets the subtask in the states. -func (stm *TaskManager) GetSubtaskInStates(tidbID string, taskID int64, states ...interface{}) (*proto.Subtask, error) { - args := []interface{}{tidbID, taskID} - args = append(args, states...) - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select * from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", args...) - if err != nil { - return nil, err - } - if len(rs) == 0 { - return nil, nil - } - return row2SubTask(rs[0]), nil -} - -// PrintSubtaskInfo log the subtask info by taskKey. -func (stm *TaskManager) PrintSubtaskInfo(taskKey int) { - rs, _ := stm.executeSQLWithNewSession(stm.ctx, - "select * from mysql.tidb_background_subtask where task_key = %?", taskKey) - - for _, r := range rs { - logutil.BgLogger().Info(fmt.Sprintf("subTask: %v\n", row2SubTask(r))) - } -} - -// GetSucceedSubtasksByStep gets the subtask in the success state. -func (stm *TaskManager) GetSucceedSubtasksByStep(taskID int64, step int64) ([]*proto.Subtask, error) { - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select * from mysql.tidb_background_subtask where task_key = %? and state = %? 
and step = %?", taskID, proto.TaskStateSucceed, step) - if err != nil { - return nil, err - } - if len(rs) == 0 { - return nil, nil - } - subtasks := make([]*proto.Subtask, 0, len(rs)) - for _, r := range rs { - subtasks = append(subtasks, row2SubTask(r)) - } - return subtasks, nil -} - -// GetSubtaskInStatesCnt gets the subtask count in the states. -func (stm *TaskManager) GetSubtaskInStatesCnt(taskID int64, states ...interface{}) (int64, error) { - args := []interface{}{taskID} - args = append(args, states...) - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select count(*) from mysql.tidb_background_subtask where task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", args...) - if err != nil { - return 0, err - } - - return rs[0].GetInt64(0), nil -} - -// CollectSubTaskError collects the subtask error. -func (stm *TaskManager) CollectSubTaskError(taskID int64) ([][]byte, error) { - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select error from mysql.tidb_background_subtask where task_key = %? AND state = %?", taskID, proto.TaskStateFailed) - if err != nil { - return nil, err - } - - subTaskErrors := make([][]byte, 0, len(rs)) - for _, err := range rs { - subTaskErrors = append(subTaskErrors, err.GetBytes(0)) - } - - return subTaskErrors, nil -} - -// HasSubtasksInStates checks if there are subtasks in the states. -func (stm *TaskManager) HasSubtasksInStates(tidbID string, taskID int64, states ...interface{}) (bool, error) { - args := []interface{}{tidbID, taskID} - args = append(args, states...) - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?) limit 1", args...) - if err != nil { - return false, err - } - - return len(rs) > 0, nil -} - -// UpdateSubtaskStateAndError updates the subtask state. -func (stm *TaskManager) UpdateSubtaskStateAndError(id int64, state string, subTaskErr string) error { - _, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set state = %?, error = %? where id = %?", state, subTaskErr, id) - return err -} - -// FinishSubtask updates the subtask meta and mark state to succeed. -func (stm *TaskManager) FinishSubtask(id int64, meta []byte) error { - _, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set meta = %?, state = %? where id = %?", meta, proto.TaskStateSucceed, id) - return err -} - -// UpdateSubtaskHeartbeat updates the heartbeat of the subtask. -func (stm *TaskManager) UpdateSubtaskHeartbeat(instanceID string, taskID int64, heartbeat time.Time) error { - _, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set exec_expired = %? where exec_id = %? and task_key = %?", heartbeat.String(), instanceID, taskID) - return err -} - -// DeleteSubtasksByTaskID deletes the subtask of the given global task ID. -func (stm *TaskManager) DeleteSubtasksByTaskID(taskID int64) error { - _, err := stm.executeSQLWithNewSession(stm.ctx, "delete from mysql.tidb_background_subtask where task_key = %?", taskID) - if err != nil { - return err - } - - return nil -} - -// GetSchedulerIDsByTaskID gets the scheduler IDs of the given global task ID. 
-func (stm *TaskManager) GetSchedulerIDsByTaskID(taskID int64) ([]string, error) { - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select distinct(exec_id) from mysql.tidb_background_subtask where task_key = %?", taskID) - if err != nil { - return nil, err - } - if len(rs) == 0 { - return nil, nil - } - - instanceIDs := make([]string, 0, len(rs)) - for _, r := range rs { - id := r.GetString(0) - instanceIDs = append(instanceIDs, id) - } - - return instanceIDs, nil -} - -// UpdateGlobalTaskAndAddSubTasks update the global task and add new subtasks -func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, isSubtaskRevert bool) error { - return stm.WithNewTxn(stm.ctx, func(se sessionctx.Context) error { - _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? where id = %?", - gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, gTask.Error, gTask.ID) - if err != nil { - return err - } - - failpoint.Inject("MockUpdateTaskErr", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(errors.New("updateTaskErr")) - } - }) - - subtaskState := proto.TaskStatePending - if isSubtaskRevert { - subtaskState = proto.TaskStateRevertPending - } - - for _, subtask := range subtasks { - // TODO: insert subtasks in batch - _, err = ExecSQL(stm.ctx, se, "insert into mysql.tidb_background_subtask(step, task_key, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?, %?)", - gTask.Step, gTask.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}) - if err != nil { - return err - } - } - - return nil - }) -} - -// CancelGlobalTask cancels global task -func (stm *TaskManager) CancelGlobalTask(taskID int64) error { - _, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_global_task set state=%? where id=%? and state in (%?, %?)", - proto.TaskStateCancelling, taskID, proto.TaskStatePending, proto.TaskStateRunning, - ) - return err -} - -// CancelGlobalTaskByKeySession cancels global task by key using input session -func (stm *TaskManager) CancelGlobalTaskByKeySession(se sessionctx.Context, taskKey string) error { - _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state=%? where task_key=%? and state in (%?, %?)", - proto.TaskStateCancelling, taskKey, proto.TaskStatePending, proto.TaskStateRunning) - return err -} - -// IsGlobalTaskCancelling checks whether the task state is cancelling -func (stm *TaskManager) IsGlobalTaskCancelling(taskID int64) (bool, error) { - rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_global_task where id=%? and state = %?", - taskID, proto.TaskStateCancelling, - ) - - if err != nil { - return false, err - } - - return len(rs) > 0, nil -} - -// GetSubtasksByStep gets subtasks of global task by step -func (stm *TaskManager) GetSubtasksByStep(taskID, step int64) ([]*proto.Subtask, error) { - rs, err := stm.executeSQLWithNewSession(stm.ctx, - "select * from mysql.tidb_background_subtask where task_key = %? 
and step = %?", - taskID, step) - if err != nil { - return nil, err - } - if len(rs) == 0 { - return nil, nil - } - subtasks := make([]*proto.Subtask, 0, len(rs)) - for _, r := range rs { - subtasks = append(subtasks, row2SubTask(r)) - } - return subtasks, nil -} diff --git a/disttask/importinto/BUILD.bazel b/disttask/importinto/BUILD.bazel deleted file mode 100644 index eabe0b7ecc10a..0000000000000 --- a/disttask/importinto/BUILD.bazel +++ /dev/null @@ -1,80 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") - -go_library( - name = "importinto", - srcs = [ - "dispatcher.go", - "job.go", - "proto.go", - "scheduler.go", - "subtask_executor.go", - "wrapper.go", - ], - importpath = "github.com/pingcap/tidb/disttask/importinto", - visibility = ["//visibility:public"], - deps = [ - "//br/pkg/lightning/backend", - "//br/pkg/lightning/backend/kv", - "//br/pkg/lightning/backend/local", - "//br/pkg/lightning/checkpoints", - "//br/pkg/lightning/common", - "//br/pkg/lightning/config", - "//br/pkg/lightning/mydump", - "//br/pkg/lightning/verification", - "//br/pkg/utils", - "//disttask/framework/dispatcher", - "//disttask/framework/handle", - "//disttask/framework/proto", - "//disttask/framework/scheduler", - "//disttask/framework/storage", - "//domain/infosync", - "//errno", - "//executor/asyncloaddata", - "//executor/importer", - "//kv", - "//parser/ast", - "//parser/mysql", - "//sessionctx", - "//sessionctx/variable", - "//table/tables", - "//util/dbterror/exeerrors", - "//util/etcd", - "//util/logutil", - "//util/mathutil", - "//util/sqlexec", - "@com_github_go_sql_driver_mysql//:mysql", - "@com_github_google_uuid//:uuid", - "@com_github_pingcap_errors//:errors", - "@com_github_pingcap_failpoint//:failpoint", - "@com_github_tikv_client_go_v2//util", - "@org_uber_go_atomic//:atomic", - "@org_uber_go_zap//:zap", - ], -) - -go_test( - name = "importinto_test", - timeout = "short", - srcs = [ - "dispatcher_test.go", - "subtask_executor_test.go", - ], - embed = [":importinto"], - flaky = True, - race = "on", - deps = [ - "//br/pkg/lightning/verification", - "//disttask/framework/proto", - "//disttask/framework/storage", - "//domain/infosync", - "//executor/importer", - "//parser/model", - "//testkit", - "//util/logutil", - "@com_github_ngaut_pools//:pools", - "@com_github_pingcap_failpoint//:failpoint", - "@com_github_stretchr_testify//require", - "@com_github_stretchr_testify//suite", - "@com_github_tikv_client_go_v2//util", - ], -) diff --git a/disttask/importinto/dispatcher.go b/disttask/importinto/dispatcher.go deleted file mode 100644 index 3f4d6822bc55d..0000000000000 --- a/disttask/importinto/dispatcher.go +++ /dev/null @@ -1,647 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
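// The task storage above runs "update global task + insert subtasks" inside an
// explicit transaction, committing only when every statement succeeds and
// rolling back otherwise. This is a minimal sketch of the same pattern on
// database/sql for comparison; the helper name withTxn is hypothetical and not
// part of the patch. Committing inside the deferred function lets a commit
// error reach the caller, mirroring how WithNewTxn surfaces commitErr.
package txnsketch

import (
	"context"
	"database/sql"
)

// withTxn runs fn inside a transaction and rolls back unless fn succeeds.
func withTxn(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) (err error) {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			_ = tx.Rollback() // best effort; keep the original error for the caller
			return
		}
		err = tx.Commit()
	}()
	return fn(tx)
}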
- -package importinto - -import ( - "context" - "encoding/json" - "strconv" - "strings" - "sync" - "time" - - dmysql "github.com/go-sql-driver/mysql" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/br/pkg/lightning/common" - "github.com/pingcap/tidb/br/pkg/lightning/config" - verify "github.com/pingcap/tidb/br/pkg/lightning/verification" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/disttask/framework/dispatcher" - "github.com/pingcap/tidb/disttask/framework/proto" - "github.com/pingcap/tidb/disttask/framework/storage" - "github.com/pingcap/tidb/domain/infosync" - "github.com/pingcap/tidb/errno" - "github.com/pingcap/tidb/executor/importer" - "github.com/pingcap/tidb/parser/ast" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/table/tables" - "github.com/pingcap/tidb/util/etcd" - "github.com/pingcap/tidb/util/logutil" - "github.com/pingcap/tidb/util/sqlexec" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -const ( - registerTaskTTL = 10 * time.Minute - refreshTaskTTLInterval = 3 * time.Minute - registerTimeout = 5 * time.Second -) - -// NewTaskRegisterWithTTL is the ctor for TaskRegister. -// It is exported for testing. -var NewTaskRegisterWithTTL = utils.NewTaskRegisterWithTTL - -type taskInfo struct { - taskID int64 - - // operation on taskInfo is run inside detect-task goroutine, so no need to synchronize. - lastRegisterTime time.Time - - // initialized lazily in register() - etcdClient *etcd.Client - taskRegister utils.TaskRegister -} - -func (t *taskInfo) register(ctx context.Context) { - if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval { - return - } - - if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval { - return - } - logger := logutil.BgLogger().With(zap.Int64("task-id", t.taskID)) - if t.taskRegister == nil { - client, err := importer.GetEtcdClient() - if err != nil { - logger.Warn("get etcd client failed", zap.Error(err)) - return - } - t.etcdClient = client - t.taskRegister = NewTaskRegisterWithTTL(client.GetClient(), registerTaskTTL, - utils.RegisterImportInto, strconv.FormatInt(t.taskID, 10)) - } - timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout) - defer cancel() - if err := t.taskRegister.RegisterTaskOnce(timeoutCtx); err != nil { - logger.Warn("register task failed", zap.Error(err)) - } else { - logger.Info("register task to pd or refresh lease success") - } - // we set it even if register failed, TTL is 10min, refresh interval is 3min, - // we can try 2 times before the lease is expired. - t.lastRegisterTime = time.Now() -} - -func (t *taskInfo) close(ctx context.Context) { - logger := logutil.BgLogger().With(zap.Int64("task-id", t.taskID)) - if t.taskRegister != nil { - timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout) - defer cancel() - if err := t.taskRegister.Close(timeoutCtx); err != nil { - logger.Warn("unregister task failed", zap.Error(err)) - } else { - logger.Info("unregister task success") - } - t.taskRegister = nil - } - if t.etcdClient != nil { - if err := t.etcdClient.Close(); err != nil { - logger.Warn("close etcd client failed", zap.Error(err)) - } - t.etcdClient = nil - } -} - -type flowHandle struct { - mu sync.RWMutex - // NOTE: there's no need to sync for below 2 fields actually, since we add a restriction that only one - // task can be running at a time. 
but we might support task queuing in the future, leave it for now. - // the last time we switch TiKV into IMPORT mode, this is a global operation, do it for one task makes - // no difference to do it for all tasks. So we do not need to record the switch time for each task. - lastSwitchTime atomic.Time - // taskInfoMap is a map from taskID to taskInfo - taskInfoMap sync.Map - - // currTaskID is the taskID of the current running task. - // It may be changed when we switch to a new task or switch to a new owner. - currTaskID atomic.Int64 - disableTiKVImportMode atomic.Bool -} - -var _ dispatcher.TaskFlowHandle = (*flowHandle)(nil) - -func (h *flowHandle) OnTicker(ctx context.Context, task *proto.Task) { - // only switch TiKV mode or register task when task is running - if task.State != proto.TaskStateRunning { - return - } - h.switchTiKVMode(ctx, task) - h.registerTask(ctx, task) -} - -func (h *flowHandle) switchTiKVMode(ctx context.Context, task *proto.Task) { - h.updateCurrentTask(task) - // only import step need to switch to IMPORT mode, - // If TiKV is in IMPORT mode during checksum, coprocessor will time out. - if h.disableTiKVImportMode.Load() || task.Step != StepImport { - return - } - - if time.Since(h.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval { - return - } - - h.mu.Lock() - defer h.mu.Unlock() - if time.Since(h.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval { - return - } - - logger := logutil.BgLogger().With(zap.Int64("task-id", task.ID)) - switcher, err := importer.GetTiKVModeSwitcher(logger) - if err != nil { - logger.Warn("get tikv mode switcher failed", zap.Error(err)) - return - } - switcher.ToImportMode(ctx) - h.lastSwitchTime.Store(time.Now()) -} - -func (h *flowHandle) registerTask(ctx context.Context, task *proto.Task) { - val, _ := h.taskInfoMap.LoadOrStore(task.ID, &taskInfo{taskID: task.ID}) - info := val.(*taskInfo) - info.register(ctx) -} - -func (h *flowHandle) unregisterTask(ctx context.Context, task *proto.Task) { - if val, loaded := h.taskInfoMap.LoadAndDelete(task.ID); loaded { - info := val.(*taskInfo) - info.close(ctx) - } -} - -func (h *flowHandle) ProcessNormalFlow(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task) ( - resSubtaskMeta [][]byte, err error) { - logger := logutil.BgLogger().With( - zap.String("type", gTask.Type), - zap.Int64("task-id", gTask.ID), - zap.String("step", stepStr(gTask.Step)), - ) - taskMeta := &TaskMeta{} - err = json.Unmarshal(gTask.Meta, taskMeta) - if err != nil { - return nil, err - } - logger.Info("process normal flow") - - defer func() { - // currently, framework will take the task as finished when err is not nil or resSubtaskMeta is empty. - taskFinished := err == nil && len(resSubtaskMeta) == 0 - if taskFinished { - // todo: we're not running in a transaction with task update - if err2 := h.finishJob(ctx, handle, gTask, taskMeta); err2 != nil { - err = err2 - } - } else if err != nil && !h.IsRetryableErr(err) { - if err2 := h.failJob(ctx, handle, gTask, taskMeta, logger, err.Error()); err2 != nil { - // todo: we're not running in a transaction with task update, there might be case - // failJob return error, but task update succeed. 
- logger.Error("call failJob failed", zap.Error(err2)) - } - } - }() - - switch gTask.Step { - case proto.StepInit: - if err := preProcess(ctx, handle, gTask, taskMeta, logger); err != nil { - return nil, err - } - if err = startJob(ctx, handle, taskMeta); err != nil { - return nil, err - } - subtaskMetas, err := generateImportStepMetas(ctx, taskMeta) - if err != nil { - return nil, err - } - logger.Info("move to import step", zap.Any("subtask-count", len(subtaskMetas))) - metaBytes := make([][]byte, 0, len(subtaskMetas)) - for _, subtaskMeta := range subtaskMetas { - bs, err := json.Marshal(subtaskMeta) - if err != nil { - return nil, err - } - metaBytes = append(metaBytes, bs) - } - gTask.Step = StepImport - return metaBytes, nil - case StepImport: - h.switchTiKV2NormalMode(ctx, gTask, logger) - failpoint.Inject("clearLastSwitchTime", func() { - h.lastSwitchTime.Store(time.Time{}) - }) - stepMeta, err2 := toPostProcessStep(handle, gTask, taskMeta) - if err2 != nil { - return nil, err2 - } - if err = job2Step(ctx, taskMeta, importer.JobStepValidating); err != nil { - return nil, err - } - logger.Info("move to post-process step ", zap.Any("result", taskMeta.Result), - zap.Any("step-meta", stepMeta)) - bs, err := json.Marshal(stepMeta) - if err != nil { - return nil, err - } - failpoint.Inject("failWhenDispatchPostProcessSubtask", func() { - failpoint.Return(nil, errors.New("injected error after StepImport")) - }) - gTask.Step = StepPostProcess - return [][]byte{bs}, nil - case StepPostProcess: - return nil, nil - default: - return nil, errors.Errorf("unknown step %d", gTask.Step) - } -} - -func (h *flowHandle) ProcessErrFlow(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, receiveErr [][]byte) ([]byte, error) { - logger := logutil.BgLogger().With( - zap.String("type", gTask.Type), - zap.Int64("task-id", gTask.ID), - zap.String("step", stepStr(gTask.Step)), - ) - logger.Info("process error flow", zap.ByteStrings("error-message", receiveErr)) - taskMeta := &TaskMeta{} - err := json.Unmarshal(gTask.Meta, taskMeta) - if err != nil { - return nil, err - } - errStrs := make([]string, 0, len(receiveErr)) - for _, errStr := range receiveErr { - errStrs = append(errStrs, string(errStr)) - } - if err = h.failJob(ctx, handle, gTask, taskMeta, logger, strings.Join(errStrs, "; ")); err != nil { - return nil, err - } - - gTask.Error = receiveErr[0] - - errStr := string(receiveErr[0]) - // do nothing if the error is resumable - if isResumableErr(errStr) { - return nil, nil - } - - if gTask.Step == StepImport { - err = rollback(ctx, handle, gTask, logger) - if err != nil { - // TODO: add error code according to spec. - gTask.Error = []byte(errStr + ", " + err.Error()) - } - } - return nil, err -} - -func (*flowHandle) GetEligibleInstances(ctx context.Context, gTask *proto.Task) ([]*infosync.ServerInfo, error) { - taskMeta := &TaskMeta{} - err := json.Unmarshal(gTask.Meta, taskMeta) - if err != nil { - return nil, err - } - if len(taskMeta.EligibleInstances) > 0 { - return taskMeta.EligibleInstances, nil - } - return dispatcher.GenerateSchedulerNodes(ctx) -} - -func (*flowHandle) IsRetryableErr(error) bool { - // TODO: check whether the error is retryable. 
- return false -} - -func (h *flowHandle) switchTiKV2NormalMode(ctx context.Context, task *proto.Task, logger *zap.Logger) { - h.updateCurrentTask(task) - if h.disableTiKVImportMode.Load() { - return - } - - h.mu.Lock() - defer h.mu.Unlock() - - switcher, err := importer.GetTiKVModeSwitcher(logger) - if err != nil { - logger.Warn("get tikv mode switcher failed", zap.Error(err)) - return - } - switcher.ToNormalMode(ctx) - - // clear it, so next task can switch TiKV mode again. - h.lastSwitchTime.Store(time.Time{}) -} - -func (h *flowHandle) updateCurrentTask(task *proto.Task) { - if h.currTaskID.Swap(task.ID) != task.ID { - taskMeta := &TaskMeta{} - if err := json.Unmarshal(task.Meta, taskMeta); err == nil { - h.disableTiKVImportMode.Store(taskMeta.Plan.DisableTiKVImportMode) - } - } -} - -// preProcess does the pre-processing for the task. -func preProcess(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta, logger *zap.Logger) error { - logger.Info("pre process") - // TODO: drop table indexes depends on the option. - // if err := dropTableIndexes(ctx, handle, taskMeta, logger); err != nil { - // return err - // } - return updateMeta(gTask, taskMeta) -} - -// nolint:deadcode -func dropTableIndexes(ctx context.Context, handle dispatcher.TaskHandle, taskMeta *TaskMeta, logger *zap.Logger) error { - tblInfo := taskMeta.Plan.TableInfo - tableName := common.UniqueTable(taskMeta.Plan.DBName, tblInfo.Name.L) - - remainIndexes, dropIndexes := common.GetDropIndexInfos(tblInfo) - for _, idxInfo := range dropIndexes { - sqlStr := common.BuildDropIndexSQL(tableName, idxInfo) - if err := executeSQL(ctx, handle, logger, sqlStr); err != nil { - if merr, ok := errors.Cause(err).(*dmysql.MySQLError); ok { - switch merr.Number { - case errno.ErrCantDropFieldOrKey, errno.ErrDropIndexNeededInForeignKey: - remainIndexes = append(remainIndexes, idxInfo) - logger.Warn("can't drop index, skip", zap.String("index", idxInfo.Name.O), zap.Error(err)) - continue - } - } - return err - } - } - if len(remainIndexes) < len(tblInfo.Indices) { - taskMeta.Plan.TableInfo = taskMeta.Plan.TableInfo.Clone() - taskMeta.Plan.TableInfo.Indices = remainIndexes - } - return nil -} - -// nolint:deadcode -func createTableIndexes(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) error { - tableName := common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) - singleSQL, multiSQLs := common.BuildAddIndexSQL(tableName, taskMeta.Plan.TableInfo, taskMeta.Plan.DesiredTableInfo) - logger.Info("build add index sql", zap.String("singleSQL", singleSQL), zap.Strings("multiSQLs", multiSQLs)) - if len(multiSQLs) == 0 { - return nil - } - - err := executeSQL(ctx, executor, logger, singleSQL) - if err == nil { - return nil - } - if !common.IsDupKeyError(err) { - // TODO: refine err msg and error code according to spec. - return errors.Errorf("Failed to create index: %v, please execute the SQL manually, sql: %s", err, singleSQL) - } - if len(multiSQLs) == 1 { - return nil - } - logger.Warn("cannot add all indexes in one statement, try to add them one by one", zap.Strings("sqls", multiSQLs), zap.Error(err)) - - for i, ddl := range multiSQLs { - err := executeSQL(ctx, executor, logger, ddl) - if err != nil && !common.IsDupKeyError(err) { - // TODO: refine err msg and error code according to spec. 
- return errors.Errorf("Failed to create index: %v, please execute the SQLs manually, sqls: %s", err, strings.Join(multiSQLs[i:], ";")) - } - } - return nil -} - -// TODO: return the result of sql. -func executeSQL(ctx context.Context, executor storage.SessionExecutor, logger *zap.Logger, sql string, args ...interface{}) (err error) { - logger.Info("execute sql", zap.String("sql", sql), zap.Any("args", args)) - return executor.WithNewSession(func(se sessionctx.Context) error { - _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...) - return err - }) -} - -func updateMeta(gTask *proto.Task, taskMeta *TaskMeta) error { - bs, err := json.Marshal(taskMeta) - if err != nil { - return err - } - gTask.Meta = bs - return nil -} - -func buildController(taskMeta *TaskMeta) (*importer.LoadDataController, error) { - idAlloc := kv.NewPanickingAllocators(0) - tbl, err := tables.TableFromMeta(idAlloc, taskMeta.Plan.TableInfo) - if err != nil { - return nil, err - } - - astArgs, err := importer.ASTArgsFromStmt(taskMeta.Stmt) - if err != nil { - return nil, err - } - controller, err := importer.NewLoadDataController(&taskMeta.Plan, tbl, astArgs) - if err != nil { - return nil, err - } - return controller, nil -} - -// todo: converting back and forth, we should unify struct and remove this function later. -func toChunkMap(engineCheckpoints map[int32]*checkpoints.EngineCheckpoint) map[int32][]Chunk { - chunkMap := make(map[int32][]Chunk, len(engineCheckpoints)) - for id, ecp := range engineCheckpoints { - chunkMap[id] = make([]Chunk, 0, len(ecp.Chunks)) - for _, chunkCheckpoint := range ecp.Chunks { - chunkMap[id] = append(chunkMap[id], toChunk(*chunkCheckpoint)) - } - } - return chunkMap -} - -func generateImportStepMetas(ctx context.Context, taskMeta *TaskMeta) (subtaskMetas []*ImportStepMeta, err error) { - var chunkMap map[int32][]Chunk - if len(taskMeta.ChunkMap) > 0 { - chunkMap = taskMeta.ChunkMap - } else { - controller, err2 := buildController(taskMeta) - if err2 != nil { - return nil, err2 - } - if err2 = controller.InitDataFiles(ctx); err2 != nil { - return nil, err2 - } - - engineCheckpoints, err2 := controller.PopulateChunks(ctx) - if err2 != nil { - return nil, err2 - } - chunkMap = toChunkMap(engineCheckpoints) - } - for id := range chunkMap { - if id == common.IndexEngineID { - continue - } - subtaskMeta := &ImportStepMeta{ - ID: id, - Chunks: chunkMap[id], - } - subtaskMetas = append(subtaskMetas, subtaskMeta) - } - return subtaskMetas, nil -} - -// we will update taskMeta in place and make gTask.Meta point to the new taskMeta. 
-func toPostProcessStep(handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) (*PostProcessStepMeta, error) { - metas, err := handle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step) - if err != nil { - return nil, err - } - - subtaskMetas := make([]*ImportStepMeta, 0, len(metas)) - for _, bs := range metas { - var subtaskMeta ImportStepMeta - if err := json.Unmarshal(bs, &subtaskMeta); err != nil { - return nil, err - } - subtaskMetas = append(subtaskMetas, &subtaskMeta) - } - var localChecksum verify.KVChecksum - columnSizeMap := make(map[int64]int64) - for _, subtaskMeta := range subtaskMetas { - checksum := verify.MakeKVChecksum(subtaskMeta.Checksum.Size, subtaskMeta.Checksum.KVs, subtaskMeta.Checksum.Sum) - localChecksum.Add(&checksum) - - taskMeta.Result.ReadRowCnt += subtaskMeta.Result.ReadRowCnt - taskMeta.Result.LoadedRowCnt += subtaskMeta.Result.LoadedRowCnt - for key, val := range subtaskMeta.Result.ColSizeMap { - columnSizeMap[key] += val - } - } - taskMeta.Result.ColSizeMap = columnSizeMap - if err2 := updateMeta(gTask, taskMeta); err2 != nil { - return nil, err2 - } - return &PostProcessStepMeta{ - Checksum: Checksum{ - Size: localChecksum.SumSize(), - KVs: localChecksum.SumKVS(), - Sum: localChecksum.Sum(), - }, - }, nil -} - -func startJob(ctx context.Context, handle dispatcher.TaskHandle, taskMeta *TaskMeta) error { - failpoint.Inject("syncBeforeJobStarted", func() { - TestSyncChan <- struct{}{} - <-TestSyncChan - }) - err := handle.WithNewSession(func(se sessionctx.Context) error { - exec := se.(sqlexec.SQLExecutor) - return importer.StartJob(ctx, exec, taskMeta.JobID) - }) - failpoint.Inject("syncAfterJobStarted", func() { - TestSyncChan <- struct{}{} - }) - return err -} - -func job2Step(ctx context.Context, taskMeta *TaskMeta, step string) error { - globalTaskManager, err := storage.GetTaskManager() - if err != nil { - return err - } - // todo: use dispatcher.TaskHandle - // we might call this in scheduler later, there's no dispatcher.TaskHandle, so we use globalTaskManager here. 
- return globalTaskManager.WithNewSession(func(se sessionctx.Context) error { - exec := se.(sqlexec.SQLExecutor) - return importer.Job2Step(ctx, exec, taskMeta.JobID, step) - }) -} - -func (h *flowHandle) finishJob(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error { - h.unregisterTask(ctx, gTask) - redactSensitiveInfo(gTask, taskMeta) - summary := &importer.JobSummary{ImportedRows: taskMeta.Result.LoadedRowCnt} - return handle.WithNewSession(func(se sessionctx.Context) error { - exec := se.(sqlexec.SQLExecutor) - return importer.FinishJob(ctx, exec, taskMeta.JobID, summary) - }) -} - -func (h *flowHandle) failJob(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, - taskMeta *TaskMeta, logger *zap.Logger, errorMsg string) error { - h.switchTiKV2NormalMode(ctx, gTask, logger) - h.unregisterTask(ctx, gTask) - redactSensitiveInfo(gTask, taskMeta) - return handle.WithNewSession(func(se sessionctx.Context) error { - exec := se.(sqlexec.SQLExecutor) - return importer.FailJob(ctx, exec, taskMeta.JobID, errorMsg) - }) -} - -func redactSensitiveInfo(gTask *proto.Task, taskMeta *TaskMeta) { - taskMeta.Stmt = "" - taskMeta.Plan.Path = ast.RedactURL(taskMeta.Plan.Path) - if err := updateMeta(gTask, taskMeta); err != nil { - // marshal failed, should not happen - logutil.BgLogger().Warn("failed to update task meta", zap.Error(err)) - } -} - -// isResumableErr checks whether it's possible to rely on checkpoint to re-import data after the error has been fixed. -func isResumableErr(string) bool { - // TODO: add more cases - return false -} - -func rollback(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task, logger *zap.Logger) (err error) { - taskMeta := &TaskMeta{} - err = json.Unmarshal(gTask.Meta, taskMeta) - if err != nil { - return err - } - - logger.Info("rollback") - - // // TODO: create table indexes depends on the option. - // // create table indexes even if the rollback is failed. - // defer func() { - // err2 := createTableIndexes(ctx, handle, taskMeta, logger) - // err = multierr.Append(err, err2) - // }() - - tableName := common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) - // truncate the table - return executeSQL(ctx, handle, logger, "TRUNCATE "+tableName) -} - -func stepStr(step int64) string { - switch step { - case proto.StepInit: - return "init" - case StepImport: - return "import" - case StepPostProcess: - return "postprocess" - default: - return "unknown" - } -} - -func init() { - dispatcher.RegisterTaskFlowHandle(proto.ImportInto, &flowHandle{}) -} diff --git a/disttask/importinto/job.go b/disttask/importinto/job.go deleted file mode 100644 index 64b61048d8c88..0000000000000 --- a/disttask/importinto/job.go +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
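The ProcessNormalFlow switch above drives a task through StepInit -> StepImport -> StepPostProcess: the init step emits one import subtask per data engine, the import step folds the per-subtask results into a single post-process subtask, and the post-process step returns no metas, which the framework treats as task completion. Below is a minimal, self-contained sketch of that step-transition shape; the task struct, step constants, and meta layout are simplified stand-ins, not the real dispatcher API.

package main

import (
	"encoding/json"
	"fmt"
)

// Simplified stand-ins for the framework's step constants and task type.
const (
	stepInit = iota
	stepImport
	stepPostProcess
)

type task struct {
	Step int
}

// processNormalFlow mirrors the pattern above: return the metas of the next
// step's subtasks and advance task.Step; an empty result means the task is done.
func processNormalFlow(t *task, engineChunks map[int32][]string) ([][]byte, error) {
	switch t.Step {
	case stepInit:
		metas := make([][]byte, 0, len(engineChunks))
		for engineID, chunks := range engineChunks {
			bs, err := json.Marshal(map[string]any{"engine": engineID, "chunks": chunks})
			if err != nil {
				return nil, err
			}
			metas = append(metas, bs)
		}
		t.Step = stepImport
		return metas, nil
	case stepImport:
		// The real flow merges the per-subtask checksums here.
		bs, err := json.Marshal(map[string]any{"checksum": "merged"})
		if err != nil {
			return nil, err
		}
		t.Step = stepPostProcess
		return [][]byte{bs}, nil
	case stepPostProcess:
		return nil, nil
	default:
		return nil, fmt.Errorf("unknown step %d", t.Step)
	}
}

func main() {
	t := &task{Step: stepInit}
	chunks := map[int32][]string{1: {"a.csv"}, 2: {"b.csv"}}
	for {
		metas, err := processNormalFlow(t, chunks)
		if err != nil {
			panic(err)
		}
		if len(metas) == 0 {
			break
		}
		fmt.Printf("advanced to step %d with %d subtask(s)\n", t.Step, len(metas))
	}
}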
- -package importinto - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/disttask/framework/handle" - "github.com/pingcap/tidb/disttask/framework/proto" - "github.com/pingcap/tidb/disttask/framework/storage" - "github.com/pingcap/tidb/domain/infosync" - "github.com/pingcap/tidb/executor/importer" - "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/util/dbterror/exeerrors" - "github.com/pingcap/tidb/util/logutil" - "github.com/pingcap/tidb/util/sqlexec" - "go.uber.org/zap" -) - -// DistImporter is a JobImporter for distributed IMPORT INTO. -type DistImporter struct { - *importer.JobImportParam - plan *importer.Plan - stmt string - logger *zap.Logger - // the instance to import data, used for single-node import, nil means import data on all instances. - instance *infosync.ServerInfo - // the files to import, when import from server file, we need to pass those file to the framework. - chunkMap map[int32][]Chunk - sourceFileSize int64 - // only set after submit task - jobID int64 - taskID int64 -} - -// NewDistImporter creates a new DistImporter. -func NewDistImporter(param *importer.JobImportParam, plan *importer.Plan, stmt string, sourceFileSize int64) (*DistImporter, error) { - return &DistImporter{ - JobImportParam: param, - plan: plan, - stmt: stmt, - logger: logutil.BgLogger(), - sourceFileSize: sourceFileSize, - }, nil -} - -// NewDistImporterCurrNode creates a new DistImporter to import data on current node. -func NewDistImporterCurrNode(param *importer.JobImportParam, plan *importer.Plan, stmt string, sourceFileSize int64) (*DistImporter, error) { - serverInfo, err := infosync.GetServerInfo() - if err != nil { - return nil, err - } - return &DistImporter{ - JobImportParam: param, - plan: plan, - stmt: stmt, - logger: logutil.BgLogger(), - instance: serverInfo, - sourceFileSize: sourceFileSize, - }, nil -} - -// NewDistImporterServerFile creates a new DistImporter to import given files on current node. -// we also run import on current node. -// todo: merge all 3 ctor into one. -func NewDistImporterServerFile(param *importer.JobImportParam, plan *importer.Plan, stmt string, ecp map[int32]*checkpoints.EngineCheckpoint, sourceFileSize int64) (*DistImporter, error) { - distImporter, err := NewDistImporterCurrNode(param, plan, stmt, sourceFileSize) - if err != nil { - return nil, err - } - distImporter.chunkMap = toChunkMap(ecp) - return distImporter, nil -} - -// Param implements JobImporter.Param. -func (ti *DistImporter) Param() *importer.JobImportParam { - return ti.JobImportParam -} - -// Import implements JobImporter.Import. -func (*DistImporter) Import() { - // todo: remove it -} - -// ImportTask import task. -func (ti *DistImporter) ImportTask(task *proto.Task) { - ti.logger.Info("start distribute IMPORT INTO") - ti.Group.Go(func() error { - defer close(ti.Done) - // task is run using distribute framework, so we only wait for the task to finish. - return handle.WaitGlobalTask(ti.GroupCtx, task) - }) -} - -// Result implements JobImporter.Result. 
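ImportTask above runs the wait for the distributed task inside the job's errgroup and signals completion by closing the Done channel from JobImportParam. A self-contained sketch of that shape; waitTask here is a stand-in for handle.WaitGlobalTask, and the channel/group wiring is simplified.

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

// waitTask stands in for handle.WaitGlobalTask: block until the task finishes
// or the context is canceled.
func waitTask(ctx context.Context) error {
	select {
	case <-time.After(100 * time.Millisecond): // pretend the task finishes quickly
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	group, ctx := errgroup.WithContext(context.Background())
	done := make(chan struct{})

	// Same shape as ImportTask above: run the wait inside the group and close
	// the done channel when it returns.
	group.Go(func() error {
		defer close(done)
		return waitTask(ctx)
	})

	<-done
	if err := group.Wait(); err != nil {
		fmt.Println("import failed:", err)
		return
	}
	fmt.Println("import finished")
}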
-func (ti *DistImporter) Result() importer.JobImportResult { - var result importer.JobImportResult - taskMeta, err := getTaskMeta(ti.jobID) - if err != nil { - result.Msg = err.Error() - return result - } - - var ( - numWarnings uint64 - numRecords uint64 - numDeletes uint64 - numSkipped uint64 - ) - numRecords = taskMeta.Result.ReadRowCnt - // todo: we don't have a strict REPLACE or IGNORE mode in physical mode, so we can't get the numDeletes/numSkipped. - // we can have it when there's duplicate detection. - msg := fmt.Sprintf(mysql.MySQLErrName[mysql.ErrLoadInfo].Raw, numRecords, numDeletes, numSkipped, numWarnings) - return importer.JobImportResult{ - Msg: msg, - Affected: taskMeta.Result.ReadRowCnt, - ColSizeMap: taskMeta.Result.ColSizeMap, - } -} - -// Close implements the io.Closer interface. -func (*DistImporter) Close() error { - return nil -} - -// SubmitTask submits a task to the distribute framework. -func (ti *DistImporter) SubmitTask(ctx context.Context) (int64, *proto.Task, error) { - var instances []*infosync.ServerInfo - if ti.instance != nil { - instances = append(instances, ti.instance) - } - // we use globalTaskManager to submit task, user might not have the privilege to system tables. - globalTaskManager, err := storage.GetTaskManager() - if err != nil { - return 0, nil, err - } - - var jobID, taskID int64 - plan := ti.plan - if err = globalTaskManager.WithNewTxn(ctx, func(se sessionctx.Context) error { - var err2 error - exec := se.(sqlexec.SQLExecutor) - // If 2 client try to execute IMPORT INTO concurrently, there's chance that both of them will pass the check. - // We can enforce ONLY one import job running by: - // - using LOCK TABLES, but it requires enable-table-lock=true, it's not enabled by default. - // - add a key to PD as a distributed lock, but it's a little complex, and we might support job queuing later. - // So we only add this simple soft check here and doc it. - activeJobCnt, err2 := importer.GetActiveJobCnt(ctx, exec) - if err2 != nil { - return err2 - } - if activeJobCnt > 0 { - return exeerrors.ErrLoadDataPreCheckFailed.FastGenByArgs("there's pending or running jobs") - } - jobID, err2 = importer.CreateJob(ctx, exec, plan.DBName, plan.TableInfo.Name.L, plan.TableInfo.ID, - plan.User, plan.Parameters, ti.sourceFileSize) - if err2 != nil { - return err2 - } - task := TaskMeta{ - JobID: jobID, - Plan: *plan, - Stmt: ti.stmt, - EligibleInstances: instances, - ChunkMap: ti.chunkMap, - } - taskMeta, err2 := json.Marshal(task) - if err2 != nil { - return err2 - } - taskID, err2 = globalTaskManager.AddGlobalTaskWithSession(se, TaskKey(jobID), proto.ImportInto, - int(plan.ThreadCnt), taskMeta) - if err2 != nil { - return err2 - } - return nil - }); err != nil { - return 0, nil, err - } - - globalTask, err := globalTaskManager.GetGlobalTaskByID(taskID) - if err != nil { - return 0, nil, err - } - if globalTask == nil { - return 0, nil, errors.Errorf("cannot find global task with ID %d", taskID) - } - // update logger with task id. - ti.jobID = jobID - ti.taskID = taskID - ti.logger = ti.logger.With(zap.Int64("task-id", globalTask.ID)) - - ti.logger.Info("job submitted to global task queue", zap.Int64("job-id", jobID)) - - return jobID, globalTask, nil -} - -func (*DistImporter) taskKey() string { - // task key is meaningless to IMPORT INTO, so we use a random uuid. - return fmt.Sprintf("%s/%s", proto.ImportInto, uuid.New().String()) -} - -// JobID returns the job id. 
-func (ti *DistImporter) JobID() int64 { - return ti.jobID -} - -func getTaskMeta(jobID int64) (*TaskMeta, error) { - globalTaskManager, err := storage.GetTaskManager() - if err != nil { - return nil, err - } - taskKey := TaskKey(jobID) - globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey) - if err != nil { - return nil, err - } - if globalTask == nil { - return nil, errors.Errorf("cannot find global task with key %s", taskKey) - } - var taskMeta TaskMeta - if err := json.Unmarshal(globalTask.Meta, &taskMeta); err != nil { - return nil, err - } - return &taskMeta, nil -} - -// GetTaskImportedRows gets the number of imported rows of a job. -// Note: for finished job, we can get the number of imported rows from task meta. -func GetTaskImportedRows(jobID int64) (uint64, error) { - globalTaskManager, err := storage.GetTaskManager() - if err != nil { - return 0, err - } - taskKey := TaskKey(jobID) - globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey) - if err != nil { - return 0, err - } - if globalTask == nil { - return 0, errors.Errorf("cannot find global task with key %s", taskKey) - } - subtasks, err := globalTaskManager.GetSubtasksByStep(globalTask.ID, StepImport) - if err != nil { - return 0, err - } - var importedRows uint64 - for _, subtask := range subtasks { - var subtaskMeta ImportStepMeta - if err2 := json.Unmarshal(subtask.Meta, &subtaskMeta); err2 != nil { - return 0, err2 - } - importedRows += subtaskMeta.Result.LoadedRowCnt - } - return importedRows, nil -} - -// TaskKey returns the task key for a job. -func TaskKey(jobID int64) string { - return fmt.Sprintf("%s/%d", proto.ImportInto, jobID) -} diff --git a/disttask/importinto/subtask_executor.go b/disttask/importinto/subtask_executor.go deleted file mode 100644 index be6de9a75d0c0..0000000000000 --- a/disttask/importinto/subtask_executor.go +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importinto - -import ( - "context" - "strconv" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/lightning/backend/local" - "github.com/pingcap/tidb/br/pkg/lightning/common" - "github.com/pingcap/tidb/br/pkg/lightning/config" - verify "github.com/pingcap/tidb/br/pkg/lightning/verification" - "github.com/pingcap/tidb/disttask/framework/proto" - "github.com/pingcap/tidb/disttask/framework/scheduler" - "github.com/pingcap/tidb/disttask/framework/storage" - "github.com/pingcap/tidb/executor/importer" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/util/logutil" - "github.com/pingcap/tidb/util/mathutil" - "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -// TestSyncChan is used to test. -var TestSyncChan = make(chan struct{}) - -// ImportMinimalTaskExecutor is a minimal task executor for IMPORT INTO. 
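GetTaskImportedRows above recovers the imported-row count of a finished job by decoding every import-step subtask meta and summing Result.LoadedRowCnt. A stripped-down, stdlib-only sketch of that aggregation; the struct below is a reduced stand-in for ImportStepMeta, not its real definition.

package main

import (
	"encoding/json"
	"fmt"
)

// Reduced stand-in for the import-step subtask meta stored by the framework.
type importStepMeta struct {
	Result struct {
		LoadedRowCnt uint64
	}
}

// sumImportedRows decodes each subtask meta blob and adds up the loaded rows,
// mirroring GetTaskImportedRows above.
func sumImportedRows(metas [][]byte) (uint64, error) {
	var total uint64
	for _, raw := range metas {
		var m importStepMeta
		if err := json.Unmarshal(raw, &m); err != nil {
			return 0, err
		}
		total += m.Result.LoadedRowCnt
	}
	return total, nil
}

func main() {
	metas := [][]byte{
		[]byte(`{"Result":{"LoadedRowCnt":100}}`),
		[]byte(`{"Result":{"LoadedRowCnt":250}}`),
	}
	total, err := sumImportedRows(metas)
	fmt.Println(total, err) // 350 <nil>
}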
-type ImportMinimalTaskExecutor struct { - mTtask *importStepMinimalTask -} - -// Run implements the SubtaskExecutor.Run interface. -func (e *ImportMinimalTaskExecutor) Run(ctx context.Context) error { - logger := logutil.BgLogger().With(zap.String("type", proto.ImportInto), zap.Int64("table-id", e.mTtask.Plan.TableInfo.ID)) - logger.Info("run minimal task") - failpoint.Inject("waitBeforeSortChunk", func() { - time.Sleep(3 * time.Second) - }) - failpoint.Inject("errorWhenSortChunk", func() { - failpoint.Return(errors.New("occur an error when sort chunk")) - }) - failpoint.Inject("syncBeforeSortChunk", func() { - TestSyncChan <- struct{}{} - <-TestSyncChan - }) - chunkCheckpoint := toChunkCheckpoint(e.mTtask.Chunk) - sharedVars := e.mTtask.SharedVars - if err := importer.ProcessChunk(ctx, &chunkCheckpoint, sharedVars.TableImporter, sharedVars.DataEngine, sharedVars.IndexEngine, sharedVars.Progress, logger); err != nil { - return err - } - - sharedVars.mu.Lock() - defer sharedVars.mu.Unlock() - sharedVars.Checksum.Add(&chunkCheckpoint.Checksum) - return nil -} - -type postProcessMinimalTaskExecutor struct { - mTask *postProcessStepMinimalTask -} - -func (e *postProcessMinimalTaskExecutor) Run(ctx context.Context) error { - mTask := e.mTask - failpoint.Inject("waitBeforePostProcess", func() { - time.Sleep(5 * time.Second) - }) - return postProcess(ctx, mTask.taskMeta, &mTask.meta, mTask.logger) -} - -// postProcess does the post-processing for the task. -func postProcess(ctx context.Context, taskMeta *TaskMeta, subtaskMeta *PostProcessStepMeta, logger *zap.Logger) (err error) { - failpoint.Inject("syncBeforePostProcess", func() { - TestSyncChan <- struct{}{} - <-TestSyncChan - }) - - logger.Info("post process") - - // TODO: create table indexes depends on the option. - // create table indexes even if the post process is failed. 
- // defer func() { - // err2 := createTableIndexes(ctx, globalTaskManager, taskMeta, logger) - // err = multierr.Append(err, err2) - // }() - - return verifyChecksum(ctx, taskMeta, subtaskMeta, logger) -} - -func verifyChecksum(ctx context.Context, taskMeta *TaskMeta, subtaskMeta *PostProcessStepMeta, logger *zap.Logger) error { - if taskMeta.Plan.Checksum == config.OpLevelOff { - return nil - } - localChecksum := verify.MakeKVChecksum(subtaskMeta.Checksum.Size, subtaskMeta.Checksum.KVs, subtaskMeta.Checksum.Sum) - logger.Info("local checksum", zap.Object("checksum", &localChecksum)) - - failpoint.Inject("waitCtxDone", func() { - <-ctx.Done() - }) - - globalTaskManager, err := storage.GetTaskManager() - if err != nil { - return err - } - remoteChecksum, err := checksumTable(ctx, globalTaskManager, taskMeta, logger) - if err != nil { - return err - } - if !remoteChecksum.IsEqual(&localChecksum) { - err2 := common.ErrChecksumMismatch.GenWithStackByArgs( - remoteChecksum.Checksum, localChecksum.Sum(), - remoteChecksum.TotalKVs, localChecksum.SumKVS(), - remoteChecksum.TotalBytes, localChecksum.SumSize(), - ) - if taskMeta.Plan.Checksum == config.OpLevelOptional { - logger.Warn("verify checksum failed, but checksum is optional, will skip it", zap.Error(err2)) - err2 = nil - } - return err2 - } - logger.Info("checksum pass", zap.Object("local", &localChecksum)) - return nil -} - -func checksumTable(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) (*local.RemoteChecksum, error) { - var ( - tableName = common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) - sql = "ADMIN CHECKSUM TABLE " + tableName - maxErrorRetryCount = 3 - distSQLScanConcurrencyFactor = 1 - remoteChecksum *local.RemoteChecksum - txnErr error - ) - - ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) - for i := 0; i < maxErrorRetryCount; i++ { - txnErr = executor.WithNewTxn(ctx, func(se sessionctx.Context) error { - // increase backoff weight - if err := setBackoffWeight(se, taskMeta, logger); err != nil { - logger.Warn("set tidb_backoff_weight failed", zap.Error(err)) - } - - distSQLScanConcurrency := se.GetSessionVars().DistSQLScanConcurrency() - se.GetSessionVars().SetDistSQLScanConcurrency(mathutil.Max(distSQLScanConcurrency/distSQLScanConcurrencyFactor, local.MinDistSQLScanConcurrency)) - defer func() { - se.GetSessionVars().SetDistSQLScanConcurrency(distSQLScanConcurrency) - }() - - rs, err := storage.ExecSQL(ctx, se, sql) - if err != nil { - return err - } - if len(rs) < 1 { - return errors.New("empty checksum result") - } - - failpoint.Inject("errWhenChecksum", func() { - if i == 0 { - failpoint.Return(errors.New("occur an error when checksum, coprocessor task terminated due to exceeding the deadline")) - } - }) - - // ADMIN CHECKSUM TABLE .
example. - // mysql> admin checksum table test.t; - // +---------+------------+---------------------+-----------+-------------+ - // | Db_name | Table_name | Checksum_crc64_xor | Total_kvs | Total_bytes | - // +---------+------------+---------------------+-----------+-------------+ - // | test | t | 8520875019404689597 | 7296873 | 357601387 | - // +---------+------------+------------- - remoteChecksum = &local.RemoteChecksum{ - Schema: rs[0].GetString(0), - Table: rs[0].GetString(1), - Checksum: rs[0].GetUint64(2), - TotalKVs: rs[0].GetUint64(3), - TotalBytes: rs[0].GetUint64(4), - } - return nil - }) - if !common.IsRetryableError(txnErr) { - break - } - distSQLScanConcurrencyFactor *= 2 - logger.Warn("retry checksum table", zap.Int("retry count", i+1), zap.Error(txnErr)) - } - return remoteChecksum, txnErr -} - -// TestChecksumTable is used to test checksum table in unit test. -func TestChecksumTable(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) (*local.RemoteChecksum, error) { - return checksumTable(ctx, executor, taskMeta, logger) -} - -func setBackoffWeight(se sessionctx.Context, taskMeta *TaskMeta, logger *zap.Logger) error { - backoffWeight := local.DefaultBackoffWeight - if val, ok := taskMeta.Plan.ImportantSysVars[variable.TiDBBackOffWeight]; ok { - if weight, err := strconv.Atoi(val); err == nil && weight > backoffWeight { - backoffWeight = weight - } - } - logger.Info("set backoff weight", zap.Int("weight", backoffWeight)) - return se.GetSessionVars().SetSystemVar(variable.TiDBBackOffWeight, strconv.Itoa(backoffWeight)) -} - -func init() { - scheduler.RegisterSubtaskExectorConstructor(proto.ImportInto, StepImport, - // The order of the subtask executors is the same as the order of the subtasks. - func(minimalTask proto.MinimalTask, step int64) (scheduler.SubtaskExecutor, error) { - task, ok := minimalTask.(*importStepMinimalTask) - if !ok { - return nil, errors.Errorf("invalid task type %T", minimalTask) - } - return &ImportMinimalTaskExecutor{mTtask: task}, nil - }, - ) - scheduler.RegisterSubtaskExectorConstructor(proto.ImportInto, StepPostProcess, - func(minimalTask proto.MinimalTask, step int64) (scheduler.SubtaskExecutor, error) { - mTask, ok := minimalTask.(*postProcessStepMinimalTask) - if !ok { - return nil, errors.Errorf("invalid task type %T", minimalTask) - } - return &postProcessMinimalTaskExecutor{mTask: mTask}, nil - }, - ) -} diff --git a/disttask/importinto/subtask_executor_test.go b/disttask/importinto/subtask_executor_test.go deleted file mode 100644 index 4596ffc795aa2..0000000000000 --- a/disttask/importinto/subtask_executor_test.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
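checksumTable above runs ADMIN CHECKSUM TABLE with up to three attempts; each retry raises tidb_backoff_weight and doubles the divisor applied to the DistSQL scan concurrency (clamped to a minimum) so the next attempt puts less load on the cluster, and it stops early only on non-retryable errors. A self-contained sketch of that retry shape; the checksum triple, the sentinel error, and the run callback are simplified stand-ins, and the sample numbers are taken from the comment above.

package main

import (
	"errors"
	"fmt"
)

// checksum is a stand-in for the (total bytes, total kvs, crc64-xor) triple that
// is kept locally and returned by ADMIN CHECKSUM TABLE.
type checksum struct {
	TotalBytes uint64
	TotalKVs   uint64
	Sum        uint64
}

// errTimeout stands in for the retryable coprocessor-deadline error.
var errTimeout = errors.New("coprocessor task terminated due to exceeding the deadline")

// checksumWithRetry mirrors the loop above: run the remote checksum with a given
// scan concurrency; on a retryable error, double the divisor (i.e. halve the
// concurrency, clamped to minConcurrency) and try again, up to maxRetry times.
func checksumWithRetry(run func(concurrency int) (*checksum, error), baseConcurrency, minConcurrency, maxRetry int) (*checksum, error) {
	factor := 1
	var (
		remote *checksum
		err    error
	)
	for i := 0; i < maxRetry; i++ {
		conc := baseConcurrency / factor
		if conc < minConcurrency {
			conc = minConcurrency
		}
		remote, err = run(conc)
		if err == nil || !errors.Is(err, errTimeout) { // stand-in for common.IsRetryableError
			break
		}
		factor *= 2
		fmt.Printf("checksum attempt %d failed, will retry with lower concurrency: %v\n", i+1, err)
	}
	return remote, err
}

func main() {
	local := checksum{TotalBytes: 357601387, TotalKVs: 7296873, Sum: 8520875019404689597}
	attempt := 0
	remote, err := checksumWithRetry(func(concurrency int) (*checksum, error) {
		attempt++
		if attempt == 1 {
			return nil, errTimeout // first attempt fails, like the errWhenChecksum failpoint
		}
		result := local // pretend the remote checksum matches the local one
		return &result, nil
	}, 15, 4, 3)
	if err != nil {
		panic(err)
	}
	if *remote != local {
		panic("checksum mismatch")
	}
	fmt.Println("checksum pass")
}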
- -package importinto_test - -import ( - "context" - "testing" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/failpoint" - verify "github.com/pingcap/tidb/br/pkg/lightning/verification" - "github.com/pingcap/tidb/disttask/framework/storage" - "github.com/pingcap/tidb/disttask/importinto" - "github.com/pingcap/tidb/executor/importer" - "github.com/pingcap/tidb/parser/model" - "github.com/pingcap/tidb/testkit" - "github.com/pingcap/tidb/util/logutil" - "github.com/stretchr/testify/require" - "github.com/tikv/client-go/v2/util" -) - -func TestChecksumTable(t *testing.T) { - ctx := context.Background() - store := testkit.CreateMockStore(t) - gtk := testkit.NewTestKit(t, store) - pool := pools.NewResourcePool(func() (pools.Resource, error) { - return gtk.Session(), nil - }, 1, 1, time.Second) - defer pool.Close() - mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool) - - taskMeta := &importinto.TaskMeta{ - Plan: importer.Plan{ - DBName: "db", - TableInfo: &model.TableInfo{ - Name: model.NewCIStr("tb"), - }, - }, - } - // fake result - localChecksum := verify.MakeKVChecksum(1, 1, 1) - gtk.MustExec("create database db") - gtk.MustExec("create table db.tb(id int)") - gtk.MustExec("insert into db.tb values(1)") - remoteChecksum, err := importinto.TestChecksumTable(ctx, mgr, taskMeta, logutil.BgLogger()) - require.NoError(t, err) - require.True(t, remoteChecksum.IsEqual(&localChecksum)) - // again - remoteChecksum, err = importinto.TestChecksumTable(ctx, mgr, taskMeta, logutil.BgLogger()) - require.NoError(t, err) - require.True(t, remoteChecksum.IsEqual(&localChecksum)) - - _ = failpoint.Enable("github.com/pingcap/tidb/disttask/importinto/errWhenChecksum", `return(true)`) - defer func() { - _ = failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/errWhenChecksum") - }() - remoteChecksum, err = importinto.TestChecksumTable(ctx, mgr, taskMeta, logutil.BgLogger()) - require.NoError(t, err) - require.True(t, remoteChecksum.IsEqual(&localChecksum)) -} diff --git a/executor/import_into.go b/executor/import_into.go deleted file mode 100644 index 92f16fb13f611..0000000000000 --- a/executor/import_into.go +++ /dev/null @@ -1,302 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
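The test above drives the retry path by enabling the errWhenChecksum failpoint before calling TestChecksumTable. The sketch below shows the same Enable/Inject/Disable pattern in isolation. Note that failpoint.Inject is only a marker: the injected body runs only in failpoint-enabled builds (after failpoint-ctl rewrites the markers, e.g. via TiDB's make failpoint-enable), and the package path used in the failpoint name here is just a placeholder.

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

// checksumOnce stands in for a function guarded by a failpoint marker. In a
// failpoint-enabled build, the Inject body runs whenever the failpoint is
// enabled; in a normal build, Inject is a no-op.
func checksumOnce() error {
	var injected error
	failpoint.Inject("errWhenChecksum", func() {
		injected = fmt.Errorf("injected checksum error")
	})
	return injected
}

func main() {
	// The failpoint path is the package import path plus the failpoint name,
	// matching the convention used in the test above; the path here is only
	// a placeholder for this sketch.
	const fp = "example.com/demo/errWhenChecksum"
	if err := failpoint.Enable(fp, `return(true)`); err != nil {
		panic(err)
	}
	defer func() {
		_ = failpoint.Disable(fp)
	}()
	fmt.Println(checksumOnce()) // prints the injected error in failpoint-enabled builds
}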
- -package executor - -import ( - "context" - "sync/atomic" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/disttask/framework/proto" - fstorage "github.com/pingcap/tidb/disttask/framework/storage" - "github.com/pingcap/tidb/disttask/importinto" - "github.com/pingcap/tidb/executor/asyncloaddata" - "github.com/pingcap/tidb/executor/importer" - "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/parser/ast" - "github.com/pingcap/tidb/parser/mysql" - plannercore "github.com/pingcap/tidb/planner/core" - "github.com/pingcap/tidb/privilege" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/sessiontxn" - "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/dbterror/exeerrors" - "github.com/pingcap/tidb/util/logutil" - "github.com/pingcap/tidb/util/sqlexec" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -var ( - // TestDetachedTaskFinished is a flag for test. - TestDetachedTaskFinished atomic.Bool - // TestCancelFunc for test. - TestCancelFunc context.CancelFunc -) - -const unknownImportedRowCount = -1 - -// ImportIntoExec represents a IMPORT INTO executor. -type ImportIntoExec struct { - baseExecutor - userSctx sessionctx.Context - importPlan *importer.Plan - controller *importer.LoadDataController - stmt string - - dataFilled bool -} - -var ( - _ Executor = (*ImportIntoExec)(nil) -) - -func newImportIntoExec(b baseExecutor, userSctx sessionctx.Context, plan *plannercore.ImportInto, tbl table.Table) ( - *ImportIntoExec, error) { - importPlan, err := importer.NewImportPlan(userSctx, plan, tbl) - if err != nil { - return nil, err - } - astArgs := importer.ASTArgsFromImportPlan(plan) - controller, err := importer.NewLoadDataController(importPlan, tbl, astArgs) - if err != nil { - return nil, err - } - return &ImportIntoExec{ - baseExecutor: b, - userSctx: userSctx, - importPlan: importPlan, - controller: controller, - stmt: plan.Stmt, - }, nil -} - -// Next implements the Executor Next interface. -func (e *ImportIntoExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { - req.GrowAndReset(e.maxChunkSize) - ctx = kv.WithInternalSourceType(ctx, kv.InternalImportInto) - if e.dataFilled { - // need to return an empty req to indicate all results have been written - return nil - } - if err2 := e.controller.InitDataFiles(ctx); err2 != nil { - return err2 - } - - // must use a new session to pre-check, else the stmt in show processlist will be changed. - newSCtx, err2 := CreateSession(e.userSctx) - if err2 != nil { - return err2 - } - defer CloseSession(newSCtx) - sqlExec := newSCtx.(sqlexec.SQLExecutor) - if err2 = e.controller.CheckRequirements(ctx, sqlExec); err2 != nil { - return err2 - } - - failpoint.Inject("cancellableCtx", func() { - // KILL is not implemented in testkit, so we use a fail-point to simulate it. - newCtx, cancel := context.WithCancel(ctx) - ctx = newCtx - TestCancelFunc = cancel - }) - // todo: we don't need Job now, remove it later. 
- parentCtx := ctx - if e.controller.Detached { - parentCtx = context.Background() - } - group, groupCtx := errgroup.WithContext(parentCtx) - param := &importer.JobImportParam{ - Job: &asyncloaddata.Job{}, - Group: group, - GroupCtx: groupCtx, - Done: make(chan struct{}), - Progress: asyncloaddata.NewProgress(false), - } - distImporter, err := e.getJobImporter(ctx, param) - if err != nil { - return err - } - defer func() { - _ = distImporter.Close() - }() - param.Progress.SourceFileSize = e.controller.TotalFileSize - jobID, task, err := distImporter.SubmitTask(ctx) - if err != nil { - return err - } - - if e.controller.Detached { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalImportInto) - se, err := CreateSession(e.userSctx) - if err != nil { - return err - } - go func() { - defer CloseSession(se) - // error is stored in system table, so we can ignore it here - //nolint: errcheck - _ = e.doImport(ctx, se, distImporter, task) - failpoint.Inject("testDetachedTaskFinished", func() { - TestDetachedTaskFinished.Store(true) - }) - }() - return e.fillJobInfo(ctx, jobID, req) - } - if err = e.doImport(ctx, e.userSctx, distImporter, task); err != nil { - return err - } - return e.fillJobInfo(ctx, jobID, req) -} - -func (e *ImportIntoExec) fillJobInfo(ctx context.Context, jobID int64, req *chunk.Chunk) error { - e.dataFilled = true - // we use globalTaskManager to get job, user might not have the privilege to system tables. - globalTaskManager, err := fstorage.GetTaskManager() - if err != nil { - return err - } - var info *importer.JobInfo - if err = globalTaskManager.WithNewSession(func(se sessionctx.Context) error { - sqlExec := se.(sqlexec.SQLExecutor) - var err2 error - info, err2 = importer.GetJob(ctx, sqlExec, jobID, e.ctx.GetSessionVars().User.String(), false) - return err2 - }); err != nil { - return err - } - fillOneImportJobInfo(info, req, unknownImportedRowCount) - return nil -} - -func (e *ImportIntoExec) getJobImporter(ctx context.Context, param *importer.JobImportParam) (*importinto.DistImporter, error) { - importFromServer, err := storage.IsLocalPath(e.controller.Path) - if err != nil { - // since we have checked this during creating controller, this should not happen. - return nil, exeerrors.ErrLoadDataInvalidURI.FastGenByArgs(err.Error()) - } - logutil.Logger(ctx).Info("get job importer", zap.Stringer("param", e.controller.Parameters), - zap.Bool("dist-task-enabled", variable.EnableDistTask.Load())) - if importFromServer { - ecp, err2 := e.controller.PopulateChunks(ctx) - if err2 != nil { - return nil, err2 - } - return importinto.NewDistImporterServerFile(param, e.importPlan, e.stmt, ecp, e.controller.TotalFileSize) - } - // if tidb_enable_dist_task=true, we import distributively, otherwise we import on current node. - if variable.EnableDistTask.Load() { - return importinto.NewDistImporter(param, e.importPlan, e.stmt, e.controller.TotalFileSize) - } - return importinto.NewDistImporterCurrNode(param, e.importPlan, e.stmt, e.controller.TotalFileSize) -} - -func (e *ImportIntoExec) doImport(ctx context.Context, se sessionctx.Context, distImporter *importinto.DistImporter, task *proto.Task) error { - distImporter.ImportTask(task) - group := distImporter.Param().Group - err := group.Wait() - // when user KILL the connection, the ctx will be canceled, we need to cancel the import job. 
- if errors.Cause(err) == context.Canceled { - globalTaskManager, err2 := fstorage.GetTaskManager() - if err2 != nil { - return err2 - } - // use background, since ctx is canceled already. - return cancelImportJob(context.Background(), globalTaskManager, distImporter.JobID()) - } - if err2 := flushStats(ctx, se, e.importPlan.TableInfo.ID, distImporter.Result()); err2 != nil { - logutil.Logger(ctx).Error("flush stats failed", zap.Error(err2)) - } - return err -} - -// ImportIntoActionExec represents a import into action executor. -type ImportIntoActionExec struct { - baseExecutor - tp ast.ImportIntoActionTp - jobID int64 -} - -var ( - _ Executor = (*ImportIntoActionExec)(nil) -) - -// Next implements the Executor Next interface. -func (e *ImportIntoActionExec) Next(ctx context.Context, _ *chunk.Chunk) error { - ctx = kv.WithInternalSourceType(ctx, kv.InternalImportInto) - - var hasSuperPriv bool - if pm := privilege.GetPrivilegeManager(e.ctx); pm != nil { - hasSuperPriv = pm.RequestVerification(e.ctx.GetSessionVars().ActiveRoles, "", "", "", mysql.SuperPriv) - } - // we use sessionCtx from GetTaskManager, user ctx might not have enough privileges. - globalTaskManager, err := fstorage.GetTaskManager() - if err != nil { - return err - } - if err = e.checkPrivilegeAndStatus(ctx, globalTaskManager, hasSuperPriv); err != nil { - return err - } - - logutil.Logger(ctx).Info("import into action", zap.Int64("jobID", e.jobID), zap.Any("action", e.tp)) - return cancelImportJob(ctx, globalTaskManager, e.jobID) -} - -func (e *ImportIntoActionExec) checkPrivilegeAndStatus(ctx context.Context, manager *fstorage.TaskManager, hasSuperPriv bool) error { - var info *importer.JobInfo - if err := manager.WithNewSession(func(se sessionctx.Context) error { - exec := se.(sqlexec.SQLExecutor) - var err2 error - info, err2 = importer.GetJob(ctx, exec, e.jobID, e.ctx.GetSessionVars().User.String(), hasSuperPriv) - return err2 - }); err != nil { - return err - } - if !info.CanCancel() { - return exeerrors.ErrLoadDataInvalidOperation.FastGenByArgs("CANCEL") - } - return nil -} - -// flushStats flushes the stats of the table. -func flushStats(ctx context.Context, se sessionctx.Context, tableID int64, result importer.JobImportResult) error { - if err := sessiontxn.NewTxn(ctx, se); err != nil { - return err - } - sessionVars := se.GetSessionVars() - sessionVars.TxnCtxMu.Lock() - defer sessionVars.TxnCtxMu.Unlock() - sessionVars.TxnCtx.UpdateDeltaForTable(tableID, int64(result.Affected), int64(result.Affected), result.ColSizeMap) - se.StmtCommit(ctx) - return se.CommitTxn(ctx) -} - -func cancelImportJob(ctx context.Context, manager *fstorage.TaskManager, jobID int64) error { - // todo: cancel is async operation, we don't wait here now, maybe add a wait syntax later. - // todo: after CANCEL, user can see the job status is Canceled immediately, but the job might still running. - // and the state of framework task might became finished since framework don't force state change DAG when update task. - // todo: add a CANCELLING status? 
- return manager.WithNewTxn(ctx, func(se sessionctx.Context) error { - exec := se.(sqlexec.SQLExecutor) - if err2 := importer.CancelJob(ctx, exec, jobID); err2 != nil { - return err2 - } - return manager.CancelGlobalTaskByKeySession(se, importinto.TaskKey(jobID)) - }) -} diff --git a/executor/importer/BUILD.bazel b/executor/importer/BUILD.bazel deleted file mode 100644 index 2cb8288221492..0000000000000 --- a/executor/importer/BUILD.bazel +++ /dev/null @@ -1,104 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") - -go_library( - name = "importer", - srcs = [ - "chunk_process.go", - "engine_process.go", - "import.go", - "job.go", - "kv_encode.go", - "precheck.go", - "table_import.go", - ], - importpath = "github.com/pingcap/tidb/executor/importer", - visibility = ["//visibility:public"], - deps = [ - "//br/pkg/lightning/backend", - "//br/pkg/lightning/backend/encode", - "//br/pkg/lightning/backend/kv", - "//br/pkg/lightning/backend/local", - "//br/pkg/lightning/checkpoints", - "//br/pkg/lightning/common", - "//br/pkg/lightning/config", - "//br/pkg/lightning/log", - "//br/pkg/lightning/mydump", - "//br/pkg/lightning/verification", - "//br/pkg/storage", - "//br/pkg/streamhelper", - "//br/pkg/utils", - "//config", - "//executor/asyncloaddata", - "//expression", - "//kv", - "//meta/autoid", - "//parser", - "//parser/ast", - "//parser/format", - "//parser/model", - "//parser/mysql", - "//parser/terror", - "//planner/core", - "//sessionctx", - "//sessionctx/stmtctx", - "//sessionctx/variable", - "//table", - "//table/tables", - "//tablecodec", - "//types", - "//util", - "//util/chunk", - "//util/dbterror", - "//util/dbterror/exeerrors", - "//util/etcd", - "//util/filter", - "//util/intest", - "//util/logutil", - "//util/sqlexec", - "//util/stringutil", - "//util/syncutil", - "@com_github_docker_go_units//:go-units", - "@com_github_pingcap_errors//:errors", - "@com_github_pingcap_failpoint//:failpoint", - "@com_github_pingcap_log//:log", - "@com_github_tikv_client_go_v2//config", - "@com_github_tikv_client_go_v2//tikv", - "@com_github_tikv_client_go_v2//util", - "@org_golang_x_exp//slices", - "@org_golang_x_sync//errgroup", - "@org_uber_go_multierr//:multierr", - "@org_uber_go_zap//:zap", - ], -) - -go_test( - name = "importer_test", - timeout = "short", - srcs = [ - "import_test.go", - "job_test.go", - "table_import_test.go", - ], - embed = [":importer"], - flaky = True, - race = "on", - shard_count = 11, - deps = [ - "//br/pkg/errors", - "//br/pkg/lightning/config", - "//config", - "//expression", - "//parser", - "//parser/ast", - "//planner/core", - "//testkit", - "//util/dbterror/exeerrors", - "//util/logutil", - "//util/mock", - "//util/sqlexec", - "@com_github_pingcap_errors//:errors", - "@com_github_pingcap_failpoint//:failpoint", - "@com_github_stretchr_testify//require", - "@org_uber_go_zap//:zap", - ], -) diff --git a/executor/importer/table_import.go b/executor/importer/table_import.go deleted file mode 100644 index 086e648328c3c..0000000000000 --- a/executor/importer/table_import.go +++ /dev/null @@ -1,565 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importer - -import ( - "context" - "fmt" - "io" - "net" - "os" - "path/filepath" - "runtime" - "strconv" - "sync" - "time" - - "github.com/docker/go-units" - "github.com/pingcap/errors" - "github.com/pingcap/tidb/br/pkg/lightning/backend" - "github.com/pingcap/tidb/br/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/br/pkg/lightning/backend/local" - "github.com/pingcap/tidb/br/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/br/pkg/lightning/common" - "github.com/pingcap/tidb/br/pkg/lightning/config" - "github.com/pingcap/tidb/br/pkg/lightning/log" - "github.com/pingcap/tidb/br/pkg/lightning/mydump" - "github.com/pingcap/tidb/br/pkg/storage" - tidb "github.com/pingcap/tidb/config" - tidbkv "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/table/tables" - "github.com/pingcap/tidb/util" - "github.com/pingcap/tidb/util/syncutil" - "go.uber.org/multierr" - "go.uber.org/zap" -) - -// NewTiKVModeSwitcher make it a var, so we can mock it in tests. -var NewTiKVModeSwitcher = local.NewTiKVModeSwitcher - -var ( - // CheckDiskQuotaInterval is the default time interval to check disk quota. - // TODO: make it dynamically adjusting according to the speed of import and the disk size. - CheckDiskQuotaInterval = time.Minute -) - -// prepareSortDir creates a new directory for import, remove previous sort directory if exists. -func prepareSortDir(e *LoadDataController, taskID int64, tidbCfg *tidb.Config) (string, error) { - sortPathSuffix := "import-" + strconv.Itoa(int(tidbCfg.Port)) - importDir := filepath.Join(tidbCfg.TempDir, sortPathSuffix) - sortDir := filepath.Join(importDir, strconv.FormatInt(taskID, 10)) - - if info, err := os.Stat(importDir); err != nil || !info.IsDir() { - if err != nil && !os.IsNotExist(err) { - e.logger.Error("stat import dir failed", zap.String("import_dir", importDir), zap.Error(err)) - return "", errors.Trace(err) - } - if info != nil && !info.IsDir() { - e.logger.Warn("import dir is not a dir, remove it", zap.String("import_dir", importDir)) - if err := os.RemoveAll(importDir); err != nil { - return "", errors.Trace(err) - } - } - e.logger.Info("import dir not exists, create it", zap.String("import_dir", importDir)) - if err := os.MkdirAll(importDir, 0o700); err != nil { - e.logger.Error("failed to make dir", zap.String("import_dir", importDir), zap.Error(err)) - return "", errors.Trace(err) - } - } - - // todo: remove this after we support checkpoint - if _, err := os.Stat(sortDir); err != nil { - if !os.IsNotExist(err) { - e.logger.Error("stat sort dir failed", zap.String("sort_dir", sortDir), zap.Error(err)) - return "", errors.Trace(err) - } - } else { - e.logger.Warn("sort dir already exists, remove it", zap.String("sort_dir", sortDir)) - if err := os.RemoveAll(sortDir); err != nil { - return "", errors.Trace(err) - } - } - return sortDir, nil -} - -// GetTiKVModeSwitcher creates a new TiKV mode switcher. 
-func GetTiKVModeSwitcher(logger *zap.Logger) (local.TiKVModeSwitcher, error) { - tidbCfg := tidb.GetGlobalConfig() - hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) - tls, err := common.NewTLS( - tidbCfg.Security.ClusterSSLCA, - tidbCfg.Security.ClusterSSLCert, - tidbCfg.Security.ClusterSSLKey, - hostPort, - nil, nil, nil, - ) - if err != nil { - return nil, err - } - return NewTiKVModeSwitcher(tls, tidbCfg.Path, logger), nil -} - -func getCachedKVStoreFrom(pdAddr string, tls *common.TLS) (tidbkv.Storage, error) { - // Disable GC because TiDB enables GC already. - keySpaceName := tidb.GetGlobalKeyspaceName() - // the kv store we get is a cached store, so we can't close it. - kvStore, err := GetKVStore(fmt.Sprintf("tikv://%s?disableGC=true&keyspaceName=%s", pdAddr, keySpaceName), tls.ToTiKVSecurityConfig()) - if err != nil { - return nil, errors.Trace(err) - } - return kvStore, nil -} - -// NewTableImporter creates a new table importer. -func NewTableImporter(param *JobImportParam, e *LoadDataController, taskID int64) (ti *TableImporter, err error) { - idAlloc := kv.NewPanickingAllocators(0) - tbl, err := tables.TableFromMeta(idAlloc, e.Table.Meta()) - if err != nil { - return nil, errors.Annotatef(err, "failed to tables.TableFromMeta %s", e.Table.Meta().Name) - } - - tidbCfg := tidb.GetGlobalConfig() - // todo: we only need to prepare this once on each node(we might call it 3 times in distribution framework) - dir, err := prepareSortDir(e, taskID, tidbCfg) - if err != nil { - return nil, err - } - - hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) - tls, err := common.NewTLS( - tidbCfg.Security.ClusterSSLCA, - tidbCfg.Security.ClusterSSLCert, - tidbCfg.Security.ClusterSSLKey, - hostPort, - nil, nil, nil, - ) - if err != nil { - return nil, err - } - - // no need to close kvStore, since it's a cached store. 
- kvStore, err := getCachedKVStoreFrom(tidbCfg.Path, tls) - if err != nil { - return nil, errors.Trace(err) - } - - backendConfig := local.BackendConfig{ - PDAddr: tidbCfg.Path, - LocalStoreDir: dir, - MaxConnPerStore: config.DefaultRangeConcurrency, - ConnCompressType: config.CompressionNone, - WorkerConcurrency: config.DefaultRangeConcurrency * 2, - KVWriteBatchSize: config.KVWriteBatchSize, - RegionSplitBatchSize: config.DefaultRegionSplitBatchSize, - RegionSplitConcurrency: runtime.GOMAXPROCS(0), - // enable after we support checkpoint - CheckpointEnabled: false, - MemTableSize: config.DefaultEngineMemCacheSize, - LocalWriterMemCacheSize: int64(config.DefaultLocalWriterMemCacheSize), - ShouldCheckTiKV: true, - DupeDetectEnabled: false, - DuplicateDetectOpt: local.DupDetectOpt{ReportErrOnDup: false}, - StoreWriteBWLimit: int(e.MaxWriteSpeed), - MaxOpenFiles: int(util.GenRLimit("table_import")), - KeyspaceName: tidb.GetGlobalKeyspaceName(), - PausePDSchedulerScope: config.PausePDSchedulerScopeTable, - } - - // todo: use a real region size getter - regionSizeGetter := &local.TableRegionSizeGetterImpl{} - localBackend, err := local.NewBackend(param.GroupCtx, tls, backendConfig, regionSizeGetter) - if err != nil { - return nil, err - } - - return &TableImporter{ - JobImportParam: param, - LoadDataController: e, - backend: localBackend, - tableInfo: &checkpoints.TidbTableInfo{ - ID: e.Table.Meta().ID, - Name: e.Table.Meta().Name.O, - Core: e.Table.Meta(), - }, - encTable: tbl, - dbID: e.DBID, - store: e.dataStore, - kvStore: kvStore, - logger: e.logger, - // this is the value we use for 50TiB data parallel import. - // this might not be the optimal value. - // todo: use different default for single-node import and distributed import. - regionSplitSize: 2 * int64(config.SplitRegionSize), - regionSplitKeys: 2 * int64(config.SplitRegionKeys), - diskQuota: adjustDiskQuota(int64(e.DiskQuota), dir, e.logger), - diskQuotaLock: new(syncutil.RWMutex), - }, nil -} - -// TableImporter is a table importer. -type TableImporter struct { - *JobImportParam - *LoadDataController - backend *local.Backend - tableInfo *checkpoints.TidbTableInfo - // this table has a separate id allocator used to record the max row id allocated. - encTable table.Table - dbID int64 - - store storage.ExternalStorage - // the kv store we get is a cached store, so we can't close it. - kvStore tidbkv.Storage - logger *zap.Logger - regionSplitSize int64 - regionSplitKeys int64 - // the smallest auto-generated ID in current import. - // if there's no auto-generated id column or the column value is not auto-generated, it will be 0. - lastInsertID uint64 - diskQuota int64 - diskQuotaLock *syncutil.RWMutex -} - -func (ti *TableImporter) getParser(ctx context.Context, chunk *checkpoints.ChunkCheckpoint) (mydump.Parser, error) { - info := LoadDataReaderInfo{ - Opener: func(ctx context.Context) (io.ReadSeekCloser, error) { - reader, err := mydump.OpenReader(ctx, &chunk.FileMeta, ti.dataStore) - if err != nil { - return nil, errors.Trace(err) - } - return reader, nil - }, - Remote: &chunk.FileMeta, - } - parser, err := ti.LoadDataController.GetParser(ctx, info) - if err != nil { - return nil, err - } - // todo: when support checkpoint, we should set pos too. - // WARN: parser.SetPos can only be set before we read anything now. should fix it before set pos. 
- parser.SetRowID(chunk.Chunk.PrevRowIDMax) - return parser, nil -} - -func (ti *TableImporter) getKVEncoder(chunk *checkpoints.ChunkCheckpoint) (kvEncoder, error) { - cfg := &encode.EncodingConfig{ - SessionOptions: encode.SessionOptions{ - SQLMode: ti.SQLMode, - Timestamp: chunk.Timestamp, - SysVars: ti.ImportantSysVars, - AutoRandomSeed: chunk.Chunk.PrevRowIDMax, - }, - Path: chunk.FileMeta.Path, - Table: ti.encTable, - Logger: log.Logger{Logger: ti.logger.With(zap.String("path", chunk.FileMeta.Path))}, - } - return newTableKVEncoder(cfg, ti) -} - -// PopulateChunks populates chunks from table regions. -// in dist framework, this should be done in the tidb node which is responsible for splitting job into subtasks -// then table-importer handles data belongs to the subtask. -func (e *LoadDataController) PopulateChunks(ctx context.Context) (ecp map[int32]*checkpoints.EngineCheckpoint, err error) { - task := log.BeginTask(e.logger, "populate chunks") - defer func() { - task.End(zap.ErrorLevel, err) - }() - - tableMeta := &mydump.MDTableMeta{ - DB: e.DBName, - Name: e.Table.Meta().Name.O, - DataFiles: e.toMyDumpFiles(), - } - dataDivideCfg := &mydump.DataDivideConfig{ - ColumnCnt: len(e.Table.Meta().Columns), - EngineDataSize: int64(config.DefaultBatchSize), - MaxChunkSize: int64(config.MaxRegionSize), - Concurrency: int(e.ThreadCnt), - EngineConcurrency: config.DefaultTableConcurrency, - IOWorkers: nil, - Store: e.dataStore, - TableMeta: tableMeta, - } - tableRegions, err2 := mydump.MakeTableRegions(ctx, dataDivideCfg) - - if err2 != nil { - e.logger.Error("populate chunks failed", zap.Error(err2)) - return nil, err2 - } - - var maxRowID int64 - timestamp := time.Now().Unix() - tableCp := &checkpoints.TableCheckpoint{ - Engines: map[int32]*checkpoints.EngineCheckpoint{}, - } - for _, region := range tableRegions { - engine, found := tableCp.Engines[region.EngineID] - if !found { - engine = &checkpoints.EngineCheckpoint{ - Status: checkpoints.CheckpointStatusLoaded, - } - tableCp.Engines[region.EngineID] = engine - } - ccp := &checkpoints.ChunkCheckpoint{ - Key: checkpoints.ChunkCheckpointKey{ - Path: region.FileMeta.Path, - Offset: region.Chunk.Offset, - }, - FileMeta: region.FileMeta, - ColumnPermutation: nil, - Chunk: region.Chunk, - Timestamp: timestamp, - } - engine.Chunks = append(engine.Chunks, ccp) - if region.Chunk.RowIDMax > maxRowID { - maxRowID = region.Chunk.RowIDMax - } - } - - if common.TableHasAutoID(e.Table.Meta()) { - tidbCfg := tidb.GetGlobalConfig() - hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) - tls, err4 := common.NewTLS( - tidbCfg.Security.ClusterSSLCA, - tidbCfg.Security.ClusterSSLCert, - tidbCfg.Security.ClusterSSLKey, - hostPort, - nil, nil, nil, - ) - if err4 != nil { - return nil, err4 - } - - // no need to close kvStore, since it's a cached store. 
- kvStore, err4 := getCachedKVStoreFrom(tidbCfg.Path, tls) - if err4 != nil { - return nil, errors.Trace(err4) - } - if err3 := common.RebaseGlobalAutoID(ctx, 0, kvStore, e.DBID, e.Table.Meta()); err3 != nil { - return nil, errors.Trace(err3) - } - newMinRowID, _, err3 := common.AllocGlobalAutoID(ctx, maxRowID, kvStore, e.DBID, e.Table.Meta()) - if err3 != nil { - return nil, errors.Trace(err3) - } - e.rebaseChunkRowID(newMinRowID, tableCp.Engines) - } - - // Add index engine checkpoint - tableCp.Engines[common.IndexEngineID] = &checkpoints.EngineCheckpoint{Status: checkpoints.CheckpointStatusLoaded} - return tableCp.Engines, nil -} - -func (*LoadDataController) rebaseChunkRowID(rowIDBase int64, engines map[int32]*checkpoints.EngineCheckpoint) { - if rowIDBase == 0 { - return - } - for _, engine := range engines { - for _, chunk := range engine.Chunks { - chunk.Chunk.PrevRowIDMax += rowIDBase - chunk.Chunk.RowIDMax += rowIDBase - } - } -} - -// a simplified version of EstimateCompactionThreshold -func (ti *TableImporter) getTotalRawFileSize(indexCnt int64) int64 { - var totalSize int64 - for _, file := range ti.dataFiles { - size := file.RealSize - if file.Type == mydump.SourceTypeParquet { - // parquet file is compressed, thus estimates with a factor of 2 - size *= 2 - } - totalSize += size - } - return totalSize * indexCnt -} - -// OpenIndexEngine opens an index engine. -func (ti *TableImporter) OpenIndexEngine(ctx context.Context, engineID int32) (*backend.OpenedEngine, error) { - idxEngineCfg := &backend.EngineConfig{ - TableInfo: ti.tableInfo, - } - idxCnt := len(ti.tableInfo.Core.Indices) - if !common.TableHasAutoRowID(ti.tableInfo.Core) { - idxCnt-- - } - // todo: getTotalRawFileSize returns size of all data files, but in distributed framework, - // we create one index engine for each engine, should reflect this in the future. - threshold := local.EstimateCompactionThreshold2(ti.getTotalRawFileSize(int64(idxCnt))) - idxEngineCfg.Local = backend.LocalEngineConfig{ - Compact: threshold > 0, - CompactConcurrency: 4, - CompactThreshold: threshold, - } - fullTableName := ti.fullTableName() - // todo: cleanup all engine data on any error since we don't support checkpoint for now - // some return path, didn't make sure all data engine and index engine are cleaned up. - // maybe we can add this in upper level to clean the whole local-sort directory - mgr := backend.MakeEngineManager(ti.backend) - return mgr.OpenEngine(ctx, idxEngineCfg, fullTableName, engineID) -} - -// OpenDataEngine opens a data engine. -func (ti *TableImporter) OpenDataEngine(ctx context.Context, engineID int32) (*backend.OpenedEngine, error) { - dataEngineCfg := &backend.EngineConfig{ - TableInfo: ti.tableInfo, - } - // todo: support checking IsRowOrdered later. - //if ti.tableMeta.IsRowOrdered { - // dataEngineCfg.Local.Compact = true - // dataEngineCfg.Local.CompactConcurrency = 4 - // dataEngineCfg.Local.CompactThreshold = local.CompactionUpperThreshold - //} - mgr := backend.MakeEngineManager(ti.backend) - return mgr.OpenEngine(ctx, dataEngineCfg, ti.fullTableName(), engineID) -} - -// ImportAndCleanup imports the engine and cleanup the engine data. -func (ti *TableImporter) ImportAndCleanup(ctx context.Context, closedEngine *backend.ClosedEngine) (int64, error) { - var kvCount int64 - importErr := closedEngine.Import(ctx, ti.regionSplitSize, ti.regionSplitKeys) - if closedEngine.GetID() != common.IndexEngineID { - // todo: change to a finer-grain progress later. 
- // each row is encoded into 1 data key - kvCount = ti.backend.GetImportedKVCount(closedEngine.GetUUID()) - } - // todo: if we need support checkpoint, engine should not be cleanup if import failed. - cleanupErr := closedEngine.Cleanup(ctx) - return kvCount, multierr.Combine(importErr, cleanupErr) -} - -// FullTableName return FQDN of the table. -func (ti *TableImporter) fullTableName() string { - return common.UniqueTable(ti.DBName, ti.Table.Meta().Name.O) -} - -// Close implements the io.Closer interface. -func (ti *TableImporter) Close() error { - ti.backend.Close() - return nil -} - -func (ti *TableImporter) setLastInsertID(id uint64) { - // todo: if we run concurrently, we should use atomic operation here. - if id == 0 { - return - } - if ti.lastInsertID == 0 || id < ti.lastInsertID { - ti.lastInsertID = id - } -} - -// CheckDiskQuota checks disk quota. -func (ti *TableImporter) CheckDiskQuota(ctx context.Context) { - var locker sync.Locker - lockDiskQuota := func() { - if locker == nil { - ti.diskQuotaLock.Lock() - locker = ti.diskQuotaLock - } - } - unlockDiskQuota := func() { - if locker != nil { - locker.Unlock() - locker = nil - } - } - - defer unlockDiskQuota() - - for { - select { - case <-ctx.Done(): - return - case <-time.After(CheckDiskQuotaInterval): - } - - largeEngines, inProgressLargeEngines, totalDiskSize, totalMemSize := local.CheckDiskQuota(ti.backend, ti.diskQuota) - if len(largeEngines) == 0 && inProgressLargeEngines == 0 { - unlockDiskQuota() - continue - } - - ti.logger.Warn("disk quota exceeded", - zap.Int64("diskSize", totalDiskSize), - zap.Int64("memSize", totalMemSize), - zap.Int64("quota", ti.diskQuota), - zap.Int("largeEnginesCount", len(largeEngines)), - zap.Int("inProgressLargeEnginesCount", inProgressLargeEngines)) - - lockDiskQuota() - - if len(largeEngines) == 0 { - ti.logger.Warn("all large engines are already importing, keep blocking all writes") - continue - } - - if err := ti.backend.FlushAllEngines(ctx); err != nil { - ti.logger.Error("flush engine for disk quota failed, check again later", log.ShortError(err)) - unlockDiskQuota() - continue - } - - // at this point, all engines are synchronized on disk. - // we then import every large engines one by one and complete. - // if any engine failed to import, we just try again next time, since the data are still intact. - var importErr error - for _, engine := range largeEngines { - // Use a larger split region size to avoid split the same region by many times. - if err := ti.backend.UnsafeImportAndReset( - ctx, - engine, - int64(config.SplitRegionSize)*int64(config.MaxSplitRegionSizeRatio), - int64(config.SplitRegionKeys)*int64(config.MaxSplitRegionSizeRatio), - ); err != nil { - importErr = multierr.Append(importErr, err) - } - } - if importErr != nil { - // discuss: should we return the error and cancel the import? 
- ti.logger.Error("import large engines failed, check again later", log.ShortError(importErr)) - } - unlockDiskQuota() - } -} - -func adjustDiskQuota(diskQuota int64, sortDir string, logger *zap.Logger) int64 { - sz, err := common.GetStorageSize(sortDir) - if err != nil { - logger.Warn("failed to get storage size", zap.Error(err)) - if diskQuota != 0 { - return diskQuota - } - logger.Info("use default quota instead", zap.Int64("quota", int64(DefaultDiskQuota))) - return int64(DefaultDiskQuota) - } - - maxDiskQuota := int64(float64(sz.Capacity) * 0.8) - switch { - case diskQuota == 0: - logger.Info("use 0.8 of the storage size as default disk quota", - zap.String("quota", units.HumanSize(float64(maxDiskQuota)))) - return maxDiskQuota - case diskQuota > maxDiskQuota: - logger.Warn("disk quota is larger than 0.8 of the storage size, use 0.8 of the storage size instead", - zap.String("quota", units.HumanSize(float64(maxDiskQuota)))) - return maxDiskQuota - default: - return diskQuota - } -} diff --git a/tests/realtikvtest/importintotest/job_test.go b/tests/realtikvtest/importintotest/job_test.go deleted file mode 100644 index 82397c946fa8e..0000000000000 --- a/tests/realtikvtest/importintotest/job_test.go +++ /dev/null @@ -1,635 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package importintotest - -import ( - "context" - "fmt" - "net/url" - "strconv" - "strings" - "sync" - "time" - - "github.com/docker/go-units" - "github.com/fsouza/fake-gcs-server/fakestorage" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/lightning/config" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/disttask/framework/proto" - "github.com/pingcap/tidb/disttask/framework/scheduler" - "github.com/pingcap/tidb/disttask/framework/storage" - "github.com/pingcap/tidb/disttask/importinto" - "github.com/pingcap/tidb/executor" - "github.com/pingcap/tidb/executor/importer" - "github.com/pingcap/tidb/parser/auth" - "github.com/pingcap/tidb/planner/core" - "github.com/pingcap/tidb/session" - "github.com/pingcap/tidb/testkit" - "github.com/pingcap/tidb/util/dbterror/exeerrors" -) - -func (s *mockGCSSuite) compareJobInfoWithoutTime(jobInfo *importer.JobInfo, row []interface{}) { - s.Equal(strconv.Itoa(int(jobInfo.ID)), row[0]) - - urlExpected, err := url.Parse(jobInfo.Parameters.FileLocation) - s.NoError(err) - urlGot, err := url.Parse(fmt.Sprintf("%v", row[1])) - s.NoError(err) - // order of query parameters might change - s.Equal(urlExpected.Query(), urlGot.Query()) - urlExpected.RawQuery, urlGot.RawQuery = "", "" - s.Equal(urlExpected.String(), urlGot.String()) - - s.Equal(utils.EncloseDBAndTable(jobInfo.TableSchema, jobInfo.TableName), row[2]) - s.Equal(strconv.Itoa(int(jobInfo.TableID)), row[3]) - s.Equal(jobInfo.Step, row[4]) - s.Equal(jobInfo.Status, row[5]) - s.Equal(units.HumanSize(float64(jobInfo.SourceFileSize)), row[6]) - if jobInfo.Summary == nil { - s.Equal("", row[7].(string)) - } else { - s.Equal(strconv.Itoa(int(jobInfo.Summary.ImportedRows)), row[7]) - } - s.Regexp(jobInfo.ErrorMessage, row[8]) - s.Equal(jobInfo.CreatedBy, row[12]) -} - -func (s *mockGCSSuite) TestShowJob() { - s.tk.MustExec("delete from mysql.tidb_import_jobs") - s.prepareAndUseDB("test_show_job") - s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") - s.tk.MustExec("CREATE TABLE t2 (i INT PRIMARY KEY);") - s.tk.MustExec("CREATE TABLE t3 (i INT PRIMARY KEY);") - s.server.CreateObject(fakestorage.Object{ - ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-job", Name: "t.csv"}, - Content: []byte("1\n2"), - }) - s.T().Cleanup(func() { - _ = s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil) - }) - // create 2 user which don't have system table privileges - s.tk.MustExec(`DROP USER IF EXISTS 'test_show_job1'@'localhost';`) - s.tk.MustExec(`CREATE USER 'test_show_job1'@'localhost';`) - s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_show_job.* to 'test_show_job1'@'localhost'`) - s.tk.MustExec(`DROP USER IF EXISTS 'test_show_job2'@'localhost';`) - s.tk.MustExec(`CREATE USER 'test_show_job2'@'localhost';`) - s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_show_job.* to 'test_show_job2'@'localhost'`) - do, err := session.GetDomain(s.store) - s.NoError(err) - tableID1 := do.MustGetTableID(s.T(), "test_show_job", "t1") - tableID2 := do.MustGetTableID(s.T(), "test_show_job", "t2") - tableID3 := do.MustGetTableID(s.T(), "test_show_job", "t3") - - // show non-exists job - err = s.tk.QueryToErr("show import job 9999999999") - s.ErrorIs(err, exeerrors.ErrLoadDataJobNotFound) - - // test show job by id using test_show_job1 - s.enableFailpoint("github.com/pingcap/tidb/executor/importer/setLastImportJobID", `return(true)`) - 
s.enableFailpoint("github.com/pingcap/tidb/disttask/framework/storage/testSetLastTaskID", "return(true)") - s.enableFailpoint("github.com/pingcap/tidb/parser/ast/forceRedactURL", "return(true)") - s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_show_job1", Hostname: "localhost"}, nil, nil, nil)) - result1 := s.tk.MustQuery(fmt.Sprintf(`import into t1 FROM 'gs://test-show-job/t.csv?access-key=aaaaaa&secret-access-key=bbbbbb&endpoint=%s'`, - gcsEndpoint)).Rows() - s.Len(result1, 1) - s.tk.MustQuery("select * from t1").Check(testkit.Rows("1", "2")) - rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() - s.Len(rows, 1) - s.Equal(result1, rows) - jobInfo := &importer.JobInfo{ - ID: importer.TestLastImportJobID.Load(), - TableSchema: "test_show_job", - TableName: "t1", - TableID: tableID1, - CreatedBy: "test_show_job1@localhost", - Parameters: importer.ImportParameters{ - FileLocation: fmt.Sprintf(`gs://test-show-job/t.csv?access-key=xxxxxx&secret-access-key=xxxxxx&endpoint=%s`, gcsEndpoint), - Format: importer.DataFormatCSV, - }, - SourceFileSize: 3, - Status: "finished", - Step: "", - Summary: &importer.JobSummary{ - ImportedRows: 2, - }, - ErrorMessage: "", - } - s.compareJobInfoWithoutTime(jobInfo, rows[0]) - - // test show job by id using test_show_job2 - s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_show_job2", Hostname: "localhost"}, nil, nil, nil)) - result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test-show-job/t.csv?endpoint=%s'`, gcsEndpoint)).Rows() - s.tk.MustQuery("select * from t2").Check(testkit.Rows("1", "2")) - rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() - s.Len(rows, 1) - s.Equal(result2, rows) - jobInfo.ID = importer.TestLastImportJobID.Load() - jobInfo.TableName = "t2" - jobInfo.TableID = tableID2 - jobInfo.CreatedBy = "test_show_job2@localhost" - jobInfo.Parameters.FileLocation = fmt.Sprintf(`gs://test-show-job/t.csv?endpoint=%s`, gcsEndpoint) - s.compareJobInfoWithoutTime(jobInfo, rows[0]) - rows = s.tk.MustQuery("show import jobs").Rows() - s.Len(rows, 1) - s.Equal(result2, rows) - - // show import jobs with root - checkJobsMatch := func(rows [][]interface{}) { - s.GreaterOrEqual(len(rows), 2) // other cases may create import jobs - var matched int - for _, r := range rows { - if r[0] == result1[0][0] { - s.Equal(result1[0], r) - matched++ - } - if r[0] == result2[0][0] { - s.Equal(result2[0], r) - matched++ - } - } - s.Equal(2, matched) - } - s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) - rows = s.tk.MustQuery("show import jobs").Rows() - checkJobsMatch(rows) - // show import job by id with root - rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() - s.Len(rows, 1) - s.Equal(result2, rows) - jobInfo.ID = importer.TestLastImportJobID.Load() - jobInfo.TableName = "t2" - jobInfo.TableID = tableID2 - jobInfo.CreatedBy = "test_show_job2@localhost" - jobInfo.Parameters.FileLocation = fmt.Sprintf(`gs://test-show-job/t.csv?endpoint=%s`, gcsEndpoint) - s.compareJobInfoWithoutTime(jobInfo, rows[0]) - - // grant SUPER to test_show_job2, now it can see all jobs - s.tk.MustExec(`GRANT SUPER on *.* to 'test_show_job2'@'localhost'`) - s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_show_job2", Hostname: "localhost"}, nil, nil, nil)) - rows = s.tk.MustQuery("show import jobs").Rows() - 
checkJobsMatch(rows) - - // show running jobs with 2 subtasks - s.enableFailpoint("github.com/pingcap/tidb/disttask/framework/scheduler/syncAfterSubtaskFinish", `return(true)`) - s.server.CreateObject(fakestorage.Object{ - ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-job", Name: "t2.csv"}, - Content: []byte("3\n4"), - }) - backup4 := config.DefaultBatchSize - config.DefaultBatchSize = 1 - s.T().Cleanup(func() { - config.DefaultBatchSize = backup4 - }) - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - defer wg.Done() - // wait first subtask finish - <-scheduler.TestSyncChan - - jobInfo = &importer.JobInfo{ - ID: importer.TestLastImportJobID.Load(), - TableSchema: "test_show_job", - TableName: "t3", - TableID: tableID3, - CreatedBy: "test_show_job2@localhost", - Parameters: importer.ImportParameters{ - FileLocation: fmt.Sprintf(`gs://test-show-job/t*.csv?access-key=xxxxxx&secret-access-key=xxxxxx&endpoint=%s`, gcsEndpoint), - Format: importer.DataFormatCSV, - }, - SourceFileSize: 6, - Status: "running", - Step: "importing", - Summary: &importer.JobSummary{ - ImportedRows: 2, - }, - ErrorMessage: "", - } - tk2 := testkit.NewTestKit(s.T(), s.store) - rows = tk2.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() - s.Len(rows, 1) - s.compareJobInfoWithoutTime(jobInfo, rows[0]) - // show processlist, should be redacted too - procRows := tk2.MustQuery("show full processlist").Rows() - - var got bool - for _, r := range procRows { - user := r[1].(string) - sql := r[7].(string) - if user == "test_show_job2" && strings.Contains(sql, "IMPORT INTO") { - s.Contains(sql, "access-key=xxxxxx") - s.Contains(sql, "secret-access-key=xxxxxx") - s.NotContains(sql, "aaaaaa") - s.NotContains(sql, "bbbbbb") - got = true - } - } - s.True(got) - - // resume the scheduler - scheduler.TestSyncChan <- struct{}{} - // wait second subtask finish - <-scheduler.TestSyncChan - rows = tk2.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() - s.Len(rows, 1) - jobInfo.Summary.ImportedRows = 4 - s.compareJobInfoWithoutTime(jobInfo, rows[0]) - // resume the scheduler, need disable failpoint first, otherwise the post-process subtask will be blocked - s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/syncAfterSubtaskFinish")) - scheduler.TestSyncChan <- struct{}{} - }() - s.tk.MustQuery(fmt.Sprintf(`import into t3 FROM 'gs://test-show-job/t*.csv?access-key=aaaaaa&secret-access-key=bbbbbb&endpoint=%s' with thread=1`, gcsEndpoint)) - wg.Wait() - s.tk.MustQuery("select * from t3").Sort().Check(testkit.Rows("1", "2", "3", "4")) -} - -func (s *mockGCSSuite) TestShowDetachedJob() { - s.prepareAndUseDB("show_detached_job") - s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") - s.tk.MustExec("CREATE TABLE t2 (i INT PRIMARY KEY);") - s.tk.MustExec("CREATE TABLE t3 (i INT PRIMARY KEY);") - s.server.CreateObject(fakestorage.Object{ - ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-detached-job", Name: "t.csv"}, - Content: []byte("1\n2"), - }) - s.server.CreateObject(fakestorage.Object{ - ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-show-detached-job", Name: "t2.csv"}, - Content: []byte("1\n1"), - }) - do, err := session.GetDomain(s.store) - s.NoError(err) - tableID1 := do.MustGetTableID(s.T(), "show_detached_job", "t1") - tableID2 := do.MustGetTableID(s.T(), "show_detached_job", "t2") - tableID3 := do.MustGetTableID(s.T(), "show_detached_job", "t3") - - jobInfo := &importer.JobInfo{ - 
TableSchema: "show_detached_job", - TableName: "t1", - TableID: tableID1, - CreatedBy: "root@%", - Parameters: importer.ImportParameters{ - FileLocation: fmt.Sprintf(`gs://test-show-detached-job/t.csv?endpoint=%s`, gcsEndpoint), - Format: importer.DataFormatCSV, - }, - SourceFileSize: 3, - Status: "pending", - Step: "", - } - s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) - result1 := s.tk.MustQuery(fmt.Sprintf(`import into t1 FROM 'gs://test-show-detached-job/t.csv?endpoint=%s' with detached`, - gcsEndpoint)).Rows() - s.Len(result1, 1) - jobID1, err := strconv.Atoi(result1[0][0].(string)) - s.NoError(err) - jobInfo.ID = int64(jobID1) - s.compareJobInfoWithoutTime(jobInfo, result1[0]) - - s.Eventually(func() bool { - rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID1)).Rows() - return rows[0][5] == "finished" - }, 10*time.Second, 500*time.Millisecond) - rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID1)).Rows() - s.Len(rows, 1) - jobInfo.Status = "finished" - jobInfo.Summary = &importer.JobSummary{ - ImportedRows: 2, - } - s.compareJobInfoWithoutTime(jobInfo, rows[0]) - s.tk.MustQuery("select * from t1").Check(testkit.Rows("1", "2")) - - // job fail with checksum mismatch - result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test-show-detached-job/t2.csv?endpoint=%s' with detached`, - gcsEndpoint)).Rows() - s.Len(result2, 1) - jobID2, err := strconv.Atoi(result2[0][0].(string)) - s.NoError(err) - jobInfo = &importer.JobInfo{ - ID: int64(jobID2), - TableSchema: "show_detached_job", - TableName: "t2", - TableID: tableID2, - CreatedBy: "root@%", - Parameters: importer.ImportParameters{ - FileLocation: fmt.Sprintf(`gs://test-show-detached-job/t2.csv?endpoint=%s`, gcsEndpoint), - Format: importer.DataFormatCSV, - }, - SourceFileSize: 3, - Status: "pending", - Step: "", - } - s.compareJobInfoWithoutTime(jobInfo, result2[0]) - s.Eventually(func() bool { - rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() - return rows[0][5] == "failed" - }, 10*time.Second, 500*time.Millisecond) - rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() - s.Len(rows, 1) - jobInfo.Status = "failed" - jobInfo.Step = importer.JobStepValidating - jobInfo.ErrorMessage = `\[Lighting:Restore:ErrChecksumMismatch]checksum mismatched remote vs local.*` - s.compareJobInfoWithoutTime(jobInfo, rows[0]) - - // subtask fail with error - s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/errorWhenSortChunk", "return(true)") - result3 := s.tk.MustQuery(fmt.Sprintf(`import into t3 FROM 'gs://test-show-detached-job/t.csv?endpoint=%s' with detached`, - gcsEndpoint)).Rows() - s.Len(result3, 1) - jobID3, err := strconv.Atoi(result3[0][0].(string)) - s.NoError(err) - jobInfo = &importer.JobInfo{ - ID: int64(jobID3), - TableSchema: "show_detached_job", - TableName: "t3", - TableID: tableID3, - CreatedBy: "root@%", - Parameters: importer.ImportParameters{ - FileLocation: fmt.Sprintf(`gs://test-show-detached-job/t.csv?endpoint=%s`, gcsEndpoint), - Format: importer.DataFormatCSV, - }, - SourceFileSize: 3, - Status: "pending", - Step: "", - } - s.compareJobInfoWithoutTime(jobInfo, result3[0]) - s.Eventually(func() bool { - rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID3)).Rows() - return rows[0][5] == "failed" - }, 10*time.Second, 500*time.Millisecond) - rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID3)).Rows() - s.Len(rows, 1) - jobInfo.Status = "failed" - 
jobInfo.Step = importer.JobStepImporting - jobInfo.ErrorMessage = `occur an error when sort chunk.*` - s.compareJobInfoWithoutTime(jobInfo, rows[0]) -} - -func (s *mockGCSSuite) TestCancelJob() { - s.prepareAndUseDB("test_cancel_job") - s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") - s.tk.MustExec("CREATE TABLE t2 (i INT PRIMARY KEY);") - s.server.CreateObject(fakestorage.Object{ - ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test_cancel_job", Name: "t.csv"}, - Content: []byte("1\n2"), - }) - s.T().Cleanup(func() { - _ = s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil) - }) - s.tk.MustExec(`DROP USER IF EXISTS 'test_cancel_job1'@'localhost';`) - s.tk.MustExec(`CREATE USER 'test_cancel_job1'@'localhost';`) - s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_cancel_job.* to 'test_cancel_job1'@'localhost'`) - s.tk.MustExec(`DROP USER IF EXISTS 'test_cancel_job2'@'localhost';`) - s.tk.MustExec(`CREATE USER 'test_cancel_job2'@'localhost';`) - s.tk.MustExec(`GRANT SELECT,UPDATE,INSERT,DELETE,ALTER on test_cancel_job.* to 'test_cancel_job2'@'localhost'`) - do, err := session.GetDomain(s.store) - s.NoError(err) - tableID1 := do.MustGetTableID(s.T(), "test_cancel_job", "t1") - tableID2 := do.MustGetTableID(s.T(), "test_cancel_job", "t2") - - // cancel non-exists job - err = s.tk.ExecToErr("cancel import job 9999999999") - s.ErrorIs(err, exeerrors.ErrLoadDataJobNotFound) - - getTask := func(jobID int64) *proto.Task { - globalTaskManager, err := storage.GetTaskManager() - s.NoError(err) - taskKey := importinto.TaskKey(jobID) - globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey) - s.NoError(err) - return globalTask - } - - // cancel a running job created by self - s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/waitBeforeSortChunk", "return(true)") - s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncAfterJobStarted", "return(true)") - s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_cancel_job1", Hostname: "localhost"}, nil, nil, nil)) - result1 := s.tk.MustQuery(fmt.Sprintf(`import into t1 FROM 'gs://test_cancel_job/t.csv?endpoint=%s' with detached`, - gcsEndpoint)).Rows() - s.Len(result1, 1) - jobID1, err := strconv.Atoi(result1[0][0].(string)) - s.NoError(err) - // wait job started - <-importinto.TestSyncChan - // dist framework has bug, the cancelled status might be overridden by running status, - // so we wait it turn running before cancel, see https://github.com/pingcap/tidb/issues/44443 - time.Sleep(3 * time.Second) - s.tk.MustExec(fmt.Sprintf("cancel import job %d", jobID1)) - rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID1)).Rows() - s.Len(rows, 1) - jobInfo := &importer.JobInfo{ - ID: int64(jobID1), - TableSchema: "test_cancel_job", - TableName: "t1", - TableID: tableID1, - CreatedBy: "test_cancel_job1@localhost", - Parameters: importer.ImportParameters{ - FileLocation: fmt.Sprintf(`gs://test_cancel_job/t.csv?endpoint=%s`, gcsEndpoint), - Format: importer.DataFormatCSV, - }, - SourceFileSize: 3, - Status: "cancelled", - Step: importer.JobStepImporting, - ErrorMessage: "cancelled by user", - } - s.compareJobInfoWithoutTime(jobInfo, rows[0]) - s.Eventually(func() bool { - task := getTask(int64(jobID1)) - return task.State == proto.TaskStateReverted - }, 10*time.Second, 500*time.Millisecond) - - // cancel again, should fail - s.ErrorIs(s.tk.ExecToErr(fmt.Sprintf("cancel import job %d", jobID1)), exeerrors.ErrLoadDataInvalidOperation) - - // cancel 
a job created by test_cancel_job1 using test_cancel_job2, should fail - s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_cancel_job2", Hostname: "localhost"}, nil, nil, nil)) - s.ErrorIs(s.tk.ExecToErr(fmt.Sprintf("cancel import job %d", jobID1)), core.ErrSpecificAccessDenied) - // cancel by root, should pass privilege check - s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) - s.ErrorIs(s.tk.ExecToErr(fmt.Sprintf("cancel import job %d", jobID1)), exeerrors.ErrLoadDataInvalidOperation) - - // cancel job in post-process phase, using test_cancel_job2 - s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "test_cancel_job2", Hostname: "localhost"}, nil, nil, nil)) - s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/waitBeforeSortChunk")) - s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/syncAfterJobStarted")) - s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncBeforePostProcess", "return(true)") - s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/waitCtxDone", "return(true)") - result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test_cancel_job/t.csv?endpoint=%s' with detached`, - gcsEndpoint)).Rows() - s.Len(result2, 1) - jobID2, err := strconv.Atoi(result2[0][0].(string)) - s.NoError(err) - // wait job reach post-process phase - <-importinto.TestSyncChan - s.tk.MustExec(fmt.Sprintf("cancel import job %d", jobID2)) - // resume the job - importinto.TestSyncChan <- struct{}{} - rows2 := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() - s.Len(rows2, 1) - jobInfo = &importer.JobInfo{ - ID: int64(jobID2), - TableSchema: "test_cancel_job", - TableName: "t2", - TableID: tableID2, - CreatedBy: "test_cancel_job2@localhost", - Parameters: importer.ImportParameters{ - FileLocation: fmt.Sprintf(`gs://test_cancel_job/t.csv?endpoint=%s`, gcsEndpoint), - Format: importer.DataFormatCSV, - }, - SourceFileSize: 3, - Status: "cancelled", - Step: importer.JobStepValidating, - ErrorMessage: "cancelled by user", - } - s.compareJobInfoWithoutTime(jobInfo, rows2[0]) - globalTaskManager, err := storage.GetTaskManager() - s.NoError(err) - taskKey := importinto.TaskKey(int64(jobID2)) - s.NoError(err) - s.Eventually(func() bool { - globalTask, err2 := globalTaskManager.GetGlobalTaskByKey(taskKey) - s.NoError(err2) - subtasks, err2 := globalTaskManager.GetSubtasksByStep(globalTask.ID, importinto.StepPostProcess) - s.NoError(err2) - s.Len(subtasks, 2) // framework will generate a subtask when canceling - var cancelled bool - for _, st := range subtasks { - if st.State == proto.TaskStateCanceled { - cancelled = true - break - } - } - return globalTask.State == proto.TaskStateReverted && cancelled - }, 5*time.Second, 1*time.Second) - - // todo: enable it when https://github.com/pingcap/tidb/issues/44443 fixed - //// cancel a pending job created by test_cancel_job2 using root - //s.NoError(failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/syncAfterJobStarted")) - //s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncBeforeJobStarted", "return(true)") - //result2 := s.tk.MustQuery(fmt.Sprintf(`import into t2 FROM 'gs://test_cancel_job/t.csv?endpoint=%s' with detached`, - // gcsEndpoint)).Rows() - //s.Len(result2, 1) - //jobID2, err := strconv.Atoi(result2[0][0].(string)) - //s.NoError(err) - //// wait job reached to the point before job started - //<-loaddata.TestSyncChan - 
//s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) - //s.tk.MustExec(fmt.Sprintf("cancel import job %d", jobID2)) - //// resume the job - //loaddata.TestSyncChan <- struct{}{} - //rows = s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID2)).Rows() - //s.Len(rows, 1) - //jobInfo = &importer.JobInfo{ - // ID: int64(jobID2), - // TableSchema: "test_cancel_job", - // TableName: "t2", - // TableID: tableID2, - // CreatedBy: "test_cancel_job2@localhost", - // Parameters: importer.ImportParameters{ - // FileLocation: fmt.Sprintf(`gs://test_cancel_job/t.csv?endpoint=%s`, gcsEndpoint), - // Format: importer.DataFormatCSV, - // }, - // SourceFileSize: 3, - // Status: "cancelled", - // Step: "", - // ErrorMessage: "cancelled by user", - //} - //s.compareJobInfoWithoutTime(jobInfo, rows[0]) - //s.Eventually(func() bool { - // task := getTask(int64(jobID2)) - // return task.State == proto.TaskStateReverted - //}, 10*time.Second, 500*time.Millisecond) -} - -func (s *mockGCSSuite) TestJobFailWhenDispatchSubtask() { - s.prepareAndUseDB("fail_job_after_import") - s.tk.MustExec("CREATE TABLE t1 (i INT PRIMARY KEY);") - s.server.CreateObject(fakestorage.Object{ - ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "fail_job_after_import", Name: "t.csv"}, - Content: []byte("1\n2"), - }) - do, err := session.GetDomain(s.store) - s.NoError(err) - tableID1 := do.MustGetTableID(s.T(), "fail_job_after_import", "t1") - - jobInfo := &importer.JobInfo{ - TableSchema: "fail_job_after_import", - TableName: "t1", - TableID: tableID1, - CreatedBy: "root@%", - Parameters: importer.ImportParameters{ - FileLocation: fmt.Sprintf(`gs://fail_job_after_import/t.csv?endpoint=%s`, gcsEndpoint), - Format: importer.DataFormatCSV, - }, - SourceFileSize: 3, - Status: "failed", - Step: importer.JobStepValidating, - ErrorMessage: "injected error after StepImport", - } - s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/failWhenDispatchPostProcessSubtask", "return(true)") - s.enableFailpoint("github.com/pingcap/tidb/executor/importer/setLastImportJobID", `return(true)`) - s.NoError(s.tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "localhost"}, nil, nil, nil)) - err = s.tk.QueryToErr(fmt.Sprintf(`import into t1 FROM 'gs://fail_job_after_import/t.csv?endpoint=%s'`, gcsEndpoint)) - s.ErrorContains(err, "injected error after StepImport") - result1 := s.tk.MustQuery(fmt.Sprintf("show import job %d", importer.TestLastImportJobID.Load())).Rows() - s.Len(result1, 1) - jobID1, err := strconv.Atoi(result1[0][0].(string)) - s.NoError(err) - jobInfo.ID = int64(jobID1) - s.compareJobInfoWithoutTime(jobInfo, result1[0]) -} - -func (s *mockGCSSuite) TestKillBeforeFinish() { - s.cleanupSysTables() - s.tk.MustExec("DROP DATABASE IF EXISTS kill_job;") - s.tk.MustExec("CREATE DATABASE kill_job;") - s.tk.MustExec(`CREATE TABLE kill_job.t (a INT, b INT, c int);`) - s.server.CreateObject(fakestorage.Object{ - ObjectAttrs: fakestorage.ObjectAttrs{BucketName: "test-load", Name: "t-1.tsv"}, - Content: []byte("1,11,111"), - }) - - s.enableFailpoint("github.com/pingcap/tidb/disttask/importinto/syncBeforeSortChunk", "return(true)") - s.enableFailpoint("github.com/pingcap/tidb/executor/cancellableCtx", "return(true)") - s.enableFailpoint("github.com/pingcap/tidb/executor/importer/setLastImportJobID", `return(true)`) - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - defer wg.Done() - sql := fmt.Sprintf(`IMPORT INTO kill_job.t FROM 'gs://test-load/t-*.tsv?endpoint=%s'`, 
gcsEndpoint) - err := s.tk.QueryToErr(sql) - s.ErrorIs(errors.Cause(err), context.Canceled) - }() - // wait for the task reach sort chunk - <-importinto.TestSyncChan - // cancel the job - executor.TestCancelFunc() - // continue the execution - importinto.TestSyncChan <- struct{}{} - wg.Wait() - jobID := importer.TestLastImportJobID.Load() - rows := s.tk.MustQuery(fmt.Sprintf("show import job %d", jobID)).Rows() - s.Len(rows, 1) - s.Equal("cancelled", rows[0][5]) - globalTaskManager, err := storage.GetTaskManager() - s.NoError(err) - taskKey := importinto.TaskKey(jobID) - s.NoError(err) - s.Eventually(func() bool { - globalTask, err2 := globalTaskManager.GetGlobalTaskByKey(taskKey) - s.NoError(err2) - return globalTask.State == proto.TaskStateReverted - }, 5*time.Second, 1*time.Second) -} From d21eb358844f308a7396f11104c9bdd63f927e8a Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Thu, 29 Jun 2023 16:59:25 +0800 Subject: [PATCH 3/4] update --- br/pkg/checksum/executor.go | 20 +++++-- br/pkg/lightning/common/util.go | 43 +++++++++++++++ br/pkg/lightning/config/config.go | 2 + br/pkg/lightning/restore/checksum.go | 54 +++++++++++++++++-- br/pkg/lightning/restore/checksum_test.go | 2 + .../lightning/restore/table_restore_test.go | 3 +- br/pkg/lightning/restore/tidb.go | 1 + br/pkg/lightning/restore/tidb_test.go | 2 + 8 files changed, 118 insertions(+), 9 deletions(-) diff --git a/br/pkg/checksum/executor.go b/br/pkg/checksum/executor.go index c30ae49fccdca..01a09b57d766b 100644 --- a/br/pkg/checksum/executor.go +++ b/br/pkg/checksum/executor.go @@ -26,7 +26,8 @@ type ExecutorBuilder struct { oldTable *metautil.Table - concurrency uint + concurrency uint + backoffWeight int } // NewExecutorBuilder returns a new executor builder. @@ -51,13 +52,19 @@ func (builder *ExecutorBuilder) SetConcurrency(conc uint) *ExecutorBuilder { return builder } +// SetBackoffWeight set the backoffWeight of the checksum executing. +func (builder *ExecutorBuilder) SetBackoffWeight(backoffWeight int) *ExecutorBuilder { + builder.backoffWeight = backoffWeight + return builder +} + // Build builds a checksum executor. func (builder *ExecutorBuilder) Build() (*Executor, error) { reqs, err := buildChecksumRequest(builder.table, builder.oldTable, builder.ts, builder.concurrency) if err != nil { return nil, errors.Trace(err) } - return &Executor{reqs: reqs}, nil + return &Executor{reqs: reqs, backoffWeight: builder.backoffWeight}, nil } func buildChecksumRequest( @@ -262,7 +269,8 @@ func updateChecksumResponse(resp, update *tipb.ChecksumResponse) { // Executor is a checksum executor. type Executor struct { - reqs []*kv.Request + reqs []*kv.Request + backoffWeight int } // Len returns the total number of checksum requests. @@ -308,7 +316,11 @@ func (exec *Executor) Execute( // // It is useful in TiDB, however, it's a place holder in BR. 
killed := uint32(0) - resp, err := sendChecksumRequest(ctx, client, req, kv.NewVariables(&killed)) + vars := kv.NewVariables(&killed) + if exec.backoffWeight > 0 { + vars.BackOffWeight = exec.backoffWeight + } + resp, err := sendChecksumRequest(ctx, client, req, vars) if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go index 621c59d820e23..e66b679f0b10a 100644 --- a/br/pkg/lightning/common/util.go +++ b/br/pkg/lightning/common/util.go @@ -36,6 +36,7 @@ import ( "github.com/pingcap/tidb/br/pkg/utils" tmysql "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table/tables" "go.uber.org/zap" ) @@ -428,3 +429,45 @@ func GetAutoRandomColumn(tblInfo *model.TableInfo) *model.ColumnInfo { } return nil } + +// GetBackoffWeightFromDB gets the backoff weight from database. +func GetBackoffWeightFromDB(ctx context.Context, db *sql.DB) (int, error) { + val, err := getSessionVariable(ctx, db, variable.TiDBBackOffWeight) + if err != nil { + return 0, err + } + return strconv.Atoi(val) +} + +// copy from dbutil to avoid import cycle +func getSessionVariable(ctx context.Context, db *sql.DB, variable string) (value string, err error) { + query := fmt.Sprintf("SHOW VARIABLES LIKE '%s'", variable) + rows, err := db.QueryContext(ctx, query) + + if err != nil { + return "", errors.Trace(err) + } + defer rows.Close() + + // Show an example. + /* + mysql> SHOW VARIABLES LIKE "binlog_format"; + +---------------+-------+ + | Variable_name | Value | + +---------------+-------+ + | binlog_format | ROW | + +---------------+-------+ + */ + + for rows.Next() { + if err = rows.Scan(&variable, &value); err != nil { + return "", errors.Trace(err) + } + } + + if err := rows.Err(); err != nil { + return "", errors.Trace(err) + } + + return value, nil +} diff --git a/br/pkg/lightning/config/config.go b/br/pkg/lightning/config/config.go index e1031c760f749..e372ad3e1bb18 100644 --- a/br/pkg/lightning/config/config.go +++ b/br/pkg/lightning/config/config.go @@ -440,6 +440,7 @@ type PostRestore struct { Level1Compact bool `toml:"level-1-compact" json:"level-1-compact"` PostProcessAtLast bool `toml:"post-process-at-last" json:"post-process-at-last"` Compact bool `toml:"compact" json:"compact"` + ChecksumViaSQL bool `toml:"checksum-via-sql" json:"checksum-via-sql"` } type CSVConfig struct { @@ -745,6 +746,7 @@ func NewConfig() *Config { Checksum: OpLevelRequired, Analyze: OpLevelOptional, PostProcessAtLast: true, + ChecksumViaSQL: true, }, } } diff --git a/br/pkg/lightning/restore/checksum.go b/br/pkg/lightning/restore/checksum.go index b30fe14e01fc1..d433eac3196ff 100644 --- a/br/pkg/lightning/restore/checksum.go +++ b/br/pkg/lightning/restore/checksum.go @@ -33,8 +33,10 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/metric" "github.com/pingcap/tidb/br/pkg/pdutil" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/mathutil" "github.com/pingcap/tipb/go-tipb" + tikvstore "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/oracle" pd "github.com/tikv/pd/client" "go.uber.org/atomic" @@ -51,6 +53,12 @@ var ( serviceSafePointTTL int64 = 10 * 60 // 10 min in seconds minDistSQLScanConcurrency = 4 + + // DefaultBackoffWeight is the default value of tidb_backoff_weight for checksum. + // when TiKV client encounters an error of "region not leader", it will keep retrying every 500 ms. 
+ // If it still fails after 2 * 20 = 40 seconds, it will return "region unavailable". + // If we increase the BackOffWeight to 6, then the TiKV client will keep retrying for 120 seconds. + DefaultBackoffWeight = 3 * tikvstore.DefBackOffWeight ) // RemoteChecksum represents a checksum result got from tidb. @@ -80,14 +88,26 @@ func newChecksumManager(ctx context.Context, rc *Controller, store kv.Storage) ( // for v4.0.0 or upper, we can use the gc ttl api var manager ChecksumManager - if pdVersion.Major >= 4 { + if pdVersion.Major >= 4 && !rc.cfg.PostRestore.ChecksumViaSQL { tlsOpt := rc.tls.ToPDSecurityOption() pdCli, err := pd.NewClientWithContext(ctx, []string{pdAddr}, tlsOpt) if err != nil { return nil, errors.Trace(err) } - manager = newTiKVChecksumManager(store.GetClient(), pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency)) + db, err := rc.tidbGlue.GetDB() + if err != nil { + return nil, errors.Trace(err) + } + backoffWeight, err := common.GetBackoffWeightFromDB(ctx, db) + // only set backoff weight when it's smaller than default value + if err == nil && backoffWeight >= DefaultBackoffWeight { + log.L().Info("get tidb_backoff_weight", zap.Int("backoff_weight", backoffWeight)) + } else { + log.L().Info("set tidb_backoff_weight to default", zap.Int("backoff_weight", DefaultBackoffWeight)) + backoffWeight = DefaultBackoffWeight + } + manager = newTiKVChecksumManager(store.GetClient(), pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency), backoffWeight) } else { db, err := rc.tidbGlue.GetDB() if err != nil { @@ -125,6 +145,15 @@ func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoi task := log.FromContext(ctx).With(zap.String("table", tableName)).Begin(zap.InfoLevel, "remote checksum") + conn, err := e.db.Conn(ctx) + if err != nil { + return nil, errors.Trace(err) + } + defer func() { + if err := conn.Close(); err != nil { + task.Warn("close connection failed", zap.Error(err)) + } + }() // ADMIN CHECKSUM TABLE
<table>,<table>
example. // mysql> admin checksum table test.t; // +---------+------------+---------------------+-----------+-------------+ @@ -132,9 +161,23 @@ func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoi // +---------+------------+---------------------+-----------+-------------+ // | test | t | 8520875019404689597 | 7296873 | 357601387 | // +---------+------------+---------------------+-----------+-------------+ + backoffWeight, err := common.GetBackoffWeightFromDB(ctx, e.db) + if err == nil && backoffWeight < DefaultBackoffWeight { + task.Info("increase tidb_backoff_weight", zap.Int("original", backoffWeight), zap.Int("new", DefaultBackoffWeight)) + // increase backoff weight + if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, DefaultBackoffWeight)); err != nil { + task.Warn("set tidb_backoff_weight failed", zap.Error(err)) + } else { + defer func() { + if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, backoffWeight)); err != nil { + task.Warn("recover tidb_backoff_weight failed", zap.Error(err)) + } + }() + } + } cs := RemoteChecksum{} - err = common.SQLWithRetry{DB: e.db, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum", + err = common.SQLWithRetry{DB: conn, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum", "ADMIN CHECKSUM TABLE "+tableName, &cs.Schema, &cs.Table, &cs.Checksum, &cs.TotalKVs, &cs.TotalBytes, ) dur := task.End(zap.ErrorLevel, err) @@ -257,20 +300,23 @@ type tikvChecksumManager struct { client kv.Client manager gcTTLManager distSQLScanConcurrency uint + backoffWeight int } // newTiKVChecksumManager return a new tikv checksum manager -func newTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint) *tikvChecksumManager { +func newTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint, backoffWeight int) *tikvChecksumManager { return &tikvChecksumManager{ client: client, manager: newGCTTLManager(pdClient), distSQLScanConcurrency: distSQLScanConcurrency, + backoffWeight: backoffWeight, } } func (e *tikvChecksumManager) checksumDB(ctx context.Context, tableInfo *checkpoints.TidbTableInfo, ts uint64) (*RemoteChecksum, error) { executor, err := checksum.NewExecutorBuilder(tableInfo.Core, ts). SetConcurrency(e.distSQLScanConcurrency). + SetBackoffWeight(e.backoffWeight). Build() if err != nil { return nil, errors.Trace(err) diff --git a/br/pkg/lightning/restore/checksum_test.go b/br/pkg/lightning/restore/checksum_test.go index 20acc23fe6be0..ba920ee58ed84 100644 --- a/br/pkg/lightning/restore/checksum_test.go +++ b/br/pkg/lightning/restore/checksum_test.go @@ -56,6 +56,7 @@ func TestDoChecksum(t *testing.T) { WithArgs("10m"). WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() + mock.ExpectClose() ctx := MockDoChecksumCtx(db) checksum, err := DoChecksum(ctx, &TidbTableInfo{DB: "test", Name: "t"}) @@ -216,6 +217,7 @@ func TestDoChecksumWithErrorAndLongOriginalLifetime(t *testing.T) { WithArgs("300h"). 
WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectClose() + mock.ExpectClose() ctx := MockDoChecksumCtx(db) _, err = DoChecksum(ctx, &TidbTableInfo{DB: "test", Name: "t"}) diff --git a/br/pkg/lightning/restore/table_restore_test.go b/br/pkg/lightning/restore/table_restore_test.go index 17fb97e346e36..ad09add849a51 100644 --- a/br/pkg/lightning/restore/table_restore_test.go +++ b/br/pkg/lightning/restore/table_restore_test.go @@ -753,6 +753,7 @@ func (s *tableRestoreSuite) TestCompareChecksumSuccess() { WithArgs("10m"). WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() + mock.ExpectClose() ctx := MockDoChecksumCtx(db) remoteChecksum, err := DoChecksum(ctx, s.tr.tableInfo) @@ -783,7 +784,7 @@ func (s *tableRestoreSuite) TestCompareChecksumFailure() { WithArgs("10m"). WillReturnResult(sqlmock.NewResult(2, 1)) mock.ExpectClose() - + mock.ExpectClose() ctx := MockDoChecksumCtx(db) remoteChecksum, err := DoChecksum(ctx, s.tr.tableInfo) require.NoError(s.T(), err) diff --git a/br/pkg/lightning/restore/tidb.go b/br/pkg/lightning/restore/tidb.go index 0e114bc035a56..33e9d5622a598 100644 --- a/br/pkg/lightning/restore/tidb.go +++ b/br/pkg/lightning/restore/tidb.go @@ -50,6 +50,7 @@ var defaultImportantVariables = map[string]string{ "default_week_format": "0", "block_encryption_mode": "aes-128-ecb", "group_concat_max_len": "1024", + "tidb_backoff_weight": "6", } // defaultImportVariablesTiDB is used in ObtainImportantVariables to retrieve the system diff --git a/br/pkg/lightning/restore/tidb_test.go b/br/pkg/lightning/restore/tidb_test.go index 9b204b2da22b1..b3ece883864f6 100644 --- a/br/pkg/lightning/restore/tidb_test.go +++ b/br/pkg/lightning/restore/tidb_test.go @@ -460,6 +460,7 @@ func TestObtainRowFormatVersionSucceed(t *testing.T) { sysVars := ObtainImportantVariables(ctx, s.tiGlue.GetSQLExecutor(), true) require.Equal(t, map[string]string{ + "tidb_backoff_weight": "6", "tidb_row_format_version": "2", "max_allowed_packet": "1073741824", "div_precision_increment": "10", @@ -487,6 +488,7 @@ func TestObtainRowFormatVersionFailure(t *testing.T) { sysVars := ObtainImportantVariables(ctx, s.tiGlue.GetSQLExecutor(), true) require.Equal(t, map[string]string{ + "tidb_backoff_weight": "6", "tidb_row_format_version": "1", "max_allowed_packet": "67108864", "div_precision_increment": "4", From ccc9a326824b0d61033f80a384a4c61e260501e6 Mon Sep 17 00:00:00 2001 From: gmhdbjd Date: Fri, 30 Jun 2023 11:07:58 +0800 Subject: [PATCH 4/4] fix bazel --- br/pkg/lightning/common/BUILD.bazel | 1 + br/pkg/lightning/restore/BUILD.bazel | 2 ++ 2 files changed, 3 insertions(+) diff --git a/br/pkg/lightning/common/BUILD.bazel b/br/pkg/lightning/common/BUILD.bazel index 2b36e457cd857..bb6c988e1e11a 100644 --- a/br/pkg/lightning/common/BUILD.bazel +++ b/br/pkg/lightning/common/BUILD.bazel @@ -23,6 +23,7 @@ go_library( "//br/pkg/utils", "//errno", "//parser/model", + "//sessionctx/variable", "//store/driver/error", "//table/tables", "//util", diff --git a/br/pkg/lightning/restore/BUILD.bazel b/br/pkg/lightning/restore/BUILD.bazel index ef5aeb106585b..f2a08f22f7d51 100644 --- a/br/pkg/lightning/restore/BUILD.bazel +++ b/br/pkg/lightning/restore/BUILD.bazel @@ -53,6 +53,7 @@ go_library( "//parser/model", "//parser/mysql", "//planner/core", + "//sessionctx/variable", "//store/driver", "//store/pdtypes", "//table", @@ -76,6 +77,7 @@ go_library( "@com_github_pingcap_kvproto//pkg/import_sstpb", "@com_github_pingcap_kvproto//pkg/metapb", "@com_github_pingcap_tipb//go-tipb", + "@com_github_tikv_client_go_v2//kv", 
"@com_github_tikv_client_go_v2//oracle", "@com_github_tikv_pd_client//:client", "@io_etcd_go_etcd_client_v3//:client",