-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
282 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
package logp | ||
|
||
import ( | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
) | ||
|
||
// RateLimitedLogger is a logger that limits log messages to at most once within | ||
// a specified period. | ||
// It is intended for logging events that occur frequently, providing a summary | ||
// with the number of occurrences within the given period. | ||
// | ||
// RateLimitedLogger takes a logger function, logFn, which is called every time | ||
// the specified period has elapsed to log the summary. | ||
type RateLimitedLogger struct { | ||
count atomic.Uint64 | ||
|
||
period time.Duration | ||
|
||
// logFn is called for logging, which receives the count of events and the | ||
// duration since the last call. | ||
logFn func(count uint64, d time.Duration) | ||
lastLog time.Time | ||
done chan struct{} | ||
|
||
// nowFn is used to acquire the current time instead of time.Now so it can | ||
// be mocked for tests. | ||
nowFn func() time.Time | ||
// newTickerFn is used to acquire a *time.Ticker instead of time.NewTicker | ||
// so it can be mocked for tests. | ||
newTickerFn func(duration time.Duration) *time.Ticker | ||
|
||
started atomic.Bool | ||
wg sync.WaitGroup | ||
ticker *time.Ticker | ||
} | ||
|
||
// NewRateLimited returns a new RateLimitedLogger. It takes a logFn, which is | ||
// called with the count of events and the period between each call, | ||
// and a period determining how often the log function should be called. | ||
func NewRateLimited( | ||
logFn func(count uint64, d time.Duration), period time.Duration) *RateLimitedLogger { | ||
return &RateLimitedLogger{ | ||
period: period, | ||
logFn: logFn, | ||
|
||
nowFn: time.Now, | ||
newTickerFn: time.NewTicker, | ||
} | ||
} | ||
|
||
func (r *RateLimitedLogger) Add() { | ||
r.count.Add(1) | ||
} | ||
|
||
func (r *RateLimitedLogger) AddN(n uint64) { | ||
r.count.Add(n) | ||
} | ||
|
||
func (r *RateLimitedLogger) Start() { | ||
if r.started.Load() { | ||
return | ||
} | ||
|
||
r.done = make(chan struct{}) | ||
r.started.Store(true) | ||
r.lastLog = r.nowFn() | ||
r.ticker = r.newTickerFn(r.period) | ||
|
||
r.wg.Add(1) | ||
go func() { | ||
defer r.wg.Done() | ||
|
||
defer r.ticker.Stop() | ||
|
||
for { | ||
select { | ||
case <-r.ticker.C: | ||
r.log() | ||
case <-r.done: | ||
r.log() | ||
return | ||
} | ||
} | ||
}() | ||
} | ||
|
||
func (r *RateLimitedLogger) Stop() { | ||
if !r.started.Load() { | ||
return | ||
} | ||
|
||
close(r.done) | ||
r.wg.Wait() | ||
r.started.Store(false) | ||
} | ||
|
||
func (r *RateLimitedLogger) log() { | ||
count := r.count.Swap(0) | ||
if count > 0 { | ||
r.lastLog = r.nowFn() | ||
r.logFn(count, r.period) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
package logp | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"io" | ||
"math" | ||
"strings" | ||
"sync" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
type syncBuffer struct { | ||
buff bytes.Buffer | ||
mu sync.Mutex | ||
} | ||
|
||
func (s *syncBuffer) Read(p []byte) (n int, err error) { | ||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
|
||
return s.buff.Read(p) | ||
} | ||
|
||
func (s *syncBuffer) Write(p []byte) (n int, err error) { | ||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
|
||
return fmt.Fprintf(&s.buff, "%s", p) | ||
} | ||
|
||
func TestRateLimitedLogger(t *testing.T) { | ||
pattern := "%d occurrences in the last %s" | ||
|
||
newLogger := func() (io.Reader, func(count uint64, d time.Duration)) { | ||
sbuff := &syncBuffer{} | ||
|
||
logFn := func(count uint64, d time.Duration) { | ||
fmt.Fprintf(sbuff, pattern, count, d) | ||
} | ||
return sbuff, logFn | ||
} | ||
|
||
now := time.Now() | ||
|
||
t.Run("Start", func(t *testing.T) { | ||
r := NewRateLimited(func(count uint64, d time.Duration) {}, math.MaxInt64) | ||
defer r.Stop() | ||
r.nowFn = func() time.Time { return now } | ||
|
||
r.Start() | ||
|
||
assert.True(t, r.started.Load(), | ||
"Start() was called, thus 'started' should be true") | ||
assert.NotEmpty(t, r.lastLog, "lastLog should have been set") | ||
}) | ||
|
||
t.Run("Start twice", func(t *testing.T) { | ||
r := NewRateLimited(func(count uint64, d time.Duration) {}, math.MaxInt64) | ||
defer r.Stop() | ||
|
||
r.nowFn = func() time.Time { return now } | ||
|
||
r.Start() | ||
r.nowFn = func() time.Time { return now.Add(time.Minute) } | ||
r.Start() | ||
|
||
assert.Equal(t, now, r.lastLog, "lastLog should have been set a second time") | ||
}) | ||
|
||
t.Run("Stop", func(t *testing.T) { | ||
tcs := []struct { | ||
name string | ||
count int | ||
}{ | ||
{name: "once", count: 1}, | ||
{name: "twice", count: 2}, | ||
} | ||
|
||
for _, tc := range tcs { | ||
t.Run(tc.name, func(t *testing.T) { | ||
buff, logFn := newLogger() | ||
r := NewRateLimited(logFn, 42*time.Second) | ||
r.nowFn = func() time.Time { return now } | ||
|
||
tch := make(chan time.Time) | ||
r.newTickerFn = func(duration time.Duration) *time.Ticker { | ||
return &time.Ticker{C: tch} | ||
} | ||
|
||
r.Start() | ||
|
||
r.nowFn = func() time.Time { return now.Add(42 * time.Second) } | ||
|
||
r.count.Add(1) | ||
for i := 0; i < tc.count; i++ { | ||
r.Stop() | ||
} | ||
|
||
bs, err := io.ReadAll(buff) | ||
require.NoError(t, err, "failed reading logs") | ||
logs := string(bs) | ||
got := strings.TrimSpace(logs) | ||
|
||
assert.False(t, r.started.Load(), | ||
"Stop() was called, thus 'started' should be false") | ||
assert.Len(t, strings.Split(got, "\n"), 1) | ||
assert.Contains(t, logs, fmt.Sprintf(pattern, 1, 42*time.Second)) | ||
|
||
}) | ||
} | ||
}) | ||
|
||
t.Run("Add", func(t *testing.T) { | ||
buff, logFn := newLogger() | ||
r := NewRateLimited(logFn, 42*time.Second) | ||
defer r.Stop() | ||
|
||
r.nowFn = func() time.Time { return now } | ||
|
||
tch := make(chan time.Time) | ||
r.newTickerFn = func(duration time.Duration) *time.Ticker { | ||
return &time.Ticker{C: tch} | ||
} | ||
|
||
r.Start() | ||
r.Add() | ||
|
||
r.nowFn = func() time.Time { return now.Add(42 * time.Second) } | ||
tch <- now.Add(42 * time.Second) | ||
|
||
var logs string | ||
assert.Eventually(t, func() bool { | ||
bs, err := io.ReadAll(buff) | ||
require.NoError(t, err, "failed reading logs") | ||
logs = strings.TrimSpace(string(bs)) | ||
|
||
return len(strings.Split(logs, "\n")) == 1 | ||
}, time.Second, 100*time.Millisecond, "should have found 1 log") | ||
|
||
assert.Contains(t, logs, fmt.Sprintf(pattern, 1, 42*time.Second)) | ||
}) | ||
|
||
t.Run("AddN", func(t *testing.T) { | ||
buff, logFn := newLogger() | ||
r := NewRateLimited(logFn, 42*time.Second) | ||
defer r.Stop() | ||
|
||
r.nowFn = func() time.Time { return now } | ||
|
||
tch := make(chan time.Time) | ||
r.newTickerFn = func(duration time.Duration) *time.Ticker { | ||
return &time.Ticker{C: tch} | ||
} | ||
|
||
r.Start() | ||
r.AddN(42) | ||
|
||
r.nowFn = func() time.Time { return now.Add(42 * time.Second) } | ||
tch <- now.Add(42 * time.Second) | ||
|
||
var logs string | ||
assert.Eventually(t, func() bool { | ||
bs, err := io.ReadAll(buff) | ||
require.NoError(t, err, "failed reading logs") | ||
logs = strings.TrimSpace(string(bs)) | ||
|
||
return len(strings.Split(logs, "\n")) == 1 | ||
}, time.Second, 100*time.Millisecond, "should have found 1 log") | ||
|
||
assert.Contains(t, logs, fmt.Sprintf(pattern, 42, 42*time.Second)) | ||
}) | ||
} |