Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add context value flag #1448

Merged
merged 21 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/pkg/service/iservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Configuration struct {
SocketPath string
CORS []string
Options []connect.HandlerOption
ContextValues map[string]any
}

/*
Expand Down
3 changes: 3 additions & 0 deletions docs/reference/flag-definitions.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ For example, when accessing flagd via HTTP, the POST body may look like this:

The evaluation context can be accessed in targeting rules using the `var` operation followed by the evaluation context property name.

The evaluation context can be appended by arbitrary key value pairs
via the `-X` command line flag.

| Description | Example |
| -------------------------------------------------------------- | ---------------------------------------------------- |
| Retrieve property from the evaluation context | `#!json { "var": "email" }` |
Expand Down
1 change: 1 addition & 0 deletions docs/reference/flagd-cli/flagd_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ flagd start [flags]
### Options

```
-X, --context-value stringToString add arbitrary key value pairs to the flag value evaluation context (default [])
-C, --cors-origin strings CORS allowed origins, * will allow all origins
-h, --help help for start
-z, --log-format string Set the logging format, e.g. console or json (default "console")
Expand Down
11 changes: 10 additions & 1 deletion flagd/cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ const (
sourcesFlagName = "sources"
syncPortFlagName = "sync-port"
uriFlagName = "uri"
contextValueFlagName = "context-value"
)

func init() {
flags := startCmd.Flags()

// allows environment variables to use _ instead of -
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) // sync-provider-args becomes SYNC_PROVIDER_ARGS
viper.SetEnvPrefix("FLAGD") // port becomes FLAGD_PORT
Expand Down Expand Up @@ -78,6 +78,8 @@ func init() {
flags.StringP(otelCAPathFlagName, "A", "", "tls certificate authority path to use with OpenTelemetry collector")
flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate "+
"from disk")
flags.StringToStringP(contextValueFlagName, "X", map[string]string{}, "add arbitrary key value pairs "+
"to the flag evaluation context")

_ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName))
_ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName))
Expand All @@ -95,6 +97,7 @@ func init() {
_ = viper.BindPFlag(uriFlagName, flags.Lookup(uriFlagName))
_ = viper.BindPFlag(syncPortFlagName, flags.Lookup(syncPortFlagName))
_ = viper.BindPFlag(ofrepPortFlagName, flags.Lookup(ofrepPortFlagName))
_ = viper.BindPFlag(contextValueFlagName, flags.Lookup(contextValueFlagName))
}

// startCmd represents the start command
Expand Down Expand Up @@ -139,6 +142,11 @@ var startCmd = &cobra.Command{
}
syncProviders = append(syncProviders, syncProvidersFromConfig...)

contextValuesToMap := make(map[string]any)
for k, v := range viper.GetStringMapString(contextValueFlagName) {
contextValuesToMap[k] = v
}

// Build Runtime -----------------------------------------------------------
rt, err := runtime.FromConfig(logger, Version, runtime.Config{
CORS: viper.GetStringSlice(corsFlagName),
Expand All @@ -156,6 +164,7 @@ var startCmd = &cobra.Command{
ServiceSocketPath: viper.GetString(socketPathFlagName),
SyncServicePort: viper.GetUint16(syncPortFlagName),
SyncProviders: syncProviders,
ContextValues: contextValuesToMap,
})
if err != nil {
rtLogger.Fatal(err.Error())
Expand Down
16 changes: 11 additions & 5 deletions flagd/pkg/runtime/from_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type Config struct {

SyncProviders []sync.SourceConfig
CORS []string

ContextValues map[string]any
}

// FromConfig builds a runtime from startup configurations
Expand Down Expand Up @@ -101,17 +103,20 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
ofrepService, err := ofrep.NewOfrepService(jsonEvaluator, config.CORS, ofrep.SvcConfiguration{
Logger: logger.WithFields(zap.String("component", "OFREPService")),
Port: config.OfrepServicePort,
})
},
config.ContextValues,
)
if err != nil {
return nil, fmt.Errorf("error creating ofrep service")
}

// flag sync service
flagSyncService, err := flagsync.NewSyncService(flagsync.SvcConfigurations{
Logger: logger.WithFields(zap.String("component", "FlagSyncService")),
Port: config.SyncServicePort,
Sources: sources,
Store: s,
Logger: logger.WithFields(zap.String("component", "FlagSyncService")),
Port: config.SyncServicePort,
Sources: sources,
Store: s,
ContextValues: config.ContextValues,
})
if err != nil {
return nil, fmt.Errorf("error creating sync service: %w", err)
Expand Down Expand Up @@ -145,6 +150,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
SocketPath: config.ServiceSocketPath,
CORS: config.CORS,
Options: options,
ContextValues: config.ContextValues,
},
SyncImpl: iSyncs,
}, nil
Expand Down
2 changes: 2 additions & 0 deletions flagd/pkg/service/flag-evaluation/connect_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
s.eval,
s.eventingConfiguration,
s.metrics,
svcConf.ContextValues,
)

marshalOpts := WithJSON(
Expand All @@ -170,6 +171,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
s.eval,
s.eventingConfiguration,
s.metrics,
svcConf.ContextValues,
)

_, newHandler := evaluationV1.NewServiceHandler(newFes, append(svcConf.Options, marshalOpts)...)
Expand Down
45 changes: 34 additions & 11 deletions flagd/pkg/service/flag-evaluation/flag_evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,24 @@ type OldFlagEvaluationService struct {
metrics telemetry.IMetricsRecorder
eventingConfiguration IEvents
flagEvalTracer trace.Tracer
contextValues map[string]any
}

// NewOldFlagEvaluationService creates a OldFlagEvaluationService with provided parameters
func NewOldFlagEvaluationService(log *logger.Logger,
eval evaluator.IEvaluator, eventingCfg IEvents, metricsRecorder telemetry.IMetricsRecorder,
func NewOldFlagEvaluationService(
log *logger.Logger,
eval evaluator.IEvaluator,
eventingCfg IEvents,
metricsRecorder telemetry.IMetricsRecorder,
contextValues map[string]any,
) *OldFlagEvaluationService {
svc := &OldFlagEvaluationService{
logger: log,
eval: eval,
metrics: &telemetry.NoopMetricsRecorder{},
eventingConfiguration: eventingCfg,
flagEvalTracer: otel.Tracer("flagEvaluationService"),
contextValues: contextValues,
}

if metricsRecorder != nil {
Expand All @@ -65,12 +71,8 @@ func (s *OldFlagEvaluationService) ResolveAll(
res := &schemaV1.ResolveAllResponse{
Flags: make(map[string]*schemaV1.AnyFlag),
}
evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}

values, err := s.eval.ResolveAllValues(sCtx, reqID, evalCtx)
values, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues))
if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
Expand Down Expand Up @@ -172,6 +174,7 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{})

err := resolve[bool](
sCtx,
s.logger,
Expand All @@ -180,6 +183,7 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
req.Msg.GetContext(),
&booleanResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -206,6 +210,7 @@ func (s *OldFlagEvaluationService) ResolveString(
req.Msg.GetContext(),
&stringResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -232,6 +237,7 @@ func (s *OldFlagEvaluationService) ResolveInt(
req.Msg.GetContext(),
&intResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -258,6 +264,7 @@ func (s *OldFlagEvaluationService) ResolveFloat(
req.Msg.GetContext(),
&floatResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -284,6 +291,7 @@ func (s *OldFlagEvaluationService) ResolveObject(
req.Msg.GetContext(),
&objectResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
Expand All @@ -293,21 +301,36 @@ func (s *OldFlagEvaluationService) ResolveObject(
return res, err
}

// mergeContexts combines values from the request context with the values from the config --context-values flag.
// Request context values have a higher priority.
func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
merged := make(map[string]any)
for k, v := range configFlagsCtx {
merged[k] = v
}
for k, v := range reqCtx {
merged[k] = v
}
return merged
}

// resolve is a generic flag resolver
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string,
evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder,
configContextValues map[string]any,
beeme1mr marked this conversation as resolved.
Show resolved Hide resolved
) error {
reqID := xid.New().String()
defer logger.ClearFields(reqID)

mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues)
logger.WriteFields(
reqID,
zap.String("flag-key", flagKey),
zap.Strings("context-keys", formatContextKeys(evaluationContext)),
zap.Strings("context-keys", formatContextKeys(mergedContext)),
)

var evalErrFormatted error
result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, evaluationContext.AsMap())
result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, mergedContext)
if evalErr != nil {
logger.WarnWithID(reqID, fmt.Sprintf("returning error response, reason: %v", evalErr))
reason = model.ErrorReason
Expand All @@ -329,9 +352,9 @@ func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver
return evalErrFormatted
}

func formatContextKeys(context *structpb.Struct) []string {
func formatContextKeys(context map[string]any) []string {
res := []string{}
for k := range context.AsMap() {
for k := range context {
res = append(res, k)
}
return res
Expand Down
11 changes: 11 additions & 0 deletions flagd/pkg/service/flag-evaluation/flag_evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ func TestConnectService_ResolveAll(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req))
if err != nil && !errors.Is(err, tt.wantErr) {
Expand Down Expand Up @@ -235,6 +236,7 @@ func TestFlag_Evaluation_ResolveBoolean(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
Expand Down Expand Up @@ -290,6 +292,7 @@ func BenchmarkFlag_Evaluation_ResolveBoolean(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -388,6 +391,7 @@ func TestFlag_Evaluation_ResolveString(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
Expand Down Expand Up @@ -443,6 +447,7 @@ func BenchmarkFlag_Evaluation_ResolveString(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -540,6 +545,7 @@ func TestFlag_Evaluation_ResolveFloat(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
Expand Down Expand Up @@ -595,6 +601,7 @@ func BenchmarkFlag_Evaluation_ResolveFloat(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -692,6 +699,7 @@ func TestFlag_Evaluation_ResolveInt(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
Expand Down Expand Up @@ -747,6 +755,7 @@ func BenchmarkFlag_Evaluation_ResolveInt(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -847,6 +856,7 @@ func TestFlag_Evaluation_ResolveObject(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)

outParsed, err := structpb.NewStruct(tt.evalFields.result)
Expand Down Expand Up @@ -910,6 +920,7 @@ func BenchmarkFlag_Evaluation_ResolveObject(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
if name != "eval returns error" {
outParsed, err := structpb.NewStruct(tt.evalFields.result)
Expand Down
Loading
Loading