diff --git a/cmd/query/app/server.go b/cmd/query/app/server.go index f1f7b4284b7..eee99ba75a5 100644 --- a/cmd/query/app/server.go +++ b/cmd/query/app/server.go @@ -176,13 +176,23 @@ func createGRPCServerOTEL( telset telemetery.Setting, ) (*grpc.Server, error) { var grpcOpts []configgrpc.ToServerOption + unaryInterceptors := []grpc.UnaryServerInterceptor{ + bearertoken.NewUnaryServerInterceptor(), + } + streamInterceptors := []grpc.StreamServerInterceptor{ + bearertoken.NewStreamServerInterceptor(), + } + + //nolint:contextcheck if tm.Enabled { - //nolint:contextcheck - grpcOpts = append(grpcOpts, - configgrpc.WithGrpcServerOption(grpc.StreamInterceptor(tenancy.NewGuardingStreamInterceptor(tm))), - configgrpc.WithGrpcServerOption(grpc.UnaryInterceptor(tenancy.NewGuardingUnaryInterceptor(tm))), - ) + unaryInterceptors = append(unaryInterceptors, tenancy.NewGuardingUnaryInterceptor(tm)) + streamInterceptors = append(streamInterceptors, tenancy.NewGuardingStreamInterceptor(tm)) } + + grpcOpts = append(grpcOpts, + configgrpc.WithGrpcServerOption(grpc.ChainUnaryInterceptor(unaryInterceptors...)), + configgrpc.WithGrpcServerOption(grpc.ChainStreamInterceptor(streamInterceptors...)), + ) return options.GRPC.ToServer( ctx, telset.Host, diff --git a/plugin/metrics/prometheus/metricsstore/reader.go b/plugin/metrics/prometheus/metricsstore/reader.go index 86cb6922e3d..df9fddd6f23 100644 --- a/plugin/metrics/prometheus/metricsstore/reader.go +++ b/plugin/metrics/prometheus/metricsstore/reader.go @@ -342,7 +342,7 @@ func getHTTPRoundTripper(c *config.Configuration, logger *zap.Logger) (rt http.R } token = tokenFromFile } - return bearertoken.RoundTripper{ + return &bearertoken.RoundTripper{ Transport: httpTransport, OverrideFromCtx: c.TokenOverrideFromContext, StaticToken: token, diff --git a/plugin/storage/grpc/factory.go b/plugin/storage/grpc/factory.go index 0f08667e194..b5ceac8a3aa 100644 --- a/plugin/storage/grpc/factory.go +++ b/plugin/storage/grpc/factory.go @@ -123,13 +123,13 @@ func (f *Factory) newRemoteStorage( if c.Auth != nil { return nil, fmt.Errorf("authenticator is not supported") } - tenancyMgr := tenancy.NewManager(&c.Tenancy) unaryInterceptors := []grpc.UnaryClientInterceptor{ bearertoken.NewUnaryClientInterceptor(), } streamInterceptors := []grpc.StreamClientInterceptor{ - tenancy.NewClientStreamInterceptor(tenancyMgr), + bearertoken.NewStreamClientInterceptor(), } + tenancyMgr := tenancy.NewManager(&c.Tenancy) if tenancyMgr.Enabled { unaryInterceptors = append(unaryInterceptors, tenancy.NewClientUnaryInterceptor(tenancyMgr)) streamInterceptors = append(streamInterceptors, tenancy.NewClientStreamInterceptor(tenancyMgr)) diff --git a/plugin/storage/grpc/shared/archive.go b/plugin/storage/grpc/shared/archive.go index cff3aacd635..2bd14685d27 100644 --- a/plugin/storage/grpc/shared/archive.go +++ b/plugin/storage/grpc/shared/archive.go @@ -33,7 +33,7 @@ type archiveWriter struct { // GetTrace takes a traceID and returns a Trace associated with that traceID from Archive Storage func (r *archiveReader) GetTrace(ctx context.Context, traceID model.TraceID) (*model.Trace, error) { - stream, err := r.client.GetArchiveTrace(upgradeContext(ctx), &storage_v1.GetTraceRequest{ + stream, err := r.client.GetArchiveTrace(ctx, &storage_v1.GetTraceRequest{ TraceID: traceID, }) if status.Code(err) == codes.NotFound { diff --git a/plugin/storage/grpc/shared/grpc_client.go b/plugin/storage/grpc/shared/grpc_client.go index 99f42bde006..4469e9ce2b7 100644 --- a/plugin/storage/grpc/shared/grpc_client.go +++ b/plugin/storage/grpc/shared/grpc_client.go @@ -12,11 +12,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "github.com/jaegertracing/jaeger/model" - "github.com/jaegertracing/jaeger/pkg/bearertoken" _ "github.com/jaegertracing/jaeger/pkg/gogocodec" // force gogo codec registration "github.com/jaegertracing/jaeger/proto-gen/storage_v1" "github.com/jaegertracing/jaeger/storage/dependencystore" @@ -30,9 +28,6 @@ var ( _ StoragePlugin = (*GRPCClient)(nil) _ ArchiveStoragePlugin = (*GRPCClient)(nil) _ PluginCapabilities = (*GRPCClient)(nil) - - // upgradeContext composites several steps of upgrading context - upgradeContext = composeContextUpgradeFuncs(upgradeContextWithBearerToken) ) // GRPCClient implements shared.StoragePlugin and reads/writes spans and dependencies @@ -58,36 +53,6 @@ func NewGRPCClient(tracedConn *grpc.ClientConn, untracedConn *grpc.ClientConn) * } } -// ContextUpgradeFunc is a functional type that can be composed to upgrade context -type ContextUpgradeFunc func(ctx context.Context) context.Context - -// composeContextUpgradeFuncs composes ContextUpgradeFunc and returns a composed function -// to run the given func in strict order. -func composeContextUpgradeFuncs(funcs ...ContextUpgradeFunc) ContextUpgradeFunc { - return func(ctx context.Context) context.Context { - for _, fun := range funcs { - ctx = fun(ctx) - } - return ctx - } -} - -// upgradeContextWithBearerToken turns the context into a gRPC outgoing context with bearer token -// in the request metadata, if the original context has bearer token attached. -// Otherwise returns original context. -func upgradeContextWithBearerToken(ctx context.Context) context.Context { - bearerToken, hasToken := bearertoken.GetBearerToken(ctx) - if hasToken { - md, ok := metadata.FromOutgoingContext(ctx) - if !ok { - md = metadata.New(nil) - } - md.Set(BearerTokenKey, bearerToken) - return metadata.NewOutgoingContext(ctx, md) - } - return ctx -} - // DependencyReader implements shared.StoragePlugin. func (c *GRPCClient) DependencyReader() dependencystore.Reader { return c @@ -117,7 +82,7 @@ func (c *GRPCClient) ArchiveSpanWriter() spanstore.Writer { // GetTrace takes a traceID and returns a Trace associated with that traceID func (c *GRPCClient) GetTrace(ctx context.Context, traceID model.TraceID) (*model.Trace, error) { - stream, err := c.readerClient.GetTrace(upgradeContext(ctx), &storage_v1.GetTraceRequest{ + stream, err := c.readerClient.GetTrace(ctx, &storage_v1.GetTraceRequest{ TraceID: traceID, }) if status.Code(err) == codes.NotFound { @@ -132,7 +97,7 @@ func (c *GRPCClient) GetTrace(ctx context.Context, traceID model.TraceID) (*mode // GetServices returns a list of all known services func (c *GRPCClient) GetServices(ctx context.Context) ([]string, error) { - resp, err := c.readerClient.GetServices(upgradeContext(ctx), &storage_v1.GetServicesRequest{}) + resp, err := c.readerClient.GetServices(ctx, &storage_v1.GetServicesRequest{}) if err != nil { return nil, fmt.Errorf("plugin error: %w", err) } @@ -145,7 +110,7 @@ func (c *GRPCClient) GetOperations( ctx context.Context, query spanstore.OperationQueryParameters, ) ([]spanstore.Operation, error) { - resp, err := c.readerClient.GetOperations(upgradeContext(ctx), &storage_v1.GetOperationsRequest{ + resp, err := c.readerClient.GetOperations(ctx, &storage_v1.GetOperationsRequest{ Service: query.ServiceName, SpanKind: query.SpanKind, }) @@ -173,7 +138,7 @@ func (c *GRPCClient) GetOperations( // FindTraces retrieves traces that match the traceQuery func (c *GRPCClient) FindTraces(ctx context.Context, query *spanstore.TraceQueryParameters) ([]*model.Trace, error) { - stream, err := c.readerClient.FindTraces(upgradeContext(ctx), &storage_v1.FindTracesRequest{ + stream, err := c.readerClient.FindTraces(ctx, &storage_v1.FindTracesRequest{ Query: &storage_v1.TraceQueryParameters{ ServiceName: query.ServiceName, OperationName: query.OperationName, @@ -212,7 +177,7 @@ func (c *GRPCClient) FindTraces(ctx context.Context, query *spanstore.TraceQuery // FindTraceIDs retrieves traceIDs that match the traceQuery func (c *GRPCClient) FindTraceIDs(ctx context.Context, query *spanstore.TraceQueryParameters) ([]model.TraceID, error) { - resp, err := c.readerClient.FindTraceIDs(upgradeContext(ctx), &storage_v1.FindTraceIDsRequest{ + resp, err := c.readerClient.FindTraceIDs(ctx, &storage_v1.FindTraceIDsRequest{ Query: &storage_v1.TraceQueryParameters{ ServiceName: query.ServiceName, OperationName: query.OperationName, diff --git a/plugin/storage/grpc/shared/grpc_client_test.go b/plugin/storage/grpc/shared/grpc_client_test.go index da8d82ce4a4..a8bb88fcf34 100644 --- a/plugin/storage/grpc/shared/grpc_client_test.go +++ b/plugin/storage/grpc/shared/grpc_client_test.go @@ -15,11 +15,9 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "github.com/jaegertracing/jaeger/model" - "github.com/jaegertracing/jaeger/pkg/bearertoken" "github.com/jaegertracing/jaeger/proto-gen/storage_v1" grpcMocks "github.com/jaegertracing/jaeger/proto-gen/storage_v1/mocks" "github.com/jaegertracing/jaeger/storage/spanstore" @@ -116,22 +114,6 @@ func TestNewGRPCClient(t *testing.T) { assert.Implements(t, (*storage_v1.StreamingSpanWriterPluginClient)(nil), client.streamWriterClient) } -func TestContextUpgradeWithToken(t *testing.T) { - testBearerToken := "test-bearer-token" - ctx := bearertoken.ContextWithBearerToken(context.Background(), testBearerToken) - upgradedToken := upgradeContextWithBearerToken(ctx) - md, ok := metadata.FromOutgoingContext(upgradedToken) - assert.Truef(t, ok, "Expected metadata in context") - bearerTokenFromMetadata := md.Get(BearerTokenKey) - assert.Equal(t, []string{testBearerToken}, bearerTokenFromMetadata) -} - -func TestContextUpgradeWithoutToken(t *testing.T) { - upgradedToken := upgradeContextWithBearerToken(context.Background()) - _, ok := metadata.FromOutgoingContext(upgradedToken) - assert.Falsef(t, ok, "Expected no metadata in context") -} - func TestGRPCClientGetServices(t *testing.T) { withGRPCClient(func(r *grpcClientTest) { r.spanReader.On("GetServices", mock.Anything, &storage_v1.GetServicesRequest{}).