diff --git a/interceptors/retry/retry.go b/interceptors/retry/retry.go index f89d687c..f0254f91 100644 --- a/interceptors/retry/retry.go +++ b/interceptors/retry/retry.go @@ -100,14 +100,15 @@ func StreamClientInterceptor(optFuncs ...CallOption) grpc.StreamClientIntercepto callOpts.onRetryCallback(parentCtx, attempt, lastErr) } var newStreamer grpc.ClientStream - newStreamer, lastErr = streamer(parentCtx, desc, cc, method, grpcOpts...) + newStreamer, lastErr = streamer(perStreamContext(parentCtx, callOpts, attempt), desc, cc, method, grpcOpts...) if lastErr == nil { retryingStreamer := &serverStreamingRetryingStream{ ClientStream: newStreamer, callOpts: callOpts, parentCtx: parentCtx, streamerCall: func(ctx context.Context) (grpc.ClientStream, error) { - return streamer(ctx, desc, cc, method, grpcOpts...) + attempt++ + return streamer(perStreamContext(ctx, callOpts, attempt), desc, cc, method, grpcOpts...) }, } return retryingStreamer, nil @@ -296,6 +297,15 @@ func perCallContext(parentCtx context.Context, callOpts *options, attempt uint) return ctx, cancel } +func perStreamContext(parentCtx context.Context, callOpts *options, attempt uint) context.Context { + ctx := parentCtx + if attempt > 0 && callOpts.includeHeader { + mdClone := metadata.ExtractOutgoing(ctx).Clone().Set(AttemptMetadataKey, fmt.Sprintf("%d", attempt)) + ctx = mdClone.ToOutgoing(ctx) + } + return ctx +} + func contextErrToGrpcErr(err error) error { switch err { case context.DeadlineExceeded: diff --git a/interceptors/retry/retry_test.go b/interceptors/retry/retry_test.go index 5d67a270..59455ef6 100644 --- a/interceptors/retry/retry_test.go +++ b/interceptors/retry/retry_test.go @@ -6,6 +6,7 @@ package retry import ( "context" "io" + "strconv" "strings" "sync" "testing" @@ -17,6 +18,7 @@ import ( "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -432,3 +434,73 @@ func TestJitterUp(t *testing.T) { assert.True(t, highCount != 0, "at least one sample should reach to >%s", high) assert.True(t, lowCount != 0, "at least one sample should to <%s", low) } + +type failingClientStream struct { + RecvMsgErr error +} + +func (s *failingClientStream) Header() (metadata.MD, error) { + return nil, nil +} + +func (s *failingClientStream) Trailer() metadata.MD { + return nil +} + +func (s *failingClientStream) CloseSend() error { + return nil +} + +func (s *failingClientStream) Context() context.Context { + return context.Background() +} + +func (s *failingClientStream) SendMsg(m any) error { + return nil +} + +func (s *failingClientStream) RecvMsg(m any) error { + return s.RecvMsgErr +} + +func TestStreamClientInterceptorAttemptMetadata(t *testing.T) { + retryCount := 5 + attempt := 0 + recvMsgErr := status.Error(codes.Unavailable, "unavailable") + + var testStreamer grpc.Streamer = func( + ctx context.Context, + desc *grpc.StreamDesc, + cc *grpc.ClientConn, + method string, + opts ...grpc.CallOption, + ) (grpc.ClientStream, error) { + if attempt > 0 { + md, ok := metadata.FromOutgoingContext(ctx) + require.True(t, ok) + + raw := md.Get(AttemptMetadataKey) + require.Len(t, raw, 1) + + attemptMetadataValue, err := strconv.Atoi(raw[0]) + require.NoError(t, err) + + require.Equal(t, attempt, attemptMetadataValue) + } + + attempt++ + + return &failingClientStream{ + RecvMsgErr: recvMsgErr, + }, nil + } + + streamClientInterceptor := StreamClientInterceptor(WithCodes(codes.Unavailable), WithMax(uint(retryCount))) + clientStream, err := streamClientInterceptor(context.Background(), &grpc.StreamDesc{}, nil, "some_method", testStreamer) + require.NoError(t, err) + + err = clientStream.RecvMsg(nil) + require.ErrorIs(t, err, recvMsgErr) + + require.Equal(t, retryCount, attempt) +}