diff --git a/pkg/ratelimit/concurrency_limiter.go b/pkg/ratelimit/concurrency_limiter.go index af768461478..e5379bc48cc 100644 --- a/pkg/ratelimit/concurrency_limiter.go +++ b/pkg/ratelimit/concurrency_limiter.go @@ -106,8 +106,8 @@ func (l *ConcurrencyLimiter) GetWaitingTasksNum() uint64 { return l.waiting } -// Acquire acquires a token from the limiter. which will block until a token is available or ctx is done, like Timeout. -func (l *ConcurrencyLimiter) Acquire(ctx context.Context) (*TaskToken, error) { +// AcquireToken acquires a token from the limiter. which will block until a token is available or ctx is done, like Timeout. +func (l *ConcurrencyLimiter) AcquireToken(ctx context.Context) (*TaskToken, error) { l.mu.Lock() if l.current >= l.limit { l.waiting++ @@ -129,27 +129,26 @@ func (l *ConcurrencyLimiter) Acquire(ctx context.Context) (*TaskToken, error) { } } l.current++ - token := &TaskToken{limiter: l} + token := &TaskToken{} l.mu.Unlock() return token, nil } -// TaskToken is a token that must be released after the task is done. -type TaskToken struct { - released bool - limiter *ConcurrencyLimiter -} - -// Release releases the token. -func (tt *TaskToken) Release() { - tt.limiter.mu.Lock() - defer tt.limiter.mu.Unlock() - if tt.released { +// ReleaseToken releases the token. +func (l *ConcurrencyLimiter) ReleaseToken(token *TaskToken) { + l.mu.Lock() + defer l.mu.Unlock() + if token.released { return } - tt.released = true - tt.limiter.current-- - if len(tt.limiter.queue) < int(tt.limiter.limit) { - tt.limiter.queue <- tt + token.released = true + l.current-- + if len(l.queue) < int(l.limit) { + l.queue <- token } } + +// TaskToken is a token that must be released after the task is done. +type TaskToken struct { + released bool +} diff --git a/pkg/ratelimit/concurrency_limiter_test.go b/pkg/ratelimit/concurrency_limiter_test.go index 26e3e9efb92..f0af1125d21 100644 --- a/pkg/ratelimit/concurrency_limiter_test.go +++ b/pkg/ratelimit/concurrency_limiter_test.go @@ -68,17 +68,17 @@ func TestConcurrencyLimiter2(t *testing.T) { defer cancel() // Acquire two tokens - token1, err := limiter.Acquire(ctx) + token1, err := limiter.AcquireToken(ctx) require.NoError(t, err, "Failed to acquire token") - token2, err := limiter.Acquire(ctx) + token2, err := limiter.AcquireToken(ctx) require.NoError(t, err, "Failed to acquire token") require.Equal(t, limit, limiter.GetRunningTasksNum(), "Expected running tasks to be 2") // Try to acquire third token, it should not be able to acquire immediately due to limit go func() { - _, err := limiter.Acquire(ctx) + _, err := limiter.AcquireToken(ctx) require.NoError(t, err, "Failed to acquire token") }() @@ -86,13 +86,13 @@ func TestConcurrencyLimiter2(t *testing.T) { require.Equal(t, uint64(1), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 1") // Release a token - token1.Release() + limiter.ReleaseToken(token1) time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run require.Equal(t, uint64(2), limiter.GetRunningTasksNum(), "Expected running tasks to be 2") require.Equal(t, uint64(0), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0") // Release the second token - token2.Release() + limiter.ReleaseToken(token2) time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run require.Equal(t, uint64(1), limiter.GetRunningTasksNum(), "Expected running tasks to be 1") } @@ -109,12 +109,12 @@ func TestConcurrencyLimiterAcquire(t *testing.T) { for i := 0; i < 100; i++ { go func(i int) { defer wg.Done() - token, err := limiter.Acquire(ctx) + token, err := limiter.AcquireToken(ctx) if err != nil { fmt.Printf("Task %d failed to acquire: %v\n", i, err) return } - defer token.Release() + defer limiter.ReleaseToken(token) // simulate takes some time time.Sleep(10 * time.Millisecond) atomic.AddInt64(&sum, 1) diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go index 361b08c49b8..07233af238b 100644 --- a/pkg/ratelimit/runner.go +++ b/pkg/ratelimit/runner.go @@ -103,13 +103,13 @@ func (cr *ConcurrentRunner) Start() { select { case task := <-cr.taskChan: if cr.limiter != nil { - token, err := cr.limiter.Acquire(context.Background()) + token, err := cr.limiter.AcquireToken(context.Background()) if err != nil { continue } - go cr.run(task.ctx, task.f, token) + go cr.run(task, token) } else { - go cr.run(task.ctx, task.f, nil) + go cr.run(task, nil) } case <-cr.stopChan: cr.pendingMu.Lock() @@ -133,10 +133,10 @@ func (cr *ConcurrentRunner) Start() { }() } -func (cr *ConcurrentRunner) run(ctx context.Context, task func(context.Context), token *TaskToken) { - task(ctx) +func (cr *ConcurrentRunner) run(task *Task, token *TaskToken) { + task.f(task.ctx) if token != nil { - token.Release() + cr.limiter.ReleaseToken(token) cr.processPendingTasks() } }