diff --git a/cmd/query/app/grpc_handler.go b/cmd/query/app/grpc_handler.go index ec6992147f9..441d1b4d3c3 100644 --- a/cmd/query/app/grpc_handler.go +++ b/cmd/query/app/grpc_handler.go @@ -43,6 +43,7 @@ const ( var ( errGRPCMetricsQueryDisabled = status.Error(codes.Unimplemented, "metrics querying is currently disabled") errNilRequest = status.Error(codes.InvalidArgument, "a nil argument is not allowed") + errUninitializedTraceID = status.Error(codes.InvalidArgument, "uninitialized TraceID is not allowed") errMissingServiceNames = status.Error(codes.InvalidArgument, "please provide at least one service name") errMissingQuantile = status.Error(codes.InvalidArgument, "please provide a quantile between (0, 1]") ) @@ -60,6 +61,12 @@ var _ api_v2.QueryServiceServer = (*GRPCHandler)(nil) // GetTrace is the gRPC handler to fetch traces based on trace-id. func (g *GRPCHandler) GetTrace(r *api_v2.GetTraceRequest, stream api_v2.QueryService_GetTraceServer) error { + if r == nil { + return errNilRequest + } + if r.TraceID == (model.TraceID{}) { + return errUninitializedTraceID + } trace, err := g.queryService.GetTrace(stream.Context(), r.TraceID) if err == spanstore.ErrTraceNotFound { g.logger.Error(msgTraceNotFound, zap.Error(err)) @@ -74,6 +81,12 @@ func (g *GRPCHandler) GetTrace(r *api_v2.GetTraceRequest, stream api_v2.QuerySer // ArchiveTrace is the gRPC handler to archive traces. func (g *GRPCHandler) ArchiveTrace(ctx context.Context, r *api_v2.ArchiveTraceRequest) (*api_v2.ArchiveTraceResponse, error) { + if r == nil { + return nil, errNilRequest + } + if r.TraceID == (model.TraceID{}) { + return nil, errUninitializedTraceID + } err := g.queryService.ArchiveTrace(ctx, r.TraceID) if err == spanstore.ErrTraceNotFound { g.logger.Error("trace not found", zap.Error(err)) @@ -89,6 +102,9 @@ func (g *GRPCHandler) ArchiveTrace(ctx context.Context, r *api_v2.ArchiveTraceRe // FindTraces is the gRPC handler to fetch traces based on TraceQueryParameters. func (g *GRPCHandler) FindTraces(r *api_v2.FindTracesRequest, stream api_v2.QueryService_FindTracesServer) error { + if r == nil { + return errNilRequest + } query := r.GetQuery() if query == nil { return status.Errorf(codes.InvalidArgument, "missing query") @@ -147,6 +163,9 @@ func (g *GRPCHandler) GetOperations( ctx context.Context, r *api_v2.GetOperationsRequest, ) (*api_v2.GetOperationsResponse, error) { + if r == nil { + return nil, errNilRequest + } operations, err := g.queryService.GetOperations(ctx, spanstore.OperationQueryParameters{ ServiceName: r.Service, SpanKind: r.SpanKind, @@ -172,8 +191,16 @@ func (g *GRPCHandler) GetOperations( // GetDependencies is the gRPC handler to fetch dependencies. func (g *GRPCHandler) GetDependencies(ctx context.Context, r *api_v2.GetDependenciesRequest) (*api_v2.GetDependenciesResponse, error) { + if r == nil { + return nil, errNilRequest + } + startTime := r.StartTime endTime := r.EndTime + if startTime == (time.Time{}) || endTime == (time.Time{}) { + return nil, status.Errorf(codes.InvalidArgument, "StartTime and EndTime must be initialized.") + } + dependencies, err := g.queryService.GetDependencies(ctx, startTime, endTime.Sub(startTime)) if err != nil { g.logger.Error("failed to fetch dependencies", zap.Error(err)) diff --git a/cmd/query/app/grpc_handler_test.go b/cmd/query/app/grpc_handler_test.go index 9e2d3e2900e..eee1ee0d243 100644 --- a/cmd/query/app/grpc_handler_test.go +++ b/cmd/query/app/grpc_handler_test.go @@ -264,6 +264,23 @@ func assertGRPCError(t *testing.T, err error, code codes.Code, msg string) { assert.Equal(t, code, s.Code()) assert.Contains(t, s.Message(), msg) } +func TestGetTraceEmptyTraceIDFailure_GRPC(t *testing.T) { + withServerAndClient(t, func(server *grpcServer, client *grpcClient) { + + server.spanReader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("model.TraceID")). + Return(mockTrace, nil).Once() + + res, err := client.GetTrace(context.Background(), &api_v2.GetTraceRequest{ + TraceID: model.TraceID{}, + }) + + assert.NoError(t, err) + + spanResChunk, err := res.Recv() + assert.ErrorIs(t, err, errUninitializedTraceID) + assert.Nil(t, spanResChunk) + }) +} func TestGetTraceDBFailureGRPC(t *testing.T) { withServerAndClient(t, func(server *grpcServer, client *grpcClient) { @@ -302,6 +319,13 @@ func TestGetTraceNotFoundGRPC(t *testing.T) { }) } +// test from GRPCHandler and not grpcClient as Generated Go client panics with `nil` request +func TestGetTraceNilRequestOnHandlerGRPC(t *testing.T) { + grpcHandler := &GRPCHandler{} + err := grpcHandler.GetTrace(nil, nil) + assert.EqualError(t, err, errNilRequest.Error()) +} + func TestArchiveTraceSuccessGRPC(t *testing.T) { withServerAndClient(t, func(server *grpcServer, client *grpcClient) { server.spanReader.On("GetTrace", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("model.TraceID")). @@ -332,6 +356,24 @@ func TestArchiveTraceNotFoundGRPC(t *testing.T) { }) } +func TestArchiveTraceEmptyTraceFailureGRPC(t *testing.T) { + withServerAndClient(t, func(server *grpcServer, client *grpcClient) { + + _, err := client.ArchiveTrace(context.Background(), &api_v2.ArchiveTraceRequest{ + TraceID: model.TraceID{}, + }) + assert.ErrorIs(t, err, errUninitializedTraceID) + }) + +} + +// test from GRPCHandler and not grpcClient as Generated Go client panics with `nil` request +func TestArchiveTraceNilRequestOnHandlerGRPC(t *testing.T) { + grpcHandler := &GRPCHandler{} + _, err := grpcHandler.ArchiveTrace(context.Background(), nil) + assert.EqualError(t, err, errNilRequest.Error()) +} + func TestArchiveTraceFailureGRPC(t *testing.T) { withServerAndClient(t, func(server *grpcServer, client *grpcClient) { @@ -348,7 +390,7 @@ func TestArchiveTraceFailureGRPC(t *testing.T) { }) } -func TestSearchSuccessGRPC(t *testing.T) { +func TestFindTracesSuccessGRPC(t *testing.T) { withServerAndClient(t, func(server *grpcServer, client *grpcClient) { server.spanReader.On("FindTraces", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("*spanstore.TraceQueryParameters")). Return([]*model.Trace{mockTraceGRPC}, nil).Once() @@ -376,7 +418,7 @@ func TestSearchSuccessGRPC(t *testing.T) { }) } -func TestSearchSuccess_SpanStreamingGRPC(t *testing.T) { +func TestFindTracesSuccess_SpanStreamingGRPC(t *testing.T) { withServerAndClient(t, func(server *grpcServer, client *grpcClient) { server.spanReader.On("FindTraces", mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("*spanstore.TraceQueryParameters")). @@ -405,7 +447,7 @@ func TestSearchSuccess_SpanStreamingGRPC(t *testing.T) { }) } -func TestSearchInvalid_GRPC(t *testing.T) { +func TestFindTracesMissingQuery_GRPC(t *testing.T) { withServerAndClient(t, func(server *grpcServer, client *grpcClient) { res, err := client.FindTraces(context.Background(), &api_v2.FindTracesRequest{ Query: nil, @@ -418,7 +460,7 @@ func TestSearchInvalid_GRPC(t *testing.T) { }) } -func TestSearchFailure_GRPC(t *testing.T) { +func TestFindTracesFailure_GRPC(t *testing.T) { withServerAndClient(t, func(server *grpcServer, client *grpcClient) { mockErrorGRPC := fmt.Errorf("whatsamattayou") @@ -444,6 +486,13 @@ func TestSearchFailure_GRPC(t *testing.T) { }) } +// test from GRPCHandler and not grpcClient as Generated Go client panics with `nil` request +func TestFindTracesNilRequestOnHandlerGRPC(t *testing.T) { + grpcHandler := &GRPCHandler{} + err := grpcHandler.FindTraces(nil, nil) + assert.EqualError(t, err, errNilRequest.Error()) +} + func TestGetServicesSuccessGRPC(t *testing.T) { withServerAndClient(t, func(server *grpcServer, client *grpcClient) { expectedServices := []string{"trifle", "bling"} @@ -507,6 +556,13 @@ func TestGetOperationsFailureGRPC(t *testing.T) { }) } +// test from GRPCHandler and not grpcClient as Generated Go client panics with `nil` request +func TestGetOperationsNilRequestOnHandlerGRPC(t *testing.T) { + grpcHandler := &GRPCHandler{} + _, err := grpcHandler.GetOperations(context.Background(), nil) + assert.EqualError(t, err, errNilRequest.Error()) +} + func TestGetDependenciesSuccessGRPC(t *testing.T) { withServerAndClient(t, func(server *grpcServer, client *grpcClient) { expectedDependencies := []model.DependencyLink{{Parent: "killer", Child: "queen", CallCount: 12}} @@ -539,6 +595,37 @@ func TestGetDependenciesFailureGRPC(t *testing.T) { }) } +func TestGetDependenciesFailureUninitializedTimeGRPC(t *testing.T) { + + timeInputs := []struct { + startTime time.Time + endTime time.Time + }{ + {time.Time{}, time.Time{}}, + {time.Now(), time.Time{}}, + {time.Time{}, time.Now()}, + } + + for _, input := range timeInputs { + withServerAndClient(t, func(server *grpcServer, client *grpcClient) { + + _, err := client.GetDependencies(context.Background(), &api_v2.GetDependenciesRequest{ + StartTime: input.startTime, + EndTime: input.endTime, + }) + + assert.Error(t, err) + }) + } +} + +// test from GRPCHandler and not grpcClient as Generated Go client panics with `nil` request +func TestGetDependenciesNilRequestOnHandlerGRPC(t *testing.T) { + grpcHandler := &GRPCHandler{} + _, err := grpcHandler.GetDependencies(context.Background(), nil) + assert.EqualError(t, err, errNilRequest.Error()) +} + func TestSendSpanChunksError(t *testing.T) { g := &GRPCHandler{ logger: zap.NewNop(),