Skip to content

Commit

Permalink
Fix predefined rule value caching behavior (#188)
Browse files Browse the repository at this point in the history
At least in some cases, the value of `rule` can get inappropriately
cached. This PR attempts to fix the issue and refactor the code to make
it a bit less likely to appear in the future.

This should allow the `rule` value for predefined rules to still get
resolved during compile time, but will hopefully ensure that the
environment can't get polluted by the cache.

Closes #187.
  • Loading branch information
jchadwick-buf authored Feb 12, 2025
1 parent b8e35a2 commit 1f18b86
Show file tree
Hide file tree
Showing 9 changed files with 408 additions and 115 deletions.
69 changes: 38 additions & 31 deletions ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,23 @@ package protovalidate

import (
"fmt"
"slices"

"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
pvcel "github.com/bufbuild/protovalidate-go/cel"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/interpreter"
"google.golang.org/protobuf/reflect/protoreflect"
)

// astSet represents a collection of compiledAST and their associated cel.Env.
type astSet struct {
env *cel.Env
asts []compiledAST
}
type astSet []compiledAST

// Merge combines a set with another, producing a new ASTSet.
func (set astSet) Merge(other astSet) astSet {
out := astSet{
env: set.env,
asts: make([]compiledAST, 0, len(set.asts)+len(other.asts)),
}
if out.env == nil {
out.env = other.env
}
out.asts = append(out.asts, set.asts...)
out.asts = append(out.asts, other.asts...)
out := make([]compiledAST, 0, len(set)+len(other))
out = append(out, set...)
out = append(out, other...)
return out
}

Expand All @@ -49,7 +42,7 @@ func (set astSet) Merge(other astSet) astSet {
// generated for it. The main usage of this is to elide tautological expressions
// from the final result.
func (set astSet) ReduceResiduals(opts ...cel.ProgramOption) (programSet, error) {
residuals := make([]compiledAST, 0, len(set.asts))
residuals := make(astSet, 0, len(set))
options := append([]cel.ProgramOption{
cel.EvalOptions(
cel.OptTrackState,
Expand All @@ -59,8 +52,12 @@ func (set astSet) ReduceResiduals(opts ...cel.ProgramOption) (programSet, error)
),
}, opts...)

for _, ast := range set.asts {
program, err := ast.toProgram(set.env, options...)
for _, ast := range set {
options := slices.Clone(options)
if ast.Value.IsValid() {
options = append(options, cel.Globals(&variable{Name: "rule", Val: ast.Value.Interface()}))
}
program, err := ast.toProgram(ast.Env, options...)
if err != nil {
residuals = append(residuals, ast)
continue
Expand All @@ -78,12 +75,13 @@ func (set astSet) ReduceResiduals(opts ...cel.ProgramOption) (programSet, error)
}
}
}
residual, err := set.env.ResidualAst(ast.AST, details)
residual, err := ast.Env.ResidualAst(ast.AST, details)
if err != nil {
residuals = append(residuals, ast)
} else {
residuals = append(residuals, compiledAST{
AST: residual,
Env: ast.Env,
Source: ast.Source,
Path: ast.Path,
Value: ast.Value,
Expand All @@ -92,20 +90,17 @@ func (set astSet) ReduceResiduals(opts ...cel.ProgramOption) (programSet, error)
}
}

return astSet{
env: set.env,
asts: residuals,
}.ToProgramSet(opts...)
return residuals.ToProgramSet(opts...)
}

// ToProgramSet generates a ProgramSet from the specified ASTs.
func (set astSet) ToProgramSet(opts ...cel.ProgramOption) (out programSet, err error) {
if l := len(set.asts); l == 0 {
if l := len(set); l == 0 {
return nil, nil
}
out = make(programSet, len(set.asts))
for i, ast := range set.asts {
out[i], err = ast.toProgram(set.env, opts...)
out = make(programSet, len(set))
for i, ast := range set {
out[i], err = ast.toProgram(ast.Env, opts...)
if err != nil {
return nil, err
}
Expand All @@ -114,19 +109,31 @@ func (set astSet) ToProgramSet(opts ...cel.ProgramOption) (out programSet, err e
}

// SetRuleValue sets the rule value for the programs in the ASTSet.
func (set *astSet) SetRuleValue(
func (set astSet) WithRuleValue(
ruleValue protoreflect.Value,
ruleDescriptor protoreflect.FieldDescriptor,
) {
set.asts = append([]compiledAST{}, set.asts...)
for i := range set.asts {
set.asts[i].Value = ruleValue
set.asts[i].Descriptor = ruleDescriptor
) (out astSet, err error) {
out = slices.Clone(set)
for i := range set {
out[i].Env, err = out[i].Env.Extend(
cel.Constant(
"rule",
pvcel.ProtoFieldToType(ruleDescriptor, true, false),
pvcel.ProtoFieldToValue(ruleDescriptor, ruleValue, false),
),
)
if err != nil {
return nil, err
}
out[i].Value = ruleValue
out[i].Descriptor = ruleDescriptor
}
return out, nil
}

type compiledAST struct {
AST *cel.Ast
Env *cel.Env
Source *validate.Constraint
Path []*validate.FieldPathElement
Value protoreflect.Value
Expand Down
29 changes: 11 additions & 18 deletions ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,20 @@ func TestASTSet_Merge(t *testing.T) {

var set astSet
other := astSet{
env: &cel.Env{},
asts: []compiledAST{
{AST: &cel.Ast{}},
{AST: &cel.Ast{}},
},
{AST: &cel.Ast{}},
{AST: &cel.Ast{}},
}
merged := set.Merge(other)
assert.Equal(t, other.env, merged.env)
assert.Equal(t, other.asts, merged.asts)
assert.Equal(t, other, merged)

another := astSet{
asts: []compiledAST{
{AST: &cel.Ast{}},
{AST: &cel.Ast{}},
{AST: &cel.Ast{}},
},
{AST: &cel.Ast{}},
{AST: &cel.Ast{}},
{AST: &cel.Ast{}},
}
merged = other.Merge(another)
assert.Equal(t, other.env, merged.env)
assert.Equal(t, other.asts, merged.asts[0:2])
assert.Equal(t, another.asts, merged.asts[2:])
assert.Equal(t, other, merged[0:2])
assert.Equal(t, another, merged[2:])
}

func TestASTSet_ToProgramSet(t *testing.T) {
Expand All @@ -69,11 +62,11 @@ func TestASTSet_ToProgramSet(t *testing.T) {
cel.Variable("foo", cel.BoolType),
)
require.NoError(t, err)
assert.Len(t, asts.asts, 1)
assert.Len(t, asts, 1)
set, err := asts.ToProgramSet()
require.NoError(t, err)
assert.Len(t, set, 1)
assert.Equal(t, asts.asts[0].Source, set[0].Source)
assert.Equal(t, asts[0].Source, set[0].Source)

empty := astSet{}
set, err = empty.ToProgramSet()
Expand All @@ -97,7 +90,7 @@ func TestASTSet_ReduceResiduals(t *testing.T) {
cel.Variable("foo", cel.BoolType),
)
require.NoError(t, err)
assert.Len(t, asts.asts, 1)
assert.Len(t, asts, 1)
set, err := asts.ReduceResiduals(cel.Globals(&variable{Name: "foo", Val: true}))
require.NoError(t, err)
assert.Empty(t, set)
Expand Down
12 changes: 6 additions & 6 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ func (c *cache) Build(
var asts astSet
constraints.Range(func(desc protoreflect.FieldDescriptor, rule protoreflect.Value) bool {
fieldEnv, compileErr := env.Extend(
cel.Constant(
"rule",
pvcel.ProtoFieldToType(desc, true, false),
pvcel.ProtoFieldToValue(desc, rule, false),
),
cel.Variable("rule", pvcel.ProtoFieldToType(desc, true, false)),
)
if compileErr != nil {
err = compileErr
Expand All @@ -88,7 +84,11 @@ func (c *cache) Build(
err = compileErr
return false
}
precomputedASTs.SetRuleValue(rule, desc)
precomputedASTs, compileErr = precomputedASTs.WithRuleValue(rule, desc)
if compileErr != nil {
err = compileErr
return false
}
asts = asts.Merge(precomputedASTs)
return true
})
Expand Down
2 changes: 1 addition & 1 deletion cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestCache_LoadOrCompileStandardConstraint(t *testing.T) {

asts, err := cache.loadOrCompileStandardConstraint(env, oneOfDesc, desc)
require.NoError(t, err)
assert.NotNil(t, asts)
assert.Nil(t, asts)

cached, ok := cache.cache[desc]
assert.True(t, ok)
Expand Down
22 changes: 11 additions & 11 deletions compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,21 @@ func compileASTs(
env *cel.Env,
envOpts ...cel.EnvOption,
) (set astSet, err error) {
set.env = env
if len(expressions.Constraints) == 0 {
return set, nil
}

if len(envOpts) > 0 {
set.env, err = env.Extend(envOpts...)
if err != nil {
return set, &CompilationError{cause: fmt.Errorf(
"failed to extend environment: %w", err)}
}
}

set.asts = make([]compiledAST, len(expressions.Constraints))
set = make([]compiledAST, len(expressions.Constraints))
for i, constraint := range expressions.Constraints {
set.asts[i], err = compileAST(set.env, constraint, expressions.RulePath)
set[i].Env = env
if len(envOpts) > 0 {
set[i].Env, err = env.Extend(envOpts...)
if err != nil {
return set, &CompilationError{cause: fmt.Errorf(
"failed to extend environment: %w", err)}
}
}
set[i], err = compileAST(set[i].Env, constraint, expressions.RulePath)
if err != nil {
return set, err
}
Expand All @@ -117,6 +116,7 @@ func compileAST(env *cel.Env, constraint *validate.Constraint, rulePath []*valid

return compiledAST{
AST: ast,
Env: env,
Source: constraint,
Path: rulePath,
}, nil
Expand Down
Loading

0 comments on commit 1f18b86

Please sign in to comment.