Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
Signed-off-by: chahatsagarmain <chahatsagar2003@gmail.com>
  • Loading branch information
chahatsagarmain committed Oct 24, 2024
1 parent c813cc6 commit b73817e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 167 deletions.
69 changes: 25 additions & 44 deletions pkg/bearertoken/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,8 @@ func (tss *tokenatedServerStream) Context() context.Context {
return tss.context
}

// getValidBearerToken attempts to retrieve the bearer token from the context.
// It does not return an error if the token is missing.
func getValidBearerToken(ctx context.Context, bearerHeader string) (string, error) {
bearerToken, ok := GetBearerToken(ctx)
if ok && bearerToken != "" {
return bearerToken, nil
}

// extract bearer token from the metadata
func ValidTokenFromGRPCMetadata(ctx context.Context, bearerHeader string) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", nil
Expand All @@ -47,38 +41,35 @@ func getValidBearerToken(ctx context.Context, bearerHeader string) (string, erro

// NewStreamServerInterceptor creates a new stream interceptor that injects the bearer token into the context if available.
func NewStreamServerInterceptor() grpc.StreamServerInterceptor {
return streamServerInterceptor
}
return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if token, _ := GetBearerToken(ss.Context()); token != "" {
return handler(srv, ss)
}

func streamServerInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
bearerToken, err := getValidBearerToken(ss.Context(), Key)
if err != nil {
return err
}
// Upgrade the bearer token to be part of the context.
bearerToken, err := ValidTokenFromGRPCMetadata(ss.Context(), Key)
if err != nil {
return err
}

if token, _ := GetBearerToken(ss.Context()); token != "" {
return handler(srv, ss)
return handler(srv, &tokenatedServerStream{
ServerStream: ss,
context: ContextWithBearerToken(ss.Context(), bearerToken),
})
}

return handler(srv, &tokenatedServerStream{
ServerStream: ss,
context: ContextWithBearerToken(ss.Context(), bearerToken),
})
}

// NewUnaryServerInterceptor creates a new unary interceptor that injects the bearer token into the context if available.
func NewUnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
bearerToken, err := getValidBearerToken(ctx, Key)
if err != nil {
return nil, err
}

if token, _ := GetBearerToken(ctx); token != "" {
return handler(ctx, req)
}

bearerToken, err := ValidTokenFromGRPCMetadata(ctx, Key)
if err != nil {
return nil, err
}

return handler(ContextWithBearerToken(ctx, bearerToken), req)
}
}
Expand All @@ -94,14 +85,9 @@ func NewUnaryClientInterceptor() grpc.UnaryClientInterceptor {
opts ...grpc.CallOption,
) error {
var token string
if md, ok := metadata.FromIncomingContext(ctx); ok {
tokens := md.Get(Key)
if len(tokens) > 1 {
return fmt.Errorf("malformed token: multiple tokens found")
}
if len(tokens) == 1 && tokens[0] != "" {
token = tokens[0]
}
token, err := ValidTokenFromGRPCMetadata(ctx, Key)
if err != nil {
return err
}

if token == "" {
Expand Down Expand Up @@ -129,14 +115,9 @@ func NewStreamClientInterceptor() grpc.StreamClientInterceptor {
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
var token string
if md, ok := metadata.FromIncomingContext(ctx); ok {
tokens := md.Get(Key)
if len(tokens) > 1 {
return nil, fmt.Errorf("malformed token: multiple tokens found")
}
if len(tokens) == 1 && tokens[0] != "" {
token = tokens[0]
}
token, err := ValidTokenFromGRPCMetadata(ctx, Key)
if err != nil {
return nil, err
}

if token == "" {
Expand Down
145 changes: 22 additions & 123 deletions pkg/bearertoken/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (s *mockServerStream) Context() context.Context {
return s.ctx
}

func TestBearerTokenInterceptors(t *testing.T) {
func TestClientInterceptors(t *testing.T) {
tests := []struct {
name string
ctx context.Context
Expand Down Expand Up @@ -57,34 +57,33 @@ func TestBearerTokenInterceptors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Unary interceptor test
unaryInterceptor := NewUnaryClientInterceptor()
unaryInvoker := func(ctx context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
verifyMetadata := func(ctx context.Context) error {
md, ok := metadata.FromOutgoingContext(ctx)
if test.expectedMD == nil {
require.False(t, ok) // There should be no metadata in this case
require.False(t, ok, "metadata should not be present")
} else {
require.True(t, ok)
require.True(t, ok, "metadata should be present")
assert.Equal(t, test.expectedMD, md)
}
return nil
}
unaryInterceptor := NewUnaryClientInterceptor()
unaryInvoker := func(ctx context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
return verifyMetadata(ctx)
}
err := unaryInterceptor(test.ctx, "method", nil, nil, nil, unaryInvoker)
if test.expectedErr == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), test.expectedErr)
require.ErrorContains(t, err, test.expectedErr)
}

// Stream interceptor test
streamInterceptor := NewStreamClientInterceptor()
streamInvoker := func(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) {
md, ok := metadata.FromOutgoingContext(ctx)
if test.expectedMD == nil {
require.False(t, ok) // There should be no metadata in this case
} else {
require.True(t, ok)
assert.Equal(t, test.expectedMD, md)
if err := verifyMetadata(ctx); err != nil {
return nil, err
}
return nil, nil
}
Expand All @@ -93,7 +92,7 @@ func TestBearerTokenInterceptors(t *testing.T) {
require.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), test.expectedErr)
assert.ErrorContains(t, err, test.expectedErr)
}
})
}
Expand Down Expand Up @@ -138,148 +137,48 @@ func TestServerInterceptors(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Test unary server interceptor
unaryInterceptor := NewUnaryServerInterceptor()
unaryHandler := func(ctx context.Context, _ any) (any, error) {
verifyToken := func(ctx context.Context) error {
token, ok := GetBearerToken(ctx)
if test.wantToken == "" {
assert.False(t, ok, "expected no token")
} else {
assert.True(t, ok, "expected token to be present")
assert.Equal(t, test.wantToken, token)
}
return nil, nil
return nil
}
// Test unary server interceptor
unaryInterceptor := NewUnaryServerInterceptor()
unaryHandler := func(ctx context.Context, _ any) (any, error) {
return nil, verifyToken(ctx)
}

_, err := unaryInterceptor(test.ctx, nil, &grpc.UnaryServerInfo{}, unaryHandler)
if test.expectedErr == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), test.expectedErr)
require.ErrorContains(t, err, test.expectedErr)
}

// Test stream server interceptor
streamInterceptor := NewStreamServerInterceptor()
mockStream := &mockServerStream{ctx: test.ctx}
streamHandler := func(_ any, stream grpc.ServerStream) error {
token, ok := GetBearerToken(stream.Context())
if test.wantToken == "" {
assert.False(t, ok, "expected no token")
} else {
assert.True(t, ok, "expected token to be present")
assert.Equal(t, test.wantToken, token)
}
return nil
return verifyToken(stream.Context())
}

err = streamInterceptor(nil, mockStream, &grpc.StreamServerInfo{}, streamHandler)
if test.expectedErr == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), test.expectedErr)
assert.ErrorContains(t, err, test.expectedErr)
}
})
}
}

func TestClientUnaryInterceptorWithBearerToken(t *testing.T) {
interceptor := NewUnaryClientInterceptor()

// Mock invoker
invoker := func(ctx context.Context, _ string, _ any, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
md, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok)
assert.Equal(t, "test-token", md[Key][0])
return nil
}

// Context with token
ctx := ContextWithBearerToken(context.Background(), "test-token")

err := interceptor(ctx, "method", nil, nil, nil, invoker)
require.NoError(t, err)
}

func TestClientStreamInterceptorWithBearerToken(t *testing.T) {
interceptor := NewStreamClientInterceptor()

// Mock streamer
streamer := func(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) {
md, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok)
assert.Equal(t, "test-token", md[Key][0])
return nil, nil
}

// Context with token
ctx := ContextWithBearerToken(context.Background(), "test-token")

_, err := interceptor(ctx, &grpc.StreamDesc{}, nil, "method", streamer)
require.NoError(t, err)
}

func TestServerUnaryInterceptorWithBearerToken(t *testing.T) {
interceptor := NewUnaryServerInterceptor()
testToken := "test-token"

// Test with token in context
handler := func(ctx context.Context, _ any) (any, error) {
token, ok := GetBearerToken(ctx)
require.True(t, ok)
assert.Equal(t, testToken, token)
return nil, nil
}

ctx := ContextWithBearerToken(context.Background(), testToken)
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{}, handler)
require.NoError(t, err)
}

func TestServerStreamInterceptorWithBearerToken(t *testing.T) {
interceptor := NewStreamServerInterceptor()
testToken := "test-token"

// Test with token in context
handler := func(_ any, stream grpc.ServerStream) error {
token, ok := GetBearerToken(stream.Context())
require.True(t, ok)
assert.Equal(t, testToken, token)
return nil
}

ctx := ContextWithBearerToken(context.Background(), testToken)
mockStream := &mockServerStream{ctx: ctx}
err := interceptor(nil, mockStream, &grpc.StreamServerInfo{}, handler)
require.NoError(t, err)
}

func TestMalformedToken(t *testing.T) {
// Context with multiple tokens
ctx := metadata.NewIncomingContext(context.Background(), metadata.MD{
Key: []string{"token1", "token2"},
})

// Unary interceptor
unaryInterceptor := NewUnaryClientInterceptor()
unaryInvoker := func(_ context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
return nil
}
err := unaryInterceptor(ctx, "method", nil, nil, nil, unaryInvoker)
require.Error(t, err)
assert.Contains(t, err.Error(), "malformed token: multiple tokens found")

// Stream interceptor
streamInterceptor := NewStreamClientInterceptor()
streamInvoker := func(_ context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) {
return nil, nil
}
_, err = streamInterceptor(ctx, &grpc.StreamDesc{}, nil, "method", streamInvoker)
require.Error(t, err)
assert.Contains(t, err.Error(), "malformed token: multiple tokens found")
}

func TestTokenatedServerStream(t *testing.T) {
originalCtx := context.Background()
testToken := "test-token"
Expand Down

0 comments on commit b73817e

Please sign in to comment.