Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use GRPC interceptors for bearer token #6063

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
20 changes: 15 additions & 5 deletions cmd/query/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,24 @@

grpcOpts = append(grpcOpts, grpc.Creds(creds))
}

unaryInterceptors := []grpc.UnaryServerInterceptor{
bearertoken.NewUnaryServerInterceptor(),
}
streamInterceptors := []grpc.StreamServerInterceptor{
bearertoken.NewStreamServerInterceptor(),
}
//nolint:contextcheck
if tm.Enabled {
//nolint:contextcheck
grpcOpts = append(grpcOpts,
grpc.StreamInterceptor(tenancy.NewGuardingStreamInterceptor(tm)),
grpc.UnaryInterceptor(tenancy.NewGuardingUnaryInterceptor(tm)),
)
unaryInterceptors = append(unaryInterceptors, tenancy.NewGuardingUnaryInterceptor(tm))
streamInterceptors = append(streamInterceptors, tenancy.NewGuardingStreamInterceptor(tm))

Check warning on line 137 in cmd/query/app/server.go

View check run for this annotation

Codecov / codecov/patch

cmd/query/app/server.go#L136-L137

Added lines #L136 - L137 were not covered by tests
}

grpcOpts = append(grpcOpts,
grpc.ChainUnaryInterceptor(unaryInterceptors...),
grpc.ChainStreamInterceptor(streamInterceptors...),
)

server := grpc.NewServer(grpcOpts...)
return server, nil
}
Expand Down
19 changes: 15 additions & 4 deletions cmd/remote-storage/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"google.golang.org/grpc/reflection"

"github.com/jaegertracing/jaeger/cmd/query/app/querysvc"
"github.com/jaegertracing/jaeger/pkg/bearertoken"
"github.com/jaegertracing/jaeger/pkg/telemetery"
"github.com/jaegertracing/jaeger/pkg/tenancy"
"github.com/jaegertracing/jaeger/plugin/storage/grpc/shared"
Expand Down Expand Up @@ -96,13 +97,23 @@ func createGRPCServer(opts *Options, tm *tenancy.Manager, handler *shared.GRPCHa
creds := credentials.NewTLS(tlsCfg)
grpcOpts = append(grpcOpts, grpc.Creds(creds))
}

unaryInterceptors := []grpc.UnaryServerInterceptor{
bearertoken.NewUnaryServerInterceptor(),
}
streamInterceptors := []grpc.StreamServerInterceptor{
bearertoken.NewStreamServerInterceptor(),
}
if tm.Enabled {
grpcOpts = append(grpcOpts,
grpc.StreamInterceptor(tenancy.NewGuardingStreamInterceptor(tm)),
grpc.UnaryInterceptor(tenancy.NewGuardingUnaryInterceptor(tm)),
)
unaryInterceptors = append(unaryInterceptors, tenancy.NewGuardingUnaryInterceptor(tm))
streamInterceptors = append(streamInterceptors, tenancy.NewGuardingStreamInterceptor(tm))
}

grpcOpts = append(grpcOpts,
grpc.ChainUnaryInterceptor(unaryInterceptors...),
grpc.ChainStreamInterceptor(streamInterceptors...),
)

server := grpc.NewServer(grpcOpts...)
yurishkuro marked this conversation as resolved.
Show resolved Hide resolved
healthServer := health.NewServer()
reflection.Register(server)
Expand Down
135 changes: 135 additions & 0 deletions pkg/bearertoken/grpc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright (c) 2024 The Jaeger Authors.
// SPDX-License-Identifier: Apache-2.0

package bearertoken

import (
"context"
"fmt"

"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

const Key = "bearer.token"

type tokenatedServerStream struct {
grpc.ServerStream
context context.Context
yurishkuro marked this conversation as resolved.
Show resolved Hide resolved
}

func (tss *tokenatedServerStream) Context() context.Context {
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
return tss.context
}

// extract bearer token from the metadata
func ValidTokenFromGRPCMetadata(ctx context.Context, bearerHeader string) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", nil
}

tokens := md.Get(bearerHeader)
if len(tokens) < 1 {
return "", nil
}
if len(tokens) > 1 {
return "", fmt.Errorf("malformed token: multiple tokens found")
}
return tokens[0], nil
}

// NewStreamServerInterceptor creates a new stream interceptor that injects the bearer token into the context if available.
func NewStreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
if token, _ := GetBearerToken(ss.Context()); token != "" {
return handler(srv, ss)
}

bearerToken, err := ValidTokenFromGRPCMetadata(ss.Context(), Key)
if err != nil {
return err
}

return handler(srv, &tokenatedServerStream{
ServerStream: ss,
context: ContextWithBearerToken(ss.Context(), bearerToken),
})
}
}

// NewUnaryServerInterceptor creates a new unary interceptor that injects the bearer token into the context if available.
func NewUnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if token, _ := GetBearerToken(ctx); token != "" {
return handler(ctx, req)
}
yurishkuro marked this conversation as resolved.
Show resolved Hide resolved

bearerToken, err := ValidTokenFromGRPCMetadata(ctx, Key)
if err != nil {
return nil, err
}

return handler(ContextWithBearerToken(ctx, bearerToken), req)
}
}

// NewUnaryClientInterceptor injects the bearer token header into gRPC request metadata.
func NewUnaryClientInterceptor() grpc.UnaryClientInterceptor {
return grpc.UnaryClientInterceptor(func(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary wrapping

Suggested change
return grpc.UnaryClientInterceptor(func(
return func(

ctx context.Context,
method string,
req, reply any,
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
var token string
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
var token string

token, err := ValidTokenFromGRPCMetadata(ctx, Key)
if err != nil {
return err
}

if token == "" {
bearerToken, ok := GetBearerToken(ctx)
if ok && bearerToken != "" {
token = bearerToken
}
}

if token != "" {
ctx = metadata.AppendToOutgoingContext(ctx, Key, token)
}
return invoker(ctx, method, req, reply, cc, opts...)
})
}

// NewStreamClientInterceptor injects the bearer token header into gRPC request metadata.
func NewStreamClientInterceptor() grpc.StreamClientInterceptor {
return grpc.StreamClientInterceptor(func(
ctx context.Context,
desc *grpc.StreamDesc,
cc *grpc.ClientConn,
method string,
streamer grpc.Streamer,
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
var token string
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
var token string

token, err := ValidTokenFromGRPCMetadata(ctx, Key)
if err != nil {
return nil, err
}

if token == "" {
bearerToken, ok := GetBearerToken(ctx)
if ok && bearerToken != "" {
token = bearerToken
}
}

if token != "" {
ctx = metadata.AppendToOutgoingContext(ctx, Key, token)
}
return streamer(ctx, desc, cc, method, opts...)
})
}
196 changes: 196 additions & 0 deletions pkg/bearertoken/grpc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// Copyright (c) 2024 The Jaeger Authors.
// SPDX-License-Identifier: Apache-2.0

package bearertoken

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

type mockServerStream struct {
ctx context.Context
grpc.ServerStream
}

func (s *mockServerStream) Context() context.Context {
return s.ctx
}

func TestClientInterceptors(t *testing.T) {
tests := []struct {
name string
ctx context.Context
expectedErr string
expectedMD metadata.MD
}{
{
name: "no token in context",
ctx: context.Background(),
expectedErr: "",
expectedMD: nil, // Expecting no metadata
},
{
name: "token in context",
ctx: ContextWithBearerToken(context.Background(), "test-token"),
expectedErr: "",
expectedMD: metadata.MD{Key: []string{"test-token"}},
},
{
name: "multiple tokens in metadata",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{Key: []string{"token1", "token2"}}),
expectedErr: "malformed token: multiple tokens found",
},
{
name: "valid token in metadata",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{Key: []string{"valid-token"}}),
expectedErr: "",
expectedMD: metadata.MD{Key: []string{"valid-token"}}, // Valid token setup
},
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing no-context / no-metadata test


for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Unary interceptor test
verifyMetadata := func(ctx context.Context) error {
md, ok := metadata.FromOutgoingContext(ctx)
if test.expectedMD == nil {
require.False(t, ok, "metadata should not be present")
} else {
require.True(t, ok, "metadata should be present")
assert.Equal(t, test.expectedMD, md)
}
return nil
}
unaryInterceptor := NewUnaryClientInterceptor()
unaryInvoker := func(ctx context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
return verifyMetadata(ctx)
}
err := unaryInterceptor(test.ctx, "method", nil, nil, nil, unaryInvoker)
if test.expectedErr == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
require.ErrorContains(t, err, test.expectedErr)
}

// Stream interceptor test
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of comments you can wrap into a subtest

t.Run("stream", func...)

streamInterceptor := NewStreamClientInterceptor()
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
streamInvoker := func(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) {
if err := verifyMetadata(ctx); err != nil {
return nil, err
}
return nil, nil
}
_, err = streamInterceptor(test.ctx, &grpc.StreamDesc{}, nil, "method", streamInvoker)
if test.expectedErr == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
assert.ErrorContains(t, err, test.expectedErr)
}
})
}
}

func TestServerInterceptors(t *testing.T) {
tests := []struct {
name string
ctx context.Context
expectedErr string
wantToken string
}{
{
name: "no token in context",
ctx: context.Background(),
expectedErr: "",
wantToken: "",
},
{
name: "token in context",
ctx: ContextWithBearerToken(context.Background(), "test-token"),
expectedErr: "",
wantToken: "test-token",
},
{
name: "multiple tokens in metadata",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
Key: []string{"token1", "token2"},
}),
expectedErr: "malformed token: multiple tokens found",
wantToken: "",
},
{
name: "valid token in metadata",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{
Key: []string{"valid-token"},
}),
expectedErr: "",
wantToken: "valid-token",
},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing no-context / no-metadata test

}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
verifyToken := func(ctx context.Context) error {
token, ok := GetBearerToken(ctx)
if test.wantToken == "" {
assert.False(t, ok, "expected no token")
} else {
assert.True(t, ok, "expected token to be present")
assert.Equal(t, test.wantToken, token)
}
return nil
}
// Test unary server interceptor
unaryInterceptor := NewUnaryServerInterceptor()
unaryHandler := func(ctx context.Context, _ any) (any, error) {
return nil, verifyToken(ctx)
}

_, err := unaryInterceptor(test.ctx, nil, &grpc.UnaryServerInfo{}, unaryHandler)
if test.expectedErr == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
require.ErrorContains(t, err, test.expectedErr)
}

// Test stream server interceptor
streamInterceptor := NewStreamServerInterceptor()
mockStream := &mockServerStream{ctx: test.ctx}
streamHandler := func(_ any, stream grpc.ServerStream) error {
return verifyToken(stream.Context())
}

err = streamInterceptor(nil, mockStream, &grpc.StreamServerInfo{}, streamHandler)
if test.expectedErr == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
assert.ErrorContains(t, err, test.expectedErr)
}
})
}
}

func TestTokenatedServerStream(t *testing.T) {
originalCtx := context.Background()
testToken := "test-token"
newCtx := ContextWithBearerToken(originalCtx, testToken)

stream := &tokenatedServerStream{
ServerStream: &mockServerStream{ctx: originalCtx},
context: newCtx,
}

// Verify that Context() returns the modified context
token, ok := GetBearerToken(stream.Context())
require.True(t, ok)
assert.Equal(t, testToken, token)
}
Loading
Loading