Skip to content

Commit

Permalink
[B-1592] Call shutdown on handler when ctx done (#233)
Browse files Browse the repository at this point in the history
Co-authored-by: Mirco Bordoni <mirco.bordoni@toolsforhumanity.com>
  • Loading branch information
carlomazzaferro and mircobordoni authored May 16, 2024
1 parent 877c637 commit d30fb87
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
1 change: 1 addition & 0 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ loop:
select {
case <-ctx.Done():
zap.S().Info("closing jobs channel")
c.handler.Shutdown()
close(jobs)
break loop
default:
Expand Down
15 changes: 14 additions & 1 deletion consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type MsgHandler struct {
msgsReceivedCount int
expectedMsg TestMsg
expectedMsgAttributes interface{}
shutdownReceived bool
}

func TestConsume(t *testing.T) {
Expand Down Expand Up @@ -109,7 +110,8 @@ func TestConsume_GracefulShutdown(t *testing.T) {
BatchSize: batchSize,
ExtendEnabled: true,
}
consumer := NewConsumer(awsCfg, config, &MsgHandler{})
msgHandler := MsgHandler{}
consumer := NewConsumer(awsCfg, config, &msgHandler)
go func() {
time.Sleep(time.Second * 1)
// Cancel context to trigger graceful shutdown
Expand All @@ -122,6 +124,11 @@ func TestConsume_GracefulShutdown(t *testing.T) {
os.Exit(1)
}()
consumer.Consume(ctx)

assert.Eventually(t, func() bool {
// Check that shutdown was called
return assert.Equal(t, true, msgHandler.shutdownReceived)
}, time.Second*2, time.Millisecond*100)
}

func createQueue(t *testing.T, ctx context.Context, awsCfg aws.Config, queueName string) *string {
Expand Down Expand Up @@ -164,6 +171,12 @@ func (m *MsgHandler) Run(ctx context.Context, msg *Message) error {
return err
}

func (m *MsgHandler) Shutdown() {
zap.S().Info("Shutting down")
m.shutdownReceived = true
// Do nothing
}

func sendTestMsg(t *testing.T, ctx context.Context, consumer *Consumer, queueUrl *string, expectedMsg TestMsg) TestMsg {
messageBodyBytes, err := json.Marshal(expectedMsg)
_, err = consumer.sqs.SendMessage(ctx, &sqs.SendMessageInput{
Expand Down
1 change: 1 addition & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ import "context"

type Handler interface {
Run(ctx context.Context, msg *Message) error
Shutdown()
}

0 comments on commit d30fb87

Please sign in to comment.