Skip to content

Commit

Permalink
feat: simplify code
Browse files Browse the repository at this point in the history
Signed-off-by: Aleksei Muratov <muratoff.alexey@gmail.com>
  • Loading branch information
alemrtv committed Nov 28, 2024
1 parent 326e384 commit 7725996
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 138 deletions.
2 changes: 1 addition & 1 deletion flagd/cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func init() {
flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate "+
"from disk")
flags.StringToStringP(contextValueFlagName, "X", map[string]string{}, "add arbitrary key value pairs "+
"to the flag value evaluation context")
"to the flag evaluation context")

_ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName))
_ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName))
Expand Down
75 changes: 28 additions & 47 deletions flagd/pkg/service/flag-evaluation/flag_evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,22 +181,16 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{})
evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
for k, v := range s.contextValues {
evalCtx[k] = v
}

err := resolve[bool](
sCtx,
s.logger,
s.eval.ResolveBooleanValue,
req.Msg.GetFlagKey(),
evalCtx,
req.Msg.GetContext(),
&booleanResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -214,23 +208,16 @@ func (s *OldFlagEvaluationService) ResolveString(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveString", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
for k, v := range s.contextValues {
evalCtx[k] = v
}

res := connect.NewResponse(&schemaV1.ResolveStringResponse{})
err := resolve[string](
sCtx,
s.logger,
s.eval.ResolveStringValue,
req.Msg.GetFlagKey(),
evalCtx,
req.Msg.GetContext(),
&stringResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -248,23 +235,16 @@ func (s *OldFlagEvaluationService) ResolveInt(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveInt", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
for k, v := range s.contextValues {
evalCtx[k] = v
}

res := connect.NewResponse(&schemaV1.ResolveIntResponse{})
err := resolve[int64](
sCtx,
s.logger,
s.eval.ResolveIntValue,
req.Msg.GetFlagKey(),
evalCtx,
req.Msg.GetContext(),
&intResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -282,23 +262,16 @@ func (s *OldFlagEvaluationService) ResolveFloat(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveFloat", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
for k, v := range s.contextValues {
evalCtx[k] = v
}

res := connect.NewResponse(&schemaV1.ResolveFloatResponse{})
err := resolve[float64](
sCtx,
s.logger,
s.eval.ResolveFloatValue,
req.Msg.GetFlagKey(),
evalCtx,
req.Msg.GetContext(),
&floatResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -316,23 +289,16 @@ func (s *OldFlagEvaluationService) ResolveObject(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveObject", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
for k, v := range s.contextValues {
evalCtx[k] = v
}

res := connect.NewResponse(&schemaV1.ResolveObjectResponse{})
err := resolve[map[string]any](
sCtx,
s.logger,
s.eval.ResolveObjectValue,
req.Msg.GetFlagKey(),
evalCtx,
req.Msg.GetContext(),
&objectResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -342,21 +308,36 @@ func (s *OldFlagEvaluationService) ResolveObject(
return res, err
}

// mergeContexts combines values from the request context with the values from the config --context-values flag.
// Request context values have a higher priority.
func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
merged := make(map[string]any)
for k, v := range configFlagsCtx {
merged[k] = v
}
for k, v := range reqCtx {
merged[k] = v
}
return merged
}

// resolve is a generic flag resolver
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string,
evaluationContext map[string]any, resp response[T], metrics telemetry.IMetricsRecorder,
evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder,
configContextValues map[string]any,
) error {
reqID := xid.New().String()
defer logger.ClearFields(reqID)

mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues)
logger.WriteFields(
reqID,
zap.String("flag-key", flagKey),
zap.Strings("context-keys", formatContextKeys(evaluationContext)),
zap.Strings("context-keys", formatContextKeys(mergedContext)),
)

var evalErrFormatted error
result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, evaluationContext)
result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, mergedContext)
if evalErr != nil {
logger.WarnWithID(reqID, fmt.Sprintf("returning error response, reason: %v", evalErr))
reason = model.ErrorReason
Expand Down
55 changes: 10 additions & 45 deletions flagd/pkg/service/flag-evaluation/flag_evaluator_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,16 @@ func (s *FlagEvaluationService) ResolveBoolean(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
for k, v := range s.contextValues {
evalCtx[k] = v
}

res := connect.NewResponse(&evalV1.ResolveBooleanResponse{})
err := resolve(
sCtx,
s.logger,
s.eval.ResolveBooleanValue,
req.Msg.GetFlagKey(),
evalCtx,
req.Msg.GetContext(),
&booleanResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -207,23 +200,16 @@ func (s *FlagEvaluationService) ResolveString(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveString", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
for k, v := range s.contextValues {
evalCtx[k] = v
}

res := connect.NewResponse(&evalV1.ResolveStringResponse{})
err := resolve(
sCtx,
s.logger,
s.eval.ResolveStringValue,
req.Msg.GetFlagKey(),
evalCtx,
req.Msg.GetContext(),
&stringResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -240,23 +226,16 @@ func (s *FlagEvaluationService) ResolveInt(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveInt", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
for k, v := range s.contextValues {
evalCtx[k] = v
}

res := connect.NewResponse(&evalV1.ResolveIntResponse{})
err := resolve(
sCtx,
s.logger,
s.eval.ResolveIntValue,
req.Msg.GetFlagKey(),
evalCtx,
req.Msg.GetContext(),
&intResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -273,23 +252,16 @@ func (s *FlagEvaluationService) ResolveFloat(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveFloat", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
for k, v := range s.contextValues {
evalCtx[k] = v
}

res := connect.NewResponse(&evalV1.ResolveFloatResponse{})
err := resolve(
sCtx,
s.logger,
s.eval.ResolveFloatValue,
req.Msg.GetFlagKey(),
evalCtx,
req.Msg.GetContext(),
&floatResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -306,23 +278,16 @@ func (s *FlagEvaluationService) ResolveObject(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveObject", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
for k, v := range s.contextValues {
evalCtx[k] = v
}

res := connect.NewResponse(&evalV1.ResolveObjectResponse{})
err := resolve(
sCtx,
s.logger,
s.eval.ResolveObjectValue,
req.Msg.GetFlagKey(),
evalCtx,
req.Msg.GetContext(),
&objectResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand Down
32 changes: 32 additions & 0 deletions flagd/pkg/service/flag-evaluation/flag_evaluator_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package service
import (
"context"
"errors"
"reflect"
"testing"

evalV1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/evaluation/v1"
Expand Down Expand Up @@ -965,3 +966,34 @@ func TestFlag_EvaluationV2_ErrorCodes(t *testing.T) {
}
}
}

func Test_mergeContexts(t *testing.T) {
type args struct {
clientContext, configContext map[string]any
}

tests := []struct {
name string
args args
want map[string]any
}{
{
name: "merge contexts",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
},
want: map[string]any{"k1": "v1", "k2": "v2", "k3": "v3"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := mergeContexts(tt.args.clientContext, tt.args.configContext)

if !reflect.DeepEqual(got, tt.want) {
t.Errorf("\ngot: %+v\nwant: %+v", got, tt.want)
}
})
}
}
20 changes: 11 additions & 9 deletions flagd/pkg/service/flag-evaluation/ofrep/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ type handler struct {

func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any) http.Handler {
h := handler{
logger,
evaluator,
contextValues,
Logger: logger,
evaluator: evaluator,
contextValues: contextValues,
}

router := mux.NewRouter()
Expand Down Expand Up @@ -122,16 +122,18 @@ func extractOfrepRequest(req *http.Request) (ofrep.Request, error) {
func flagdContext(
log *logger.Logger, requestID string, request ofrep.Request, contextValues map[string]any,
) map[string]any {
context := map[string]any{}
context := make(map[string]any)
for k, v := range contextValues {
context[k] = v
}

if res, ok := request.Context.(map[string]any); ok {
context = res
for k, v := range res {
context[k] = v
}
} else {
log.WarnWithID(requestID, "provided context does not comply with flagd, continuing ignoring the context")
}

for k, v := range contextValues {
context[k] = v
}

return context
}
Loading

0 comments on commit 7725996

Please sign in to comment.