From 2aa08d0433f639c71457385490193fd47b3efe98 Mon Sep 17 00:00:00 2001 From: Deng Ming Date: Mon, 31 Oct 2022 14:49:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=BB=B6=E8=BF=9F=E9=98=9F=E5=88=97=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 1 + constrain.go | 8 -- queue/delay_queue.go | 160 +++++++++++++++---------- queue/delay_queue_test.go | 245 +++++++++++++++++++++++++++++++++++++- 4 files changed, 338 insertions(+), 76 deletions(-) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 9de654cc..e49ccf43 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -2,6 +2,7 @@ - [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101) - [queue: API 定义](https://github.com/gotomicro/ekit/pull/109) - [queue: 基于堆和切片的优先级队列](https://github.com/gotomicro/ekit/pull/110) +- [queue: 延时队列](https://github.com/gotomicro/ekit/pull/111) # v0.0.4 - [slice: 重构 index 和 contains 的方法,直接调用对应Func 版本](https://github.com/gotomicro/ekit/pull/87) diff --git a/constrain.go b/constrain.go index 7246d1f4..33faccff 100644 --- a/constrain.go +++ b/constrain.go @@ -25,11 +25,3 @@ type RealNumber interface { type Number interface { RealNumber | ~complex64 | ~complex128 } - -type Comparable[T any] interface { - // CompareTo 方法只能返回以下三个返回值: - // 1: dst 比较大 - // 0: 两者一样大小 - // -1: dst 比较小 - CompareTo(dst T) int -} diff --git a/queue/delay_queue.go b/queue/delay_queue.go index c25396f9..4f28b46c 100644 --- a/queue/delay_queue.go +++ b/queue/delay_queue.go @@ -19,119 +19,151 @@ import ( "sync" "time" - "github.com/gotomicro/ekit/internal/queue" + "github.com/gotomicro/ekit/list" - "github.com/gotomicro/ekit" + "github.com/gotomicro/ekit/internal/queue" ) -type DelayQueue[T Delayable[T]] struct { +type DelayQueue[T Delayable] struct { q queue.PriorityQueue[T] mutex sync.RWMutex - enqueueSignal chan struct{} - dequeueSignal chan struct{} + enqueueReqs *list.LinkedList[delayQueueReq] + dequeueReqs *list.LinkedList[delayQueueReq] } -func NewDelayQueue[T Delayable[T]](compare ekit.Comparator[T]) *DelayQueue[T] { +type delayQueueReq struct { + ch chan struct{} +} + +func NewDelayQueue[T Delayable](c int) *DelayQueue[T] { return &DelayQueue[T]{ - q: *queue.NewPriorityQueue[T](0, compare), - enqueueSignal: make(chan struct{}, 1), - dequeueSignal: make(chan struct{}, 1), + q: *queue.NewPriorityQueue[T](c, func(src T, dst T) int { + srcDelay := src.Delay() + dstDelay := dst.Delay() + if srcDelay > dstDelay { + return 1 + } + if srcDelay == dstDelay { + return 0 + } + return -1 + }), + enqueueReqs: list.NewLinkedList[delayQueueReq](), + dequeueReqs: list.NewLinkedList[delayQueueReq](), } } func (d *DelayQueue[T]) Enqueue(ctx context.Context, t T) error { + // 确保 ctx 没有过期 + if ctx.Err() != nil { + return ctx.Err() + } for { d.mutex.Lock() err := d.q.Enqueue(t) - d.mutex.Unlock() if err == queue.ErrOutOfCapacity { + ch := make(chan struct{}, 1) + _ = d.enqueueReqs.Append(delayQueueReq{ch: ch}) + d.mutex.Unlock() select { case <-ctx.Done(): return ctx.Err() - case <-d.dequeueSignal: - continue + case <-ch: } + continue } - if err == nil { // 这里使用写锁,是为了在 Dequeue 那边 // 当一开始的 Peek 返回 queue.ErrEmptyQueue 的时候不会错过这个入队信号 - d.mutex.Lock() - head, err := d.q.Peek() - if err != nil { - // 这种情况就是出现在入队成功之后,元素立刻被取走了 - // 这里 err 预期应该只有 queue.ErrEmptyQueue 一种可能 - d.mutex.Lock() + if d.dequeueReqs.Len() == 0 { + // 没人等。 + d.mutex.Unlock() return nil } - if t.CompareTo(head) == 0 { - select { - case d.enqueueSignal <- struct{}{}: - default: - } + req, err := d.dequeueReqs.Delete(0) + if err == nil { + // 唤醒出队的 + req.ch <- struct{}{} } - d.mutex.Lock() } + d.mutex.Unlock() return err } - } func (d *DelayQueue[T]) Dequeue(ctx context.Context) (T, error) { - ticker := time.NewTicker(0) + // 确保 ctx 没有过期 + if ctx.Err() != nil { + var t T + return t, ctx.Err() + } + ticker := time.NewTicker(time.Second) ticker.Stop() + defer func() { + ticker.Stop() + }() for { - d.mutex.RLock() + d.mutex.Lock() head, err := d.q.Peek() - d.mutex.RUnlock() + if err != nil && err != queue.ErrEmptyQueue { + var t T + return t, err + } if err == queue.ErrEmptyQueue { + ch := make(chan struct{}, 1) + _ = d.dequeueReqs.Append(delayQueueReq{ch: ch}) + d.mutex.Unlock() select { case <-ctx.Done(): var t T return t, ctx.Err() - case <-d.enqueueSignal: + case <-ch: } - } else { - ticker.Reset(head.Delay()) - select { - case <-ctx.Done(): - var t T - return t, ctx.Err() - case <-ticker.C: - var t T - d.mutex.Lock() - t, err = d.q.Dequeue() + continue + } + + delay := head.Delay() + // 已经到期了 + if delay <= 0 { + // 拿着锁,所以不然不可能返回 error + t, _ := d.q.Dequeue() + d.wakeEnqueue() + d.mutex.Unlock() + return t, nil + } + + // 在进入 select 之前必须要释放锁 + d.mutex.Unlock() + ticker.Reset(delay) + select { + case <-ctx.Done(): + var t T + return t, ctx.Err() + case <-ticker.C: + var t T + d.mutex.Lock() + t, err = d.q.Dequeue() + // 被人抢走了,理论上是不会出现这个可能的 + if err != nil { d.mutex.Unlock() - // 被人抢走了,理论上是不会出现这个可能的 - if err == queue.ErrEmptyQueue { - continue - } - select { - case d.dequeueSignal <- struct{}{}: - default: - } - return t, nil - case <-d.enqueueSignal: + continue } + d.wakeEnqueue() + d.mutex.Unlock() + return t, nil } } } -type Delayable[T any] interface { - Delay() time.Duration - ekit.Comparable[T] -} - -type user struct { -} - -func (u user) Delay() time.Duration { - //TODO implement me - panic("implement me") +func (d *DelayQueue[T]) wakeEnqueue() { + req, err := d.enqueueReqs.Delete(0) + if err == nil { + // 唤醒等待入队的 + req.ch <- struct{}{} + } } -func (u user) CompareTo(dst user) int { - //TODO implement me - panic("implement me") +type Delayable interface { + Delay() time.Duration } diff --git a/queue/delay_queue_test.go b/queue/delay_queue_test.go index 45d00516..40a2181e 100644 --- a/queue/delay_queue_test.go +++ b/queue/delay_queue_test.go @@ -15,11 +15,248 @@ package queue import ( - "fmt" + "context" "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestNewDelayQueue(t *testing.T) { - q := NewDelayQueue() - fmt.Println(q) +func TestDelayQueue_Dequeue(t *testing.T) { + t.Parallel() + now := time.Now() + testCases := []struct { + name string + q *DelayQueue[delayElem] + timeout time.Duration + wantVal int + wantErr error + }{ + { + name: "dequeued", + q: newDelayQueue(t, delayElem{ + deadline: now.Add(time.Millisecond * 10), + val: 11, + }), + timeout: time.Second, + wantVal: 11, + }, + { + // 元素本身就已经过期了 + name: "already deadline", + q: newDelayQueue(t, delayElem{ + deadline: now.Add(-time.Millisecond * 10), + val: 11, + }), + timeout: time.Second, + wantVal: 11, + }, + { + // 已经超时了的 context 设置 + name: "invalid context", + q: newDelayQueue(t, delayElem{ + deadline: now.Add(time.Millisecond * 10), + val: 11, + }), + timeout: -time.Second, + wantErr: context.DeadlineExceeded, + }, + { + name: "empty and timeout", + q: NewDelayQueue[delayElem](10), + timeout: time.Second, + wantErr: context.DeadlineExceeded, + }, + { + name: "not empty but timeout", + q: newDelayQueue(t, delayElem{ + deadline: now.Add(time.Second * 10), + val: 11, + }), + timeout: time.Second, + wantErr: context.DeadlineExceeded, + }, + } + + for _, tt := range testCases { + tc := tt + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) + defer cancel() + ele, err := tc.q.Dequeue(ctx) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantVal, ele.val) + }) + } + + // 最开始没有元素,然后进去了一个元素 + t.Run("dequeue while enqueue", func(t *testing.T) { + q := NewDelayQueue[delayElem](3) + go func() { + time.Sleep(time.Millisecond * 500) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, delayElem{ + val: 123, + deadline: time.Now().Add(time.Millisecond * 100), + }) + require.NoError(t, err) + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + ele, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, ele.val) + }) + + // 进去了一个更加短超时时间的元素 + // 于是后面两个都会拿出来,但是时间短的会先拿出来 + t.Run("enqueue short ele", func(t *testing.T) { + q := NewDelayQueue[delayElem](3) + // 长时间过期的元素 + err := q.Enqueue(context.Background(), delayElem{ + val: 234, + deadline: time.Now().Add(time.Second), + }) + require.NoError(t, err) + + go func() { + time.Sleep(time.Millisecond * 200) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + err := q.Enqueue(ctx, delayElem{ + val: 123, + deadline: time.Now().Add(time.Millisecond * 300), + }) + require.NoError(t, err) + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + // 先拿出短时间的 + ele, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, ele.val) + // 再拿出长时间的 + ele, err = q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 234, ele.val) + + // 没有元素了,会超时 + _, err = q.Dequeue(ctx) + require.Equal(t, context.DeadlineExceeded, err) + }) +} + +func TestDelayQueue_Enqueue(t *testing.T) { + t.Parallel() + now := time.Now() + testCases := []struct { + name string + q *DelayQueue[delayElem] + timeout time.Duration + val delayElem + wantErr error + }{ + { + name: "enqueued", + q: NewDelayQueue[delayElem](3), + timeout: time.Second, + val: delayElem{val: 123, deadline: now.Add(time.Minute)}, + }, + { + // context 本身已经过期了 + name: "invalid context", + q: NewDelayQueue[delayElem](3), + timeout: -time.Second, + val: delayElem{val: 123, deadline: now.Add(time.Minute)}, + wantErr: context.DeadlineExceeded, + }, + { + // enqueue 的时候阻塞住了,直到超时 + name: "enqueue timeout", + q: newDelayQueue(t, delayElem{val: 123, deadline: now.Add(time.Minute)}), + timeout: time.Millisecond * 100, + val: delayElem{val: 234, deadline: now.Add(time.Minute)}, + wantErr: context.DeadlineExceeded, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) + defer cancel() + err := tc.q.Enqueue(ctx, tc.val) + assert.Equal(t, tc.wantErr, err) + }) + } + + // 队列满了,这时候入队。 + // 在等待一段时间之后,队列元素被取走一个 + t.Run("enqueue while dequeue", func(t *testing.T) { + t.Parallel() + q := newDelayQueue(t, delayElem{val: 123, deadline: time.Now().Add(time.Second)}) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + ele, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, ele.val) + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + err := q.Enqueue(ctx, delayElem{val: 345, deadline: time.Now().Add(time.Millisecond * 1500)}) + require.NoError(t, err) + }) + + // 入队相同过期时间的元素 + // 但是因为我们在入队的时候是分别计算 Delay 的 + // 那么就会导致虽然过期时间是相同的,但是因为调用 Delay 有先后之分 + // 所以会造成 dstDelay 就是要比 srcDelay 小一点 + t.Run("enqueue with same deadline", func(t *testing.T) { + t.Parallel() + q := NewDelayQueue[delayElem](3) + deadline := time.Now().Add(time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + err := q.Enqueue(ctx, delayElem{val: 123, deadline: deadline}) + require.NoError(t, err) + err = q.Enqueue(ctx, delayElem{val: 456, deadline: deadline}) + require.NoError(t, err) + err = q.Enqueue(ctx, delayElem{val: 789, deadline: deadline}) + require.NoError(t, err) + + ele, err := q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 123, ele.val) + + ele, err = q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 789, ele.val) + + ele, err = q.Dequeue(ctx) + require.NoError(t, err) + require.Equal(t, 456, ele.val) + }) +} + +func newDelayQueue(t *testing.T, eles ...delayElem) *DelayQueue[delayElem] { + q := NewDelayQueue[delayElem](len(eles)) + for _, ele := range eles { + err := q.Enqueue(context.Background(), ele) + require.NoError(t, err) + } + return q +} + +type delayElem struct { + deadline time.Time + val int +} + +func (d delayElem) Delay() time.Duration { + return time.Until(d.deadline) }