Skip to content

Commit

Permalink
refactor token
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Leung <rleungx@gmail.com>
  • Loading branch information
rleungx committed May 15, 2024
1 parent 2278609 commit c6673ab
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 31 deletions.
35 changes: 17 additions & 18 deletions pkg/ratelimit/concurrency_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ func (l *ConcurrencyLimiter) GetWaitingTasksNum() uint64 {
return l.waiting
}

// Acquire acquires a token from the limiter. which will block until a token is available or ctx is done, like Timeout.
func (l *ConcurrencyLimiter) Acquire(ctx context.Context) (*TaskToken, error) {
// AcquireToken acquires a token from the limiter. which will block until a token is available or ctx is done, like Timeout.
func (l *ConcurrencyLimiter) AcquireToken(ctx context.Context) (*TaskToken, error) {
l.mu.Lock()
if l.current >= l.limit {
l.waiting++
Expand All @@ -129,27 +129,26 @@ func (l *ConcurrencyLimiter) Acquire(ctx context.Context) (*TaskToken, error) {
}
}
l.current++
token := &TaskToken{limiter: l}
token := &TaskToken{}
l.mu.Unlock()
return token, nil
}

// TaskToken is a token that must be released after the task is done.
type TaskToken struct {
released bool
limiter *ConcurrencyLimiter
}

// Release releases the token.
func (tt *TaskToken) Release() {
tt.limiter.mu.Lock()
defer tt.limiter.mu.Unlock()
if tt.released {
// ReleaseToken releases the token.
func (l *ConcurrencyLimiter) ReleaseToken(token *TaskToken) {
l.mu.Lock()
defer l.mu.Unlock()
if token.released {
return
}
tt.released = true
tt.limiter.current--
if len(tt.limiter.queue) < int(tt.limiter.limit) {
tt.limiter.queue <- tt
token.released = true
l.current--
if len(l.queue) < int(l.limit) {
l.queue <- token
}
}

// TaskToken is a token that must be released after the task is done.
type TaskToken struct {
released bool
}
14 changes: 7 additions & 7 deletions pkg/ratelimit/concurrency_limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,31 +68,31 @@ func TestConcurrencyLimiter2(t *testing.T) {
defer cancel()

// Acquire two tokens
token1, err := limiter.Acquire(ctx)
token1, err := limiter.AcquireToken(ctx)
require.NoError(t, err, "Failed to acquire token")

token2, err := limiter.Acquire(ctx)
token2, err := limiter.AcquireToken(ctx)
require.NoError(t, err, "Failed to acquire token")

require.Equal(t, limit, limiter.GetRunningTasksNum(), "Expected running tasks to be 2")

// Try to acquire third token, it should not be able to acquire immediately due to limit
go func() {
_, err := limiter.Acquire(ctx)
_, err := limiter.AcquireToken(ctx)
require.NoError(t, err, "Failed to acquire token")
}()

time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run
require.Equal(t, uint64(1), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 1")

// Release a token
token1.Release()
limiter.ReleaseToken(token1)
time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run
require.Equal(t, uint64(2), limiter.GetRunningTasksNum(), "Expected running tasks to be 2")
require.Equal(t, uint64(0), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0")

// Release the second token
token2.Release()
limiter.ReleaseToken(token2)
time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run
require.Equal(t, uint64(1), limiter.GetRunningTasksNum(), "Expected running tasks to be 1")
}
Expand All @@ -109,12 +109,12 @@ func TestConcurrencyLimiterAcquire(t *testing.T) {
for i := 0; i < 100; i++ {
go func(i int) {
defer wg.Done()
token, err := limiter.Acquire(ctx)
token, err := limiter.AcquireToken(ctx)
if err != nil {
fmt.Printf("Task %d failed to acquire: %v\n", i, err)
return
}
defer token.Release()
defer limiter.ReleaseToken(token)
// simulate takes some time
time.Sleep(10 * time.Millisecond)
atomic.AddInt64(&sum, 1)
Expand Down
12 changes: 6 additions & 6 deletions pkg/ratelimit/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ func (cr *ConcurrentRunner) Start() {
select {
case task := <-cr.taskChan:
if cr.limiter != nil {
token, err := cr.limiter.Acquire(context.Background())
token, err := cr.limiter.AcquireToken(context.Background())
if err != nil {
continue
}
go cr.run(task.ctx, task.f, token)
go cr.run(task, token)
} else {
go cr.run(task.ctx, task.f, nil)
go cr.run(task, nil)

Check warning on line 112 in pkg/ratelimit/runner.go

View check run for this annotation

Codecov / codecov/patch

pkg/ratelimit/runner.go#L112

Added line #L112 was not covered by tests
}
case <-cr.stopChan:
cr.pendingMu.Lock()
Expand All @@ -133,10 +133,10 @@ func (cr *ConcurrentRunner) Start() {
}()
}

func (cr *ConcurrentRunner) run(ctx context.Context, task func(context.Context), token *TaskToken) {
task(ctx)
func (cr *ConcurrentRunner) run(task *Task, token *TaskToken) {
task.f(task.ctx)
if token != nil {
token.Release()
cr.limiter.ReleaseToken(token)
cr.processPendingTasks()
}
}
Expand Down

0 comments on commit c6673ab

Please sign in to comment.