diff --git a/br/pkg/pdutil/pd.go b/br/pkg/pdutil/pd.go index 84208bcd0af8c..9f257c33dd61b 100644 --- a/br/pkg/pdutil/pd.go +++ b/br/pkg/pdutil/pd.go @@ -250,6 +250,8 @@ type PdController struct { // control the pause schedulers goroutine schedulerPauseCh chan struct{} + // control the ttl of pausing schedulers + SchedulerPauseTTL time.Duration } // NewPdController creates a new PdController. @@ -445,7 +447,7 @@ func (p *PdController) getStoreInfoWith( func (p *PdController) doPauseSchedulers(ctx context.Context, schedulers []string, post pdHTTPRequest) ([]string, error) { // pause this scheduler with 300 seconds - body, err := json.Marshal(pauseSchedulerBody{Delay: int64(pauseTimeout.Seconds())}) + body, err := json.Marshal(pauseSchedulerBody{Delay: int64(p.ttlOfPausing().Seconds())}) if err != nil { return nil, errors.Trace(err) } @@ -454,9 +456,11 @@ func (p *PdController) doPauseSchedulers(ctx context.Context, schedulers []strin for _, scheduler := range schedulers { prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler) for _, addr := range p.getAllPDAddrs() { + var resp []byte _, err = post(ctx, addr, prefix, p.cli, http.MethodPost, body) if err == nil { removedSchedulers = append(removedSchedulers, scheduler) + log.Info("Paused scheduler.", zap.String("response", string(resp)), zap.String("on", addr)) break } } @@ -491,7 +495,7 @@ func (p *PdController) pauseSchedulersAndConfigWith( } go func() { - tick := time.NewTicker(pauseTimeout / 3) + tick := time.NewTicker(p.ttlOfPausing() / 3) defer tick.Stop() for { @@ -637,7 +641,7 @@ func (p *PdController) doUpdatePDScheduleConfig( func (p *PdController) doPauseConfigs(ctx context.Context, cfg map[string]interface{}, post pdHTTPRequest) error { // pause this scheduler with 300 seconds - prefix := fmt.Sprintf("%s?ttlSecond=%.0f", configPrefix, pauseTimeout.Seconds()) + prefix := fmt.Sprintf("%s?ttlSecond=%.0f", configPrefix, p.ttlOfPausing().Seconds()) return p.doUpdatePDScheduleConfig(ctx, cfg, post, prefix) } @@ -1075,6 +1079,13 @@ func (p *PdController) Close() { } } +func (p *PdController) ttlOfPausing() time.Duration { + if p.SchedulerPauseTTL > 0 { + return p.SchedulerPauseTTL + } + return pauseTimeout +} + // FetchPDVersion get pd version func FetchPDVersion(ctx context.Context, tls *common.TLS, pdAddr string) (*semver.Version, error) { // An example of PD version API. diff --git a/br/pkg/task/operator/BUILD.bazel b/br/pkg/task/operator/BUILD.bazel index 5ce85cbd1313f..83f9f042f6a89 100644 --- a/br/pkg/task/operator/BUILD.bazel +++ b/br/pkg/task/operator/BUILD.bazel @@ -9,6 +9,7 @@ go_library( importpath = "github.com/pingcap/tidb/br/pkg/task/operator", visibility = ["//visibility:public"], deps = [ + "//br/pkg/errors", "//br/pkg/logutil", "//br/pkg/pdutil", "//br/pkg/task", @@ -18,6 +19,7 @@ go_library( "@com_github_spf13_pflag//:pflag", "@org_golang_google_grpc//keepalive", "@org_golang_x_sync//errgroup", + "@org_uber_go_multierr//:multierr", "@org_uber_go_zap//:zap", ], ) diff --git a/br/pkg/task/operator/cmd.go b/br/pkg/task/operator/cmd.go index 909d18911c8d0..1917a9acd3b1b 100644 --- a/br/pkg/task/operator/cmd.go +++ b/br/pkg/task/operator/cmd.go @@ -5,16 +5,21 @@ package operator import ( "context" "crypto/tls" + "fmt" + "math/rand" + "os" "strings" "sync" "time" "github.com/pingcap/errors" "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/pdutil" "github.com/pingcap/tidb/br/pkg/task" "github.com/pingcap/tidb/br/pkg/utils" + "go.uber.org/multierr" "go.uber.org/zap" "golang.org/x/sync/errgroup" "google.golang.org/grpc/keepalive" @@ -38,13 +43,16 @@ func dialPD(ctx context.Context, cfg *task.Config) (*pdutil.PdController, error) } func (cx *AdaptEnvForSnapshotBackupContext) cleanUpWith(f func(ctx context.Context)) { - _ = cx.cleanUpWithErr(func(ctx context.Context) error { f(ctx); return nil }) + cx.cleanUpWithRetErr(nil, func(ctx context.Context) error { f(ctx); return nil }) } -func (cx *AdaptEnvForSnapshotBackupContext) cleanUpWithErr(f func(ctx context.Context) error) error { +func (cx *AdaptEnvForSnapshotBackupContext) cleanUpWithRetErr(errOut *error, f func(ctx context.Context) error) { ctx, cancel := context.WithTimeout(context.Background(), cx.cfg.TTL) defer cancel() - return f(ctx) + err := f(ctx) + if errOut != nil { + *errOut = multierr.Combine(*errOut, err) + } } type AdaptEnvForSnapshotBackupContext struct { @@ -58,6 +66,18 @@ type AdaptEnvForSnapshotBackupContext struct { runGrp *errgroup.Group } +func (cx *AdaptEnvForSnapshotBackupContext) Close() { + cx.pdMgr.Close() + cx.kvMgr.Close() +} + +func (cx *AdaptEnvForSnapshotBackupContext) GetBackOffer(operation string) utils.Backoffer { + state := utils.InitialRetryState(64, 1*time.Second, 10*time.Second) + bo := utils.GiveUpRetryOn(&state, berrors.ErrPossibleInconsistency) + bo = utils.VerboseRetry(bo, logutil.CL(cx).With(zap.String("operation", operation))) + return bo +} + func (cx *AdaptEnvForSnapshotBackupContext) ReadyL(name string, notes ...zap.Field) { logutil.CL(cx).Info("Stage ready.", append(notes, zap.String("component", name))...) cx.rdGrp.Done() @@ -77,6 +97,7 @@ func AdaptEnvForSnapshotBackup(ctx context.Context, cfg *PauseGcConfig) error { if err != nil { return errors.Annotate(err, "failed to dial PD") } + mgr.SchedulerPauseTTL = cfg.TTL var tconf *tls.Config if cfg.TLS.IsEnabled() { tconf, err = cfg.TLS.ToTLSConfig() @@ -97,6 +118,8 @@ func AdaptEnvForSnapshotBackup(ctx context.Context, cfg *PauseGcConfig) error { rdGrp: sync.WaitGroup{}, runGrp: eg, } + defer cx.Close() + cx.rdGrp.Add(3) eg.Go(func() error { return pauseGCKeeper(cx) }) @@ -104,66 +127,98 @@ func AdaptEnvForSnapshotBackup(ctx context.Context, cfg *PauseGcConfig) error { eg.Go(func() error { return pauseImporting(cx) }) go func() { cx.rdGrp.Wait() + if cfg.OnAllReady != nil { + cfg.OnAllReady() + } hintAllReady() }() + defer func() { + if cfg.OnExit != nil { + cfg.OnExit() + } + }() return eg.Wait() } +func getCallerName() string { + name, err := os.Hostname() + if err != nil { + name = fmt.Sprintf("UNKNOWN-%d", rand.Int63()) + } + return fmt.Sprintf("operator@%sT%d#%d", name, time.Now().Unix(), os.Getpid()) +} + func pauseImporting(cx *AdaptEnvForSnapshotBackupContext) error { - denyLightning := utils.NewSuspendImporting("prepare_for_snapshot_backup", cx.kvMgr) - if _, err := denyLightning.DenyAllStores(cx, cx.cfg.TTL); err != nil { + suspendLightning := utils.NewSuspendImporting(getCallerName(), cx.kvMgr) + _, err := utils.WithRetryV2(cx, cx.GetBackOffer("suspend_lightning"), func(_ context.Context) (map[uint64]bool, error) { + return suspendLightning.DenyAllStores(cx, cx.cfg.TTL) + }) + if err != nil { return errors.Trace(err) } cx.ReadyL("pause_lightning") - cx.runGrp.Go(func() error { - err := denyLightning.Keeper(cx, cx.cfg.TTL) + cx.runGrp.Go(func() (err error) { + defer cx.cleanUpWithRetErr(&err, func(ctx context.Context) error { + if ctx.Err() != nil { + //nolint: all_revive,revive // There is a false positive on returning in `defer`. + return errors.Annotate(ctx.Err(), "cleaning up timed out") + } + res, err := utils.WithRetryV2(ctx, cx.GetBackOffer("restore_lightning"), + //nolint: all_revive,revive // There is a false positive on returning in `defer`. + func(ctx context.Context) (map[uint64]bool, error) { return suspendLightning.AllowAllStores(ctx) }) + if err != nil { + //nolint: all_revive,revive // There is a false positive on returning in `defer`. + return errors.Annotatef(err, "failed to allow all stores") + } + //nolint: all_revive,revive // There is a false positive on returning in `defer`. + return suspendLightning.ConsistentWithPrev(res) + }) + + err = suspendLightning.Keeper(cx, cx.cfg.TTL) if errors.Cause(err) != context.Canceled { logutil.CL(cx).Warn("keeper encounters error.", logutil.ShortError(err)) + return err } - return cx.cleanUpWithErr(func(ctx context.Context) error { - for { - if ctx.Err() != nil { - return errors.Annotate(ctx.Err(), "cleaning up timed out") - } - res, err := denyLightning.AllowAllStores(ctx) - if err != nil { - logutil.CL(ctx).Warn("Failed to restore lightning, will retry.", logutil.ShortError(err)) - // Retry for 10 times. - time.Sleep(cx.cfg.TTL / 10) - continue - } - return denyLightning.ConsistentWithPrev(res) - } - }) + // Clean up the canceled error. + err = nil + return }) return nil } -func pauseGCKeeper(ctx *AdaptEnvForSnapshotBackupContext) error { +func pauseGCKeeper(cx *AdaptEnvForSnapshotBackupContext) (err error) { // Note: should we remove the service safepoint as soon as this exits? sp := utils.BRServiceSafePoint{ ID: utils.MakeSafePointID(), - TTL: int64(ctx.cfg.TTL.Seconds()), - BackupTS: ctx.cfg.SafePoint, + TTL: int64(cx.cfg.TTL.Seconds()), + BackupTS: cx.cfg.SafePoint, } if sp.BackupTS == 0 { - rts, err := ctx.pdMgr.GetMinResolvedTS(ctx) + rts, err := cx.pdMgr.GetMinResolvedTS(cx) if err != nil { return err } - logutil.CL(ctx).Info("No service safepoint provided, using the minimal resolved TS.", zap.Uint64("min-resolved-ts", rts)) + logutil.CL(cx).Info("No service safepoint provided, using the minimal resolved TS.", zap.Uint64("min-resolved-ts", rts)) sp.BackupTS = rts } - err := utils.StartServiceSafePointKeeper(ctx, ctx.pdMgr.GetPDClient(), sp) + err = utils.StartServiceSafePointKeeper(cx, cx.pdMgr.GetPDClient(), sp) if err != nil { return err } - ctx.ReadyL("pause_gc", zap.Object("safepoint", sp)) + cx.ReadyL("pause_gc", zap.Object("safepoint", sp)) + defer cx.cleanUpWithRetErr(&err, func(ctx context.Context) error { + cancelSP := utils.BRServiceSafePoint{ + ID: sp.ID, + TTL: 0, + } + //nolint: all_revive,revive // There is a false positive on returning in `defer`. + return utils.UpdateServiceSafePoint(ctx, cx.pdMgr.GetPDClient(), cancelSP) + }) // Note: in fact we can directly return here. // But the name `keeper` implies once the function exits, // the GC should be resume, so let's block here. - <-ctx.Done() + <-cx.Done() return nil } diff --git a/br/pkg/task/operator/config.go b/br/pkg/task/operator/config.go index 998fdc64d961e..693d4908bdee6 100644 --- a/br/pkg/task/operator/config.go +++ b/br/pkg/task/operator/config.go @@ -14,10 +14,13 @@ type PauseGcConfig struct { SafePoint uint64 `json:"safepoint" yaml:"safepoint"` TTL time.Duration `json:"ttl" yaml:"ttl"` + + OnAllReady func() `json:"-" yaml:"-"` + OnExit func() `json:"-" yaml:"-"` } func DefineFlagsForPrepareSnapBackup(f *pflag.FlagSet) { - _ = f.DurationP("ttl", "i", 5*time.Minute, "The time-to-live of the safepoint.") + _ = f.DurationP("ttl", "i", 2*time.Minute, "The time-to-live of the safepoint.") _ = f.Uint64P("safepoint", "t", 0, "The GC safepoint to be kept.") } diff --git a/br/pkg/utils/BUILD.bazel b/br/pkg/utils/BUILD.bazel index d119c77364e1b..46853bdac6f92 100644 --- a/br/pkg/utils/BUILD.bazel +++ b/br/pkg/utils/BUILD.bazel @@ -90,7 +90,7 @@ go_test( ], embed = [":utils"], flaky = True, - shard_count = 29, + shard_count = 37, deps = [ "//br/pkg/errors", "//br/pkg/metautil", diff --git a/br/pkg/utils/retry.go b/br/pkg/utils/retry.go index 20482d7c423a2..c0476e8db3701 100644 --- a/br/pkg/utils/retry.go +++ b/br/pkg/utils/retry.go @@ -4,16 +4,20 @@ package utils import ( "context" + stderrs "errors" "strings" "sync" "time" "github.com/cznic/mathutil" + "github.com/google/uuid" "github.com/pingcap/errors" + "github.com/pingcap/log" tmysql "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/terror" "github.com/tikv/client-go/v2/tikv" "go.uber.org/multierr" + "go.uber.org/zap" ) var retryableServerError = []string{ @@ -180,3 +184,77 @@ func (r *RetryWithBackoffer) RequestBackOff(ms int) { func (r *RetryWithBackoffer) Inner() *tikv.Backoffer { return r.bo } + +type verboseBackoffer struct { + inner Backoffer + logger *zap.Logger + groupID uuid.UUID +} + +func (v *verboseBackoffer) NextBackoff(err error) time.Duration { + nextBackoff := v.inner.NextBackoff(err) + v.logger.Warn("Encountered err, retrying.", + zap.Stringer("nextBackoff", nextBackoff), + zap.String("err", err.Error()), + zap.Stringer("gid", v.groupID)) + return nextBackoff +} + +// Attempt returns the remain attempt times +func (v *verboseBackoffer) Attempt() int { + attempt := v.inner.Attempt() + if attempt > 0 { + v.logger.Debug("Retry attempt hint.", zap.Int("attempt", attempt), zap.Stringer("gid", v.groupID)) + } else { + v.logger.Warn("Retry limit exceeded.", zap.Stringer("gid", v.groupID)) + } + return attempt +} + +func VerboseRetry(bo Backoffer, logger *zap.Logger) Backoffer { + if logger == nil { + logger = log.L() + } + vlog := &verboseBackoffer{ + inner: bo, + logger: logger, + groupID: uuid.New(), + } + return vlog +} + +type failedOnErr struct { + inner Backoffer + failed bool + failedOn []error +} + +// NextBackoff returns a duration to wait before retrying again +func (f *failedOnErr) NextBackoff(err error) time.Duration { + for _, fatalErr := range f.failedOn { + if stderrs.Is(errors.Cause(err), fatalErr) { + f.failed = true + return 0 + } + } + if !f.failed { + return f.inner.NextBackoff(err) + } + return 0 +} + +// Attempt returns the remain attempt times +func (f *failedOnErr) Attempt() int { + if f.failed { + return 0 + } + return f.inner.Attempt() +} + +func GiveUpRetryOn(bo Backoffer, errs ...error) Backoffer { + return &failedOnErr{ + inner: bo, + failed: false, + failedOn: errs, + } +} diff --git a/br/pkg/utils/retry_test.go b/br/pkg/utils/retry_test.go index eeef8c61c0480..c2afe35f47741 100644 --- a/br/pkg/utils/retry_test.go +++ b/br/pkg/utils/retry_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/pingcap/errors" + berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/utils" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/tikv" @@ -47,3 +48,26 @@ func TestRetryAdapter(t *testing.T) { req.Greater(time.Since(begin), 200*time.Millisecond) } + +func TestFailNowIf(t *testing.T) { + mockBO := utils.InitialRetryState(100, time.Second, time.Second) + err1 := errors.New("error1") + err2 := errors.New("error2") + assert := require.New(t) + + bo := utils.GiveUpRetryOn(&mockBO, err1) + + // Test NextBackoff with an error that is not in failedOn + assert.Equal(time.Second, bo.NextBackoff(err2)) + assert.NotEqualValues(0, bo.Attempt()) + + annotatedErr := errors.Annotate(errors.Annotate(err1, "meow?"), "nya?") + assert.Equal(time.Duration(0), bo.NextBackoff(annotatedErr)) + assert.Equal(0, bo.Attempt()) + + mockBO = utils.InitialRetryState(100, time.Second, time.Second) + bo = utils.GiveUpRetryOn(&mockBO, berrors.ErrBackupNoLeader) + annotatedErr = berrors.ErrBackupNoLeader.FastGen("leader is taking an adventure") + assert.Equal(time.Duration(0), bo.NextBackoff(annotatedErr)) + assert.Equal(0, bo.Attempt()) +} diff --git a/br/pkg/utils/suspend_importing.go b/br/pkg/utils/suspend_importing.go index c2df70229c525..0fffb40727af4 100644 --- a/br/pkg/utils/suspend_importing.go +++ b/br/pkg/utils/suspend_importing.go @@ -1,3 +1,4 @@ +// Copyright 2023 PingCAP, Inc. Licensed under Apache-2.0. package utils import ( @@ -7,6 +8,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" berrors "github.com/pingcap/tidb/br/pkg/errors" "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/util/engine" @@ -86,6 +88,7 @@ func (d *SuspendImporting) forEachStores(ctx context.Context, makeReq func() *im } result := map[uint64]bool{} + log.Info("SuspendImporting/forEachStores: hint of current store.", zap.Stringers("stores", stores)) for _, store := range stores { logutil.CL(ctx).Info("Handling store.", zap.Stringer("store", store)) if engine.IsTiFlash(store) { diff --git a/br/pkg/utils/suspend_importing_test.go b/br/pkg/utils/suspend_importing_test.go index 8ee04af072048..9ce3f271a169e 100644 --- a/br/pkg/utils/suspend_importing_test.go +++ b/br/pkg/utils/suspend_importing_test.go @@ -1,3 +1,4 @@ +// Copyright 2023 PingCAP, Inc. Licensed under Apache-2.0. package utils_test import ( diff --git a/tests/realtikvtest/brietest/BUILD.bazel b/tests/realtikvtest/brietest/BUILD.bazel index 49ea32406c7d6..6bd9e8bf7740d 100644 --- a/tests/realtikvtest/brietest/BUILD.bazel +++ b/tests/realtikvtest/brietest/BUILD.bazel @@ -8,10 +8,13 @@ go_test( "binlog_test.go", "flashback_test.go", "main_test.go", + "operator_test.go", ], flaky = True, race = "on", deps = [ + "//br/pkg/task", + "//br/pkg/task/operator", "//config", "//ddl/util", "//parser/mysql", @@ -21,11 +24,14 @@ go_test( "//testkit/testsetup", "//tests/realtikvtest", "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_kvproto//pkg/import_sstpb", "@com_github_pingcap_tipb//go-binlog", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//oracle", "@com_github_tikv_client_go_v2//util", + "@com_github_tikv_pd_client//:client", "@org_golang_google_grpc//:grpc", + "@org_golang_google_grpc//credentials/insecure", "@org_uber_go_goleak//:goleak", ], ) diff --git a/tests/realtikvtest/brietest/operator_test.go b/tests/realtikvtest/brietest/operator_test.go new file mode 100644 index 0000000000000..3e3010132c297 --- /dev/null +++ b/tests/realtikvtest/brietest/operator_test.go @@ -0,0 +1,201 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package brietest + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + "time" + + "github.com/pingcap/kvproto/pkg/import_sstpb" + "github.com/pingcap/tidb/br/pkg/task" + "github.com/pingcap/tidb/br/pkg/task/operator" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/oracle" + pd "github.com/tikv/pd/client" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +var ( + serviceGCSafepointPrefix = "pd/api/v1/gc/safepoint" + schedulersPrefix = "pd/api/v1/schedulers" +) + +func getJSON(url string, response any) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + return json.NewDecoder(resp.Body).Decode(response) +} + +func pdAPI(cfg operator.PauseGcConfig, path string) string { + return fmt.Sprintf("http://%s/%s", cfg.Config.PD[0], path) +} + +type GcSafePoints struct { + SPs []struct { + ServiceID string `json:"service_id"` + ExpiredAt int64 `json:"expired_at"` + SafePoint int64 `json:"safe_point"` + } `json:"service_gc_safe_points"` +} + +func verifyGCStopped(t *require.Assertions, cfg operator.PauseGcConfig) { + var result GcSafePoints + t.NoError(getJSON(pdAPI(cfg, serviceGCSafepointPrefix), &result)) + for _, sp := range result.SPs { + if sp.ServiceID != "gc_worker" { + t.Equal(int64(cfg.SafePoint)-1, sp.SafePoint, result.SPs) + } + } +} + +func verifyGCNotStopped(t *require.Assertions, cfg operator.PauseGcConfig) { + var result GcSafePoints + t.NoError(getJSON(pdAPI(cfg, serviceGCSafepointPrefix), &result)) + for _, sp := range result.SPs { + if sp.ServiceID != "gc_worker" { + t.FailNowf("the service gc safepoint exists", "it is %#v", sp) + } + } +} + +func verifyLightningStopped(t *require.Assertions, cfg operator.PauseGcConfig) { + cx := context.Background() + pdc, err := pd.NewClient(cfg.Config.PD, pd.SecurityOption{}) + t.NoError(err) + defer pdc.Close() + stores, err := pdc.GetAllStores(cx, pd.WithExcludeTombstone()) + t.NoError(err) + s := stores[0] + conn, err := grpc.DialContext(cx, s.Address, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + t.NoError(err) + ingestCli := import_sstpb.NewImportSSTClient(conn) + res, err := ingestCli.Ingest(cx, &import_sstpb.IngestRequest{}) + t.NoError(err) + t.NotNil(res.GetError(), "res = %s", res) +} + +func verifySchedulersStopped(t *require.Assertions, cfg operator.PauseGcConfig) { + var ( + schedulers []string + pausedSchedulers []string + target = pdAPI(cfg, schedulersPrefix) + ) + + t.NoError(getJSON(target, &schedulers)) + enabledSchedulers := map[string]struct{}{} + for _, sched := range schedulers { + enabledSchedulers[sched] = struct{}{} + } + t.NoError(getJSON(target+"?status=paused", &pausedSchedulers)) + for _, scheduler := range pausedSchedulers { + t.Contains(enabledSchedulers, scheduler) + } +} + +func verifySchedulerNotStopped(t *require.Assertions, cfg operator.PauseGcConfig) { + var ( + schedulers []string + pausedSchedulers []string + target = pdAPI(cfg, schedulersPrefix) + ) + + t.NoError(getJSON(target, &schedulers)) + enabledSchedulers := map[string]struct{}{} + for _, sched := range schedulers { + enabledSchedulers[sched] = struct{}{} + } + t.NoError(getJSON(target+"?status=paused", &pausedSchedulers)) + for _, scheduler := range pausedSchedulers { + t.NotContains(enabledSchedulers, scheduler) + } +} + +func cleanUpGCSafepoint(cfg operator.PauseGcConfig, t *testing.T) { + var result GcSafePoints + pdCli, err := pd.NewClient(cfg.PD, pd.SecurityOption{}) + require.NoError(t, err) + defer pdCli.Close() + getJSON(pdAPI(cfg, serviceGCSafepointPrefix), &result) + for _, sp := range result.SPs { + if sp.ServiceID != "gc_worker" { + sp.SafePoint = 0 + _, err := pdCli.UpdateServiceGCSafePoint(context.Background(), sp.ServiceID, 0, 0) + require.NoError(t, err) + } + } +} + +func TestOperator(t *testing.T) { + req := require.New(t) + rd := make(chan struct{}) + ex := make(chan struct{}) + cfg := operator.PauseGcConfig{ + Config: task.Config{ + PD: []string{"127.0.0.1:2379"}, + }, + TTL: 5 * time.Minute, + SafePoint: oracle.GoTimeToTS(time.Now()), + OnAllReady: func() { + close(rd) + }, + OnExit: func() { + close(ex) + }, + } + + cleanUpGCSafepoint(cfg, t) + + verifyGCNotStopped(req, cfg) + verifySchedulerNotStopped(req, cfg) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + req.NoError(operator.AdaptEnvForSnapshotBackup(ctx, &cfg)) + }() + req.Eventually(func() bool { + select { + case <-rd: + return true + default: + return false + } + }, 10*time.Second, time.Second) + + cancel() + verifyGCStopped(req, cfg) + verifyLightningStopped(req, cfg) + verifySchedulersStopped(req, cfg) + + req.Eventually(func() bool { + select { + case <-ex: + return true + default: + return false + } + }, 10*time.Second, time.Second) + + verifySchedulerNotStopped(req, cfg) + verifyGCNotStopped(req, cfg) +}