diff --git a/pkg/util/stop/stopper.go b/pkg/util/stop/stopper.go index 8f713487e571..3f58806325ad 100644 --- a/pkg/util/stop/stopper.go +++ b/pkg/util/stop/stopper.go @@ -207,7 +207,13 @@ func (s *Stopper) RunWorker(ctx context.Context, f func(context.Context)) { func (s *Stopper) AddCloser(c Closer) { s.mu.Lock() defer s.mu.Unlock() - s.mu.closers = append(s.mu.closers, c) + select { + case <-s.stopper: + // Close immediately. + c.Close() + default: + s.mu.closers = append(s.mu.closers, c) + } } // WithCancelOnQuiesce returns a child context which is canceled when the @@ -217,7 +223,7 @@ func (s *Stopper) AddCloser(c Closer) { // Canceling this context releases resources associated with it, so code should // call cancel as soon as the operations running in this Context complete. func (s *Stopper) WithCancelOnQuiesce(ctx context.Context) (context.Context, func()) { - return s.withCancel(ctx, s.mu.qCancels) + return s.withCancel(ctx, s.mu.qCancels, s.quiescer) } // WithCancelOnStop returns a child context which is canceled when the @@ -227,24 +233,31 @@ func (s *Stopper) WithCancelOnQuiesce(ctx context.Context) (context.Context, fun // Canceling this context releases resources associated with it, so code should // call cancel as soon as the operations running in this Context complete. func (s *Stopper) WithCancelOnStop(ctx context.Context) (context.Context, func()) { - return s.withCancel(ctx, s.mu.sCancels) + return s.withCancel(ctx, s.mu.sCancels, s.stopper) } func (s *Stopper) withCancel( - ctx context.Context, cancels map[int]func(), + ctx context.Context, cancels map[int]func(), cancelCh chan struct{}, ) (context.Context, func()) { var cancel func() ctx, cancel = context.WithCancel(ctx) s.mu.Lock() defer s.mu.Unlock() - id := s.mu.idAlloc - s.mu.idAlloc++ - cancels[id] = cancel - return ctx, func() { + select { + case <-cancelCh: + // Cancel immediately. cancel() - s.mu.Lock() - defer s.mu.Unlock() - delete(cancels, id) + return ctx, func() {} + default: + id := s.mu.idAlloc + s.mu.idAlloc++ + cancels[id] = cancel + return ctx, func() { + cancel() + s.mu.Lock() + defer s.mu.Unlock() + delete(cancels, id) + } } } @@ -461,10 +474,13 @@ func (s *Stopper) Stop(ctx context.Context) { } s.Quiesce(ctx) + s.mu.Lock() for _, cancel := range s.mu.sCancels { cancel() } close(s.stopper) + s.mu.Unlock() + s.stop.Wait() s.mu.Lock() defer s.mu.Unlock() diff --git a/pkg/util/stop/stopper_test.go b/pkg/util/stop/stopper_test.go index daafcefee7ba..95739177f6c6 100644 --- a/pkg/util/stop/stopper_test.go +++ b/pkg/util/stop/stopper_test.go @@ -17,6 +17,7 @@ package stop_test import ( "context" "fmt" + "runtime" "sync" "sync/atomic" "testing" @@ -253,6 +254,35 @@ func TestStopperClosers(t *testing.T) { } } +func TestStopperCloserConcurrent(t *testing.T) { + defer leaktest.AfterTest(t)() + const trials = 10 + for i := 0; i < trials; i++ { + s := stop.NewStopper() + var tc1 testCloser + + // Add Closer and Stop concurrently. There should be + // no circumstance where the Closer is not called. + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + runtime.Gosched() + s.AddCloser(&tc1) + }() + go func() { + defer wg.Done() + runtime.Gosched() + s.Stop(context.Background()) + }() + wg.Wait() + + if !tc1 { + t.Errorf("expected true; got %t", tc1) + } + } +} + func TestStopperNumTasks(t *testing.T) { defer leaktest.AfterTest(t)() s := stop.NewStopper() @@ -397,6 +427,44 @@ func TestStopperWithCancel(t *testing.T) { } } +func TestStopperWithCancelConcurrent(t *testing.T) { + defer leaktest.AfterTest(t)() + const trials = 10 + for i := 0; i < trials; i++ { + s := stop.NewStopper() + ctx := context.Background() + var ctx1, ctx2 context.Context + + // Tie two contexts to the Stopper and Stop concurrently. There should + // be no circumstance where either Context is not canceled. + var wg sync.WaitGroup + wg.Add(3) + go func() { + defer wg.Done() + runtime.Gosched() + ctx1, _ = s.WithCancelOnQuiesce(ctx) + }() + go func() { + defer wg.Done() + runtime.Gosched() + ctx2, _ = s.WithCancelOnStop(ctx) + }() + go func() { + defer wg.Done() + runtime.Gosched() + s.Stop(ctx) + }() + wg.Wait() + + if err := ctx1.Err(); err != context.Canceled { + t.Errorf("should be canceled: %v", err) + } + if err := ctx2.Err(); err != context.Canceled { + t.Errorf("should be canceled: %v", err) + } + } +} + func TestStopperShouldQuiesce(t *testing.T) { defer leaktest.AfterTest(t)() s := stop.NewStopper()