Skip to content

Commit

Permalink
wip: rate limited logger
Browse files Browse the repository at this point in the history
  • Loading branch information
AndersonQ committed Aug 6, 2024
1 parent 1160ccd commit b4dcc44
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 0 deletions.
105 changes: 105 additions & 0 deletions logp/ratelimited.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package logp

import (
"sync"
"sync/atomic"
"time"
)

// RateLimitedLogger is a logger that limits log messages to at most once within
// a specified period.
// It is intended for logging events that occur frequently, providing a summary
// with the number of occurrences within the given period.
//
// RateLimitedLogger takes a logger function, logFn, which is called every time
// the specified period has elapsed to log the summary.
type RateLimitedLogger struct {
count atomic.Uint64

period time.Duration

// logFn is called for logging, which receives the count of events and the
// duration since the last call.
logFn func(count uint64, d time.Duration)
lastLog time.Time
done chan struct{}

// nowFn is used to acquire the current time instead of time.Now so it can
// be mocked for tests.
nowFn func() time.Time
// newTickerFn is used to acquire a *time.Ticker instead of time.NewTicker
// so it can be mocked for tests.
newTickerFn func(duration time.Duration) *time.Ticker

started atomic.Bool
wg sync.WaitGroup
ticker *time.Ticker
}

// NewRateLimited returns a new RateLimitedLogger. It takes a logFn, which is
// called with the count of events and the period between each call,
// and a period determining how often the log function should be called.
func NewRateLimited(
logFn func(count uint64, d time.Duration), period time.Duration) *RateLimitedLogger {
return &RateLimitedLogger{
period: period,
logFn: logFn,

nowFn: time.Now,
newTickerFn: time.NewTicker,
}
}

func (r *RateLimitedLogger) Add() {
r.count.Add(1)
}

func (r *RateLimitedLogger) AddN(n uint64) {
r.count.Add(n)
}

func (r *RateLimitedLogger) Start() {
if r.started.Load() {
return
}

r.done = make(chan struct{})
r.started.Store(true)
r.lastLog = r.nowFn()
r.ticker = r.newTickerFn(r.period)

r.wg.Add(1)
go func() {
defer r.wg.Done()

defer r.ticker.Stop()

for {
select {
case <-r.ticker.C:
r.log()
case <-r.done:
r.log()
return
}
}
}()
}

func (r *RateLimitedLogger) Stop() {
if !r.started.Load() {
return
}

close(r.done)
r.wg.Wait()
r.started.Store(false)
}

func (r *RateLimitedLogger) log() {
count := r.count.Swap(0)
if count > 0 {
r.lastLog = r.nowFn()
r.logFn(count, r.period)
}
}
177 changes: 177 additions & 0 deletions logp/ratelimited_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package logp

import (
"bytes"
"fmt"
"io"
"math"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type syncBuffer struct {
buff bytes.Buffer
mu sync.Mutex
}

func (s *syncBuffer) Read(p []byte) (n int, err error) {
s.mu.Lock()
defer s.mu.Unlock()

return s.buff.Read(p)
}

func (s *syncBuffer) Write(p []byte) (n int, err error) {
s.mu.Lock()
defer s.mu.Unlock()

return fmt.Fprintf(&s.buff, "%s", p)
}

func TestRateLimitedLogger(t *testing.T) {
pattern := "%d occurrences in the last %s"

newLogger := func() (io.Reader, func(count uint64, d time.Duration)) {
sbuff := &syncBuffer{}

logFn := func(count uint64, d time.Duration) {
fmt.Fprintf(sbuff, pattern, count, d)
}
return sbuff, logFn
}

now := time.Now()

t.Run("Start", func(t *testing.T) {
r := NewRateLimited(func(count uint64, d time.Duration) {}, math.MaxInt64)
defer r.Stop()
r.nowFn = func() time.Time { return now }

r.Start()

assert.True(t, r.started.Load(),
"Start() was called, thus 'started' should be true")
assert.NotEmpty(t, r.lastLog, "lastLog should have been set")
})

t.Run("Start twice", func(t *testing.T) {
r := NewRateLimited(func(count uint64, d time.Duration) {}, math.MaxInt64)
defer r.Stop()

r.nowFn = func() time.Time { return now }

r.Start()
r.nowFn = func() time.Time { return now.Add(time.Minute) }
r.Start()

assert.Equal(t, now, r.lastLog, "lastLog should have been set a second time")
})

t.Run("Stop", func(t *testing.T) {
tcs := []struct {
name string
count int
}{
{name: "once", count: 1},
{name: "twice", count: 2},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
buff, logFn := newLogger()
r := NewRateLimited(logFn, 42*time.Second)
r.nowFn = func() time.Time { return now }

tch := make(chan time.Time)
r.newTickerFn = func(duration time.Duration) *time.Ticker {
return &time.Ticker{C: tch}
}

r.Start()

r.nowFn = func() time.Time { return now.Add(42 * time.Second) }

r.count.Add(1)
for i := 0; i < tc.count; i++ {
r.Stop()
}

bs, err := io.ReadAll(buff)
require.NoError(t, err, "failed reading logs")
logs := string(bs)
got := strings.TrimSpace(logs)

assert.False(t, r.started.Load(),
"Stop() was called, thus 'started' should be false")
assert.Len(t, strings.Split(got, "\n"), 1)
assert.Contains(t, logs, fmt.Sprintf(pattern, 1, 42*time.Second))

})
}
})

t.Run("Add", func(t *testing.T) {
buff, logFn := newLogger()
r := NewRateLimited(logFn, 42*time.Second)
defer r.Stop()

r.nowFn = func() time.Time { return now }

tch := make(chan time.Time)
r.newTickerFn = func(duration time.Duration) *time.Ticker {
return &time.Ticker{C: tch}
}

r.Start()
r.Add()

r.nowFn = func() time.Time { return now.Add(42 * time.Second) }
tch <- now.Add(42 * time.Second)

var logs string
assert.Eventually(t, func() bool {
bs, err := io.ReadAll(buff)
require.NoError(t, err, "failed reading logs")
logs = strings.TrimSpace(string(bs))

return len(strings.Split(logs, "\n")) == 1
}, time.Second, 100*time.Millisecond, "should have found 1 log")

assert.Contains(t, logs, fmt.Sprintf(pattern, 1, 42*time.Second))
})

t.Run("AddN", func(t *testing.T) {
buff, logFn := newLogger()
r := NewRateLimited(logFn, 42*time.Second)
defer r.Stop()

r.nowFn = func() time.Time { return now }

tch := make(chan time.Time)
r.newTickerFn = func(duration time.Duration) *time.Ticker {
return &time.Ticker{C: tch}
}

r.Start()
r.AddN(42)

r.nowFn = func() time.Time { return now.Add(42 * time.Second) }
tch <- now.Add(42 * time.Second)

var logs string
assert.Eventually(t, func() bool {
bs, err := io.ReadAll(buff)
require.NoError(t, err, "failed reading logs")
logs = strings.TrimSpace(string(bs))

return len(strings.Split(logs, "\n")) == 1
}, time.Second, 100*time.Millisecond, "should have found 1 log")

assert.Contains(t, logs, fmt.Sprintf(pattern, 42, 42*time.Second))
})
}

0 comments on commit b4dcc44

Please sign in to comment.