From 981f24881e89a61c5cf907da452727cafa63d9c8 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Tue, 25 Feb 2025 18:19:42 -0800 Subject: [PATCH 1/2] Separate unnest optimization from composer to capture type info --- policy/compiler_test.go | 3 + policy/composer.go | 154 +++++++++++++++++++++++++++------------- policy/helper_test.go | 9 +++ 3 files changed, 118 insertions(+), 48 deletions(-) diff --git a/policy/compiler_test.go b/policy/compiler_test.go index b318d2d6..545b01f8 100644 --- a/policy/compiler_test.go +++ b/policy/compiler_test.go @@ -79,6 +79,9 @@ func TestRuleComposerUnnest(t *testing.T) { if normalize(unparsed) != normalize(tc.composed) { t.Errorf("cel.AstToString() got %s, wanted %s", unparsed, tc.composed) } + if !ast.OutputType().IsEquivalentType(tc.outputType) { + t.Errorf("ast.OutputType() got %v, wanted %v", ast.OutputType(), tc.outputType) + } r.setup(t, env, ast) r.run(t) }) diff --git a/policy/composer.go b/policy/composer.go index 0b9be2a5..80263031 100644 --- a/policy/composer.go +++ b/policy/composer.go @@ -72,13 +72,21 @@ type RuleComposer struct { // Compose stitches together a set of expressions within a CompiledRule into a single CEL ast. 
func (c *RuleComposer) Compose(r *CompiledRule) (*cel.Ast, *cel.Issues) { ruleRoot, _ := c.env.Compile("true") - opt := cel.NewStaticOptimizer( - &ruleComposerImpl{ - rule: r, - varIndices: []varIndex{}, - exprUnnestHeight: c.exprUnnestHeight, - }) - return opt.Optimize(c.env, ruleRoot) + composer := &ruleComposerImpl{ + rule: r, + varIndices: []varIndex{}, + } + opt := cel.NewStaticOptimizer(composer) + ast, iss := opt.Optimize(c.env, ruleRoot) + if iss.Err() != nil { + return nil, iss + } + unnester := &ruleUnnesterImpl{ + varIndices: []varIndex{}, + exprUnnestHeight: c.exprUnnestHeight, + } + opt = cel.NewStaticOptimizer(unnester) + return opt.Optimize(c.env, ast) } type varIndex struct { @@ -93,8 +101,6 @@ type ruleComposerImpl struct { rule *CompiledRule nextVarIndex int varIndices []varIndex - - exprUnnestHeight int } // Optimize implements an AST optimizer for CEL which composes an expression graph into a single @@ -103,21 +109,16 @@ func (opt *ruleComposerImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *as // The input to optimize is a dummy expression which is completely replaced according // to the configuration of the rule composition graph. ruleExpr := opt.optimizeRule(ctx, opt.rule) - // If the rule is deeply nested, it may need to be unnested. This process may generate - // additional variables that are included in the `sortedVariables` list. - ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr) - // Collect all variables associated with the rule expression. - allVars := opt.sortedVariables() // If there were no variables, return the expression. - if len(allVars) == 0 { + if len(opt.varIndices) == 0 { return ctx.NewAST(ruleExpr) } // Otherwise populate the cel.@block with the variable declarations and wrap the expression // in the block. 
- varExprs := make([]ast.Expr, len(allVars)) - for i, vi := range allVars { + varExprs := make([]ast.Expr, len(opt.varIndices)) + for i, vi := range opt.varIndices { varExprs[i] = vi.expr err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType)) if err != nil { @@ -197,15 +198,94 @@ func (opt *ruleComposerImpl) rewriteVariableName(ctx *cel.OptimizerContext) ast. }) } -func (opt *ruleComposerImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr ast.Expr) ast.Expr { - // Split the expr into local variables based on expression height - ruleAST := ctx.NewAST(ruleExpr) - ruleNav := ast.NavigateAST(ruleAST) +// registerVariable creates an entry for a variable name within the cel.@block used to enumerate +// variables within composed policy expression. +func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) { + varName := fmt.Sprintf("variables.%s", v.Name()) + indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex) + varExpr := ctx.CopyASTAndMetadata(v.Expr().NativeRep()) + ast.PostOrderVisit(varExpr, opt.rewriteVariableName(ctx)) + vi := varIndex{ + index: opt.nextVarIndex, + indexVar: indexVar, + localVar: varName, + expr: varExpr, + celType: v.Declaration().Type()} + opt.varIndices = append(opt.varIndices, vi) + opt.nextVarIndex++ +} + +type ruleUnnesterImpl struct { + nextVarIndex int + varIndices []varIndex + exprUnnestHeight int +} + +func (opt *ruleUnnesterImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST { + // If the input AST does not start with a cel.@block, return + ruleExpr := ast.NavigateAST(a) + var varExprs []ast.Expr + if ruleExpr.Kind() == ast.CallKind && ruleExpr.AsCall().FunctionName() == "cel.@block" { + // Extract the expr from the cel.@block, args[1], as a navigable expr value. + // Also extract the variable declarations and all associated types from the cel.@block as + // varIndex values, but without doing any rewrites as the types are all correct already. 
+ block := ruleExpr.AsCall() + ruleExpr = block.Args()[1].(ast.NavigableExpr) + + // Collect the list of variables associated with the block + varSet := block.Args()[0].(ast.NavigableExpr) + vars := varSet.AsList() + varExprs = make([]ast.Expr, vars.Size()) + copy(varExprs, vars.Elements()) + for i := 0; i < vars.Size(); i++ { + v := varExprs[i] + indexVar := fmt.Sprintf("@index%d", i) + t := a.GetType(v.ID()) + vi := varIndex{ + index: i, + indexVar: indexVar, + localVar: indexVar, + expr: v, + celType: t, + } + opt.varIndices = append(opt.varIndices, vi) + opt.nextVarIndex++ + + err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType)) + if err != nil { + ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error()) + } + } + } + + // Attempt to unnest the rule. + blockVarCount := len(varExprs) + ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr) + // If there were no variables, return the expression. + if len(opt.varIndices) == blockVarCount { + return a + } + + // Otherwise populate the cel.@block with the variable declarations and wrap the expression + // in the block. + for i := blockVarCount; i < len(opt.varIndices); i++ { + vi := opt.varIndices[i] + varExprs = append(varExprs, vi.expr) + err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType)) + if err != nil { + ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error()) + } + } + blockExpr := ctx.NewCall("cel.@block", ctx.NewList(varExprs, []int32{}), ruleExpr) + return ctx.NewAST(blockExpr) +} + +func (opt *ruleUnnesterImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr ast.NavigableExpr) ast.NavigableExpr { // Unnest expressions are ordered from leaf to root via the ast.MatchDescendants call. 
- heights := ast.Heights(ruleAST) + heights := ast.Heights(ast.NewAST(ruleExpr, nil)) unnestMap := map[int64]bool{} unnestExprs := []ast.NavigableExpr{} - ast.MatchDescendants(ruleNav, func(e ast.NavigableExpr) bool { + ast.MatchDescendants(ruleExpr, func(e ast.NavigableExpr) bool { // If the expression is a comprehension, then all unnest candidates captured previously that relate // to the comprehension body should be removed from the list of candidate branches for unnesting. if e.Kind() == ast.ComprehensionKind { @@ -243,31 +323,14 @@ func (opt *ruleComposerImpl) maybeUnnestRule(ctx *cel.OptimizerContext, ruleExpr continue } reduceHeight(heights, e, opt.exprUnnestHeight) - opt.registerBranchVariable(ctx, e) + opt.registerUnnestVariable(ctx, e) } return ruleExpr } -// registerVariable creates an entry for a variable name within the cel.@block used to enumerate -// variables within composed policy expression. -func (opt *ruleComposerImpl) registerVariable(ctx *cel.OptimizerContext, v *CompiledVariable) { - varName := fmt.Sprintf("variables.%s", v.Name()) - indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex) - varExpr := ctx.CopyASTAndMetadata(v.Expr().NativeRep()) - ast.PostOrderVisit(varExpr, opt.rewriteVariableName(ctx)) - vi := varIndex{ - index: opt.nextVarIndex, - indexVar: indexVar, - localVar: varName, - expr: varExpr, - celType: v.Declaration().Type()} - opt.varIndices = append(opt.varIndices, vi) - opt.nextVarIndex++ -} - -// registerBranchVariable creates an entry for a variable name within the cel.@block used to unnest +// registerUnnestVariable creates an entry for a variable name within the cel.@block used to unnest // a deeply nested logical branch or logical operator. 
-func (opt *ruleComposerImpl) registerBranchVariable(ctx *cel.OptimizerContext, varExpr ast.NavigableExpr) { +func (opt *ruleUnnesterImpl) registerUnnestVariable(ctx *cel.OptimizerContext, varExpr ast.NavigableExpr) { indexVar := fmt.Sprintf("@index%d", opt.nextVarIndex) varExprCopy := ctx.CopyASTAndMetadata(ctx.NewAST(varExpr)) vi := varIndex{ @@ -281,11 +344,6 @@ func (opt *ruleComposerImpl) registerBranchVariable(ctx *cel.OptimizerContext, v opt.nextVarIndex++ } -// sortedVariables returns the variables ordered by their declaration index. -func (opt *ruleComposerImpl) sortedVariables() []varIndex { - return opt.varIndices -} - // compositionStep interface represents an intermediate stage of rule and match expression composition // // The CompiledRule and CompiledMatch types are meant to represent standalone tuples of condition diff --git a/policy/helper_test.go b/policy/helper_test.go index fbe3183a..8e117331 100644 --- a/policy/helper_test.go +++ b/policy/helper_test.go @@ -216,6 +216,7 @@ var ( expr string composed string composerOpts []ComposerOption + outputType *cel.Type }{ { name: "unnest", @@ -233,6 +234,7 @@ var ( : @index3], @index2 ? optional.of("some divisible by 2") : @index4) `, + outputType: cel.OptionalType(cel.StringType), }, { name: "required_labels", @@ -248,6 +250,7 @@ var ( "invalid values provided on one or more labels: %s".format([@index2])], @index3 ? optional.of(@index4) : (@index5 ? optional.of(@index6) : optional.none())) `, + outputType: cel.OptionalType(cel.StringType), }, { name: "required_labels", @@ -264,6 +267,7 @@ var ( (@index1.size() > 0) ? optional.of("missing one or more required labels: %s".format([@index1])) : @index3)`, + outputType: cel.OptionalType(cel.StringType), }, { name: "nested_rule2", @@ -277,6 +281,7 @@ var ( resource.?user.orValue("").startsWith("bad") ? (@index2 ? 
{"banned": "restricted_region"} : {"banned": "bad_actor"}) : @index3)`, + outputType: cel.MapType(cel.StringType, cel.StringType), }, { name: "nested_rule2", @@ -293,6 +298,7 @@ var ( : (!(resource.origin in @index0) ? {"banned": "unconfigured_region"} : {}))`, + outputType: cel.MapType(cel.StringType, cel.StringType), }, { name: "limits", @@ -310,6 +316,7 @@ var ( ? ((now.getHours() < 21) ? optional.of(@index4 + "!") : ((now.getHours() < 22) ? optional.of(@index4 + "!!") : @index5)) : @index6)`, + outputType: cel.OptionalType(cel.StringType), }, { name: "limits", @@ -327,6 +334,7 @@ var ( ? ((now.getHours() < 21) ? optional.of(@index4 + "!") : @index5) : optional.of(@index3.format([@index0, @index2]))) `, + outputType: cel.OptionalType(cel.StringType), }, { name: "limits", @@ -342,6 +350,7 @@ var ( ((now.getHours() < 22) ? optional.of(@index4 + "!!") : ((now.getHours() < 24) ? optional.of(@index4 + "!!!") : optional.none()))], (now.getHours() >= 20) ? @index5 : optional.of(@index3.format([@index0, @index2])))`, + outputType: cel.OptionalType(cel.StringType), }, } From 610601ed75147a7cd24a6ca13fdbee083b92bdd1 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Wed, 26 Feb 2025 12:17:27 -0800 Subject: [PATCH 2/2] Simplify the variable tracking during unnest --- policy/composer.go | 40 ++++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/policy/composer.go b/policy/composer.go index 80263031..76247248 100644 --- a/policy/composer.go +++ b/policy/composer.go @@ -222,9 +222,11 @@ type ruleUnnesterImpl struct { } func (opt *ruleUnnesterImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST { - // If the input AST does not start with a cel.@block, return + // Since the optimizer is based on the original environment provided to the composer, + // a second pass on the `cel.@block` will require a rebuilding of the cel environment ruleExpr := ast.NavigateAST(a) var varExprs []ast.Expr + var varDecls []cel.EnvOption 
if ruleExpr.Kind() == ast.CallKind && ruleExpr.AsCall().FunctionName() == "cel.@block" { // Extract the expr from the cel.@block, args[1], as a navigable expr value. // Also extract the variable declarations and all associated types from the cel.@block as @@ -233,42 +235,36 @@ func (opt *ruleUnnesterImpl) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *as ruleExpr = block.Args()[1].(ast.NavigableExpr) // Collect the list of variables associated with the block - varSet := block.Args()[0].(ast.NavigableExpr) - vars := varSet.AsList() + blockList := block.Args()[0].(ast.NavigableExpr) + vars := blockList.AsList() varExprs = make([]ast.Expr, vars.Size()) + varDecls = make([]cel.EnvOption, vars.Size()) copy(varExprs, vars.Elements()) - for i := 0; i < vars.Size(); i++ { - v := varExprs[i] + for i, v := range varExprs { + // Track the variable in the varDecls set. indexVar := fmt.Sprintf("@index%d", i) - t := a.GetType(v.ID()) - vi := varIndex{ - index: i, - indexVar: indexVar, - localVar: indexVar, - expr: v, - celType: t, - } - opt.varIndices = append(opt.varIndices, vi) + celType := a.GetType(v.ID()) + varDecls[i] = cel.Variable(indexVar, celType) opt.nextVarIndex++ - - err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType)) - if err != nil { - ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error()) - } + } + } + if len(varDecls) != 0 { + err := ctx.ExtendEnv(varDecls...) + if err != nil { + ctx.ReportErrorAtID(ruleExpr.ID(), "%s", err.Error()) } } // Attempt to unnest the rule. - blockVarCount := len(varExprs) ruleExpr = opt.maybeUnnestRule(ctx, ruleExpr) // If there were no variables, return the expression. - if len(opt.varIndices) == blockVarCount { + if len(opt.varIndices) == 0 { return a } // Otherwise populate the cel.@block with the variable declarations and wrap the expression // in the block. 
- for i := blockVarCount; i < len(opt.varIndices); i++ { + for i := 0; i < len(opt.varIndices); i++ { vi := opt.varIndices[i] varExprs = append(varExprs, vi.expr) err := ctx.ExtendEnv(cel.Variable(vi.indexVar, vi.celType))