diff --git a/flagd/cmd/start.go b/flagd/cmd/start.go index eaa1d8f0d..7cc16141a 100644 --- a/flagd/cmd/start.go +++ b/flagd/cmd/start.go @@ -79,7 +79,7 @@ func init() { flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate "+ "from disk") flags.StringToStringP(contextValueFlagName, "X", map[string]string{}, "add arbitrary key value pairs "+ - "to the flag value evaluation context") + "to the flag evaluation context") _ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName)) _ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName)) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator.go b/flagd/pkg/service/flag-evaluation/flag_evaluator.go index 911ac10f7..7d4ab71ca 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator.go @@ -181,22 +181,16 @@ func (s *OldFlagEvaluationService) ResolveBoolean( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{}) - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - for k, v := range s.contextValues { - evalCtx[k] = v - } err := resolve[bool]( sCtx, s.logger, s.eval.ResolveBooleanValue, req.Msg.GetFlagKey(), - evalCtx, + req.Msg.GetContext(), &booleanResponse{schemaV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -214,23 +208,16 @@ func (s *OldFlagEvaluationService) ResolveString( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveString", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - for k, v := range s.contextValues { - evalCtx[k] = v - } - res := connect.NewResponse(&schemaV1.ResolveStringResponse{}) err := resolve[string]( sCtx, s.logger, s.eval.ResolveStringValue, req.Msg.GetFlagKey(), - evalCtx, + req.Msg.GetContext(), &stringResponse{schemaV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -248,23 +235,16 @@ func (s *OldFlagEvaluationService) ResolveInt( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveInt", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - for k, v := range s.contextValues { - evalCtx[k] = v - } - res := connect.NewResponse(&schemaV1.ResolveIntResponse{}) err := resolve[int64]( sCtx, s.logger, s.eval.ResolveIntValue, req.Msg.GetFlagKey(), - evalCtx, + req.Msg.GetContext(), &intResponse{schemaV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -282,23 +262,16 @@ func (s *OldFlagEvaluationService) ResolveFloat( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveFloat", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - for k, v := range s.contextValues { - evalCtx[k] = v - } - res := connect.NewResponse(&schemaV1.ResolveFloatResponse{}) err := resolve[float64]( sCtx, s.logger, s.eval.ResolveFloatValue, req.Msg.GetFlagKey(), - evalCtx, + req.Msg.GetContext(), &floatResponse{schemaV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -316,23 +289,16 @@ func (s *OldFlagEvaluationService) ResolveObject( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveObject", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - for k, v := range s.contextValues { - evalCtx[k] = v - } - res := connect.NewResponse(&schemaV1.ResolveObjectResponse{}) err := resolve[map[string]any]( sCtx, s.logger, s.eval.ResolveObjectValue, req.Msg.GetFlagKey(), - evalCtx, + req.Msg.GetContext(), &objectResponse{schemaV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -342,21 +308,36 @@ func (s *OldFlagEvaluationService) ResolveObject( return res, err } +// mergeContexts combines values from the request context with the values from the config --context-values flag. +// Request context values have a higher priority. +func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any { + merged := make(map[string]any) + for k, v := range configFlagsCtx { + merged[k] = v + } + for k, v := range reqCtx { + merged[k] = v + } + return merged +} + // resolve is a generic flag resolver func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string, - evaluationContext map[string]any, resp response[T], metrics telemetry.IMetricsRecorder, + evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder, + configContextValues map[string]any, ) error { reqID := xid.New().String() defer logger.ClearFields(reqID) + mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues) logger.WriteFields( reqID, zap.String("flag-key", flagKey), - zap.Strings("context-keys", formatContextKeys(evaluationContext)), + zap.Strings("context-keys", formatContextKeys(mergedContext)), ) var evalErrFormatted error - result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, evaluationContext) + result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, mergedContext) if evalErr != nil { logger.WarnWithID(reqID, fmt.Sprintf("returning error response, reason: %v", evalErr)) reason = model.ErrorReason diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go index f66dd3d25..c0413f2fa 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go @@ -174,23 +174,16 @@ func (s *FlagEvaluationService) ResolveBoolean( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - for k, v := range s.contextValues { - evalCtx[k] = v - } - res := connect.NewResponse(&evalV1.ResolveBooleanResponse{}) err := resolve( sCtx, s.logger, s.eval.ResolveBooleanValue, req.Msg.GetFlagKey(), - evalCtx, + req.Msg.GetContext(), &booleanResponse{evalV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -207,23 +200,16 @@ func (s *FlagEvaluationService) ResolveString( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveString", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - for k, v := range s.contextValues { - evalCtx[k] = v - } - res := connect.NewResponse(&evalV1.ResolveStringResponse{}) err := resolve( sCtx, s.logger, s.eval.ResolveStringValue, req.Msg.GetFlagKey(), - evalCtx, + req.Msg.GetContext(), &stringResponse{evalV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -240,23 +226,16 @@ func (s *FlagEvaluationService) ResolveInt( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveInt", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - for k, v := range s.contextValues { - evalCtx[k] = v - } - res := connect.NewResponse(&evalV1.ResolveIntResponse{}) err := resolve( sCtx, s.logger, s.eval.ResolveIntValue, req.Msg.GetFlagKey(), - evalCtx, + req.Msg.GetContext(), &intResponse{evalV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -273,23 +252,16 @@ func (s *FlagEvaluationService) ResolveFloat( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveFloat", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - for k, v := range s.contextValues { - evalCtx[k] = v - } - res := connect.NewResponse(&evalV1.ResolveFloatResponse{}) err := resolve( sCtx, s.logger, s.eval.ResolveFloatValue, req.Msg.GetFlagKey(), - evalCtx, + req.Msg.GetContext(), &floatResponse{evalV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) @@ -306,23 +278,16 @@ func (s *FlagEvaluationService) ResolveObject( sCtx, span := s.flagEvalTracer.Start(ctx, "resolveObject", trace.WithSpanKind(trace.SpanKindServer)) defer span.End() - evalCtx := map[string]any{} - if e := req.Msg.GetContext(); e != nil { - evalCtx = e.AsMap() - } - for k, v := range s.contextValues { - evalCtx[k] = v - } - res := connect.NewResponse(&evalV1.ResolveObjectResponse{}) err := resolve( sCtx, s.logger, s.eval.ResolveObjectValue, req.Msg.GetFlagKey(), - evalCtx, + req.Msg.GetContext(), &objectResponse{evalV1Resp: res}, s.metrics, + s.contextValues, ) if err != nil { span.RecordError(err) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go index 5305cd546..f99c8a335 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go @@ -3,6 +3,7 @@ package service import ( "context" "errors" + "reflect" "testing" evalV1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/evaluation/v1" @@ -965,3 +966,34 @@ func TestFlag_EvaluationV2_ErrorCodes(t *testing.T) { } } } + +func Test_mergeContexts(t *testing.T) { + type args struct { + clientContext, configContext map[string]any + } + + tests := []struct { + name string + args args + want map[string]any + }{ + { + name: "merge contexts", + args: args{ + clientContext: map[string]any{"k1": "v1", "k2": "v2"}, + configContext: map[string]any{"k2": "v22", "k3": "v3"}, + }, + want: map[string]any{"k1": "v1", "k2": "v2", "k3": "v3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeContexts(tt.args.clientContext, tt.args.configContext) + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("\ngot: %+v\nwant: %+v", got, tt.want) + } + }) + } +} diff --git a/flagd/pkg/service/flag-evaluation/ofrep/handler.go b/flagd/pkg/service/flag-evaluation/ofrep/handler.go index 10217817c..b3dcc972f 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/handler.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/handler.go @@ -27,9 +27,9 @@ type handler struct { func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any) http.Handler { h := handler{ - logger, - evaluator, - contextValues, + Logger: logger, + evaluator: evaluator, + contextValues: contextValues, } router := mux.NewRouter() @@ -122,16 +122,18 @@ func extractOfrepRequest(req *http.Request) (ofrep.Request, error) { func flagdContext( log *logger.Logger, requestID string, request ofrep.Request, contextValues map[string]any, ) map[string]any { - context := map[string]any{} + context := make(map[string]any) + for k, v := range contextValues { + context[k] = v + } + if res, ok := request.Context.(map[string]any); ok { - context = res + for k, v := range res { + context[k] = v + } } else { log.WarnWithID(requestID, "provided context does not comply with flagd, continuing ignoring the context") } - for k, v := range contextValues { - context[k] = v - } - return context } diff --git a/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go b/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go index 0b38bcb04..8ff7f42d6 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go @@ -290,39 +290,3 @@ func TestWriteJSONResponse(t *testing.T) { } } -func Test_flagContext(t *testing.T) { - log := logger.NewLogger(nil, false) - - type args struct { - requestID string - req ofrep.Request - contextValues map[string]any - } - tests := []struct { - name string - args args - want map[string]any - }{ - { - name: "merge contexts", - args: args{ - requestID: "", - req: ofrep.Request{ - Context: map[string]any{"k1": "v1", "k2": "v2"}, - }, - contextValues: map[string]any{"k2": "v22", "k3": "v3"}, - }, - want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := flagdContext(log, tt.args.requestID, tt.args.req, tt.args.contextValues) - - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("\ngot: %+v\nwant: %+v", got, tt.want) - } - }) - } -}