diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go index 827d2f0ed3bee..5f2fc289a60d3 100644 --- a/br/pkg/lightning/backend/local/local.go +++ b/br/pkg/lightning/backend/local/local.go @@ -234,15 +234,15 @@ func (b *encodingBuilder) MakeEmptyRows() kv.Rows { type targetInfoGetter struct { tls *common.TLS targetDBGlue glue.Glue - pdAddr string + pdCli pd.Client } // NewTargetInfoGetter creates an TargetInfoGetter with local backend implementation. -func NewTargetInfoGetter(tls *common.TLS, g glue.Glue, pdAddr string) backend.TargetInfoGetter { +func NewTargetInfoGetter(tls *common.TLS, g glue.Glue, pdCli pd.Client) backend.TargetInfoGetter { return &targetInfoGetter{ tls: tls, targetDBGlue: g, - pdAddr: pdAddr, + pdCli: pdCli, } } @@ -264,10 +264,10 @@ func (g *targetInfoGetter) CheckRequirements(ctx context.Context, checkCtx *back if err := checkTiDBVersion(ctx, versionStr, localMinTiDBVersion, localMaxTiDBVersion); err != nil { return err } - if err := tikv.CheckPDVersion(ctx, g.tls, g.pdAddr, localMinPDVersion, localMaxPDVersion); err != nil { + if err := tikv.CheckPDVersion(ctx, g.tls, g.pdCli.GetLeaderAddr(), localMinPDVersion, localMaxPDVersion); err != nil { return err } - if err := tikv.CheckTiKVVersion(ctx, g.tls, g.pdAddr, localMinTiKVVersion, localMaxTiKVVersion); err != nil { + if err := tikv.CheckTiKVVersion(ctx, g.tls, g.pdCli.GetLeaderAddr(), localMinTiKVVersion, localMaxTiKVVersion); err != nil { return err } @@ -512,7 +512,7 @@ func NewLocalBackend( writeLimiter: writeLimiter, logger: log.FromContext(ctx), encBuilder: NewEncodingBuilder(ctx), - targetInfoGetter: NewTargetInfoGetter(tls, g, cfg.TiDB.PdAddr), + targetInfoGetter: NewTargetInfoGetter(tls, g, pdCtl.GetPDClient()), shouldCheckWriteStall: cfg.Cron.SwitchMode.Duration == 0, } if m, ok := metric.FromContext(ctx); ok { diff --git a/br/pkg/lightning/common/security.go b/br/pkg/lightning/common/security.go index a48abc48c2c54..03893ddf16b75 100644 --- a/br/pkg/lightning/common/security.go +++ b/br/pkg/lightning/common/security.go @@ -20,6 +20,7 @@ import ( "net" "net/http" "net/http/httptest" + "strings" "github.com/pingcap/errors" "github.com/pingcap/tidb/br/pkg/httputil" @@ -86,8 +87,15 @@ func NewTLSFromMockServer(server *httptest.Server) *TLS { } } +// GetMockTLSUrl returns tls's host for mock test +func GetMockTLSUrl(tls *TLS) string { + return tls.url +} + // WithHost creates a new TLS instance with the host replaced. func (tc *TLS) WithHost(host string) *TLS { + host = strings.TrimPrefix(host, "http://") + host = strings.TrimPrefix(host, "https://") var url string if tc.inner != nil { url = "https://" + host diff --git a/br/pkg/lightning/common/security_test.go b/br/pkg/lightning/common/security_test.go index e34ef3622500c..4ba9825efc883 100644 --- a/br/pkg/lightning/common/security_test.go +++ b/br/pkg/lightning/common/security_test.go @@ -70,6 +70,49 @@ func TestGetJSONSecure(t *testing.T) { require.Equal(t, "/dddd", result.Path) } +func TestWithHost(t *testing.T) { + mockTLSServer := httptest.NewTLSServer(http.HandlerFunc(respondPathHandler)) + defer mockTLSServer.Close() + mockServer := httptest.NewServer(http.HandlerFunc(respondPathHandler)) + defer mockServer.Close() + + testCases := []struct { + expected string + host string + secure bool + }{ + { + "https://127.0.0.1:2379", + "http://127.0.0.1:2379", + true, + }, + { + "http://127.0.0.1:2379", + "https://127.0.0.1:2379", + false, + }, + { + "http://127.0.0.1:2379/pd/api/v1/stores", + "127.0.0.1:2379/pd/api/v1/stores", + false, + }, + { + "https://127.0.0.1:2379", + "127.0.0.1:2379", + true, + }, + } + + for _, testCase := range testCases { + server := mockServer + if testCase.secure { + server = mockTLSServer + } + tls := common.NewTLSFromMockServer(server) + require.Equal(t, testCase.expected, common.GetMockTLSUrl(tls.WithHost(testCase.host))) + } +} + func TestInvalidTLS(t *testing.T) { tempDir := t.TempDir() caPath := filepath.Join(tempDir, "ca.pem") diff --git a/br/pkg/lightning/importer/BUILD.bazel b/br/pkg/lightning/importer/BUILD.bazel new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/br/pkg/lightning/restore/BUILD.bazel b/br/pkg/lightning/restore/BUILD.bazel index 01bf6145d86d3..fd33c7c77d91e 100644 --- a/br/pkg/lightning/restore/BUILD.bazel +++ b/br/pkg/lightning/restore/BUILD.bazel @@ -165,6 +165,7 @@ go_test( "@com_github_stretchr_testify//suite", "@com_github_tikv_client_go_v2//config", "@com_github_tikv_client_go_v2//oracle", + "@com_github_tikv_client_go_v2//testutils", "@com_github_tikv_pd_client//:client", "@com_github_xitongsys_parquet_go//writer", "@com_github_xitongsys_parquet_go_source//buffer", diff --git a/br/pkg/lightning/restore/get_pre_info.go b/br/pkg/lightning/restore/get_pre_info.go index 6251126a6c24e..7085202770bd1 100644 --- a/br/pkg/lightning/restore/get_pre_info.go +++ b/br/pkg/lightning/restore/get_pre_info.go @@ -50,6 +50,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/dbterror" "github.com/pingcap/tidb/util/mock" + pd "github.com/tikv/pd/client" "go.uber.org/zap" "golang.org/x/exp/maps" ) @@ -117,12 +118,14 @@ type TargetInfoGetterImpl struct { targetDBGlue glue.Glue tls *common.TLS backend backend.TargetInfoGetter + pdCli pd.Client } // NewTargetInfoGetterImpl creates a TargetInfoGetterImpl object. func NewTargetInfoGetterImpl( cfg *config.Config, targetDB *sql.DB, + pdCli pd.Client, ) (*TargetInfoGetterImpl, error) { targetDBGlue := glue.NewExternalTiDBGlue(targetDB, cfg.TiDB.SQLMode) tls, err := cfg.ToTLS() @@ -134,7 +137,10 @@ func NewTargetInfoGetterImpl( case config.BackendTiDB: backendTargetInfoGetter = tidb.NewTargetInfoGetter(targetDB) case config.BackendLocal: - backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDBGlue, cfg.TiDB.PdAddr) + if pdCli == nil { + return nil, common.ErrUnknown.GenWithStack("pd client is required when using local backend") + } + backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDBGlue, pdCli) default: return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend) } @@ -143,6 +149,7 @@ func NewTargetInfoGetterImpl( targetDBGlue: targetDBGlue, tls: tls, backend: backendTargetInfoGetter, + pdCli: pdCli, }, nil } @@ -231,7 +238,7 @@ func (g *TargetInfoGetterImpl) GetTargetSysVariablesForImport(ctx context.Contex // It uses the PD interface through TLS to get the information. func (g *TargetInfoGetterImpl) GetReplicationConfig(ctx context.Context) (*pdtypes.ReplicationConfig, error) { result := new(pdtypes.ReplicationConfig) - if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdReplicate, &result); err != nil { + if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdReplicate, &result); err != nil { return nil, errors.Trace(err) } return result, nil @@ -242,7 +249,7 @@ func (g *TargetInfoGetterImpl) GetReplicationConfig(ctx context.Context) (*pdtyp // It uses the PD interface through TLS to get the information. func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdtypes.StoresInfo, error) { result := new(pdtypes.StoresInfo) - if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdStores, result); err != nil { + if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdStores, result); err != nil { return nil, errors.Trace(err) } return result, nil @@ -253,7 +260,7 @@ func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdtypes.Sto // It uses the PD interface through TLS to get the information. func (g *TargetInfoGetterImpl) GetEmptyRegionsInfo(ctx context.Context) (*pdtypes.RegionsInfo, error) { result := new(pdtypes.RegionsInfo) - if err := g.tls.WithHost(g.cfg.TiDB.PdAddr).GetJSON(ctx, pdEmptyRegions, &result); err != nil { + if err := g.tls.WithHost(g.pdCli.GetLeaderAddr()).GetJSON(ctx, pdEmptyRegions, &result); err != nil { return nil, errors.Trace(err) } return result, nil diff --git a/br/pkg/lightning/restore/get_pre_info_test.go b/br/pkg/lightning/restore/get_pre_info_test.go index 71c2810d0b60e..fbcfd7d9063c7 100644 --- a/br/pkg/lightning/restore/get_pre_info_test.go +++ b/br/pkg/lightning/restore/get_pre_info_test.go @@ -758,7 +758,10 @@ func TestGetPreInfoIsTableEmpty(t *testing.T) { require.NoError(t, err) lnConfig := config.NewConfig() lnConfig.TikvImporter.Backend = config.BackendLocal - targetGetter, err := NewTargetInfoGetterImpl(lnConfig, db) + _, err = NewTargetInfoGetterImpl(lnConfig, db, nil) + require.ErrorContains(t, err, "pd client is required when using local backend") + lnConfig.TikvImporter.Backend = config.BackendTiDB + targetGetter, err := NewTargetInfoGetterImpl(lnConfig, db, nil) require.NoError(t, err) require.Equal(t, lnConfig, targetGetter.cfg) diff --git a/br/pkg/lightning/restore/precheck.go b/br/pkg/lightning/restore/precheck.go index a76854556a165..4e987c757c8e8 100644 --- a/br/pkg/lightning/restore/precheck.go +++ b/br/pkg/lightning/restore/precheck.go @@ -8,6 +8,7 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/config" "github.com/pingcap/tidb/br/pkg/lightning/mydump" ropts "github.com/pingcap/tidb/br/pkg/lightning/restore/opts" + pd "github.com/tikv/pd/client" ) type CheckItemID string @@ -57,7 +58,7 @@ type PrecheckItemBuilder struct { checkpointsDB checkpoints.DB } -func NewPrecheckItemBuilderFromConfig(ctx context.Context, cfg *config.Config, opts ...ropts.PrecheckItemBuilderOption) (*PrecheckItemBuilder, error) { +func NewPrecheckItemBuilderFromConfig(ctx context.Context, cfg *config.Config, pdCli pd.Client, opts ...ropts.PrecheckItemBuilderOption) (*PrecheckItemBuilder, error) { var gerr error builderCfg := new(ropts.PrecheckItemBuilderConfig) for _, o := range opts { @@ -67,7 +68,7 @@ func NewPrecheckItemBuilderFromConfig(ctx context.Context, cfg *config.Config, o if err != nil { return nil, errors.Trace(err) } - targetInfoGetter, err := NewTargetInfoGetterImpl(cfg, targetDB) + targetInfoGetter, err := NewTargetInfoGetterImpl(cfg, targetDB, pdCli) if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/lightning/restore/restore.go b/br/pkg/lightning/restore/restore.go index e6c2406577a52..973f5aed4dab3 100644 --- a/br/pkg/lightning/restore/restore.go +++ b/br/pkg/lightning/restore/restore.go @@ -206,6 +206,7 @@ type Controller struct { pauser *common.Pauser backend backend.Backend tidbGlue glue.Glue + pdCli pd.Client alterTableLock sync.Mutex sysVars map[string]string @@ -329,6 +330,7 @@ func NewRestoreControllerWithPauser( } var backend backend.Backend + var pdCli pd.Client switch cfg.TikvImporter.Backend { case config.BackendTiDB: backend = tidb.NewTiDBBackend(ctx, db, cfg.TikvImporter.OnDuplicate, errorMgr) @@ -343,9 +345,13 @@ func NewRestoreControllerWithPauser( if maxOpenFiles < 0 { maxOpenFiles = math.MaxInt32 } + pdCli, err = pd.NewClientWithContext(ctx, []string{cfg.TiDB.PdAddr}, tls.ToPDSecurityOption()) + if err != nil { + return nil, errors.Trace(err) + } if cfg.TikvImporter.DuplicateResolution != config.DupeResAlgNone { - if err := tikv.CheckTiKVVersion(ctx, tls, cfg.TiDB.PdAddr, minTiKVVersionForDuplicateResolution, maxTiKVVersionForDuplicateResolution); err != nil { + if err := tikv.CheckTiKVVersion(ctx, tls, pdCli.GetLeaderAddr(), minTiKVVersionForDuplicateResolution, maxTiKVVersionForDuplicateResolution); err != nil { if berrors.Is(err, berrors.ErrVersionMismatch) { log.FromContext(ctx).Warn("TiKV version doesn't support duplicate resolution. The resolution algorithm will fall back to 'none'", zap.Error(err)) cfg.TikvImporter.DuplicateResolution = config.DupeResAlgNone @@ -392,6 +398,7 @@ func NewRestoreControllerWithPauser( targetDBGlue: p.Glue, tls: tls, backend: backend, + pdCli: pdCli, } preInfoGetter, err := NewPreRestoreInfoGetter( cfg, @@ -420,6 +427,7 @@ func NewRestoreControllerWithPauser( checksumWorks: worker.NewPool(ctx, cfg.TiDB.ChecksumTableConcurrency, "checksum"), pauser: p.Pauser, backend: backend, + pdCli: pdCli, tidbGlue: p.Glue, sysVars: defaultImportantVariables, tls: tls, @@ -448,6 +456,9 @@ func NewRestoreControllerWithPauser( func (rc *Controller) Close() { rc.backend.Close() rc.tidbGlue.GetSQLExecutor().Close() + if rc.pdCli != nil { + rc.pdCli.Close() + } } func (rc *Controller) Run(ctx context.Context) error { @@ -1860,7 +1871,7 @@ func (rc *Controller) fullCompact(ctx context.Context) error { } func (rc *Controller) doCompact(ctx context.Context, level int32) error { - tls := rc.tls.WithHost(rc.cfg.TiDB.PdAddr) + tls := rc.tls.WithHost(rc.pdCli.GetLeaderAddr()) return tikv.ForAllStores( ctx, tls, diff --git a/br/pkg/lightning/restore/table_restore_test.go b/br/pkg/lightning/restore/table_restore_test.go index 17fb97e346e36..5f6d3027a10c0 100644 --- a/br/pkg/lightning/restore/table_restore_test.go +++ b/br/pkg/lightning/restore/table_restore_test.go @@ -67,6 +67,8 @@ import ( filter "github.com/pingcap/tidb/util/table-filter" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/tikv/client-go/v2/testutils" + pd "github.com/tikv/pd/client" ) type tableRestoreSuiteBase struct { @@ -1099,6 +1101,8 @@ func (s *tableRestoreSuite) TestCheckClusterResource() { require.NoError(s.T(), err) mockStore, err := storage.NewLocalStorage(dir) require.NoError(s.T(), err) + _, _, pdClient, err := testutils.NewMockTiKV("", nil) + require.NoError(s.T(), err) for _, ca := range cases { server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { var err error @@ -1115,9 +1119,11 @@ func (s *tableRestoreSuite) TestCheckClusterResource() { url := strings.TrimPrefix(server.URL, "https://") cfg := &config.Config{TiDB: config.DBStore{PdAddr: url}} + pdCli := &mockPDClient{Client: pdClient, leaderAddr: url} targetInfoGetter := &TargetInfoGetterImpl{ - cfg: cfg, - tls: tls, + cfg: cfg, + tls: tls, + pdCli: pdCli, } preInfoGetter := &PreRestoreInfoGetterImpl{ cfg: cfg, @@ -1132,6 +1138,7 @@ func (s *tableRestoreSuite) TestCheckClusterResource() { checkTemplate: template, preInfoGetter: preInfoGetter, precheckItemBuilder: theCheckBuilder, + pdCli: pdCli, } var sourceSize int64 err = rc.store.WalkDir(ctx, &storage.WalkOption{}, func(path string, size int64) error { @@ -1168,6 +1175,15 @@ func (mockTaskMetaMgr) CheckTasksExclusively(ctx context.Context, action func(ta return err } +type mockPDClient struct { + pd.Client + leaderAddr string +} + +func (m *mockPDClient) GetLeaderAddr() string { + return m.leaderAddr +} + func (s *tableRestoreSuite) TestCheckClusterRegion() { type testCase struct { stores pdtypes.StoresInfo @@ -1184,6 +1200,8 @@ func (s *tableRestoreSuite) TestCheckClusterRegion() { } return regions } + _, _, pdClient, err := testutils.NewMockTiKV("", nil) + require.NoError(s.T(), err) testCases := []testCase{ { @@ -1263,10 +1281,12 @@ func (s *tableRestoreSuite) TestCheckClusterRegion() { url := strings.TrimPrefix(server.URL, "https://") cfg := &config.Config{TiDB: config.DBStore{PdAddr: url}} + pdCli := &mockPDClient{Client: pdClient, leaderAddr: url} targetInfoGetter := &TargetInfoGetterImpl{ - cfg: cfg, - tls: tls, + cfg: cfg, + tls: tls, + pdCli: pdCli, } dbMetas := []*mydump.MDDatabaseMeta{} preInfoGetter := &PreRestoreInfoGetterImpl{ @@ -1283,6 +1303,7 @@ func (s *tableRestoreSuite) TestCheckClusterRegion() { preInfoGetter: preInfoGetter, dbInfos: make(map[string]*checkpoints.TidbDBInfo), precheckItemBuilder: theCheckBuilder, + pdCli: pdCli, } preInfoGetter.dbInfosCache = rc.dbInfos