From b73817e979b005918272238e3dc9d43da58617b8 Mon Sep 17 00:00:00 2001 From: chahatsagarmain Date: Fri, 25 Oct 2024 03:10:36 +0530 Subject: [PATCH] minor changes Signed-off-by: chahatsagarmain --- pkg/bearertoken/grpc.go | 69 ++++++----------- pkg/bearertoken/grpc_test.go | 145 ++++++----------------------------- 2 files changed, 47 insertions(+), 167 deletions(-) diff --git a/pkg/bearertoken/grpc.go b/pkg/bearertoken/grpc.go index c860764d75b..2be9c7c3860 100644 --- a/pkg/bearertoken/grpc.go +++ b/pkg/bearertoken/grpc.go @@ -22,14 +22,8 @@ func (tss *tokenatedServerStream) Context() context.Context { return tss.context } -// getValidBearerToken attempts to retrieve the bearer token from the context. -// It does not return an error if the token is missing. -func getValidBearerToken(ctx context.Context, bearerHeader string) (string, error) { - bearerToken, ok := GetBearerToken(ctx) - if ok && bearerToken != "" { - return bearerToken, nil - } - +// extract bearer token from the metadata +func ValidTokenFromGRPCMetadata(ctx context.Context, bearerHeader string) (string, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { return "", nil @@ -47,38 +41,35 @@ func getValidBearerToken(ctx context.Context, bearerHeader string) (string, erro // NewStreamServerInterceptor creates a new stream interceptor that injects the bearer token into the context if available. func NewStreamServerInterceptor() grpc.StreamServerInterceptor { - return streamServerInterceptor -} + return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if token, _ := GetBearerToken(ss.Context()); token != "" { + return handler(srv, ss) + } -func streamServerInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - bearerToken, err := getValidBearerToken(ss.Context(), Key) - if err != nil { - return err - } - // Upgrade the bearer token to be part of the context. + bearerToken, err := ValidTokenFromGRPCMetadata(ss.Context(), Key) + if err != nil { + return err + } - if token, _ := GetBearerToken(ss.Context()); token != "" { - return handler(srv, ss) + return handler(srv, &tokenatedServerStream{ + ServerStream: ss, + context: ContextWithBearerToken(ss.Context(), bearerToken), + }) } - - return handler(srv, &tokenatedServerStream{ - ServerStream: ss, - context: ContextWithBearerToken(ss.Context(), bearerToken), - }) } // NewUnaryServerInterceptor creates a new unary interceptor that injects the bearer token into the context if available. func NewUnaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - bearerToken, err := getValidBearerToken(ctx, Key) - if err != nil { - return nil, err - } - if token, _ := GetBearerToken(ctx); token != "" { return handler(ctx, req) } + bearerToken, err := ValidTokenFromGRPCMetadata(ctx, Key) + if err != nil { + return nil, err + } + return handler(ContextWithBearerToken(ctx, bearerToken), req) } } @@ -94,14 +85,9 @@ func NewUnaryClientInterceptor() grpc.UnaryClientInterceptor { opts ...grpc.CallOption, ) error { var token string - if md, ok := metadata.FromIncomingContext(ctx); ok { - tokens := md.Get(Key) - if len(tokens) > 1 { - return fmt.Errorf("malformed token: multiple tokens found") - } - if len(tokens) == 1 && tokens[0] != "" { - token = tokens[0] - } + token, err := ValidTokenFromGRPCMetadata(ctx, Key) + if err != nil { + return err } if token == "" { @@ -129,14 +115,9 @@ func NewStreamClientInterceptor() grpc.StreamClientInterceptor { opts ...grpc.CallOption, ) (grpc.ClientStream, error) { var token string - if md, ok := metadata.FromIncomingContext(ctx); ok { - tokens := md.Get(Key) - if len(tokens) > 1 { - return nil, fmt.Errorf("malformed token: multiple tokens found") - } - if len(tokens) == 1 && tokens[0] != "" { - token = tokens[0] - } + token, err := ValidTokenFromGRPCMetadata(ctx, Key) + if err != nil { + return nil, err } if token == "" { diff --git a/pkg/bearertoken/grpc_test.go b/pkg/bearertoken/grpc_test.go index d4acfab7917..9b4a87ecbaf 100644 --- a/pkg/bearertoken/grpc_test.go +++ b/pkg/bearertoken/grpc_test.go @@ -22,7 +22,7 @@ func (s *mockServerStream) Context() context.Context { return s.ctx } -func TestBearerTokenInterceptors(t *testing.T) { +func TestClientInterceptors(t *testing.T) { tests := []struct { name string ctx context.Context @@ -57,34 +57,33 @@ func TestBearerTokenInterceptors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Unary interceptor test - unaryInterceptor := NewUnaryClientInterceptor() - unaryInvoker := func(ctx context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error { + verifyMetadata := func(ctx context.Context) error { md, ok := metadata.FromOutgoingContext(ctx) if test.expectedMD == nil { - require.False(t, ok) // There should be no metadata in this case + require.False(t, ok, "metadata should not be present") } else { - require.True(t, ok) + require.True(t, ok, "metadata should be present") assert.Equal(t, test.expectedMD, md) } return nil } + unaryInterceptor := NewUnaryClientInterceptor() + unaryInvoker := func(ctx context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error { + return verifyMetadata(ctx) + } err := unaryInterceptor(test.ctx, "method", nil, nil, nil, unaryInvoker) if test.expectedErr == "" { require.NoError(t, err) } else { require.Error(t, err) - assert.Contains(t, err.Error(), test.expectedErr) + require.ErrorContains(t, err, test.expectedErr) } // Stream interceptor test streamInterceptor := NewStreamClientInterceptor() streamInvoker := func(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) { - md, ok := metadata.FromOutgoingContext(ctx) - if test.expectedMD == nil { - require.False(t, ok) // There should be no metadata in this case - } else { - require.True(t, ok) - assert.Equal(t, test.expectedMD, md) + if err := verifyMetadata(ctx); err != nil { + return nil, err } return nil, nil } @@ -93,7 +92,7 @@ func TestBearerTokenInterceptors(t *testing.T) { require.NoError(t, err) } else { require.Error(t, err) - assert.Contains(t, err.Error(), test.expectedErr) + assert.ErrorContains(t, err, test.expectedErr) } }) } @@ -138,9 +137,7 @@ func TestServerInterceptors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - // Test unary server interceptor - unaryInterceptor := NewUnaryServerInterceptor() - unaryHandler := func(ctx context.Context, _ any) (any, error) { + verifyToken := func(ctx context.Context) error { token, ok := GetBearerToken(ctx) if test.wantToken == "" { assert.False(t, ok, "expected no token") @@ -148,7 +145,12 @@ func TestServerInterceptors(t *testing.T) { assert.True(t, ok, "expected token to be present") assert.Equal(t, test.wantToken, token) } - return nil, nil + return nil + } + // Test unary server interceptor + unaryInterceptor := NewUnaryServerInterceptor() + unaryHandler := func(ctx context.Context, _ any) (any, error) { + return nil, verifyToken(ctx) } _, err := unaryInterceptor(test.ctx, nil, &grpc.UnaryServerInfo{}, unaryHandler) @@ -156,21 +158,14 @@ func TestServerInterceptors(t *testing.T) { require.NoError(t, err) } else { require.Error(t, err) - assert.Contains(t, err.Error(), test.expectedErr) + require.ErrorContains(t, err, test.expectedErr) } // Test stream server interceptor streamInterceptor := NewStreamServerInterceptor() mockStream := &mockServerStream{ctx: test.ctx} streamHandler := func(_ any, stream grpc.ServerStream) error { - token, ok := GetBearerToken(stream.Context()) - if test.wantToken == "" { - assert.False(t, ok, "expected no token") - } else { - assert.True(t, ok, "expected token to be present") - assert.Equal(t, test.wantToken, token) - } - return nil + return verifyToken(stream.Context()) } err = streamInterceptor(nil, mockStream, &grpc.StreamServerInfo{}, streamHandler) @@ -178,108 +173,12 @@ func TestServerInterceptors(t *testing.T) { require.NoError(t, err) } else { require.Error(t, err) - assert.Contains(t, err.Error(), test.expectedErr) + assert.ErrorContains(t, err, test.expectedErr) } }) } } -func TestClientUnaryInterceptorWithBearerToken(t *testing.T) { - interceptor := NewUnaryClientInterceptor() - - // Mock invoker - invoker := func(ctx context.Context, _ string, _ any, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error { - md, ok := metadata.FromOutgoingContext(ctx) - require.True(t, ok) - assert.Equal(t, "test-token", md[Key][0]) - return nil - } - - // Context with token - ctx := ContextWithBearerToken(context.Background(), "test-token") - - err := interceptor(ctx, "method", nil, nil, nil, invoker) - require.NoError(t, err) -} - -func TestClientStreamInterceptorWithBearerToken(t *testing.T) { - interceptor := NewStreamClientInterceptor() - - // Mock streamer - streamer := func(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) { - md, ok := metadata.FromOutgoingContext(ctx) - require.True(t, ok) - assert.Equal(t, "test-token", md[Key][0]) - return nil, nil - } - - // Context with token - ctx := ContextWithBearerToken(context.Background(), "test-token") - - _, err := interceptor(ctx, &grpc.StreamDesc{}, nil, "method", streamer) - require.NoError(t, err) -} - -func TestServerUnaryInterceptorWithBearerToken(t *testing.T) { - interceptor := NewUnaryServerInterceptor() - testToken := "test-token" - - // Test with token in context - handler := func(ctx context.Context, _ any) (any, error) { - token, ok := GetBearerToken(ctx) - require.True(t, ok) - assert.Equal(t, testToken, token) - return nil, nil - } - - ctx := ContextWithBearerToken(context.Background(), testToken) - _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{}, handler) - require.NoError(t, err) -} - -func TestServerStreamInterceptorWithBearerToken(t *testing.T) { - interceptor := NewStreamServerInterceptor() - testToken := "test-token" - - // Test with token in context - handler := func(_ any, stream grpc.ServerStream) error { - token, ok := GetBearerToken(stream.Context()) - require.True(t, ok) - assert.Equal(t, testToken, token) - return nil - } - - ctx := ContextWithBearerToken(context.Background(), testToken) - mockStream := &mockServerStream{ctx: ctx} - err := interceptor(nil, mockStream, &grpc.StreamServerInfo{}, handler) - require.NoError(t, err) -} - -func TestMalformedToken(t *testing.T) { - // Context with multiple tokens - ctx := metadata.NewIncomingContext(context.Background(), metadata.MD{ - Key: []string{"token1", "token2"}, - }) - - // Unary interceptor - unaryInterceptor := NewUnaryClientInterceptor() - unaryInvoker := func(_ context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error { - return nil - } - err := unaryInterceptor(ctx, "method", nil, nil, nil, unaryInvoker) - require.Error(t, err) - assert.Contains(t, err.Error(), "malformed token: multiple tokens found") - - // Stream interceptor - streamInterceptor := NewStreamClientInterceptor() - streamInvoker := func(_ context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) { - return nil, nil - } - _, err = streamInterceptor(ctx, &grpc.StreamDesc{}, nil, "method", streamInvoker) - require.Error(t, err) - assert.Contains(t, err.Error(), "malformed token: multiple tokens found") -} - func TestTokenatedServerStream(t *testing.T) { originalCtx := context.Background() testToken := "test-token"