diff --git a/mux.go b/mux.go
index b0e0203f..4645a18f 100644
--- a/mux.go
+++ b/mux.go
@@ -99,8 +99,9 @@ func newMux(dst string, option *ClientOption, init, dead wire, wireFn wireFn, wi
 	for i := 0; i < len(m.wire); i++ {
 		m.wire[i].Store(init)
 	}
-	m.dpool = newPool(option.BlockingPoolSize, dead, wireFn)
-	m.spool = newPool(option.BlockingPoolSize, dead, wireNoBgFn)
+
+	m.dpool = newPool(option.BlockingPoolSize, dead, option.IdleConnTTL, option.BlockingPoolMinSize, wireFn)
+	m.spool = newPool(option.BlockingPoolSize, dead, option.IdleConnTTL, option.BlockingPoolMinSize, wireNoBgFn)
 	return m
 }
 
diff --git a/pipe_test.go b/pipe_test.go
index 3d6c14f6..30bfd5c7 100644
--- a/pipe_test.go
+++ b/pipe_test.go
@@ -1005,7 +1005,7 @@ func TestDoStreamRecycle(t *testing.T) {
 	go func() {
 		mock.Expect("PING").ReplyString("OK")
 	}()
-	conns := newPool(1, nil, nil)
+	conns := newPool(1, nil, 0, 0, nil)
 	s := p.DoStream(context.Background(), conns, cmds.NewCompleted([]string{"PING"}))
 	buf := bytes.NewBuffer(nil)
 	if err := s.Error(); err != nil {
@@ -1058,7 +1058,7 @@ func TestDoStreamRecycleDestinationFull(t *testing.T) {
 	go func() {
 		mock.Expect("PING").ReplyBlobString("OK")
 	}()
-	conns := newPool(1, nil, nil)
+	conns := newPool(1, nil, 0, 0, nil)
 	s := p.DoStream(context.Background(), conns, cmds.NewCompleted([]string{"PING"}))
 	buf := &limitedbuffer{buf: make([]byte, 1)}
 	if err := s.Error(); err != nil {
@@ -1091,7 +1091,7 @@ func TestDoMultiStreamRecycle(t *testing.T) {
 	go func() {
 		mock.Expect("PING").Expect("PING").ReplyString("OK").ReplyString("OK")
 	}()
-	conns := newPool(1, nil, nil)
+	conns := newPool(1, nil, 0, 0, nil)
 	s := p.DoMultiStream(context.Background(), conns, cmds.NewCompleted([]string{"PING"}), cmds.NewCompleted([]string{"PING"}))
 	buf := bytes.NewBuffer(nil)
 	if err := s.Error(); err != nil {
@@ -1124,7 +1124,7 @@ func TestDoMultiStreamRecycleDestinationFull(t *testing.T) {
 	go func() {
 		mock.Expect("PING").Expect("PING").ReplyBlobString("OK").ReplyBlobString("OK")
 	}()
-	conns := newPool(1, nil, nil)
+	conns := newPool(1, nil, 0, 0, nil)
 	s := p.DoMultiStream(context.Background(), conns, cmds.NewCompleted([]string{"PING"}), cmds.NewCompleted([]string{"PING"}))
 	buf := &limitedbuffer{buf: make([]byte, 1)}
 	if err := s.Error(); err != nil {
@@ -3569,7 +3569,7 @@ func TestAlreadyCanceledContext(t *testing.T) {
 		t.Fatalf("unexpected err %v", err)
 	}
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	if s := p.DoStream(ctx, cp, cmds.NewCompleted([]string{"GET", "a"})); !errors.Is(s.Error(), context.Canceled) {
 		t.Fatalf("unexpected err %v", s.Error())
 	}
@@ -3614,7 +3614,7 @@ func TestCancelContext_DoStream(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
 	defer cancel()
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	s := p.DoStream(ctx, cp, cmds.NewCompleted([]string{"GET", "a"}))
 	if err := s.Error(); err != io.EOF && !strings.Contains(err.Error(), "i/o") {
 		t.Fatalf("unexpected err %v", err)
@@ -3631,7 +3631,7 @@ func TestWriteDeadlineIsShorterThanContextDeadline_DoStream(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
 	defer cancel()
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	startTime := time.Now()
 	s := p.DoStream(ctx, cp, cmds.NewCompleted([]string{"GET", "a"}))
 	if err := s.Error(); err != io.EOF && !strings.Contains(err.Error(), "i/o") {
@@ -3652,7 +3652,7 @@ func TestWriteDeadlineIsNoShorterThanContextDeadline_DoStreamBlocked(t *testing.
 	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer cancel()
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	startTime := time.Now()
 	s := p.DoStream(ctx, cp, cmds.NewBlockingCompleted([]string{"BLPOP", "a"}))
 	if err := s.Error(); err != io.EOF && !strings.Contains(err.Error(), "i/o") {
@@ -3727,7 +3727,7 @@ func TestCancelContext_DoMultiStream(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
 	defer cancel()
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	s := p.DoMultiStream(ctx, cp, cmds.NewCompleted([]string{"GET", "a"}))
 	if err := s.Error(); err != io.EOF && !strings.Contains(err.Error(), "i/o") {
 		t.Fatalf("unexpected err %v", err)
@@ -3744,7 +3744,7 @@ func TestWriteDeadlineIsShorterThanContextDeadline_DoMultiStream(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
 	defer cancel()
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	startTime := time.Now()
 	s := p.DoMultiStream(ctx, cp, cmds.NewCompleted([]string{"GET", "a"}))
 	if err := s.Error(); err != io.EOF && !strings.Contains(err.Error(), "i/o") {
@@ -3765,7 +3765,7 @@ func TestWriteDeadlineIsNoShorterThanContextDeadline_DoMultiStreamBlocked(t *tes
 	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer cancel()
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	startTime := time.Now()
 	s := p.DoMultiStream(ctx, cp, cmds.NewBlockingCompleted([]string{"BLPOP", "a"}))
 	if err := s.Error(); err != io.EOF && !strings.Contains(err.Error(), "i/o") {
@@ -3797,7 +3797,7 @@ func TestTimeout_DoStream(t *testing.T) {
 	defer ShouldNotLeaked(SetupLeakDetection())
 	p, _, _, _ := setup(t, ClientOption{ConnWriteTimeout: time.Millisecond * 30})
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	s := p.DoStream(context.Background(), cp, cmds.NewCompleted([]string{"GET", "a"}))
 
 	if err := s.Error(); err != io.EOF && !strings.Contains(err.Error(), "i/o") {
@@ -3817,7 +3817,7 @@ func TestForceClose_DoStream_Block(t *testing.T) {
 		p.Close()
 	}()
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	s := p.DoStream(context.Background(), cp, cmds.NewBlockingCompleted([]string{"GET", "a"}))
 
 	if s.Error() != nil {
@@ -3874,7 +3874,7 @@ func TestTimeout_DoMultiStream(t *testing.T) {
 	defer ShouldNotLeaked(SetupLeakDetection())
 	p, _, _, _ := setup(t, ClientOption{ConnWriteTimeout: time.Millisecond * 30})
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	s := p.DoMultiStream(context.Background(), cp, cmds.NewCompleted([]string{"GET", "a"}))
 
 	if err := s.Error(); err != io.EOF && !strings.Contains(err.Error(), "i/o") {
@@ -3894,7 +3894,7 @@ func TestForceClose_DoMultiStream_Block(t *testing.T) {
 		p.Close()
 	}()
 
-	cp := newPool(1, nil, nil)
+	cp := newPool(1, nil, 0, 0, nil)
 	s := p.DoMultiStream(context.Background(), cp, cmds.NewBlockingCompleted([]string{"GET", "a"}))
 
 	if s.Error() != nil {
diff --git a/pool.go b/pool.go
index f0a69c63..a698e2f8 100644
--- a/pool.go
+++ b/pool.go
@@ -1,30 +1,39 @@
 package rueidis
 
-import "sync"
+import (
+	"sync"
+	"time"
+)
 
-func newPool(cap int, dead wire, makeFn func() wire) *pool {
+func newPool(cap int, dead wire, idleConnTTL time.Duration, minSize int, makeFn func() wire) *pool {
 	if cap <= 0 {
 		cap = DefaultPoolSize
 	}
 
 	return &pool{
-		size: 0,
-		cap:  cap,
-		dead: dead,
-		make: makeFn,
-		list: make([]wire, 0, 4),
-		cond: sync.NewCond(&sync.Mutex{}),
+		size:        0,
+		minSize:     minSize,
+		cap:         cap,
+		dead:        dead,
+		make:        makeFn,
+		list:        make([]wire, 0, 4),
+		cond:        sync.NewCond(&sync.Mutex{}),
+		idleConnTTL: idleConnTTL,
 	}
 }
 
 type pool struct {
-	dead wire
-	cond *sync.Cond
-	make func() wire
-	list []wire
-	size int
-	cap  int
-	down bool
+	dead          wire
+	cond          *sync.Cond
+	make          func() wire
+	list          []wire
+	size          int
+	minSize       int
+	cap           int
+	down          bool
+	idleConnTTL   time.Duration
+	timer         *time.Timer
+	timerIsActive bool
 }
 
 func (p *pool) Acquire() (v wire) {
@@ -50,6 +59,7 @@
 	p.cond.L.Lock()
 	if !p.down && v.Error() == nil {
 		p.list = append(p.list, v)
+		p.startTimerIfNeeded()
 	} else {
 		p.size--
 		v.Close()
@@ -61,9 +71,49 @@
 func (p *pool) Close() {
 	p.cond.L.Lock()
 	p.down = true
+	p.stopTimer()
 	for _, w := range p.list {
 		w.Close()
 	}
 	p.cond.L.Unlock()
 	p.cond.Broadcast()
 }
+
+func (p *pool) startTimerIfNeeded() {
+	if p.idleConnTTL == 0 || p.timerIsActive || len(p.list) <= p.minSize {
+		return
+	}
+
+	p.timerIsActive = true
+	if p.timer == nil {
+		p.timer = time.AfterFunc(p.idleConnTTL, p.removeIdleConns)
+	} else {
+		p.timer.Reset(p.idleConnTTL)
+	}
+}
+
+func (p *pool) removeIdleConns() {
+	p.cond.L.Lock()
+	defer p.cond.L.Unlock()
+
+	p.timerIsActive = false
+	if p.down || len(p.list) <= p.minSize {
+		return
+	}
+
+	newLen := min(p.minSize, len(p.list))
+	for i, w := range p.list[newLen:] {
+		w.Close()
+		p.list[newLen+i] = nil
+		p.size--
+	}
+
+	p.list = p.list[:newLen]
+}
+
+func (p *pool) stopTimer() {
+	p.timerIsActive = false
+	if p.timer != nil {
+		p.timer.Stop()
+	}
+}
diff --git a/pool_test.go b/pool_test.go
index 32fef647..16d58d54 100644
--- a/pool_test.go
+++ b/pool_test.go
@@ -5,6 +5,7 @@ import (
 	"runtime"
 	"sync/atomic"
 	"testing"
+	"time"
 )
 
 var dead = deadFn()
@@ -14,7 +15,7 @@ func TestPool(t *testing.T) {
 	defer ShouldNotLeaked(SetupLeakDetection())
 	setup := func(size int) (*pool, *int32) {
 		var count int32
-		return newPool(size, dead, func() wire {
+		return newPool(size, dead, 0, 0, func() wire {
 			atomic.AddInt32(&count, 1)
 			closed := false
 			return &mockWire{
@@ -32,7 +33,7 @@
 	}
 
 	t.Run("DefaultPoolSize", func(t *testing.T) {
-		p := newPool(0, dead, func() wire { return nil })
+		p := newPool(0, dead, 0, 0, func() wire { return nil })
 		if cap(p.list) == 0 {
 			t.Fatalf("DefaultPoolSize is not applied")
 		}
@@ -180,7 +181,7 @@ func TestPoolError(t *testing.T) {
 	defer ShouldNotLeaked(SetupLeakDetection())
 	setup := func(size int) (*pool, *int32) {
 		var count int32
-		return newPool(size, dead, func() wire {
+		return newPool(size, dead, 0, 0, func() wire {
 			w := &pipe{}
 			w.pshks.Store(emptypshks)
 			c := atomic.AddInt32(&count, 1)
@@ -211,3 +212,92 @@
 		}
 	})
 }
+
+func TestPoolWithIdleTTL(t *testing.T) {
+	defer ShouldNotLeaked(SetupLeakDetection())
+	setup := func(size int, ttl time.Duration, minSize int) *pool {
+		return newPool(size, dead, ttl, minSize, func() wire {
+			closed := false
+			return &mockWire{
+				CloseFn: func() {
+					closed = true
+				},
+				ErrorFn: func() error {
+					if closed {
+						return ErrClosing
+					}
+					return nil
+				},
+			}
+		})
+	}
+
+	t.Run("Removing idle conns. Min size is not 0", func(t *testing.T) {
+		minSize := 3
+		p := setup(0, time.Millisecond*50, minSize)
+		conns := make([]wire, 10)
+
+		for i := 0; i < 2; i++ {
+			for i := range conns {
+				w := p.Acquire()
+				conns[i] = w
+			}
+
+			for _, w := range conns {
+				p.Store(w)
+			}
+
+			time.Sleep(time.Millisecond * 60)
+			p.cond.Broadcast()
+			time.Sleep(time.Millisecond * 40)
+
+			p.cond.L.Lock()
+			if p.size != minSize {
+				defer p.cond.L.Unlock()
+				t.Fatalf("size must be equal to %d, actual: %d", minSize, p.size)
+			}
+
+			if len(p.list) != minSize {
+				defer p.cond.L.Unlock()
+				t.Fatalf("pool len must be equal to %d, actual: %d", minSize, len(p.list))
+			}
+			p.cond.L.Unlock()
+		}
+
+		p.Close()
+	})
+
+	t.Run("Removing idle conns. Min size is 0", func(t *testing.T) {
+		p := setup(0, time.Millisecond*50, 0)
+		conns := make([]wire, 10)
+
+		for i := 0; i < 2; i++ {
+			for i := range conns {
+				w := p.Acquire()
+				conns[i] = w
+			}
+
+			for _, w := range conns {
+				p.Store(w)
+			}
+
+			time.Sleep(time.Millisecond * 60)
+			p.cond.Broadcast()
+			time.Sleep(time.Millisecond * 40)
+
+			p.cond.L.Lock()
+			if p.size != 0 {
+				defer p.cond.L.Unlock()
+				t.Fatalf("size must be equal to 0, actual: %d", p.size)
+			}
+
+			if len(p.list) != 0 {
+				defer p.cond.L.Unlock()
+				t.Fatalf("pool len must be equal to 0, actual: %d", len(p.list))
+			}
+			p.cond.L.Unlock()
+		}
+
+		p.Close()
+	})
+}
diff --git a/rueidis.go b/rueidis.go
index d0919a1d..6d9591f8 100644
--- a/rueidis.go
+++ b/rueidis.go
@@ -133,6 +133,15 @@ type ClientOption struct {
 	// WriteBufferEachConn is the size of the bufio.NewWriterSize for each connection, default to DefaultWriteBuffer (0.5 MiB).
 	WriteBufferEachConn int
 
+	// IdleConnTTL is the duration after which an idle connection is closed.
+	// If IdleConnTTL is 0, idle connections will not be closed.
+	IdleConnTTL time.Duration
+	// BlockingPoolMinSize is the minimum size of the connection pool
+	// shared by blocking commands (ex BLPOP, XREAD with BLOCK).
+	// Only relevant if IdleConnTTL is not 0: the pool will not be shrunk
+	// below this size when idle connections are closed.
+	BlockingPoolMinSize int
+
 	// BlockingPoolSize is the size of the connection pool shared by blocking commands (ex BLPOP, XREAD with BLOCK).
 	// The default is DefaultPoolSize.
 	BlockingPoolSize int
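
For reference, a minimal usage sketch of the new options from a caller's point of view. The module path (github.com/redis/rueidis), the address, and the concrete sizes and durations are illustrative assumptions; only IdleConnTTL, BlockingPoolMinSize, and the existing BlockingPoolSize come from this change:

package main

import (
	"time"

	"github.com/redis/rueidis"
)

func main() {
	// IdleConnTTL and BlockingPoolMinSize are only wired into the pools
	// created in mux.go, i.e. the pools backing blocking and streaming
	// commands; the numbers below are arbitrary examples.
	client, err := rueidis.NewClient(rueidis.ClientOption{
		InitAddress:         []string{"127.0.0.1:6379"},
		BlockingPoolSize:    16,              // upper bound of the blocking pool
		IdleConnTTL:         5 * time.Minute, // close idle pooled connections after 5 minutes
		BlockingPoolMinSize: 4,               // never shrink the pool below 4 connections
	})
	if err != nil {
		panic(err)
	}
	defer client.Close()
}

With such a configuration, connections returned to the blocking pool that stay unused for the TTL are closed down to the configured minimum, instead of lingering for the lifetime of the client.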