From b1fba6221a90d8003890fba9cae879c737f0c4b2 Mon Sep 17 00:00:00 2001 From: vivek-ng Date: Sun, 22 Nov 2020 01:11:22 -0600 Subject: [PATCH] add support for timeout --- rateLimiter.go | 33 ++++++++++++++++++++++++++++++--- rateLimiter_test.go | 19 +++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/rateLimiter.go b/rateLimiter.go index 5e2a4c5..66146e5 100644 --- a/rateLimiter.go +++ b/rateLimiter.go @@ -3,6 +3,7 @@ package limiter import ( "container/list" "sync" + "time" ) type waiter struct { @@ -14,7 +15,7 @@ type Limiter struct { limit int mu sync.Mutex waitList list.List - //notify []chan struct{} + timeout *int } func NewLimiter(limit int) *Limiter { @@ -23,11 +24,34 @@ func NewLimiter(limit int) *Limiter { } } +func (l *Limiter) WithTimeout(timeout int) *Limiter { + l.timeout = &timeout + return l +} + func (l *Limiter) Wait() { ok, ch := l.proceed() - if !ok { - <-ch + if ok { + return } + if l.timeout != nil { + select { + case <-ch: + case <-time.After((time.Duration(*l.timeout) * time.Second)): + l.mu.Lock() + for w := l.waitList.Front(); w != nil; w = w.Next() { + ele := w.Value.(waiter) + if ele.done == ch { + close(ch) + l.waitList.Remove(w) + break + } + } + l.mu.Unlock() + } + return + } + <-ch } func (l *Limiter) proceed() (bool, chan struct{}) { @@ -51,6 +75,9 @@ func (l *Limiter) Finish() { defer l.mu.Unlock() l.count -= 1 first := l.waitList.Front() + if first == nil { + return + } w := l.waitList.Remove(first).(waiter) w.done <- struct{}{} close(w.done) diff --git a/rateLimiter_test.go b/rateLimiter_test.go index 4d48621..1f79772 100644 --- a/rateLimiter_test.go +++ b/rateLimiter_test.go @@ -45,3 +45,22 @@ func TestConcurrentRateLimiterBlocking(t *testing.T) { wg.Wait() assert.Equal(t, 0, l.waitList.Len()) } + +func TestConcurrentRateLimiterTimeout(t *testing.T) { + l := NewLimiter(2).WithTimeout(2) + + var wg sync.WaitGroup + wg.Add(5) + + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + l.Wait() + }() + } + time.Sleep(3 * time.Second) + wg.Wait() + l.Finish() + l.Finish() + assert.Equal(t, 0, l.waitList.Len()) +}