diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index 58ec84157f0..24a75012331 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -99,9 +99,9 @@ func NewCluster(parentCtx context.Context, persistConfig *config.PersistConfig, clusterID: clusterID, checkMembershipCh: checkMembershipCh, - heartbeatRunner: ratelimit.NewConcurrentRunner(ctx, heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - miscRunner: ratelimit.NewConcurrentRunner(ctx, miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - logRunner: ratelimit.NewConcurrentRunner(ctx, logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + heartbeatRunner: ratelimit.NewConcurrentRunner(heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + miscRunner: ratelimit.NewConcurrentRunner(miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + logRunner: ratelimit.NewConcurrentRunner(logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), } c.coordinator = schedule.NewCoordinator(ctx, c, hbStreams) err = c.ruleManager.Initialize(persistConfig.GetMaxReplicas(), persistConfig.GetLocationLabels(), persistConfig.GetIsolationLevel()) @@ -549,9 +549,9 @@ func (c *Cluster) StartBackgroundJobs() { go c.runUpdateStoreStats() go c.runCoordinator() go c.runMetricsCollectionJob() - c.heartbeatRunner.Start() - c.miscRunner.Start() - c.logRunner.Start() + c.heartbeatRunner.Start(c.ctx) + c.miscRunner.Start(c.ctx) + c.logRunner.Start(c.ctx) c.running.Store(true) } diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go index 1ac7ae899af..57a19e4e682 100644 --- a/pkg/ratelimit/runner.go +++ b/pkg/ratelimit/runner.go @@ -43,7 +43,7 @@ const ( // Runner is the interface for running tasks. type Runner interface { RunTask(id uint64, name string, f func(), opts ...TaskOption) error - Start() + Start(ctx context.Context) Stop() } @@ -81,11 +81,8 @@ type ConcurrentRunner struct { } // NewConcurrentRunner creates a new ConcurrentRunner. -func NewConcurrentRunner(ctx context.Context, name string, limiter *ConcurrencyLimiter, maxPendingDuration time.Duration) *ConcurrentRunner { - ctx, cancel := context.WithCancel(ctx) +func NewConcurrentRunner(name string, limiter *ConcurrencyLimiter, maxPendingDuration time.Duration) *ConcurrentRunner { s := &ConcurrentRunner{ - ctx: ctx, - cancel: cancel, name: name, limiter: limiter, maxPendingDuration: maxPendingDuration, @@ -107,7 +104,8 @@ func WithRetained(retained bool) TaskOption { } // Start starts the runner. -func (cr *ConcurrentRunner) Start() { +func (cr *ConcurrentRunner) Start(ctx context.Context) { + cr.ctx, cr.cancel = context.WithCancel(ctx) cr.wg.Add(1) ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -246,7 +244,7 @@ func (*SyncRunner) RunTask(_ uint64, _ string, f func(), _ ...TaskOption) error } // Start starts the runner. -func (*SyncRunner) Start() {} +func (*SyncRunner) Start(context.Context) {} // Stop stops the runner. func (*SyncRunner) Stop() {} diff --git a/pkg/ratelimit/runner_test.go b/pkg/ratelimit/runner_test.go index a3eac7f238e..d4aa0825e83 100644 --- a/pkg/ratelimit/runner_test.go +++ b/pkg/ratelimit/runner_test.go @@ -25,8 +25,8 @@ import ( func TestConcurrentRunner(t *testing.T) { t.Run("RunTask", func(t *testing.T) { - runner := NewConcurrentRunner(context.TODO(), "test", NewConcurrencyLimiter(1), time.Second) - runner.Start() + runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), time.Second) + runner.Start(context.TODO()) defer runner.Stop() var wg sync.WaitGroup @@ -47,8 +47,8 @@ func TestConcurrentRunner(t *testing.T) { }) t.Run("MaxPendingDuration", func(t *testing.T) { - runner := NewConcurrentRunner(context.TODO(), "test", NewConcurrencyLimiter(1), 2*time.Millisecond) - runner.Start() + runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), 2*time.Millisecond) + runner.Start(context.TODO()) defer runner.Stop() var wg sync.WaitGroup for i := 0; i < 10; i++ { @@ -76,8 +76,8 @@ func TestConcurrentRunner(t *testing.T) { }) t.Run("DuplicatedTask", func(t *testing.T) { - runner := NewConcurrentRunner(context.TODO(), "test", NewConcurrencyLimiter(1), time.Minute) - runner.Start() + runner := NewConcurrentRunner("test", NewConcurrencyLimiter(1), time.Minute) + runner.Start(context.TODO()) defer runner.Stop() for i := 1; i < 11; i++ { regionID := uint64(i) diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 812cbb437f0..ed1080f617a 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -204,9 +204,9 @@ func NewRaftCluster(ctx context.Context, clusterID uint64, basicCluster *core.Ba etcdClient: etcdClient, BasicCluster: basicCluster, storage: storage, - heartbeatRunner: ratelimit.NewConcurrentRunner(ctx, heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - miscRunner: ratelimit.NewConcurrentRunner(ctx, miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), - logRunner: ratelimit.NewConcurrentRunner(ctx, logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + heartbeatRunner: ratelimit.NewConcurrentRunner(heartbeatTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + miscRunner: ratelimit.NewConcurrentRunner(miscTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), + logRunner: ratelimit.NewConcurrentRunner(logTaskRunner, ratelimit.NewConcurrencyLimiter(uint64(runtime.NumCPU()*2)), time.Minute), } } @@ -364,9 +364,9 @@ func (c *RaftCluster) Start(s Server) error { go c.startGCTuner() c.running = true - c.heartbeatRunner.Start() - c.miscRunner.Start() - c.logRunner.Start() + c.heartbeatRunner.Start(c.ctx) + c.miscRunner.Start(c.ctx) + c.logRunner.Start(c.ctx) return nil }