diff --git a/pkg/isclosed/isclosed.go b/pkg/isclosed/isclosed.go new file mode 100644 index 0000000..a3992ab --- /dev/null +++ b/pkg/isclosed/isclosed.go @@ -0,0 +1,37 @@ +package isclosed + +import ( + "context" + "sync" +) + +// All returns a channel that is closed either when each of the channels is read from or the passed +// context is canceled. All is useful for implementing graceful shutdown of a number of Sends or +// other running goroutines that indicate their state via a returned read only channel. The graceful +// shutdown can be circumvented via the context passed to All to ensure shutdowns will not deadlock. +func All(ctx context.Context, done ...<-chan struct{}) <-chan struct{} { + var ( + shutdown = make(chan struct{}) + wg sync.WaitGroup + ) + + wg.Add(len(done)) + for _, ch := range done { + go func() { + defer wg.Done() + + select { + case <-ctx.Done(): + case <-ch: + } + }() + } + + go func() { + defer close(shutdown) + + wg.Wait() + }() + + return shutdown +} diff --git a/pkg/isclosed/isclosed_test.go b/pkg/isclosed/isclosed_test.go new file mode 100644 index 0000000..04acf30 --- /dev/null +++ b/pkg/isclosed/isclosed_test.go @@ -0,0 +1,96 @@ +package isclosed + +import ( + "context" + "testing" + "time" +) + +var maxTestTimeout = 3 * time.Second + +func TestAll_DoneAfterAllClose(t *testing.T) { + var ( + ctx = context.Background() + a, b, c = make(chan struct{}), make(chan struct{}), make(chan struct{}) + done = All(ctx, a, b, c) + ) + + close(a) + close(b) + close(c) // close all channels + eventually(t, done, maxTestTimeout) +} + +func TestAll_DoneAfterCtxCancel(t *testing.T) { + var ( + ctx, cancel = context.WithCancel(context.Background()) + a, b, c = make(chan struct{}), make(chan struct{}), make(chan struct{}) + done = All(ctx, a, b, c) + ) + + close(a) + close(b) + cancel() // c is never closed, but context is canceled + eventually(t, done, maxTestTimeout) +} + +func TestAll_DoneAfterCtxCancelWithNilChannels(t *testing.T) { + var ( + ctx, cancel = context.WithCancel(context.Background()) + done = All(ctx, nil, nil, nil) + ) + + cancel() + eventually(t, done, maxTestTimeout) +} + +// TestAll_DoneNonBlocking verifies that if all the input channels close, the All function's returned +// channel should also close regardless of the order of the input channel arguments. +func TestAll_DoneNonBlocking(t *testing.T) { + var ( + ctx = context.Background() + a, b, c = make(chan struct{}), make(chan struct{}), make(chan struct{}) + start = make(chan struct{}) + + // done is only closed once all three input channels are closed as the context is never + // cancelled. + done = All(ctx, c, b, a) + ) + + // By default channel a closes with no dependencies. + go func() { + <-start + close(a) + }() + + // Only close channel b after channel a is closed. + go func() { + <-a + close(b) + }() + + // Only close the c channel after both channel b and channel a are closed. + go func() { + <-a + <-b + close(c) + }() + + // Start the closing of the channels once all waiting routines are running. + close(start) + + // Require that the done channel is eventually closed without context cancellation even with + // dependencies on closing between various channels as long as there is no deadlock state. + eventually(t, done, maxTestTimeout) +} + +// eventually blocks until done is closed or d time duration passes. +func eventually(t *testing.T, done <-chan struct{}, d time.Duration) { + t.Helper() + + select { + case <-done: + case <-time.After(d): + t.Fatal("timed out waiting for done to close") + } +} diff --git a/pkg/timelock/timelock.go b/pkg/timelock/timelock.go index 454688a..ac1255a 100644 --- a/pkg/timelock/timelock.go +++ b/pkg/timelock/timelock.go @@ -18,6 +18,8 @@ import ( "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/rpc" "github.com/rs/zerolog" + + "github.com/smartcontractkit/timelock-worker/pkg/isclosed" "github.com/smartcontractkit/timelock-worker/pkg/timelock/contract" ) @@ -122,7 +124,6 @@ func NewTimelockWorker(nodeURL, timelockAddress, callProxyAddress, privateKey st // It handles the retrieval of old and new events, contexts and cancellations. func (tw *Worker) Listen() error { ctxwc, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer cancel() // Log timelock-worker configuration. tw.startLog() @@ -154,6 +155,7 @@ func (tw *Worker) Listen() error { select { case <-ctxwc.Done(): case <-processingDone: + cancel() } tw.logger.Info().Msg("shutting down timelock-worker") @@ -161,9 +163,10 @@ func (tw *Worker) Listen() error { tw.dumpOperationStore(time.Now) // Wait for all goroutines to finish. - <-historyDone - <-newDone - <-schedulingDone + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + <-isclosed.All(shutdownCtx, schedulingDone, historyDone, newDone, processingDone) return nil } @@ -273,7 +276,16 @@ func (tw *Worker) retrieveHistoricalLogs(ctx context.Context) (<-chan struct{}, // processLogs is implemented as a fan-in for all the logs channels, merging all the data and handling logs sequentially. // This function is thread safe. func (tw *Worker) processLogs(ctx context.Context, oldLog, newLog <-chan types.Log) <-chan struct{} { - done := make(chan struct{}) + var ( + done, newDone, oldDone = make(chan struct{}), make(chan struct{}), make(chan struct{}) + ctxwc, cancel = context.WithCancel(ctx) + ) + + // Cancel the context and shutdown the processing routine if no more logs are available. + go func() { + defer cancel() + <-isclosed.All(ctxwc, oldDone, newDone) + }() // This is the goroutine watching over the subscribed and historical logs. go func() { @@ -281,18 +293,30 @@ func (tw *Worker) processLogs(ctx context.Context, oldLog, newLog <-chan types.L for { select { - case log := <-newLog: - if err := tw.handleLog(ctx, log); err != nil { + case log, open := <-newLog: + if !open { + close(newDone) + newLog = nil + continue + } + + if err := tw.handleLog(ctxwc, log); err != nil { tw.logger.Error().Msgf("error processing new log: %v\n", log) } - case log := <-oldLog: - if err := tw.handleLog(ctx, log); err != nil { + case log, open := <-oldLog: + if !open { + close(oldDone) + oldLog = nil + continue + } + + if err := tw.handleLog(ctxwc, log); err != nil { tw.logger.Error().Msgf("error processing historical log: %v\n", log) } - case <-ctx.Done(): - tw.logger.Info().Msgf("received OS signal") + case <-ctxwc.Done(): + tw.logger.Info().Msgf("cancelled processing logs") SetReadyStatus(HealthStatusError) return }