Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate unnest optimization from composer to capture type info #1138

Merged
merged 2 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions policy/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ func TestRuleComposerUnnest(t *testing.T) {
if normalize(unparsed) != normalize(tc.composed) {
t.Errorf("cel.AstToString() got %s, wanted %s", unparsed, tc.composed)
}
if !ast.OutputType().IsEquivalentType(tc.outputType) {
t.Errorf("ast.OutputType() got %v, wanted %v", ast.OutputType(), tc.outputType)
}
r.setup(t, env, ast)
r.run(t)
})
Expand Down
150 changes: 102 additions & 48 deletions policy/composer.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,21 @@ type RuleComposer struct {
// Compose stitches together a set of expressions within a CompiledRule into a single CEL ast.
func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) {
ruleRoot, _ := c.env.Compile("true")
opt := cel.NewStaticOptimizer(
&ruleComposerImpl{
rule: r,
varIndices: []varIndex{},
exprUnnestHeight: c.exprUnnestHeight,
})
return opt.Optimize(c.env, ruleRoot)
composer := &ruleComposerImpl{
rule: r,
varIndices: []varIndex{},
}
opt := cel.NewStaticOptimizer(composer)
ast, iss := opt.Optimize(c.env, ruleRoot)
if iss.Err() != nil {
return nil, iss
}
unnester := &ruleUnnesterImpl{
varIndices: []varIndex{},
exprUnnestHeight: c.exprUnnestHeight,
}
opt = cel.NewStaticOptimizer(unnester)
return opt.Optimize(c.env, ast)
}

type varIndex struct {
Expand All @@ -93,8 +101,6 @@ type ruleComposerImpl struct {
rule *CompiledRule
nextVarIndex int
varIndices []varIndex

exprUnnestHeight int
}

// Optimize implements an AST optimizer for CEL which composes an expression graph into a single
Expand All @@ -103,21 +109,16 @@ func (opt *ruleComposerImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *as
// The input to optimize is a dummy expression which is completely replaced according
// to the configuration of the rule composition graph.
ruleExpr := opt.optimizeRule(ctx, opt.rule)
// If the rule is deeply nested, it may need to be unnested. This process may generate
// additional variables that are included in the `sortedVariables` list.
ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr)

// Collect all variables associated with the rule expression.
allVars := opt.sortedVariables()
// If there were no variables, return the expression.
if len(allVars) == 0 {
if len(opt.varIndices) == 0 {
return ctx.NewAST(ruleExpr)
}

// Otherwise populate the cel.@block with the variable declarations and wrap the expression
// in the block.
varExprs := make([]ast.Expr, len(allVars))
for i, vi := range allVars {
varExprs := make([]ast.Expr, len(opt.varIndices))
for i, vi := range opt.varIndices {
varExprs[i] = vi.expr
err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType))
if err != nil {
Expand Down Expand Up @@ -197,15 +198,90 @@ func (opt *ruleComposerImpl) rewriteVariableName(ctx *cel.OptimizerContext) ast.
})
}

func (opt *ruleComposerImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr ast.Expr) ast.Expr {
// Split the expr into local variables based on expression height
ruleAST := ctx.NewAST(ruleExpr)
ruleNav := ast.NavigateAST(ruleAST)
// registerVariable creates an entry for a variable name within the cel.@block used to enumerate
// variables within composed policy expression.
func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) {
varName := fmt.Sprintf("variables.%s", v.Name())
indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex)
varExpr := ctx.CopyASTAndMetadata(v.Expr().NativeRep())
ast.PostOrderVisit(varExpr, opt.rewriteVariableName(ctx))
vi := varIndex{
index: opt.nextVarIndex,
indexVar: indexVar,
localVar: varName,
expr: varExpr,
celType: v.Declaration().Type()}
opt.varIndices = append(opt.varIndices, vi)
opt.nextVarIndex++
}

type ruleUnnesterImpl struct {
nextVarIndex int
varIndices []varIndex
exprUnnestHeight int
}

func (opt *ruleUnnesterImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
// Since the optimizer is based on the original environment provided to the composer,
// a second pass on the `cel.@block` will require a rebuilding of the cel environment
ruleExpr := ast.NavigateAST(a)
var varExprs []ast.Expr
var varDecls []cel.EnvOption
if ruleExpr.Kind() == ast.CallKind && ruleExpr.AsCall().FunctionName() == "cel.@block" {
// Extract the expr from the cel.@block, args[1], as a navigable expr value.
// Also extract the variable declarations and all associated types from the cel.@block as
// varIndex values, but without doing any rewrites as the types are all correct already.
block := ruleExpr.AsCall()
ruleExpr = block.Args()[1].(ast.NavigableExpr)

// Collect the list of variables associated with the block
blockList := block.Args()[0].(ast.NavigableExpr)
vars := blockList.AsList()
varExprs = make([]ast.Expr, vars.Size())
varDecls = make([]cel.EnvOption, vars.Size())
copy(varExprs, vars.Elements())
for i, v := range varExprs {
// Track the variable he varDecls set.
indexVar := fmt.Sprintf("@index%d", i)
celType := a.GetType(v.ID())
varDecls[i] = cel.Variable(indexVar, celType)
opt.nextVarIndex++
}
}
if len(varDecls) != 0 {
err := ctx.ExtendEnv(varDecls...)
if err != nil {
ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error())
}
}

// Attempt to unnest the rule.
ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr)
// If there were no variables, return the expression.
if len(opt.varIndices) == 0 {
return a
}

// Otherwise populate the cel.@block with the variable declarations and wrap the expression
// in the block.
for i := 0; i < len(opt.varIndices); i++ {
vi := opt.varIndices[i]
varExprs = append(varExprs, vi.expr)
err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType))
if err != nil {
ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error())
}
}
blockExpr := ctx.NewCall("cel.@block", ctx.NewList(varExprs, []int32{}), ruleExpr)
return ctx.NewAST(blockExpr)
}

func (opt *ruleUnnesterImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr ast.NavigableExpr) ast.NavigableExpr {
// Unnest expressions are ordered from leaf to root via the ast.MatchDescendants call.
heights := ast.Heights(ruleAST)
heights := ast.Heights(ast.NewAST(ruleExpr, nil))
unnestMap := map[int64]bool{}
unnestExprs := []ast.NavigableExpr{}
ast.MatchDescendants(ruleNav, func(e ast.NavigableExpr) bool {
ast.MatchDescendants(ruleExpr, func(e ast.NavigableExpr) bool {
// If the expression is a comprehension, then all unnest candidates captured previously that relate
// to the comprehension body should be removed from the list of candidate branches for unnesting.
if e.Kind() == ast.ComprehensionKind {
Expand Down Expand Up @@ -243,31 +319,14 @@ func (opt *ruleComposerImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr
continue
}
reduceHeight(heights, e, opt.exprUnnestHeight)
opt.registerBranchVariable(ctx, e)
opt.registerUnnestVariable(ctx, e)
}
return ruleExpr
}

// registerVariable creates an entry for a variable name within the cel.@block used to enumerate
// variables within composed policy expression.
func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) {
varName := fmt.Sprintf("variables.%s", v.Name())
indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex)
varExpr := ctx.CopyASTAndMetadata(v.Expr().NativeRep())
ast.PostOrderVisit(varExpr, opt.rewriteVariableName(ctx))
vi := varIndex{
index: opt.nextVarIndex,
indexVar: indexVar,
localVar: varName,
expr: varExpr,
celType: v.Declaration().Type()}
opt.varIndices = append(opt.varIndices, vi)
opt.nextVarIndex++
}

// registerBranchVariable creates an entry for a variable name within the cel.@block used to unnest
// registerUnnestVariable creates an entry for a variable name within the cel.@block used to unnest
// a deeply nested logical branch or logical operator.
func (opt *ruleComposerImpl) registerBranchVariable(ctx *cel.OptimizerContext, varExpr ast.NavigableExpr) {
func (opt *ruleUnnesterImpl) registerUnnestVariable(ctx *cel.OptimizerContext, varExpr ast.NavigableExpr) {
indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex)
varExprCopy := ctx.CopyASTAndMetadata(ctx.NewAST(varExpr))
vi := varIndex{
Expand All @@ -281,11 +340,6 @@ func (opt *ruleComposerImpl) registerBranchVariable(ctx *cel.OptimizerContext, v
opt.nextVarIndex++
}

// sortedVariables returns the variables ordered by their declaration index.
func (opt *ruleComposerImpl) sortedVariables() []varIndex {
return opt.varIndices
}

// compositionStep interface represents an intermediate stage of rule and match expression composition
//
// The CompiledRule and CompiledMatch types are meant to represent standalone tuples of condition
Expand Down
9 changes: 9 additions & 0 deletions policy/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ var (
expr string
composed string
composerOpts []ComposerOption
outputType *cel.Type
}{
{
name: "unnest",
Expand All @@ -233,6 +234,7 @@ var (
: @index3],
@index2 ? optional.of("some divisible by 2") : @index4)
`,
outputType: cel.OptionalType(cel.StringType),
},
{
name: "required_labels",
Expand All @@ -248,6 +250,7 @@ var (
"invalid values provided on one or more labels: %s".format([@index2])],
@index3 ? optional.of(@index4) : (@index5 ? optional.of(@index6) : optional.none()))
`,
outputType: cel.OptionalType(cel.StringType),
},
{
name: "required_labels",
Expand All @@ -264,6 +267,7 @@ var (
(@index1.size() > 0)
? optional.of("missing one or more required labels: %s".format([@index1]))
: @index3)`,
outputType: cel.OptionalType(cel.StringType),
},
{
name: "nested_rule2",
Expand All @@ -277,6 +281,7 @@ var (
resource.?user.orValue("").startsWith("bad")
? (@index2 ? {"banned": "restricted_region"} : {"banned": "bad_actor"})
: @index3)`,
outputType: cel.MapType(cel.StringType, cel.StringType),
},
{
name: "nested_rule2",
Expand All @@ -293,6 +298,7 @@ var (
: (!(resource.origin in @index0)
? {"banned": "unconfigured_region"}
: {}))`,
outputType: cel.MapType(cel.StringType, cel.StringType),
},
{
name: "limits",
Expand All @@ -310,6 +316,7 @@ var (
? ((now.getHours() < 21) ? optional.of(@index4 + "!") :
((now.getHours() < 22) ? optional.of(@index4 + "!!") : @index5))
: @index6)`,
outputType: cel.OptionalType(cel.StringType),
},
{
name: "limits",
Expand All @@ -327,6 +334,7 @@ var (
? ((now.getHours() < 21) ? optional.of(@index4 + "!") : @index5)
: optional.of(@index3.format([@index0, @index2])))
`,
outputType: cel.OptionalType(cel.StringType),
},
{
name: "limits",
Expand All @@ -342,6 +350,7 @@ var (
((now.getHours() < 22) ? optional.of(@index4 + "!!") :
((now.getHours() < 24) ? optional.of(@index4 + "!!!") : optional.none()))],
(now.getHours() >= 20) ? @index5 : optional.of(@index3.format([@index0, @index2])))`,
outputType: cel.OptionalType(cel.StringType),
},
}

Expand Down