diff --git a/go.mod b/go.mod index deb04c427bd..fa2bf6eea55 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/go-openapi/strfmt v0.21.7 github.com/gogo/protobuf v1.3.2 github.com/golang/protobuf v1.5.3 - github.com/google/cel-go v0.17.6 + github.com/google/cel-go v0.18.0 github.com/google/go-cmp v0.5.9 github.com/google/gops v0.3.28 github.com/google/uuid v1.3.1 @@ -175,7 +175,7 @@ require ( golang.org/x/tools v0.11.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20230803162519-f966b187b2e5 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20230726155614-23370e0ffb3e // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index 52862a250bf..411ee5a2b48 100644 --- a/go.sum +++ b/go.sum @@ -341,8 +341,8 @@ github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Z github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -github.com/google/cel-go v0.17.6 h1:QDvHTIJunIsbgN8yVukx0HGnsqVLSY6xGqo+17IjIyM= -github.com/google/cel-go v0.17.6/go.mod h1:HXZKzB0LXqer5lHHgfWAnlYwJaQBDKMjxjulNQzhwhY= +github.com/google/cel-go v0.18.0 h1:u74MPiEC8mejBrkXqrTWT102g5IFEUjxOngzQIijMzU= +github.com/google/cel-go v0.18.0/go.mod h1:PVAybmSnWkNMUZR/tEWFUiJ1Np4Hz0MHsZJcgC4zln4= github.com/google/gnostic v0.5.7-v3refs h1:FhTMOKj2VhjpouxvWJAV1TL304uMlb9zcDqkl6cEI54= github.com/google/gnostic v0.5.7-v3refs/go.mod h1:73MKFl6jIHelAJNaBGFzt3SPtZULs9dYrGFt8OiIsHQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -1309,8 +1309,8 @@ google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaE google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto v0.0.0-20230803162519-f966b187b2e5 h1:L6iMMGrtzgHsWofoFcihmDEMYeDR9KN/ThbPWGrh++g= google.golang.org/genproto v0.0.0-20230803162519-f966b187b2e5/go.mod h1:oH/ZOT02u4kWEp7oYBGYFFkCdKS/uYR9Z7+0/xuuFp8= -google.golang.org/genproto/googleapis/api v0.0.0-20230726155614-23370e0ffb3e h1:z3vDksarJxsAKM5dmEGv0GHwE2hKJ096wZra71Vs4sw= -google.golang.org/genproto/googleapis/api v0.0.0-20230726155614-23370e0ffb3e/go.mod h1:rsr7RhLuwsDKL7RmgDDCUc6yaGr1iqceVb5Wv6f6YvQ= +google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 h1:nIgk/EEq3/YlnmVVXVnm14rC2oxgs1o0ong4sD/rd44= +google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5/go.mod h1:5DZzOUPCLYL3mNkQ0ms0F3EuUNZ7py1Bqeq6sxzI7/Q= google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d h1:uvYuEyMHKNt+lT4K3bN6fGswmK8qSvcreM3BwjDh+y4= google.golang.org/genproto/googleapis/rpc v0.0.0-20230822172742-b8732ec3820d/go.mod h1:+Bk1OCOj40wS2hwAMA+aCW9ypzm63QTBBHp6lQ3p+9M= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= diff --git a/vendor/github.com/google/cel-go/cel/BUILD.bazel b/vendor/github.com/google/cel-go/cel/BUILD.bazel index 0905f635395..62b903c81b4 100644 --- a/vendor/github.com/google/cel-go/cel/BUILD.bazel +++ b/vendor/github.com/google/cel-go/cel/BUILD.bazel @@ -10,9 +10,11 @@ go_library( "cel.go", "decls.go", "env.go", + "folding.go", "io.go", "library.go", "macro.go", + "optimizer.go", "options.go", "program.go", "validator.go", @@ -56,7 +58,9 @@ go_test( "cel_test.go", "decls_test.go", "env_test.go", + "folding_test.go", "io_test.go", + "validator_test.go", ], data = [ "//cel/testdata:gen_test_fds", diff --git a/vendor/github.com/google/cel-go/cel/decls.go b/vendor/github.com/google/cel-go/cel/decls.go index 0f9501341b4..b59e3708de6 100644 --- a/vendor/github.com/google/cel-go/cel/decls.go +++ b/vendor/github.com/google/cel-go/cel/decls.go @@ -353,43 +353,3 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) { return nil, fmt.Errorf("unsupported decl: %v", d) } } - -func typeValueToKind(tv ref.Type) (Kind, error) { - switch tv { - case types.BoolType: - return BoolKind, nil - case types.DoubleType: - return DoubleKind, nil - case types.IntType: - return IntKind, nil - case types.UintType: - return UintKind, nil - case types.ListType: - return ListKind, nil - case types.MapType: - return MapKind, nil - case types.StringType: - return StringKind, nil - case types.BytesType: - return BytesKind, nil - case types.DurationType: - return DurationKind, nil - case types.TimestampType: - return TimestampKind, nil - case types.NullType: - return NullTypeKind, nil - case types.TypeType: - return TypeKind, nil - default: - switch tv.TypeName() { - case "dyn": - return DynKind, nil - case "google.protobuf.Any": - return AnyKind, nil - case "optional": - return OpaqueKind, nil - default: - return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName()) - } - } -} diff --git a/vendor/github.com/google/cel-go/cel/env.go b/vendor/github.com/google/cel-go/cel/env.go index 69ba34dbd59..786a13c4d51 100644 --- a/vendor/github.com/google/cel-go/cel/env.go +++ b/vendor/github.com/google/cel-go/cel/env.go @@ -38,26 +38,37 @@ type Source = common.Source // Ast representing the checked or unchecked expression, its source, and related metadata such as // source position information. type Ast struct { - expr *exprpb.Expr - info *exprpb.SourceInfo - source Source - refMap map[int64]*celast.ReferenceInfo - typeMap map[int64]*types.Type + source Source + impl *celast.AST } // Expr returns the proto serializable instance of the parsed/checked expression. +// +// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr() +// the result instead. func (ast *Ast) Expr() *exprpb.Expr { - return ast.expr + if ast == nil { + return nil + } + pbExpr, _ := celast.ExprToProto(ast.impl.Expr()) + return pbExpr } // IsChecked returns whether the Ast value has been successfully type-checked. func (ast *Ast) IsChecked() bool { - return ast.typeMap != nil && len(ast.typeMap) > 0 + if ast == nil { + return false + } + return ast.impl.IsChecked() } // SourceInfo returns character offset and newline position information about expression elements. func (ast *Ast) SourceInfo() *exprpb.SourceInfo { - return ast.info + if ast == nil { + return nil + } + pbInfo, _ := celast.SourceInfoToProto(ast.impl.SourceInfo()) + return pbInfo } // ResultType returns the output type of the expression if the Ast has been type-checked, else @@ -65,9 +76,6 @@ func (ast *Ast) SourceInfo() *exprpb.SourceInfo { // // Deprecated: use OutputType func (ast *Ast) ResultType() *exprpb.Type { - if !ast.IsChecked() { - return chkdecls.Dyn - } out := ast.OutputType() t, err := TypeToExprType(out) if err != nil { @@ -79,16 +87,18 @@ func (ast *Ast) ResultType() *exprpb.Type { // OutputType returns the output type of the expression if the Ast has been type-checked, else // returns cel.DynType as the parse step cannot infer types. func (ast *Ast) OutputType() *Type { - t, found := ast.typeMap[ast.expr.GetId()] - if !found { - return DynType + if ast == nil { + return types.ErrorType } - return t + return ast.impl.GetType(ast.impl.Expr().ID()) } // Source returns a view of the input used to create the Ast. This source may be complete or // constructed from the SourceInfo. func (ast *Ast) Source() Source { + if ast == nil { + return nil + } return ast.source } @@ -196,29 +206,28 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) { // It is possible to have both non-nil Ast and Issues values returned from this call: however, // the mere presence of an Ast does not imply that it is valid for use. func (e *Env) Check(ast *Ast) (*Ast, *Issues) { - // Note, errors aren't currently possible on the Ast to ParsedExpr conversion. - pe, _ := AstToParsedExpr(ast) - // Construct the internal checker env, erroring if there is an issue adding the declarations. chk, err := e.initChecker() if err != nil { errs := common.NewErrors(ast.Source()) errs.ReportError(common.NoLocation, err.Error()) - return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo()) + return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo()) } - res, errs := checker.Check(pe, ast.Source(), chk) + checked, errs := checker.Check(ast.impl, ast.Source(), chk) if len(errs.GetErrors()) > 0 { - return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo()) + return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo()) } // Manually create the Ast to ensure that the Ast source information (which may be more // detailed than the information provided by Check), is returned to the caller. ast = &Ast{ - source: ast.Source(), - expr: res.Expr, - info: res.SourceInfo, - refMap: res.ReferenceMap, - typeMap: res.TypeMap} + source: ast.Source(), + impl: checked} + + // Avoid creating a validator config if it's not needed. + if len(e.validators) == 0 { + return ast, nil + } // Generate a validator configuration from the set of configured validators. vConfig := newValidatorConfig() @@ -228,9 +237,9 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { } } // Apply additional validators on the type-checked result. - iss := NewIssuesWithSourceInfo(errs, ast.SourceInfo()) + iss := NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo()) for _, v := range e.validators { - v.Validate(e, vConfig, res, iss) + v.Validate(e, vConfig, checked, iss) } if iss.Err() != nil { return nil, iss @@ -424,16 +433,11 @@ func (e *Env) Parse(txt string) (*Ast, *Issues) { // It is possible to have both non-nil Ast and Issues values returned from this call; however, // the mere presence of an Ast does not imply that it is valid for use. func (e *Env) ParseSource(src Source) (*Ast, *Issues) { - res, errs := e.prsr.Parse(src) + parsed, errs := e.prsr.Parse(src) if len(errs.GetErrors()) > 0 { return nil, &Issues{errs: errs} } - // Manually create the Ast to ensure that the text source information is propagated on - // subsequent calls to Check. - return &Ast{ - source: src, - expr: res.GetExpr(), - info: res.GetSourceInfo()}, nil + return &Ast{source: src, impl: parsed}, nil } // Program generates an evaluable instance of the Ast within the environment (Env). @@ -529,8 +533,9 @@ func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) { // TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an // Ast format and then Program again. func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) { - pruned := interpreter.PruneAst(a.Expr(), a.SourceInfo().GetMacroCalls(), details.State()) - expr, err := AstToString(ParsedExprToAst(pruned)) + pruned := interpreter.PruneAst(a.impl.Expr(), a.impl.SourceInfo().MacroCalls(), details.State()) + newAST := &Ast{source: a.Source(), impl: pruned} + expr, err := AstToString(newAST) if err != nil { return nil, err } @@ -551,13 +556,7 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) { // EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and // extension functions provided by estimator. func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) { - checked := &celast.CheckedAST{ - Expr: ast.Expr(), - SourceInfo: ast.SourceInfo(), - TypeMap: ast.typeMap, - ReferenceMap: ast.refMap, - } - return checker.Cost(checked, estimator, opts...) + return checker.Cost(ast.impl, estimator, opts...) } // configure applies a series of EnvOptions to the current environment. @@ -699,7 +698,7 @@ type Error = common.Error // Note: in the future, non-fatal warnings and notices may be inspectable via the Issues struct. type Issues struct { errs *common.Errors - info *exprpb.SourceInfo + info *celast.SourceInfo } // NewIssues returns an Issues struct from a common.Errors object. @@ -710,7 +709,7 @@ func NewIssues(errs *common.Errors) *Issues { // NewIssuesWithSourceInfo returns an Issues struct from a common.Errors object with SourceInfo metatata // which can be used with the `ReportErrorAtID` method for additional error reports within the context // information that's inferred from an expression id. -func NewIssuesWithSourceInfo(errs *common.Errors, info *exprpb.SourceInfo) *Issues { +func NewIssuesWithSourceInfo(errs *common.Errors, info *celast.SourceInfo) *Issues { return &Issues{ errs: errs, info: info, @@ -760,30 +759,7 @@ func (i *Issues) String() string { // The source metadata for the expression at `id`, if present, is attached to the error report. // To ensure that source metadata is attached to error reports, use NewIssuesWithSourceInfo. func (i *Issues) ReportErrorAtID(id int64, message string, args ...any) { - i.errs.ReportErrorAtID(id, locationByID(id, i.info), message, args...) -} - -// locationByID returns a common.Location given an expression id. -// -// TODO: move this functionality into the native SourceInfo and an overhaul of the common.Source -// as this implementation relies on the abstractions present in the protobuf SourceInfo object, -// and is replicated in the checker. -func locationByID(id int64, sourceInfo *exprpb.SourceInfo) common.Location { - positions := sourceInfo.GetPositions() - var line = 1 - if offset, found := positions[id]; found { - col := int(offset) - for _, lineOffset := range sourceInfo.GetLineOffsets() { - if lineOffset < offset { - line++ - col = int(offset - lineOffset) - } else { - break - } - } - return common.NewLocation(line, col) - } - return common.NoLocation + i.errs.ReportErrorAtID(id, i.info.GetStartLocation(id), message, args...) } // getStdEnv lazy initializes the CEL standard environment. @@ -814,6 +790,13 @@ func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, b return nil, false } +// FindStructFieldNames returns an empty set of field for the interop provider. +// +// To inspect the field names, migrate to a `types.Provider` implementation. +func (p *interopCELTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) { + return []string{}, false +} + // FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field // name, if one exists. // diff --git a/vendor/github.com/google/cel-go/cel/folding.go b/vendor/github.com/google/cel-go/cel/folding.go new file mode 100644 index 00000000000..025b2b04757 --- /dev/null +++ b/vendor/github.com/google/cel-go/cel/folding.go @@ -0,0 +1,558 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +// ConstantFoldingOption defines a functional option for configuring constant folding. +type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) + +// MaxConstantFoldIterations limits the number of times literals may be folding during optimization. +// +// Defaults to 100 if not set. +func MaxConstantFoldIterations(limit int) ConstantFoldingOption { + return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) { + opt.maxFoldIterations = limit + return opt, nil + } +} + +// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate +// literal values within function calls and select statements with their evaluated result. +func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) { + folder := &constantFoldingOptimizer{ + maxFoldIterations: defaultMaxConstantFoldIterations, + } + var err error + for _, o := range opts { + folder, err = o(folder) + if err != nil { + return nil, err + } + } + return folder, nil +} + +type constantFoldingOptimizer struct { + maxFoldIterations int +} + +// Optimize queries the expression graph for scalar and aggregate literal expressions within call and +// select statements and then evaluates them and replaces the call site with the literal result. +// +// Note: only values which can be represented as literals in CEL syntax are supported. +func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + + // Walk the list of foldable expression and continue to fold until there are no more folds left. + // All of the fold candidates returned by the constantExprMatcher should succeed unless there's + // a logic bug with the selection of expressions. + foldableExprs := ast.MatchDescendants(root, constantExprMatcher) + foldCount := 0 + for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations { + for _, fold := range foldableExprs { + // If the expression could be folded because it's a non-strict call, and the + // branches are pruned, continue to the next fold. + if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) { + continue + } + // Otherwise, assume all context is needed to evaluate the expression. + err := tryFold(ctx, a, fold) + if err != nil { + ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error()) + return a + } + } + foldCount++ + foldableExprs = ast.MatchDescendants(root, constantExprMatcher) + } + // Once all of the constants have been folded, try to run through the remaining comprehensions + // one last time. In this case, there's no guarantee they'll run, so we only update the + // target comprehension node with the literal value if the evaluation succeeds. + for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) { + tryFold(ctx, a, compre) + } + + // If the output is a list, map, or struct which contains optional entries, then prune it + // to make sure that the optionals, if resolved, do not surface in the output literal. + pruneOptionalElements(ctx, root) + + // Ensure that all intermediate values in the folded expression can be represented as valid + // CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents` + // to avoid extra allocations during this final pass through the AST. + ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() != ast.LiteralKind { + return + } + val := e.AsLiteral() + adapted, err := adaptLiteral(ctx, val) + if err != nil { + ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error()) + return + } + e.SetKindCase(adapted) + })) + + return a +} + +// tryFold attempts to evaluate a sub-expression to a literal. +// +// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise +// the method will return an error. +func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error { + // Assume all context is needed to evaluate the expression. + subAST := &Ast{ + impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()), + } + prg, err := ctx.Program(subAST) + if err != nil { + return err + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + return err + } + // Clear any macro metadata associated with the fold. + a.SourceInfo().ClearMacroCall(expr.ID()) + // Update the fold expression to be a literal. + expr.SetKindCase(ctx.NewLiteral(out)) + return nil +} + +// maybePruneBranches inspects the non-strict call expression to determine whether +// a branch can be removed. Evaluation will naturally prune logical and / or calls, +// but conditional will not be pruned cleanly, so this is one small area where the +// constant folding step reimplements a portion of the evaluator. +func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool { + call := expr.AsCall() + args := call.Args() + switch call.FunctionName() { + case operators.LogicalAnd, operators.LogicalOr: + return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr) + case operators.Conditional: + cond := args[0] + truthy := args[1] + falsy := args[2] + if cond.Kind() != ast.LiteralKind { + return false + } + if cond.AsLiteral() == types.True { + expr.SetKindCase(truthy) + } else { + expr.SetKindCase(falsy) + } + return true + case operators.In: + haystack := args[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + expr.SetKindCase(ctx.NewLiteral(types.False)) + return true + } + needle := args[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + expr.SetKindCase(ctx.NewLiteral(types.True)) + return true + } + } + } + } + return false +} + +func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool { + shortcircuit := types.False + skip := types.True + if function == operators.LogicalOr { + shortcircuit = types.True + skip = types.False + } + newArgs := []ast.Expr{} + for _, arg := range args { + if arg.Kind() != ast.LiteralKind { + newArgs = append(newArgs, arg) + continue + } + if arg.AsLiteral() == skip { + continue + } + if arg.AsLiteral() == shortcircuit { + expr.SetKindCase(arg) + return true + } + } + if len(newArgs) == 0 { + newArgs = append(newArgs, args[0]) + expr.SetKindCase(newArgs[0]) + return true + } + if len(newArgs) == 1 { + expr.SetKindCase(newArgs[0]) + return true + } + expr.SetKindCase(ctx.NewCall(function, newArgs...)) + return true +} + +// pruneOptionalElements works from the bottom up to resolve optional elements within +// aggregate literals. +// +// Note, many aggregate literals will be resolved as arguments to functions or select +// statements, so this method exists to handle the case where the literal could not be +// fully resolved or exists outside of a call, select, or comprehension context. +func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) { + aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher) + for _, lit := range aggregateLiterals { + switch lit.Kind() { + case ast.ListKind: + pruneOptionalListElements(ctx, lit) + case ast.MapKind: + pruneOptionalMapEntries(ctx, lit) + case ast.StructKind: + pruneOptionalStructFields(ctx, lit) + } + } +} + +func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) { + l := e.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() + if len(optIndices) == 0 { + return + } + updatedElems := []ast.Expr{} + updatedIndices := []int32{} + for i, e := range elems { + if !l.IsOptional(int32(i)) { + updatedElems = append(updatedElems, e) + continue + } + if e.Kind() != ast.LiteralKind { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(i)) + continue + } + optElemVal, ok := e.AsLiteral().(*types.Optional) + if !ok { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(i)) + continue + } + if !optElemVal.HasValue() { + continue + } + e.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedElems = append(updatedElems, e) + } + e.SetKindCase(ctx.NewList(updatedElems, updatedIndices)) +} + +func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) { + m := e.AsMap() + entries := m.Entries() + updatedEntries := []ast.EntryExpr{} + modified := false + for _, e := range entries { + entry := e.AsMapEntry() + key := entry.Key() + val := entry.Value() + // If the entry is not optional, or the value-side of the optional hasn't + // been resolved to a literal, then preserve the entry as-is. + if !entry.IsOptional() || val.Kind() != ast.LiteralKind { + updatedEntries = append(updatedEntries, e) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedEntries = append(updatedEntries, e) + continue + } + // When the key is not a literal, but the value is, then it needs to be + // restored to an optional value. + if key.Kind() != ast.LiteralKind { + undoOptVal, err := adaptLiteral(ctx, optElemVal) + if err != nil { + ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err) + } + val.SetKindCase(undoOptVal) + updatedEntries = append(updatedEntries, e) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedEntry := ctx.NewMapEntry(key, val, false) + updatedEntries = append(updatedEntries, updatedEntry) + } + if modified { + e.SetKindCase(ctx.NewMap(updatedEntries)) + } +} + +func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) { + s := e.AsStruct() + fields := s.Fields() + updatedFields := []ast.EntryExpr{} + modified := false + for _, f := range fields { + field := f.AsStructField() + val := field.Value() + if !field.IsOptional() || val.Kind() != ast.LiteralKind { + updatedFields = append(updatedFields, f) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedFields = append(updatedFields, f) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedField := ctx.NewStructField(field.Name(), val, false) + updatedFields = append(updatedFields, updatedField) + } + if modified { + e.SetKindCase(ctx.NewStruct(s.TypeName(), updatedFields)) + } +} + +// adaptLiteral converts a runtime CEL value to its equivalent literal expression. +// +// For strongly typed values, the type-provider will be used to reconstruct the fields +// which are present in the literal and their equivalent initialization values. +func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) { + switch t := val.Type().(type) { + case *types.Type: + switch t { + case types.BoolType, types.BytesType, types.DoubleType, types.IntType, + types.NullType, types.StringType, types.UintType: + return ctx.NewLiteral(val), nil + case types.DurationType: + return ctx.NewCall( + overloads.TypeConvertDuration, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.TimestampType: + return ctx.NewCall( + overloads.TypeConvertTimestamp, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.OptionalType: + opt := val.(*types.Optional) + if !opt.HasValue() { + return ctx.NewCall("optional.none"), nil + } + target, err := adaptLiteral(ctx, opt.GetValue()) + if err != nil { + return nil, err + } + return ctx.NewCall("optional.of", target), nil + case types.TypeType: + return ctx.NewIdent(val.(*types.Type).TypeName()), nil + case types.ListType: + l, ok := val.(traits.Lister) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + elems := make([]ast.Expr, l.Size().(types.Int)) + idx := 0 + it := l.Iterator() + for it.HasNext() == types.True { + elemVal := it.Next() + elemExpr, err := adaptLiteral(ctx, elemVal) + if err != nil { + return nil, err + } + elems[idx] = elemExpr + idx++ + } + return ctx.NewList(elems, []int32{}), nil + case types.MapType: + m, ok := val.(traits.Mapper) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + entries := make([]ast.EntryExpr, m.Size().(types.Int)) + idx := 0 + it := m.Iterator() + for it.HasNext() == types.True { + keyVal := it.Next() + keyExpr, err := adaptLiteral(ctx, keyVal) + if err != nil { + return nil, err + } + valVal := m.Get(keyVal) + valExpr, err := adaptLiteral(ctx, valVal) + if err != nil { + return nil, err + } + entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false) + idx++ + } + return ctx.NewMap(entries), nil + default: + provider := ctx.CELTypeProvider() + fields, found := provider.FindStructFieldNames(t.TypeName()) + if !found { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + tester := val.(traits.FieldTester) + indexer := val.(traits.Indexer) + fieldInits := []ast.EntryExpr{} + for _, f := range fields { + field := types.String(f) + if tester.IsSet(field) != types.True { + continue + } + fieldVal := indexer.Get(field) + fieldExpr, err := adaptLiteral(ctx, fieldVal) + if err != nil { + return nil, err + } + fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false)) + } + return ctx.NewStruct(t.TypeName(), fieldInits), nil + } + } + return nil, fmt.Errorf("failed to adapt %v to literal", val) +} + +// constantExprMatcher matches calls, select statements, and comprehensions whose arguments +// are all constant scalar or aggregate literal values. +// +// Only comprehensions which are not nested are included as possible constant folds, and only +// if all variables referenced in the comprehension stack exist are only iteration or +// accumulation variables. +func constantExprMatcher(e ast.NavigableExpr) bool { + switch e.Kind() { + case ast.CallKind: + return constantCallMatcher(e) + case ast.SelectKind: + sel := e.AsSelect() // guaranteed to be a navigable value + return constantMatcher(sel.Operand().(ast.NavigableExpr)) + case ast.ComprehensionKind: + if isNestedComprehension(e) { + return false + } + vars := map[string]bool{} + constantExprs := true + visitor := ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() == ast.ComprehensionKind { + nested := e.AsComprehension() + vars[nested.AccuVar()] = true + vars[nested.IterVar()] = true + } + if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] { + constantExprs = false + } + }) + ast.PreOrderVisit(e, visitor) + return constantExprs + default: + return false + } +} + +// constantCallMatcher identifies strict and non-strict calls which can be folded. +func constantCallMatcher(e ast.NavigableExpr) bool { + call := e.AsCall() + children := e.Children() + fnName := call.FunctionName() + if fnName == operators.LogicalAnd { + for _, child := range children { + if child.Kind() == ast.LiteralKind { + return true + } + } + } + if fnName == operators.LogicalOr { + for _, child := range children { + if child.Kind() == ast.LiteralKind { + return true + } + } + } + if fnName == operators.Conditional { + cond := children[0] + if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType { + return true + } + } + if fnName == operators.In { + haystack := children[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + return true + } + needle := children[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + return true + } + } + } + } + // convert all other calls with constant arguments + for _, child := range children { + if !constantMatcher(child) { + return false + } + } + return true +} + +func isNestedComprehension(e ast.NavigableExpr) bool { + parent, found := e.Parent() + for found { + if parent.Kind() == ast.ComprehensionKind { + return true + } + parent, found = parent.Parent() + } + return false +} + +func aggregateLiteralMatcher(e ast.NavigableExpr) bool { + return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind +} + +var ( + constantMatcher = ast.ConstantValueMatcher() +) + +const ( + defaultMaxConstantFoldIterations = 100 +) diff --git a/vendor/github.com/google/cel-go/cel/inlining.go b/vendor/github.com/google/cel-go/cel/inlining.go new file mode 100644 index 00000000000..9fc3be278c1 --- /dev/null +++ b/vendor/github.com/google/cel-go/cel/inlining.go @@ -0,0 +1,220 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/traits" +) + +// InlineVariable holds a variable name to be matched and an AST representing +// the expression graph which should be used to replace it. +type InlineVariable struct { + name string + alias string + def *ast.AST +} + +// Name returns the qualified variable or field selection to replace. +func (v *InlineVariable) Name() string { + return v.name +} + +// Alias returns the alias to use when performing cel.bind() calls during inlining. +func (v *InlineVariable) Alias() string { + return v.alias +} + +// Expr returns the inlined expression value. +func (v *InlineVariable) Expr() ast.Expr { + return v.def.Expr() +} + +// Type indicates the inlined expression type. +func (v *InlineVariable) Type() *Type { + return v.def.GetType(v.def.Expr().ID()) +} + +// NewInlineVariable declares a variable name to be replaced by a checked expression. +func NewInlineVariable(name string, definition *Ast) *InlineVariable { + return NewInlineVariableWithAlias(name, name, definition) +} + +// NewInlineVariableWithAlias declares a variable name to be replaced by a checked expression. +// If the variable occurs more than once, the provided alias will be used to replace the expressions +// where the variable name occurs. +func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable { + return &InlineVariable{name: name, alias: alias, def: definition.impl} +} + +// NewInliningOptimizer creates and optimizer which replaces variables with expression definitions. +// +// If a variable occurs one time, the variable is replaced by the inline definition. If the +// variable occurs more than once, the variable occurences are replaced by a cel.bind() call. +func NewInliningOptimizer(inlineVars ...*InlineVariable) ASTOptimizer { + return &inliningOptimizer{variables: inlineVars} +} + +type inliningOptimizer struct { + variables []*InlineVariable +} + +func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + for _, inlineVar := range opt.variables { + matches := ast.MatchDescendants(root, opt.matchVariable(inlineVar.Name())) + // Skip cases where the variable isn't in the expression graph + if len(matches) == 0 { + continue + } + + // For a single match, do a direct replacement of the expression sub-graph. + if len(matches) == 1 { + opt.inlineExpr(ctx, matches[0], ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type()) + continue + } + + if !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) { + for _, match := range matches { + opt.inlineExpr(ctx, match, ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type()) + } + continue + } + // For multiple matches, find the least common ancestor (lca) and insert the + // variable as a cel.bind() macro. + var lca ast.NavigableExpr = nil + ancestors := map[int64]bool{} + for _, match := range matches { + // Update the identifier matches with the provided alias. + aliasExpr := ctx.NewIdent(inlineVar.Alias()) + opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type()) + parent, found := match, true + for found { + _, hasAncestor := ancestors[parent.ID()] + if hasAncestor && (lca == nil || lca.Depth() < parent.Depth()) { + lca = parent + } + ancestors[parent.ID()] = true + parent, found = parent.Parent() + } + } + + // Update the least common ancestor by inserting a cel.bind() call to the alias. + inlined := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), inlineVar.Expr(), lca) + opt.inlineExpr(ctx, lca, inlined, inlineVar.Type()) + } + return a +} + +// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining +// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is +// made to determine whether the inlined value can be presence or existence tested. +func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) { + switch prev.Kind() { + case ast.SelectKind: + sel := prev.AsSelect() + if !sel.IsTestOnly() { + prev.SetKindCase(inlined) + return + } + opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType) + default: + prev.SetKindCase(inlined) + } +} + +// rewritePresenceExpr converts the inlined expression, when it occurs within a has() macro, to type-safe +// expression appropriate for the inlined type, if possible. +// +// If the rewrite is not possible an error is reported at the inline expression site. +func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) { + // If the input inlined expression is not a select expression it won't work with the has() + // macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error. + ctx.sourceInfo.ClearMacroCall(prev.ID()) + if inlined.Kind() == ast.SelectKind { + inlinedSel := inlined.AsSelect() + prev.SetKindCase( + ctx.NewPresenceTest(prev.ID(), inlinedSel.Operand(), inlinedSel.FieldName())) + return + } + if inlinedType.IsAssignableType(NullType) { + prev.SetKindCase( + ctx.NewCall(operators.NotEquals, + inlined, + ctx.NewLiteral(types.NullValue), + )) + return + } + if inlinedType.HasTrait(traits.SizerType) { + prev.SetKindCase( + ctx.NewCall(operators.NotEquals, + ctx.NewMemberCall(overloads.Size, inlined), + ctx.NewLiteral(types.IntZero), + )) + return + } + ctx.ReportErrorAtID(prev.ID(), "unable to inline expression type %v into presence test", inlinedType) +} + +// isBindable indicates whether the inlined type can be used within a cel.bind() if the expression +// being replaced occurs within a presence test. Value types with a size() method or field selection +// support can be bound. +// +// In future iterations, support may also be added for indexer types which can be rewritten as an `in` +// expression; however, this would imply a rewrite of the inlined expression that may not be necessary +// in most cases. +func isBindable(matches []ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) bool { + if inlinedType.IsAssignableType(NullType) || + inlinedType.HasTrait(traits.SizerType) || + inlinedType.HasTrait(traits.FieldTesterType) { + return true + } + for _, m := range matches { + if m.Kind() != ast.SelectKind { + continue + } + sel := m.AsSelect() + if sel.IsTestOnly() { + return false + } + } + return true +} + +// matchVariable matches simple identifiers, select expressions, and presence test expressions +// which match the (potentially) qualified variable name provided as input. +// +// Note, this function does not support inlining against select expressions which includes optional +// field selection. This may be a future refinement. +func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher { + return func(e ast.NavigableExpr) bool { + if e.Kind() == ast.IdentKind && e.AsIdent() == varName { + return true + } + if e.Kind() == ast.SelectKind { + sel := e.AsSelect() + // While the `ToQualifiedName` call could take the select directly, this + // would skip presence tests from possible matches, which we would like + // to include. + qualName, found := containers.ToQualifiedName(sel.Operand()) + return found && qualName+"."+sel.FieldName() == varName + } + return false + } +} diff --git a/vendor/github.com/google/cel-go/cel/io.go b/vendor/github.com/google/cel-go/cel/io.go index 80f63140e30..3133fb9d7d6 100644 --- a/vendor/github.com/google/cel-go/cel/io.go +++ b/vendor/github.com/google/cel-go/cel/io.go @@ -47,17 +47,11 @@ func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast { // // Prefer CheckedExprToAst if loading expressions from storage. func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) (*Ast, error) { - checkedAST, err := ast.CheckedExprToCheckedAST(checkedExpr) + checked, err := ast.ToAST(checkedExpr) if err != nil { return nil, err } - return &Ast{ - expr: checkedAST.Expr, - info: checkedAST.SourceInfo, - source: src, - refMap: checkedAST.ReferenceMap, - typeMap: checkedAST.TypeMap, - }, nil + return &Ast{source: src, impl: checked}, nil } // AstToCheckedExpr converts an Ast to an protobuf CheckedExpr value. @@ -67,13 +61,7 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) { if !a.IsChecked() { return nil, fmt.Errorf("cannot convert unchecked ast") } - cAst := &ast.CheckedAST{ - Expr: a.expr, - SourceInfo: a.info, - ReferenceMap: a.refMap, - TypeMap: a.typeMap, - } - return ast.CheckedASTToCheckedExpr(cAst) + return ast.ToProto(a.impl) } // ParsedExprToAst converts a parsed expression proto message to an Ast. @@ -89,18 +77,12 @@ func ParsedExprToAst(parsedExpr *exprpb.ParsedExpr) *Ast { // // Prefer ParsedExprToAst if loading expressions from storage. func ParsedExprToAstWithSource(parsedExpr *exprpb.ParsedExpr, src Source) *Ast { - si := parsedExpr.GetSourceInfo() - if si == nil { - si = &exprpb.SourceInfo{} - } + info, _ := ast.ProtoToSourceInfo(parsedExpr.GetSourceInfo()) if src == nil { - src = common.NewInfoSource(si) - } - return &Ast{ - expr: parsedExpr.GetExpr(), - info: si, - source: src, + src = common.NewInfoSource(parsedExpr.GetSourceInfo()) } + e, _ := ast.ProtoToExpr(parsedExpr.GetExpr()) + return &Ast{source: src, impl: ast.NewAST(e, info)} } // AstToParsedExpr converts an Ast to an protobuf ParsedExpr value. @@ -116,9 +98,7 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) { // Note, the conversion may not be an exact replica of the original expression, but will produce // a string that is semantically equivalent and whose textual representation is stable. func AstToString(a *Ast) (string, error) { - expr := a.Expr() - info := a.SourceInfo() - return parser.Unparse(expr, info) + return parser.Unparse(a.impl.Expr(), a.impl.SourceInfo()) } // RefValueToValue converts between ref.Val and api.expr.Value. diff --git a/vendor/github.com/google/cel-go/cel/library.go b/vendor/github.com/google/cel-go/cel/library.go index 4d232085c24..deddc14e59a 100644 --- a/vendor/github.com/google/cel-go/cel/library.go +++ b/vendor/github.com/google/cel-go/cel/library.go @@ -20,6 +20,7 @@ import ( "strings" "time" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/stdlib" @@ -28,8 +29,6 @@ import ( "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/interpreter" "github.com/google/cel-go/parser" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) const ( @@ -313,7 +312,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption { Types(types.OptionalType), // Configure the optMap and optFlatMap macros. - Macros(NewReceiverMacro(optMapMacro, 2, optMap)), + Macros(ReceiverMacro(optMapMacro, 2, optMap)), // Global and member functions for working with optional values. Function(optionalOfFunc, @@ -374,7 +373,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption { Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)), } if lib.version >= 1 { - opts = append(opts, Macros(NewReceiverMacro(optFlatMapMacro, 2, optFlatMap))) + opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap))) } return opts } @@ -386,57 +385,57 @@ func (lib *optionalLib) ProgramOptions() []ProgramOption { } } -func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { +func optMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) { varIdent := args[0] varName := "" - switch varIdent.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr: - varName = varIdent.GetIdentExpr().GetName() + switch varIdent.Kind() { + case ast.IdentKind: + varName = varIdent.AsIdent() default: - return nil, meh.NewError(varIdent.GetId(), "optMap() variable name must be a simple identifier") + return nil, meh.NewError(varIdent.ID(), "optMap() variable name must be a simple identifier") } mapExpr := args[1] - return meh.GlobalCall( + return meh.NewCall( operators.Conditional, - meh.ReceiverCall(hasValueFunc, target), - meh.GlobalCall(optionalOfFunc, - meh.Fold( - unusedIterVar, + meh.NewMemberCall(hasValueFunc, target), + meh.NewCall(optionalOfFunc, + meh.NewComprehension( meh.NewList(), + unusedIterVar, varName, - meh.ReceiverCall(valueFunc, target), - meh.LiteralBool(false), - meh.Ident(varName), + meh.NewMemberCall(valueFunc, target), + meh.NewLiteral(types.False), + meh.NewIdent(varName), mapExpr, ), ), - meh.GlobalCall(optionalNoneFunc), + meh.NewCall(optionalNoneFunc), ), nil } -func optFlatMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { +func optFlatMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) { varIdent := args[0] varName := "" - switch varIdent.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr: - varName = varIdent.GetIdentExpr().GetName() + switch varIdent.Kind() { + case ast.IdentKind: + varName = varIdent.AsIdent() default: - return nil, meh.NewError(varIdent.GetId(), "optFlatMap() variable name must be a simple identifier") + return nil, meh.NewError(varIdent.ID(), "optFlatMap() variable name must be a simple identifier") } mapExpr := args[1] - return meh.GlobalCall( + return meh.NewCall( operators.Conditional, - meh.ReceiverCall(hasValueFunc, target), - meh.Fold( - unusedIterVar, + meh.NewMemberCall(hasValueFunc, target), + meh.NewComprehension( meh.NewList(), + unusedIterVar, varName, - meh.ReceiverCall(valueFunc, target), - meh.LiteralBool(false), - meh.Ident(varName), + meh.NewMemberCall(valueFunc, target), + meh.NewLiteral(types.False), + meh.NewIdent(varName), mapExpr, ), - meh.GlobalCall(optionalNoneFunc), + meh.NewCall(optionalNoneFunc), ), nil } diff --git a/vendor/github.com/google/cel-go/cel/macro.go b/vendor/github.com/google/cel-go/cel/macro.go index 1eb414c8be4..4db1fd57a9c 100644 --- a/vendor/github.com/google/cel-go/cel/macro.go +++ b/vendor/github.com/google/cel-go/cel/macro.go @@ -15,6 +15,11 @@ package cel import ( + "fmt" + + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" "github.com/google/cel-go/parser" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -26,7 +31,14 @@ import ( // a Macro should be created per arg-count or as a var arg macro. type Macro = parser.Macro -// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree. +// MacroFactory defines an expansion function which converts a call and its arguments to a cel.Expr value. +type MacroFactory = parser.MacroExpander + +// MacroExprFactory assists with the creation of Expr values in a manner which is consistent +// the internal semantics and id generation behaviors of the parser and checker libraries. +type MacroExprFactory = parser.ExprHelper + +// MacroExpander converts a call and its associated arguments into a protobuf Expr representation. // // If the MacroExpander determines within the implementation that an expansion is not needed it may return // a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments @@ -36,48 +48,197 @@ type Macro = parser.Macro // and produces as output an Expr ast node. // // Note: when the Macro.IsReceiverStyle() method returns true, the target argument will be nil. -type MacroExpander = parser.MacroExpander +type MacroExpander func(eh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) // MacroExprHelper exposes helper methods for creating new expressions within a CEL abstract syntax tree. -type MacroExprHelper = parser.ExprHelper +// ExprHelper assists with the manipulation of proto-based Expr values in a manner which is +// consistent with the source position and expression id generation code leveraged by both +// the parser and type-checker. +type MacroExprHelper interface { + // Copy the input expression with a brand new set of identifiers. + Copy(*exprpb.Expr) *exprpb.Expr + + // LiteralBool creates an Expr value for a bool literal. + LiteralBool(value bool) *exprpb.Expr + + // LiteralBytes creates an Expr value for a byte literal. + LiteralBytes(value []byte) *exprpb.Expr + + // LiteralDouble creates an Expr value for double literal. + LiteralDouble(value float64) *exprpb.Expr + + // LiteralInt creates an Expr value for an int literal. + LiteralInt(value int64) *exprpb.Expr + + // LiteralString creates am Expr value for a string literal. + LiteralString(value string) *exprpb.Expr + + // LiteralUint creates an Expr value for a uint literal. + LiteralUint(value uint64) *exprpb.Expr + + // NewList creates a CreateList instruction where the list is comprised of the optional set + // of elements provided as arguments. + NewList(elems ...*exprpb.Expr) *exprpb.Expr + + // NewMap creates a CreateStruct instruction for a map where the map is comprised of the + // optional set of key, value entries. + NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr + + // NewMapEntry creates a Map Entry for the key, value pair. + NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry + + // NewObject creates a CreateStruct instruction for an object with a given type name and + // optional set of field initializers. + NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr + + // NewObjectFieldInit creates a new Object field initializer from the field name and value. + NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry + + // Fold creates a fold comprehension instruction. + // + // - iterVar is the iteration variable name. + // - iterRange represents the expression that resolves to a list or map where the elements or + // keys (respectively) will be iterated over. + // - accuVar is the accumulation variable name, typically parser.AccumulatorName. + // - accuInit is the initial expression whose value will be set for the accuVar prior to + // folding. + // - condition is the expression to test to determine whether to continue folding. + // - step is the expression to evaluation at the conclusion of a single fold iteration. + // - result is the computation to evaluate at the conclusion of the fold. + // + // The accuVar should not shadow variable names that you would like to reference within the + // environment in the step and condition expressions. Presently, the name __result__ is commonly + // used by built-in macros but this may change in the future. + Fold(iterVar string, + iterRange *exprpb.Expr, + accuVar string, + accuInit *exprpb.Expr, + condition *exprpb.Expr, + step *exprpb.Expr, + result *exprpb.Expr) *exprpb.Expr + + // Ident creates an identifier Expr value. + Ident(name string) *exprpb.Expr + + // AccuIdent returns an accumulator identifier for use with comprehension results. + AccuIdent() *exprpb.Expr + + // GlobalCall creates a function call Expr value for a global (free) function. + GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr + + // ReceiverCall creates a function call Expr value for a receiver-style function. + ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr + + // PresenceTest creates a Select TestOnly Expr value for modelling has() semantics. + PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr + + // Select create a field traversal Expr value. + Select(operand *exprpb.Expr, field string) *exprpb.Expr + + // OffsetLocation returns the Location of the expression identifier. + OffsetLocation(exprID int64) common.Location + + // NewError associates an error message with a given expression id. + NewError(exprID int64, message string) *Error +} + +// GlobalMacro creates a Macro for a global function with the specified arg count. +func GlobalMacro(function string, argCount int, factory MacroFactory) Macro { + return parser.NewGlobalMacro(function, argCount, factory) +} + +// ReceiverMacro creates a Macro for a receiver function matching the specified arg count. +func ReceiverMacro(function string, argCount int, factory MacroFactory) Macro { + return parser.NewReceiverMacro(function, argCount, factory) +} + +// GlobalVarArgMacro creates a Macro for a global function with a variable arg count. +func GlobalVarArgMacro(function string, factory MacroFactory) Macro { + return parser.NewGlobalVarArgMacro(function, factory) +} + +// ReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count. +func ReceiverVarArgMacro(function string, factory MacroFactory) Macro { + return parser.NewReceiverVarArgMacro(function, factory) +} // NewGlobalMacro creates a Macro for a global function with the specified arg count. +// +// Deprecated: use GlobalMacro func NewGlobalMacro(function string, argCount int, expander MacroExpander) Macro { - return parser.NewGlobalMacro(function, argCount, expander) + expand := adaptingExpander{expander} + return parser.NewGlobalMacro(function, argCount, expand.Expander) } // NewReceiverMacro creates a Macro for a receiver function matching the specified arg count. +// +// Deprecated: use ReceiverMacro func NewReceiverMacro(function string, argCount int, expander MacroExpander) Macro { - return parser.NewReceiverMacro(function, argCount, expander) + expand := adaptingExpander{expander} + return parser.NewReceiverMacro(function, argCount, expand.Expander) } // NewGlobalVarArgMacro creates a Macro for a global function with a variable arg count. +// +// Deprecated: use GlobalVarArgMacro func NewGlobalVarArgMacro(function string, expander MacroExpander) Macro { - return parser.NewGlobalVarArgMacro(function, expander) + expand := adaptingExpander{expander} + return parser.NewGlobalVarArgMacro(function, expand.Expander) } // NewReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count. +// +// Deprecated: use ReceiverVarArgMacro func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro { - return parser.NewReceiverVarArgMacro(function, expander) + expand := adaptingExpander{expander} + return parser.NewReceiverVarArgMacro(function, expand.Expander) } // HasMacroExpander expands the input call arguments into a presence test, e.g. has(.field) func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { - return parser.MakeHas(meh, target, args) + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + arg, err := adaptToExpr(args[0]) + if err != nil { + return nil, err + } + if arg.Kind() == ast.SelectKind { + s := arg.AsSelect() + return adaptToProto(ph.NewPresenceTest(s.Operand(), s.FieldName())) + } + return nil, ph.NewError(arg.ID(), "invalid argument to has() macro") } // ExistsMacroExpander expands the input call arguments into a comprehension that returns true if any of the // elements in the range match the predicate expressions: // .exists(, ) func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { - return parser.MakeExists(meh, target, args) + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeExists(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) } // ExistsOneMacroExpander expands the input call arguments into a comprehension that returns true if exactly // one of the elements in the range match the predicate expressions: // .exists_one(, ) func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { - return parser.MakeExistsOne(meh, target, args) + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeExistsOne(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) } // MapMacroExpander expands the input call arguments into a comprehension that transforms each element in the @@ -91,14 +252,30 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex // In the second form only iterVar values which return true when provided to the predicate expression // are transformed. func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { - return parser.MakeMap(meh, target, args) + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeMap(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) } // FilterMacroExpander expands the input call arguments into a comprehension which produces a list which contains // only elements which match the provided predicate expression: // .filter(, ) func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { - return parser.MakeFilter(meh, target, args) + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeFilter(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) } var ( @@ -142,3 +319,258 @@ var ( // NoMacros provides an alias to an empty list of macros NoMacros = []Macro{} ) + +type adaptingExpander struct { + legacyExpander MacroExpander +} + +func (adapt *adaptingExpander) Expander(eh parser.ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + var legacyTarget *exprpb.Expr = nil + var err *Error = nil + if target != nil { + legacyTarget, err = adaptToProto(target) + if err != nil { + return nil, err + } + } + legacyArgs := make([]*exprpb.Expr, len(args)) + for i, arg := range args { + legacyArgs[i], err = adaptToProto(arg) + if err != nil { + return nil, err + } + } + ah := &adaptingHelper{modernHelper: eh} + legacyExpr, err := adapt.legacyExpander(ah, legacyTarget, legacyArgs) + if err != nil { + return nil, err + } + ex, err := adaptToExpr(legacyExpr) + if err != nil { + return nil, err + } + return ex, nil +} + +func wrapErr(id int64, message string, err error) *common.Error { + return &common.Error{ + Location: common.NoLocation, + Message: fmt.Sprintf("%s: %v", message, err), + ExprID: id, + } +} + +type adaptingHelper struct { + modernHelper parser.ExprHelper +} + +// Copy the input expression with a brand new set of identifiers. +func (ah *adaptingHelper) Copy(e *exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.Copy(mustAdaptToExpr(e))) +} + +// LiteralBool creates an Expr value for a bool literal. +func (ah *adaptingHelper) LiteralBool(value bool) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bool(value))) +} + +// LiteralBytes creates an Expr value for a byte literal. +func (ah *adaptingHelper) LiteralBytes(value []byte) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bytes(value))) +} + +// LiteralDouble creates an Expr value for double literal. +func (ah *adaptingHelper) LiteralDouble(value float64) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Double(value))) +} + +// LiteralInt creates an Expr value for an int literal. +func (ah *adaptingHelper) LiteralInt(value int64) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Int(value))) +} + +// LiteralString creates am Expr value for a string literal. +func (ah *adaptingHelper) LiteralString(value string) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.String(value))) +} + +// LiteralUint creates an Expr value for a uint literal. +func (ah *adaptingHelper) LiteralUint(value uint64) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Uint(value))) +} + +// NewList creates a CreateList instruction where the list is comprised of the optional set +// of elements provided as arguments. +func (ah *adaptingHelper) NewList(elems ...*exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewList(mustAdaptToExprs(elems)...)) +} + +// NewMap creates a CreateStruct instruction for a map where the map is comprised of the +// optional set of key, value entries. +func (ah *adaptingHelper) NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { + adaptedEntries := make([]ast.EntryExpr, len(entries)) + for i, e := range entries { + adaptedEntries[i] = mustAdaptToEntryExpr(e) + } + return mustAdaptToProto(ah.modernHelper.NewMap(adaptedEntries...)) +} + +// NewMapEntry creates a Map Entry for the key, value pair. +func (ah *adaptingHelper) NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { + return mustAdaptToProtoEntry( + ah.modernHelper.NewMapEntry(mustAdaptToExpr(key), mustAdaptToExpr(val), optional)) +} + +// NewObject creates a CreateStruct instruction for an object with a given type name and +// optional set of field initializers. +func (ah *adaptingHelper) NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { + adaptedEntries := make([]ast.EntryExpr, len(fieldInits)) + for i, e := range fieldInits { + adaptedEntries[i] = mustAdaptToEntryExpr(e) + } + return mustAdaptToProto(ah.modernHelper.NewStruct(typeName, adaptedEntries...)) +} + +// NewObjectFieldInit creates a new Object field initializer from the field name and value. +func (ah *adaptingHelper) NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { + return mustAdaptToProtoEntry( + ah.modernHelper.NewStructField(field, mustAdaptToExpr(init), optional)) +} + +// Fold creates a fold comprehension instruction. +// +// - iterVar is the iteration variable name. +// - iterRange represents the expression that resolves to a list or map where the elements or +// keys (respectively) will be iterated over. +// - accuVar is the accumulation variable name, typically parser.AccumulatorName. +// - accuInit is the initial expression whose value will be set for the accuVar prior to +// folding. +// - condition is the expression to test to determine whether to continue folding. +// - step is the expression to evaluation at the conclusion of a single fold iteration. +// - result is the computation to evaluate at the conclusion of the fold. +// +// The accuVar should not shadow variable names that you would like to reference within the +// environment in the step and condition expressions. Presently, the name __result__ is commonly +// used by built-in macros but this may change in the future. +func (ah *adaptingHelper) Fold(iterVar string, + iterRange *exprpb.Expr, + accuVar string, + accuInit *exprpb.Expr, + condition *exprpb.Expr, + step *exprpb.Expr, + result *exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto( + ah.modernHelper.NewComprehension( + mustAdaptToExpr(iterRange), + iterVar, + accuVar, + mustAdaptToExpr(accuInit), + mustAdaptToExpr(condition), + mustAdaptToExpr(step), + mustAdaptToExpr(result), + ), + ) +} + +// Ident creates an identifier Expr value. +func (ah *adaptingHelper) Ident(name string) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewIdent(name)) +} + +// AccuIdent returns an accumulator identifier for use with comprehension results. +func (ah *adaptingHelper) AccuIdent() *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewAccuIdent()) +} + +// GlobalCall creates a function call Expr value for a global (free) function. +func (ah *adaptingHelper) GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewCall(function, mustAdaptToExprs(args)...)) +} + +// ReceiverCall creates a function call Expr value for a receiver-style function. +func (ah *adaptingHelper) ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto( + ah.modernHelper.NewMemberCall(function, mustAdaptToExpr(target), mustAdaptToExprs(args)...)) +} + +// PresenceTest creates a Select TestOnly Expr value for modelling has() semantics. +func (ah *adaptingHelper) PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr { + op := mustAdaptToExpr(operand) + return mustAdaptToProto(ah.modernHelper.NewPresenceTest(op, field)) +} + +// Select create a field traversal Expr value. +func (ah *adaptingHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr { + op := mustAdaptToExpr(operand) + return mustAdaptToProto(ah.modernHelper.NewSelect(op, field)) +} + +// OffsetLocation returns the Location of the expression identifier. +func (ah *adaptingHelper) OffsetLocation(exprID int64) common.Location { + return ah.modernHelper.OffsetLocation(exprID) +} + +// NewError associates an error message with a given expression id. +func (ah *adaptingHelper) NewError(exprID int64, message string) *Error { + return ah.modernHelper.NewError(exprID, message) +} + +func mustAdaptToExprs(exprs []*exprpb.Expr) []ast.Expr { + adapted := make([]ast.Expr, len(exprs)) + for i, e := range exprs { + adapted[i] = mustAdaptToExpr(e) + } + return adapted +} + +func mustAdaptToExpr(e *exprpb.Expr) ast.Expr { + out, _ := adaptToExpr(e) + return out +} + +func adaptToExpr(e *exprpb.Expr) (ast.Expr, *Error) { + if e == nil { + return nil, nil + } + out, err := ast.ProtoToExpr(e) + if err != nil { + return nil, wrapErr(e.GetId(), "proto conversion failure", err) + } + return out, nil +} + +func mustAdaptToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) ast.EntryExpr { + out, _ := ast.ProtoToEntryExpr(e) + return out +} + +func mustAdaptToProto(e ast.Expr) *exprpb.Expr { + out, _ := adaptToProto(e) + return out +} + +func adaptToProto(e ast.Expr) (*exprpb.Expr, *Error) { + if e == nil { + return nil, nil + } + out, err := ast.ExprToProto(e) + if err != nil { + return nil, wrapErr(e.ID(), "expr conversion failure", err) + } + return out, nil +} + +func mustAdaptToProtoEntry(e ast.EntryExpr) *exprpb.Expr_CreateStruct_Entry { + out, _ := ast.EntryExprToProto(e) + return out +} + +func toParserHelper(meh MacroExprHelper) (parser.ExprHelper, *Error) { + ah, ok := meh.(*adaptingHelper) + if !ok { + return nil, common.NewError(0, + fmt.Sprintf("unsupported macro helper: %v (%T)", meh, meh), + common.NoLocation) + } + return ah.modernHelper, nil +} diff --git a/vendor/github.com/google/cel-go/cel/optimizer.go b/vendor/github.com/google/cel-go/cel/optimizer.go new file mode 100644 index 00000000000..9422a7eb63b --- /dev/null +++ b/vendor/github.com/google/cel-go/cel/optimizer.go @@ -0,0 +1,390 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order. +// +// The static optimizer normalizes expression ids and type-checking run between optimization +// passes to ensure that the final optimized output is a valid expression with metadata consistent +// with what would have been generated from a parsed and checked expression. +// +// Note: source position information is best-effort and likely wrong, but optimized expressions +// should be suitable for calls to parser.Unparse. +type StaticOptimizer struct { + optimizers []ASTOptimizer +} + +// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied +// to a checked expression. +func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { + return &StaticOptimizer{ + optimizers: optimizers, + } +} + +// Optimize applies a sequence of optimizations to an Ast within a given environment. +// +// If issues are encountered, the Issues.Err() return value will be non-nil. +func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { + // Make a copy of the AST to be optimized. + optimized := ast.Copy(a.impl) + + // Create the optimizer context, could be pooled in the future. + issues := NewIssues(common.NewErrors(a.Source())) + ids := newMonotonicIDGen(ast.MaxID(a.impl)) + fac := &optimizerExprFactory{ + nextID: ids.nextID, + renumberID: ids.renumberID, + fac: ast.NewExprFactory(), + sourceInfo: optimized.SourceInfo(), + } + ctx := &OptimizerContext{ + optimizerExprFactory: fac, + Env: env, + Issues: issues, + } + + // Apply the optimizations sequentially. + for _, o := range opt.optimizers { + optimized = o.Optimize(ctx, optimized) + if issues.Err() != nil { + return nil, issues + } + // Normalize expression id metadata including coordination with macro call metadata. + normalizeIDs(env, optimized) + + // Recheck the updated expression for any possible type-agreement or validation errors. + parsed := &Ast{ + source: a.Source(), + impl: ast.NewAST(optimized.Expr(), optimized.SourceInfo())} + checked, iss := ctx.Check(parsed) + if iss.Err() != nil { + return nil, iss + } + optimized = checked.impl + } + + // Return the optimized result. + return &Ast{ + source: a.Source(), + impl: optimized, + }, nil +} + +// normalizeIDs ensures that the metadata present with an AST is reset in a manner such +// that the ids within the expression correspond to the ids within macros. +func normalizeIDs(e *Env, optimized *ast.AST) { + ids := newStableIDGen() + optimized.Expr().RenumberIDs(ids.renumberID) + allExprMap := make(map[int64]ast.Expr) + ast.PostOrderVisit(optimized.Expr(), ast.NewExprVisitor(func(e ast.Expr) { + allExprMap[e.ID()] = e + })) + info := optimized.SourceInfo() + + // First, update the macro call ids themselves. + for id, call := range info.MacroCalls() { + info.ClearMacroCall(id) + callID := ids.renumberID(id) + if e, found := allExprMap[callID]; found && e.Kind() == ast.LiteralKind { + continue + } + info.SetMacroCall(callID, call) + } + + // Second, update the macro call id references to ensure that macro pointers are' + // updated consistently across macros. + for id, call := range info.MacroCalls() { + call.RenumberIDs(ids.renumberID) + resetMacroCall(optimized, call, allExprMap) + info.SetMacroCall(id, call) + } +} + +func resetMacroCall(optimized *ast.AST, call ast.Expr, allExprMap map[int64]ast.Expr) { + modified := []ast.Expr{} + ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) { + if _, found := allExprMap[e.ID()]; found { + modified = append(modified, e) + } + })) + for _, m := range modified { + updated := allExprMap[m.ID()] + m.SetKindCase(updated) + } +} + +// newMonotonicIDGen increments numbers from an initial seed value. +func newMonotonicIDGen(seed int64) *monotonicIDGenerator { + return &monotonicIDGenerator{seed: seed} +} + +type monotonicIDGenerator struct { + seed int64 +} + +func (gen *monotonicIDGenerator) nextID() int64 { + gen.seed++ + return gen.seed +} + +func (gen *monotonicIDGenerator) renumberID(int64) int64 { + return gen.nextID() +} + +// newStableIDGen ensures that new ids are only created the first time they are encountered. +func newStableIDGen() *stableIDGenerator { + return &stableIDGenerator{ + idMap: make(map[int64]int64), + } +} + +type stableIDGenerator struct { + idMap map[int64]int64 + nextID int64 +} + +func (gen *stableIDGenerator) renumberID(id int64) int64 { + if id == 0 { + return 0 + } + if newID, found := gen.idMap[id]; found { + return newID + } + gen.nextID++ + gen.idMap[id] = gen.nextID + return gen.nextID +} + +// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate +// subexpressions and report any errors encountered along the way. The context also embeds the +// optimizerExprFactory which can be used to generate new sub-expressions with expression ids +// consistent with the expectations of a parsed expression. +type OptimizerContext struct { + *Env + *optimizerExprFactory + *Issues +} + +// ASTOptimizer applies an optimization over an AST and returns the optimized result. +type ASTOptimizer interface { + // Optimize optimizes a type-checked AST within an Environment and accumulates any issues. + Optimize(*OptimizerContext, *ast.AST) *ast.AST +} + +type optimizerExprFactory struct { + nextID func() int64 + renumberID ast.IDGenerator + fac ast.ExprFactory + sourceInfo *ast.SourceInfo +} + +// CopyExpr copies the structure of the input ast.Expr and renumbers the identifiers in a manner +// consistent with the CEL parser / checker. +func (opt *optimizerExprFactory) CopyExpr(e ast.Expr) ast.Expr { + copy := opt.fac.CopyExpr(e) + copy.RenumberIDs(opt.renumberID) + return copy +} + +// NewBindMacro creates a cel.bind() call with a variable name, initialization expression, and remaining expression. +// +// Note: the macroID indicates the insertion point, the call id that matched the macro signature, which will be used +// for coordinating macro metadata with the bind call. This piece of data is what makes it possible to unparse +// optimized expressions which use the bind() call. +// +// Example: +// +// cel.bind(myVar, a && b || c, !myVar || (myVar && d)) +// - varName: myVar +// - varInit: a && b || c +// - remaining: !myVar || (myVar && d) +func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) ast.Expr { + bindID := opt.nextID() + varID := opt.nextID() + + varInit = opt.CopyExpr(varInit) + varInit.RenumberIDs(opt.renumberID) + + remaining = opt.fac.CopyExpr(remaining) + remaining.RenumberIDs(opt.renumberID) + + // Place the expanded macro form in the macro calls list so that the inlined + // call can be unparsed. + opt.sourceInfo.SetMacroCall(macroID, + opt.fac.NewMemberCall(0, "bind", + opt.fac.NewIdent(opt.nextID(), "cel"), + opt.fac.NewIdent(varID, varName), + varInit, + remaining)) + + // Replace the parent node with the intercepted inlining using cel.bind()-like + // generated comprehension AST. + return opt.fac.NewComprehension(bindID, + opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}), + "#unused", + varName, + opt.fac.CopyExpr(varInit), + opt.fac.NewLiteral(opt.nextID(), types.False), + opt.fac.NewIdent(varID, varName), + opt.fac.CopyExpr(remaining)) +} + +// NewCall creates a global function call invocation expression. +// +// Example: +// +// countByField(list, fieldName) +// - function: countByField +// - args: [list, fieldName] +func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr { + return opt.fac.NewCall(opt.nextID(), function, args...) +} + +// NewMemberCall creates a member function call invocation expression where 'target' is the receiver of the call. +// +// Example: +// +// list.countByField(fieldName) +// - function: countByField +// - target: list +// - args: [fieldName] +func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr { + return opt.fac.NewMemberCall(opt.nextID(), function, target, args...) +} + +// NewIdent creates a new identifier expression. +// +// Examples: +// +// - simple_var_name +// - qualified.subpackage.var_name +func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr { + return opt.fac.NewIdent(opt.nextID(), name) +} + +// NewLiteral creates a new literal expression value. +// +// The range of valid values for a literal generated during optimization is different than for expressions +// generated via parsing / type-checking, as the ref.Val may be _any_ CEL value so long as the value can +// be converted back to a literal-like form. +func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr { + return opt.fac.NewLiteral(opt.nextID(), value) +} + +// NewList creates a list expression with a set of optional indices. +// +// Examples: +// +// [a, b] +// - elems: [a, b] +// - optIndices: [] +// +// [a, ?b, ?c] +// - elems: [a, b, c] +// - optIndices: [1, 2] +func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr { + return opt.fac.NewList(opt.nextID(), elems, optIndices) +} + +// NewMap creates a map from a set of entry expressions which contain a key and value expression. +func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr { + return opt.fac.NewMap(opt.nextID(), entries) +} + +// NewMapEntry creates a map entry with a key and value expression and a flag to indicate whether the +// entry is optional. +// +// Examples: +// +// {a: b} +// - key: a +// - value: b +// - optional: false +// +// {?a: ?b} +// - key: a +// - value: b +// - optional: true +func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional) +} + +// NewPresenceTest creates a new presence test macro call. +// +// Example: +// +// has(msg.field_name) +// - operand: msg +// - field: field_name +func (opt *optimizerExprFactory) NewPresenceTest(macroID int64, operand ast.Expr, field string) ast.Expr { + // Copy the input operand and renumber it. + operand = opt.CopyExpr(operand) + operand.RenumberIDs(opt.renumberID) + + // Place the expanded macro form in the macro calls list so that the inlined call can be unparsed. + opt.sourceInfo.SetMacroCall(macroID, + opt.fac.NewCall(0, "has", + opt.fac.NewSelect(opt.nextID(), operand, field))) + + // Generate a new presence test macro. + return opt.fac.NewPresenceTest(opt.nextID(), opt.CopyExpr(operand), field) +} + +// NewSelect creates a select expression where a field value is selected from an operand. +// +// Example: +// +// msg.field_name +// - operand: msg +// - field: field_name +func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr { + return opt.fac.NewSelect(opt.nextID(), operand, field) +} + +// NewStruct creates a new typed struct value with an set of field initializations. +// +// Example: +// +// pkg.TypeName{field: value} +// - typeName: pkg.TypeName +// - fields: [{field: value}] +func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr { + return opt.fac.NewStruct(opt.nextID(), typeName, fields) +} + +// NewStructField creates a struct field initialization. +// +// Examples: +// +// {count: 3u} +// - field: count +// - value: 3u +// - optional: false +// +// {?count: x} +// - field: count +// - value: x +// - optional: true +func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewStructField(opt.nextID(), field, value, isOptional) +} diff --git a/vendor/github.com/google/cel-go/cel/options.go b/vendor/github.com/google/cel-go/cel/options.go index d47f55d8eed..a6bff0dc45b 100644 --- a/vendor/github.com/google/cel-go/cel/options.go +++ b/vendor/github.com/google/cel-go/cel/options.go @@ -447,6 +447,8 @@ const ( OptTrackCost EvalOption = 1 << iota // OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality. + // + // Deprecated: use ext.ValidateFormatString() as this option is now a no-op. OptCheckStringFormat EvalOption = 1 << iota ) diff --git a/vendor/github.com/google/cel-go/cel/program.go b/vendor/github.com/google/cel-go/cel/program.go index 11c5c447e37..cec4839df2e 100644 --- a/vendor/github.com/google/cel-go/cel/program.go +++ b/vendor/github.com/google/cel-go/cel/program.go @@ -19,7 +19,6 @@ import ( "fmt" "sync" - celast "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/interpreter" @@ -148,7 +147,7 @@ func (p *prog) clone() *prog { // ProgramOption values. // // If the program cannot be configured the prog will be nil, with a non-nil error response. -func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { +func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) { // Build the dispatcher, interpreter, and default program value. disp := interpreter.NewDispatcher() @@ -208,34 +207,6 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { if len(p.regexOptimizations) > 0 { decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...)) } - // Enable compile-time checking of syntax/cardinality for string.format calls. - if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat { - var isValidType func(id int64, validTypes ...ref.Type) (bool, error) - if ast.IsChecked() { - isValidType = func(id int64, validTypes ...ref.Type) (bool, error) { - t := ast.typeMap[id] - if t.Kind() == DynKind { - return true, nil - } - for _, vt := range validTypes { - k, err := typeValueToKind(vt) - if err != nil { - return false, err - } - if t.Kind() == k { - return true, nil - } - } - return false, nil - } - } else { - // if the AST isn't type-checked, short-circuit validation - isValidType = func(id int64, validTypes ...ref.Type) (bool, error) { - return true, nil - } - } - decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType)) - } // Enable exhaustive eval, state tracking and cost tracking last since they require a factory. if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 { @@ -263,33 +234,16 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { decs = append(decs, interpreter.Observe(observers...)) } - return p.clone().initInterpretable(ast, decs) + return p.clone().initInterpretable(a, decs) } return newProgGen(factory) } - return p.initInterpretable(ast, decorators) + return p.initInterpretable(a, decorators) } -func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) { - // Unchecked programs do not contain type and reference information and may be slower to execute. - if !ast.IsChecked() { - interpretable, err := - p.interpreter.NewUncheckedInterpretable(ast.Expr(), decs...) - if err != nil { - return nil, err - } - p.interpretable = interpretable - return p, nil - } - - // When the AST has been checked it contains metadata that can be used to speed up program execution. - checked := &celast.CheckedAST{ - Expr: ast.Expr(), - SourceInfo: ast.SourceInfo(), - TypeMap: ast.typeMap, - ReferenceMap: ast.refMap, - } - interpretable, err := p.interpreter.NewInterpretable(checked, decs...) +func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) { + // When the AST has been exprAST it contains metadata that can be used to speed up program execution. + interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...) if err != nil { return nil, err } @@ -559,8 +513,6 @@ func (p *evalActivationPool) Put(value any) { } var ( - emptyEvalState = interpreter.NewEvalState() - // activationPool is an internally managed pool of Activation values that wrap map[string]any inputs activationPool = newEvalActivationPool() diff --git a/vendor/github.com/google/cel-go/cel/validator.go b/vendor/github.com/google/cel-go/cel/validator.go index 78b31138186..b50c674520a 100644 --- a/vendor/github.com/google/cel-go/cel/validator.go +++ b/vendor/github.com/google/cel-go/cel/validator.go @@ -21,8 +21,6 @@ import ( "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/overloads" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) const ( @@ -69,7 +67,7 @@ type ASTValidator interface { // // See individual validators for more information on their configuration keys and configuration // properties. - Validate(*Env, ValidatorConfig, *ast.CheckedAST, *Issues) + Validate(*Env, ValidatorConfig, *ast.AST, *Issues) } // ValidatorConfig provides an accessor method for querying validator configuration state. @@ -180,7 +178,7 @@ func ValidateComprehensionNestingLimit(limit int) ASTValidator { return nestingLimitValidator{limit: limit} } -type argChecker func(env *Env, call, arg ast.NavigableExpr) error +type argChecker func(env *Env, call, arg ast.Expr) error func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator { return formatValidator{ @@ -203,8 +201,8 @@ func (v formatValidator) Name() string { // Validate searches the AST for uses of a given function name with a constant argument and performs a check // on whether the argument is a valid literal value. -func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) { - root := ast.NavigateCheckedAST(a) +func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) { + root := ast.NavigateAST(a) funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName)) for _, call := range funcCalls { callArgs := call.AsCall().Args() @@ -221,8 +219,8 @@ func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, } } -func evalCall(env *Env, call, arg ast.NavigableExpr) error { - ast := ParsedExprToAst(&exprpb.ParsedExpr{Expr: call.ToExpr()}) +func evalCall(env *Env, call, arg ast.Expr) error { + ast := &Ast{impl: ast.NewAST(call, ast.NewSourceInfo(nil))} prg, err := env.Program(ast) if err != nil { return err @@ -231,7 +229,7 @@ func evalCall(env *Env, call, arg ast.NavigableExpr) error { return err } -func compileRegex(_ *Env, _, arg ast.NavigableExpr) error { +func compileRegex(_ *Env, _, arg ast.Expr) error { pattern := arg.AsLiteral().Value().(string) _, err := regexp.Compile(pattern) return err @@ -244,25 +242,14 @@ func (homogeneousAggregateLiteralValidator) Name() string { return homogeneousValidatorName } -// Configure implements the ASTValidatorConfigurer interface and currently sets the list of standard -// and exempt functions from homogeneous aggregate literal checks. -// -// TODO: Move this call into the string.format() ASTValidator once ported. -func (homogeneousAggregateLiteralValidator) Configure(c MutableValidatorConfig) error { - emptyList := []string{} - exemptFunctions := c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, emptyList).([]string) - exemptFunctions = append(exemptFunctions, "format") - return c.Set(HomogeneousAggregateLiteralExemptFunctions, exemptFunctions) -} - // Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types. // // This validator makes an exception for list and map literals which occur at any level of nesting within // string format calls. -func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.CheckedAST, iss *Issues) { +func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) { var exemptedFunctions []string exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string) - root := ast.NavigateCheckedAST(a) + root := ast.NavigateAST(a) listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind)) for _, listExpr := range listExprs { if inExemptFunction(listExpr, exemptedFunctions) { @@ -273,7 +260,7 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig optIndices := l.OptionalIndices() var elemType *Type for i, e := range elements { - et := e.Type() + et := a.GetType(e.ID()) if isOptionalIndex(i, optIndices) { et = et.Parameters()[0] } @@ -296,9 +283,10 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig entries := m.Entries() var keyType, valType *Type for _, e := range entries { - key, val := e.Key(), e.Value() - kt, vt := key.Type(), val.Type() - if e.IsOptional() { + mapEntry := e.AsMapEntry() + key, val := mapEntry.Key(), mapEntry.Value() + kt, vt := a.GetType(key.ID()), a.GetType(val.ID()) + if mapEntry.IsOptional() { vt = vt.Parameters()[0] } if keyType == nil && valType == nil { @@ -316,7 +304,8 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig } func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool { - if parent, found := e.Parent(); found { + parent, found := e.Parent() + for found { if parent.Kind() == ast.CallKind { fnName := parent.AsCall().FunctionName() for _, exempt := range exemptFunctions { @@ -325,9 +314,7 @@ func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool { } } } - if parent.Kind() == ast.ListKind || parent.Kind() == ast.MapKind { - return inExemptFunction(parent, exemptFunctions) - } + parent, found = parent.Parent() } return false } @@ -353,8 +340,8 @@ func (v nestingLimitValidator) Name() string { return "cel.lib.std.validate.comprehension_nesting_limit" } -func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) { - root := ast.NavigateCheckedAST(a) +func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) { + root := ast.NavigateAST(a) comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) if len(comprehensions) <= v.limit { return diff --git a/vendor/github.com/google/cel-go/checker/checker.go b/vendor/github.com/google/cel-go/checker/checker.go index 720e4fa968f..57fb3ce5ea1 100644 --- a/vendor/github.com/google/cel-go/checker/checker.go +++ b/vendor/github.com/google/cel-go/checker/checker.go @@ -18,6 +18,7 @@ package checker import ( "fmt" + "reflect" "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" @@ -25,139 +26,98 @@ import ( "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/types/ref" ) type checker struct { + *ast.AST + ast.ExprFactory env *Env errors *typeErrors mappings *mapping freeTypeVarCounter int - sourceInfo *exprpb.SourceInfo - types map[int64]*types.Type - references map[int64]*ast.ReferenceInfo } // Check performs type checking, giving a typed AST. -// The input is a ParsedExpr proto and an env which encapsulates -// type binding of variables, declarations of built-in functions, -// descriptions of protocol buffers, and a registry for errors. -// Returns a CheckedExpr proto, which might not be usable if -// there are errors in the error registry. -func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*ast.CheckedAST, *common.Errors) { +// +// The input is a parsed AST and an env which encapsulates type binding of variables, +// declarations of built-in functions, descriptions of protocol buffers, and a registry for +// errors. +// +// Returns a type-checked AST, which might not be usable if there are errors in the error +// registry. +func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.Errors) { errs := common.NewErrors(source) + typeMap := make(map[int64]*types.Type) + refMap := make(map[int64]*ast.ReferenceInfo) c := checker{ + AST: ast.NewCheckedAST(parsed, typeMap, refMap), + ExprFactory: ast.NewExprFactory(), env: env, errors: &typeErrors{errs: errs}, mappings: newMapping(), freeTypeVarCounter: 0, - sourceInfo: parsedExpr.GetSourceInfo(), - types: make(map[int64]*types.Type), - references: make(map[int64]*ast.ReferenceInfo), } - c.check(parsedExpr.GetExpr()) + c.check(c.Expr()) - // Walk over the final type map substituting any type parameters either by their bound value or - // by DYN. - m := make(map[int64]*types.Type) - for id, t := range c.types { - m[id] = substitute(c.mappings, t, true) + // Walk over the final type map substituting any type parameters either by their bound value + // or by DYN. + for id, t := range c.TypeMap() { + c.SetType(id, substitute(c.mappings, t, true)) } - - return &ast.CheckedAST{ - Expr: parsedExpr.GetExpr(), - SourceInfo: parsedExpr.GetSourceInfo(), - TypeMap: m, - ReferenceMap: c.references, - }, errs + return c.AST, errs } -func (c *checker) check(e *exprpb.Expr) { +func (c *checker) check(e ast.Expr) { if e == nil { return } - switch e.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - literal := e.GetConstExpr() - switch literal.GetConstantKind().(type) { - case *exprpb.Constant_BoolValue: - c.checkBoolLiteral(e) - case *exprpb.Constant_BytesValue: - c.checkBytesLiteral(e) - case *exprpb.Constant_DoubleValue: - c.checkDoubleLiteral(e) - case *exprpb.Constant_Int64Value: - c.checkInt64Literal(e) - case *exprpb.Constant_NullValue: - c.checkNullLiteral(e) - case *exprpb.Constant_StringValue: - c.checkStringLiteral(e) - case *exprpb.Constant_Uint64Value: - c.checkUint64Literal(e) + switch e.Kind() { + case ast.LiteralKind: + literal := ref.Val(e.AsLiteral()) + switch literal.Type() { + case types.BoolType, types.BytesType, types.DoubleType, types.IntType, + types.NullType, types.StringType, types.UintType: + c.setType(e, literal.Type().(*types.Type)) + default: + c.errors.unexpectedASTType(e.ID(), c.location(e), "literal", literal.Type().TypeName()) } - case *exprpb.Expr_IdentExpr: + case ast.IdentKind: c.checkIdent(e) - case *exprpb.Expr_SelectExpr: + case ast.SelectKind: c.checkSelect(e) - case *exprpb.Expr_CallExpr: + case ast.CallKind: c.checkCall(e) - case *exprpb.Expr_ListExpr: + case ast.ListKind: c.checkCreateList(e) - case *exprpb.Expr_StructExpr: + case ast.MapKind: + c.checkCreateMap(e) + case ast.StructKind: c.checkCreateStruct(e) - case *exprpb.Expr_ComprehensionExpr: + case ast.ComprehensionKind: c.checkComprehension(e) default: - c.errors.unexpectedASTType(e.GetId(), c.location(e), e) + c.errors.unexpectedASTType(e.ID(), c.location(e), "unspecified", reflect.TypeOf(e).Name()) } } -func (c *checker) checkInt64Literal(e *exprpb.Expr) { - c.setType(e, types.IntType) -} - -func (c *checker) checkUint64Literal(e *exprpb.Expr) { - c.setType(e, types.UintType) -} - -func (c *checker) checkStringLiteral(e *exprpb.Expr) { - c.setType(e, types.StringType) -} - -func (c *checker) checkBytesLiteral(e *exprpb.Expr) { - c.setType(e, types.BytesType) -} - -func (c *checker) checkDoubleLiteral(e *exprpb.Expr) { - c.setType(e, types.DoubleType) -} - -func (c *checker) checkBoolLiteral(e *exprpb.Expr) { - c.setType(e, types.BoolType) -} - -func (c *checker) checkNullLiteral(e *exprpb.Expr) { - c.setType(e, types.NullType) -} - -func (c *checker) checkIdent(e *exprpb.Expr) { - identExpr := e.GetIdentExpr() +func (c *checker) checkIdent(e ast.Expr) { + identName := e.AsIdent() // Check to see if the identifier is declared. - if ident := c.env.LookupIdent(identExpr.GetName()); ident != nil { + if ident := c.env.LookupIdent(identName); ident != nil { c.setType(e, ident.Type()) c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value())) // Overwrite the identifier with its fully qualified name. - identExpr.Name = ident.Name() + e.SetKindCase(c.NewIdent(e.ID(), ident.Name())) return } c.setType(e, types.ErrorType) - c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), identExpr.GetName()) + c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), identName) } -func (c *checker) checkSelect(e *exprpb.Expr) { - sel := e.GetSelectExpr() +func (c *checker) checkSelect(e ast.Expr) { + sel := e.AsSelect() // Before traversing down the tree, try to interpret as qualified name. qname, found := containers.ToQualifiedName(e) if found { @@ -170,31 +130,26 @@ func (c *checker) checkSelect(e *exprpb.Expr) { // variable name. c.setType(e, ident.Type()) c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value())) - identName := ident.Name() - e.ExprKind = &exprpb.Expr_IdentExpr{ - IdentExpr: &exprpb.Expr_Ident{ - Name: identName, - }, - } + e.SetKindCase(c.NewIdent(e.ID(), ident.Name())) return } } - resultType := c.checkSelectField(e, sel.GetOperand(), sel.GetField(), false) - if sel.TestOnly { + resultType := c.checkSelectField(e, sel.Operand(), sel.FieldName(), false) + if sel.IsTestOnly() { resultType = types.BoolType } c.setType(e, substitute(c.mappings, resultType, false)) } -func (c *checker) checkOptSelect(e *exprpb.Expr) { +func (c *checker) checkOptSelect(e ast.Expr) { // Collect metadata related to the opt select call packaged by the parser. - call := e.GetCallExpr() - operand := call.GetArgs()[0] - field := call.GetArgs()[1] + call := e.AsCall() + operand := call.Args()[0] + field := call.Args()[1] fieldName, isString := maybeUnwrapString(field) if !isString { - c.errors.notAnOptionalFieldSelection(field.GetId(), c.location(field), field) + c.errors.notAnOptionalFieldSelection(field.ID(), c.location(field), field) return } @@ -204,7 +159,7 @@ func (c *checker) checkOptSelect(e *exprpb.Expr) { c.setReference(e, ast.NewFunctionReference("select_optional_field")) } -func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *types.Type { +func (c *checker) checkSelectField(e, operand ast.Expr, field string, optional bool) *types.Type { // Interpret as field selection, first traversing down the operand. c.check(operand) operandType := substitute(c.mappings, c.getType(operand), false) @@ -222,7 +177,7 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option // Objects yield their field type declaration as the selection result type, but only if // the field is defined. messageType := targetType - if fieldType, found := c.lookupFieldType(e.GetId(), messageType.TypeName(), field); found { + if fieldType, found := c.lookupFieldType(e.ID(), messageType.TypeName(), field); found { resultType = fieldType } case types.TypeParamKind: @@ -236,7 +191,7 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option // Dynamic / error values are treated as DYN type. Errors are handled this way as well // in order to allow forward progress on the check. if !isDynOrError(targetType) { - c.errors.typeDoesNotSupportFieldSelection(e.GetId(), c.location(e), targetType) + c.errors.typeDoesNotSupportFieldSelection(e.ID(), c.location(e), targetType) } resultType = types.DynType } @@ -248,35 +203,34 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option return resultType } -func (c *checker) checkCall(e *exprpb.Expr) { +func (c *checker) checkCall(e ast.Expr) { // Note: similar logic exists within the `interpreter/planner.go`. If making changes here // please consider the impact on planner.go and consolidate implementations or mirror code // as appropriate. - call := e.GetCallExpr() - fnName := call.GetFunction() + call := e.AsCall() + fnName := call.FunctionName() if fnName == operators.OptSelect { c.checkOptSelect(e) return } - args := call.GetArgs() + args := call.Args() // Traverse arguments. for _, arg := range args { c.check(arg) } - target := call.GetTarget() // Regular static call with simple name. - if target == nil { + if !call.IsMemberFunction() { // Check for the existence of the function. fn := c.env.LookupFunction(fnName) if fn == nil { - c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName) + c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), fnName) c.setType(e, types.ErrorType) return } // Overwrite the function name with its fully qualified resolved name. - call.Function = fn.Name() + e.SetKindCase(c.NewCall(e.ID(), fn.Name(), args...)) // Check to see whether the overload resolves. c.resolveOverloadOrError(e, fn, nil, args) return @@ -287,6 +241,7 @@ func (c *checker) checkCall(e *exprpb.Expr) { // target a.b. // // Check whether the target is a namespaced function name. + target := call.Target() qualifiedPrefix, maybeQualified := containers.ToQualifiedName(target) if maybeQualified { maybeQualifiedName := qualifiedPrefix + "." + fnName @@ -295,15 +250,14 @@ func (c *checker) checkCall(e *exprpb.Expr) { // The function name is namespaced and so preserving the target operand would // be an inaccurate representation of the desired evaluation behavior. // Overwrite with fully-qualified resolved function name sans receiver target. - call.Target = nil - call.Function = fn.Name() + e.SetKindCase(c.NewCall(e.ID(), fn.Name(), args...)) c.resolveOverloadOrError(e, fn, nil, args) return } } // Regular instance call. - c.check(call.Target) + c.check(target) fn := c.env.LookupFunction(fnName) // Function found, attempt overload resolution. if fn != nil { @@ -312,11 +266,11 @@ func (c *checker) checkCall(e *exprpb.Expr) { } // Function name not declared, record error. c.setType(e, types.ErrorType) - c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName) + c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), fnName) } func (c *checker) resolveOverloadOrError( - e *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) { + e ast.Expr, fn *decls.FunctionDecl, target ast.Expr, args []ast.Expr) { // Attempt to resolve the overload. resolution := c.resolveOverload(e, fn, target, args) // No such overload, error noted in the resolveOverload call, type recorded here. @@ -330,7 +284,7 @@ func (c *checker) resolveOverloadOrError( } func (c *checker) resolveOverload( - call *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution { + call ast.Expr, fn *decls.FunctionDecl, target ast.Expr, args []ast.Expr) *overloadResolution { var argTypes []*types.Type if target != nil { @@ -362,8 +316,8 @@ func (c *checker) resolveOverload( for i, argType := range argTypes { if !c.isAssignable(argType, types.BoolType) { c.errors.typeMismatch( - args[i].GetId(), - c.locationByID(args[i].GetId()), + args[i].ID(), + c.locationByID(args[i].ID()), types.BoolType, argType) resultType = types.ErrorType @@ -408,29 +362,29 @@ func (c *checker) resolveOverload( for i, argType := range argTypes { argTypes[i] = substitute(c.mappings, argType, true) } - c.errors.noMatchingOverload(call.GetId(), c.location(call), fn.Name(), argTypes, target != nil) + c.errors.noMatchingOverload(call.ID(), c.location(call), fn.Name(), argTypes, target != nil) return nil } return newResolution(checkedRef, resultType) } -func (c *checker) checkCreateList(e *exprpb.Expr) { - create := e.GetListExpr() +func (c *checker) checkCreateList(e ast.Expr) { + create := e.AsList() var elemsType *types.Type - optionalIndices := create.GetOptionalIndices() + optionalIndices := create.OptionalIndices() optionals := make(map[int32]bool, len(optionalIndices)) for _, optInd := range optionalIndices { optionals[optInd] = true } - for i, e := range create.GetElements() { + for i, e := range create.Elements() { c.check(e) elemType := c.getType(e) if optionals[int32(i)] { var isOptional bool elemType, isOptional = maybeUnwrapOptional(elemType) if !isOptional && !isDyn(elemType) { - c.errors.typeMismatch(e.GetId(), c.location(e), types.NewOptionalType(elemType), elemType) + c.errors.typeMismatch(e.ID(), c.location(e), types.NewOptionalType(elemType), elemType) } } elemsType = c.joinTypes(e, elemsType, elemType) @@ -442,32 +396,24 @@ func (c *checker) checkCreateList(e *exprpb.Expr) { c.setType(e, types.NewListType(elemsType)) } -func (c *checker) checkCreateStruct(e *exprpb.Expr) { - str := e.GetStructExpr() - if str.GetMessageName() != "" { - c.checkCreateMessage(e) - } else { - c.checkCreateMap(e) - } -} - -func (c *checker) checkCreateMap(e *exprpb.Expr) { - mapVal := e.GetStructExpr() +func (c *checker) checkCreateMap(e ast.Expr) { + mapVal := e.AsMap() var mapKeyType *types.Type var mapValueType *types.Type - for _, ent := range mapVal.GetEntries() { - key := ent.GetMapKey() + for _, e := range mapVal.Entries() { + entry := e.AsMapEntry() + key := entry.Key() c.check(key) mapKeyType = c.joinTypes(key, mapKeyType, c.getType(key)) - val := ent.GetValue() + val := entry.Value() c.check(val) valType := c.getType(val) - if ent.GetOptionalEntry() { + if entry.IsOptional() { var isOptional bool valType, isOptional = maybeUnwrapOptional(valType) if !isOptional && !isDyn(valType) { - c.errors.typeMismatch(val.GetId(), c.location(val), types.NewOptionalType(valType), valType) + c.errors.typeMismatch(val.ID(), c.location(val), types.NewOptionalType(valType), valType) } } mapValueType = c.joinTypes(val, mapValueType, valType) @@ -480,25 +426,28 @@ func (c *checker) checkCreateMap(e *exprpb.Expr) { c.setType(e, types.NewMapType(mapKeyType, mapValueType)) } -func (c *checker) checkCreateMessage(e *exprpb.Expr) { - msgVal := e.GetStructExpr() +func (c *checker) checkCreateStruct(e ast.Expr) { + msgVal := e.AsStruct() // Determine the type of the message. resultType := types.ErrorType - ident := c.env.LookupIdent(msgVal.GetMessageName()) + ident := c.env.LookupIdent(msgVal.TypeName()) if ident == nil { c.errors.undeclaredReference( - e.GetId(), c.location(e), c.env.container.Name(), msgVal.GetMessageName()) + e.ID(), c.location(e), c.env.container.Name(), msgVal.TypeName()) c.setType(e, types.ErrorType) return } // Ensure the type name is fully qualified in the AST. typeName := ident.Name() - msgVal.MessageName = typeName - c.setReference(e, ast.NewIdentReference(ident.Name(), nil)) + if msgVal.TypeName() != typeName { + e.SetKindCase(c.NewStruct(e.ID(), typeName, msgVal.Fields())) + msgVal = e.AsStruct() + } + c.setReference(e, ast.NewIdentReference(typeName, nil)) identKind := ident.Type().Kind() if identKind != types.ErrorKind { if identKind != types.TypeKind { - c.errors.notAType(e.GetId(), c.location(e), ident.Type().DeclaredTypeName()) + c.errors.notAType(e.ID(), c.location(e), ident.Type().DeclaredTypeName()) } else { resultType = ident.Type().Parameters()[0] // Backwards compatibility test between well-known types and message types @@ -509,7 +458,7 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) { } else if resultType.Kind() == types.StructKind { typeName = resultType.DeclaredTypeName() } else { - c.errors.notAMessageType(e.GetId(), c.location(e), resultType.DeclaredTypeName()) + c.errors.notAMessageType(e.ID(), c.location(e), resultType.DeclaredTypeName()) resultType = types.ErrorType } } @@ -517,37 +466,38 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) { c.setType(e, resultType) // Check the field initializers. - for _, ent := range msgVal.GetEntries() { - field := ent.GetFieldKey() - value := ent.GetValue() + for _, f := range msgVal.Fields() { + field := f.AsStructField() + fieldName := field.Name() + value := field.Value() c.check(value) fieldType := types.ErrorType - ft, found := c.lookupFieldType(ent.GetId(), typeName, field) + ft, found := c.lookupFieldType(f.ID(), typeName, fieldName) if found { fieldType = ft } valType := c.getType(value) - if ent.GetOptionalEntry() { + if field.IsOptional() { var isOptional bool valType, isOptional = maybeUnwrapOptional(valType) if !isOptional && !isDyn(valType) { - c.errors.typeMismatch(value.GetId(), c.location(value), types.NewOptionalType(valType), valType) + c.errors.typeMismatch(value.ID(), c.location(value), types.NewOptionalType(valType), valType) } } if !c.isAssignable(fieldType, valType) { - c.errors.fieldTypeMismatch(ent.GetId(), c.locationByID(ent.GetId()), field, fieldType, valType) + c.errors.fieldTypeMismatch(f.ID(), c.locationByID(f.ID()), fieldName, fieldType, valType) } } } -func (c *checker) checkComprehension(e *exprpb.Expr) { - comp := e.GetComprehensionExpr() - c.check(comp.GetIterRange()) - c.check(comp.GetAccuInit()) - accuType := c.getType(comp.GetAccuInit()) - rangeType := substitute(c.mappings, c.getType(comp.GetIterRange()), false) +func (c *checker) checkComprehension(e ast.Expr) { + comp := e.AsComprehension() + c.check(comp.IterRange()) + c.check(comp.AccuInit()) + accuType := c.getType(comp.AccuInit()) + rangeType := substitute(c.mappings, c.getType(comp.IterRange()), false) var varType *types.Type switch rangeType.Kind() { @@ -564,32 +514,32 @@ func (c *checker) checkComprehension(e *exprpb.Expr) { // Set the range iteration variable to type DYN as well. varType = types.DynType default: - c.errors.notAComprehensionRange(comp.GetIterRange().GetId(), c.location(comp.GetIterRange()), rangeType) + c.errors.notAComprehensionRange(comp.IterRange().ID(), c.location(comp.IterRange()), rangeType) varType = types.ErrorType } // Create a scope for the comprehension since it has a local accumulation variable. // This scope will contain the accumulation variable used to compute the result. c.env = c.env.enterScope() - c.env.AddIdents(decls.NewVariable(comp.GetAccuVar(), accuType)) + c.env.AddIdents(decls.NewVariable(comp.AccuVar(), accuType)) // Create a block scope for the loop. c.env = c.env.enterScope() - c.env.AddIdents(decls.NewVariable(comp.GetIterVar(), varType)) + c.env.AddIdents(decls.NewVariable(comp.IterVar(), varType)) // Check the variable references in the condition and step. - c.check(comp.GetLoopCondition()) - c.assertType(comp.GetLoopCondition(), types.BoolType) - c.check(comp.GetLoopStep()) - c.assertType(comp.GetLoopStep(), accuType) + c.check(comp.LoopCondition()) + c.assertType(comp.LoopCondition(), types.BoolType) + c.check(comp.LoopStep()) + c.assertType(comp.LoopStep(), accuType) // Exit the loop's block scope before checking the result. c.env = c.env.exitScope() - c.check(comp.GetResult()) + c.check(comp.Result()) // Exit the comprehension scope. c.env = c.env.exitScope() - c.setType(e, substitute(c.mappings, c.getType(comp.GetResult()), false)) + c.setType(e, substitute(c.mappings, c.getType(comp.Result()), false)) } // Checks compatibility of joined types, and returns the most general common type. -func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *types.Type { +func (c *checker) joinTypes(e ast.Expr, previous, current *types.Type) *types.Type { if previous == nil { return current } @@ -599,7 +549,7 @@ func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *type if c.dynAggregateLiteralElementTypesEnabled() { return types.DynType } - c.errors.typeMismatch(e.GetId(), c.location(e), previous, current) + c.errors.typeMismatch(e.ID(), c.location(e), previous, current) return types.ErrorType } @@ -633,41 +583,41 @@ func (c *checker) isAssignableList(l1, l2 []*types.Type) bool { return false } -func maybeUnwrapString(e *exprpb.Expr) (string, bool) { - switch e.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - literal := e.GetConstExpr() - switch literal.GetConstantKind().(type) { - case *exprpb.Constant_StringValue: - return literal.GetStringValue(), true +func maybeUnwrapString(e ast.Expr) (string, bool) { + switch e.Kind() { + case ast.LiteralKind: + literal := e.AsLiteral() + switch v := literal.(type) { + case types.String: + return string(v), true } } return "", false } -func (c *checker) setType(e *exprpb.Expr, t *types.Type) { - if old, found := c.types[e.GetId()]; found && !old.IsExactType(t) { - c.errors.incompatibleType(e.GetId(), c.location(e), e, old, t) +func (c *checker) setType(e ast.Expr, t *types.Type) { + if old, found := c.TypeMap()[e.ID()]; found && !old.IsExactType(t) { + c.errors.incompatibleType(e.ID(), c.location(e), e, old, t) return } - c.types[e.GetId()] = t + c.SetType(e.ID(), t) } -func (c *checker) getType(e *exprpb.Expr) *types.Type { - return c.types[e.GetId()] +func (c *checker) getType(e ast.Expr) *types.Type { + return c.TypeMap()[e.ID()] } -func (c *checker) setReference(e *exprpb.Expr, r *ast.ReferenceInfo) { - if old, found := c.references[e.GetId()]; found && !old.Equals(r) { - c.errors.referenceRedefinition(e.GetId(), c.location(e), e, old, r) +func (c *checker) setReference(e ast.Expr, r *ast.ReferenceInfo) { + if old, found := c.ReferenceMap()[e.ID()]; found && !old.Equals(r) { + c.errors.referenceRedefinition(e.ID(), c.location(e), e, old, r) return } - c.references[e.GetId()] = r + c.SetReference(e.ID(), r) } -func (c *checker) assertType(e *exprpb.Expr, t *types.Type) { +func (c *checker) assertType(e ast.Expr, t *types.Type) { if !c.isAssignable(t, c.getType(e)) { - c.errors.typeMismatch(e.GetId(), c.location(e), t, c.getType(e)) + c.errors.typeMismatch(e.ID(), c.location(e), t, c.getType(e)) } } @@ -683,26 +633,12 @@ func newResolution(r *ast.ReferenceInfo, t *types.Type) *overloadResolution { } } -func (c *checker) location(e *exprpb.Expr) common.Location { - return c.locationByID(e.GetId()) +func (c *checker) location(e ast.Expr) common.Location { + return c.locationByID(e.ID()) } func (c *checker) locationByID(id int64) common.Location { - positions := c.sourceInfo.GetPositions() - var line = 1 - if offset, found := positions[id]; found { - col := int(offset) - for _, lineOffset := range c.sourceInfo.GetLineOffsets() { - if lineOffset < offset { - line++ - col = int(offset - lineOffset) - } else { - break - } - } - return common.NewLocation(line, col) - } - return common.NoLocation + return c.SourceInfo().GetStartLocation(id) } func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*types.Type, bool) { diff --git a/vendor/github.com/google/cel-go/checker/cost.go b/vendor/github.com/google/cel-go/checker/cost.go index d02f2628ab7..b6109d9182e 100644 --- a/vendor/github.com/google/cel-go/checker/cost.go +++ b/vendor/github.com/google/cel-go/checker/cost.go @@ -22,8 +22,6 @@ import ( "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/parser" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // WARNING: Any changes to cost calculations in this file require a corresponding change in interpreter/runtimecost.go @@ -58,7 +56,7 @@ type AstNode interface { // Type returns the deduced type of the AstNode. Type() *types.Type // Expr returns the expression of the AstNode. - Expr() *exprpb.Expr + Expr() ast.Expr // ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression. // For constants and inline list and map declarations, the exact size is returned. For concatenated list, strings // and bytes, the size is derived from the size estimates of the operands. nil is returned if there is no @@ -69,7 +67,7 @@ type AstNode interface { type astNode struct { path []string t *types.Type - expr *exprpb.Expr + expr ast.Expr derivedSize *SizeEstimate } @@ -81,7 +79,7 @@ func (e astNode) Type() *types.Type { return e.t } -func (e astNode) Expr() *exprpb.Expr { +func (e astNode) Expr() ast.Expr { return e.expr } @@ -90,29 +88,27 @@ func (e astNode) ComputedSize() *SizeEstimate { return e.derivedSize } var v uint64 - switch ek := e.expr.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - switch ck := ek.ConstExpr.GetConstantKind().(type) { - case *exprpb.Constant_StringValue: + switch e.expr.Kind() { + case ast.LiteralKind: + switch ck := e.expr.AsLiteral().(type) { + case types.String: // converting to runes here is an O(n) operation, but // this is consistent with how size is computed at runtime, // and how the language definition defines string size - v = uint64(len([]rune(ck.StringValue))) - case *exprpb.Constant_BytesValue: - v = uint64(len(ck.BytesValue)) - case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue, - *exprpb.Constant_Int64Value, *exprpb.Constant_TimestampValue, *exprpb.Constant_Uint64Value, - *exprpb.Constant_NullValue: + v = uint64(len([]rune(ck))) + case types.Bytes: + v = uint64(len(ck)) + case types.Bool, types.Double, types.Duration, + types.Int, types.Timestamp, types.Uint, + types.Null: v = uint64(1) default: return nil } - case *exprpb.Expr_ListExpr: - v = uint64(len(ek.ListExpr.GetElements())) - case *exprpb.Expr_StructExpr: - if ek.StructExpr.GetMessageName() == "" { - v = uint64(len(ek.StructExpr.GetEntries())) - } + case ast.ListKind: + v = uint64(e.expr.AsList().Size()) + case ast.MapKind: + v = uint64(e.expr.AsMap().Size()) default: return nil } @@ -261,7 +257,7 @@ type coster struct { iterRanges iterRangeScopes // computedSizes tracks the computed sizes of call results. computedSizes map[int64]SizeEstimate - checkedAST *ast.CheckedAST + checkedAST *ast.AST estimator CostEstimator // presenceTestCost will either be a zero or one based on whether has() macros count against cost computations. presenceTestCost CostEstimate @@ -270,8 +266,8 @@ type coster struct { // Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names. type iterRangeScopes map[string][]int64 -func (vs iterRangeScopes) push(varName string, expr *exprpb.Expr) { - vs[varName] = append(vs[varName], expr.GetId()) +func (vs iterRangeScopes) push(varName string, expr ast.Expr) { + vs[varName] = append(vs[varName], expr.ID()) } func (vs iterRangeScopes) pop(varName string) { @@ -304,9 +300,9 @@ func PresenceTestHasCost(hasCost bool) CostOption { } // Cost estimates the cost of the parsed and type checked CEL expression. -func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) { +func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) { c := &coster{ - checkedAST: checker, + checkedAST: checked, estimator: estimator, exprPath: map[int64][]string{}, iterRanges: map[string][]int64{}, @@ -319,28 +315,30 @@ func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) return CostEstimate{}, err } } - return c.cost(checker.Expr), nil + return c.cost(checked.Expr()), nil } -func (c *coster) cost(e *exprpb.Expr) CostEstimate { +func (c *coster) cost(e ast.Expr) CostEstimate { if e == nil { return CostEstimate{} } var cost CostEstimate - switch e.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: + switch e.Kind() { + case ast.LiteralKind: cost = constCost - case *exprpb.Expr_IdentExpr: + case ast.IdentKind: cost = c.costIdent(e) - case *exprpb.Expr_SelectExpr: + case ast.SelectKind: cost = c.costSelect(e) - case *exprpb.Expr_CallExpr: + case ast.CallKind: cost = c.costCall(e) - case *exprpb.Expr_ListExpr: + case ast.ListKind: cost = c.costCreateList(e) - case *exprpb.Expr_StructExpr: + case ast.MapKind: + cost = c.costCreateMap(e) + case ast.StructKind: cost = c.costCreateStruct(e) - case *exprpb.Expr_ComprehensionExpr: + case ast.ComprehensionKind: cost = c.costComprehension(e) default: return CostEstimate{} @@ -348,53 +346,51 @@ func (c *coster) cost(e *exprpb.Expr) CostEstimate { return cost } -func (c *coster) costIdent(e *exprpb.Expr) CostEstimate { - identExpr := e.GetIdentExpr() - +func (c *coster) costIdent(e ast.Expr) CostEstimate { + identName := e.AsIdent() // build and track the field path - if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok { - switch c.checkedAST.TypeMap[iterRange].Kind() { + if iterRange, ok := c.iterRanges.peek(identName); ok { + switch c.checkedAST.GetType(iterRange).Kind() { case types.ListKind: c.addPath(e, append(c.exprPath[iterRange], "@items")) case types.MapKind: c.addPath(e, append(c.exprPath[iterRange], "@keys")) } } else { - c.addPath(e, []string{identExpr.GetName()}) + c.addPath(e, []string{identName}) } return selectAndIdentCost } -func (c *coster) costSelect(e *exprpb.Expr) CostEstimate { - sel := e.GetSelectExpr() +func (c *coster) costSelect(e ast.Expr) CostEstimate { + sel := e.AsSelect() var sum CostEstimate - if sel.GetTestOnly() { + if sel.IsTestOnly() { // recurse, but do not add any cost // this is equivalent to how evalTestOnly increments the runtime cost counter // but does not add any additional cost for the qualifier, except here we do // the reverse (ident adds cost) sum = sum.Add(c.presenceTestCost) - sum = sum.Add(c.cost(sel.GetOperand())) + sum = sum.Add(c.cost(sel.Operand())) return sum } - sum = sum.Add(c.cost(sel.GetOperand())) - targetType := c.getType(sel.GetOperand()) + sum = sum.Add(c.cost(sel.Operand())) + targetType := c.getType(sel.Operand()) switch targetType.Kind() { case types.MapKind, types.StructKind, types.TypeParamKind: sum = sum.Add(selectAndIdentCost) } // build and track the field path - c.addPath(e, append(c.getPath(sel.GetOperand()), sel.GetField())) + c.addPath(e, append(c.getPath(sel.Operand()), sel.FieldName())) return sum } -func (c *coster) costCall(e *exprpb.Expr) CostEstimate { - call := e.GetCallExpr() - target := call.GetTarget() - args := call.GetArgs() +func (c *coster) costCall(e ast.Expr) CostEstimate { + call := e.AsCall() + args := call.Args() var sum CostEstimate @@ -405,22 +401,20 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate { argTypes[i] = c.newAstNode(arg) } - ref := c.checkedAST.ReferenceMap[e.GetId()] - if ref == nil || len(ref.OverloadIDs) == 0 { + overloadIDs := c.checkedAST.GetOverloadIDs(e.ID()) + if len(overloadIDs) == 0 { return CostEstimate{} } var targetType AstNode - if target != nil { - if call.Target != nil { - sum = sum.Add(c.cost(call.GetTarget())) - targetType = c.newAstNode(call.GetTarget()) - } + if call.IsMemberFunction() { + sum = sum.Add(c.cost(call.Target())) + targetType = c.newAstNode(call.Target()) } // Pick a cost estimate range that covers all the overload cost estimation ranges fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0} var resultSize *SizeEstimate - for _, overload := range ref.OverloadIDs { - overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts) + for _, overload := range overloadIDs { + overloadCost := c.functionCost(call.FunctionName(), overload, &targetType, argTypes, argCosts) fnCost = fnCost.Union(overloadCost.CostEstimate) if overloadCost.ResultSize != nil { if resultSize == nil { @@ -443,62 +437,54 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate { } } if resultSize != nil { - c.computedSizes[e.GetId()] = *resultSize + c.computedSizes[e.ID()] = *resultSize } return sum.Add(fnCost) } -func (c *coster) costCreateList(e *exprpb.Expr) CostEstimate { - create := e.GetListExpr() +func (c *coster) costCreateList(e ast.Expr) CostEstimate { + create := e.AsList() var sum CostEstimate - for _, e := range create.GetElements() { + for _, e := range create.Elements() { sum = sum.Add(c.cost(e)) } return sum.Add(createListBaseCost) } -func (c *coster) costCreateStruct(e *exprpb.Expr) CostEstimate { - str := e.GetStructExpr() - if str.MessageName != "" { - return c.costCreateMessage(e) - } - return c.costCreateMap(e) -} - -func (c *coster) costCreateMap(e *exprpb.Expr) CostEstimate { - mapVal := e.GetStructExpr() +func (c *coster) costCreateMap(e ast.Expr) CostEstimate { + mapVal := e.AsMap() var sum CostEstimate - for _, ent := range mapVal.GetEntries() { - key := ent.GetMapKey() - sum = sum.Add(c.cost(key)) - - sum = sum.Add(c.cost(ent.GetValue())) + for _, ent := range mapVal.Entries() { + entry := ent.AsMapEntry() + sum = sum.Add(c.cost(entry.Key())) + sum = sum.Add(c.cost(entry.Value())) } return sum.Add(createMapBaseCost) } -func (c *coster) costCreateMessage(e *exprpb.Expr) CostEstimate { - msgVal := e.GetStructExpr() +func (c *coster) costCreateStruct(e ast.Expr) CostEstimate { + msgVal := e.AsStruct() var sum CostEstimate - for _, ent := range msgVal.GetEntries() { - sum = sum.Add(c.cost(ent.GetValue())) + for _, ent := range msgVal.Fields() { + field := ent.AsStructField() + sum = sum.Add(c.cost(field.Value())) } return sum.Add(createMessageBaseCost) } -func (c *coster) costComprehension(e *exprpb.Expr) CostEstimate { - comp := e.GetComprehensionExpr() +func (c *coster) costComprehension(e ast.Expr) CostEstimate { + comp := e.AsComprehension() var sum CostEstimate - sum = sum.Add(c.cost(comp.GetIterRange())) - sum = sum.Add(c.cost(comp.GetAccuInit())) + sum = sum.Add(c.cost(comp.IterRange())) + sum = sum.Add(c.cost(comp.AccuInit())) // Track the iterRange of each IterVar for field path construction - c.iterRanges.push(comp.GetIterVar(), comp.GetIterRange()) - loopCost := c.cost(comp.GetLoopCondition()) - stepCost := c.cost(comp.GetLoopStep()) - c.iterRanges.pop(comp.GetIterVar()) - sum = sum.Add(c.cost(comp.Result)) - rangeCnt := c.sizeEstimate(c.newAstNode(comp.GetIterRange())) + c.iterRanges.push(comp.IterVar(), comp.IterRange()) + loopCost := c.cost(comp.LoopCondition()) + stepCost := c.cost(comp.LoopStep()) + c.iterRanges.pop(comp.IterVar()) + sum = sum.Add(c.cost(comp.Result())) + rangeCnt := c.sizeEstimate(c.newAstNode(comp.IterRange())) rangeCost := rangeCnt.MultiplyByCost(stepCost.Add(loopCost)) sum = sum.Add(rangeCost) @@ -643,26 +629,26 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())} } -func (c *coster) getType(e *exprpb.Expr) *types.Type { - return c.checkedAST.TypeMap[e.GetId()] +func (c *coster) getType(e ast.Expr) *types.Type { + return c.checkedAST.GetType(e.ID()) } -func (c *coster) getPath(e *exprpb.Expr) []string { - return c.exprPath[e.GetId()] +func (c *coster) getPath(e ast.Expr) []string { + return c.exprPath[e.ID()] } -func (c *coster) addPath(e *exprpb.Expr, path []string) { - c.exprPath[e.GetId()] = path +func (c *coster) addPath(e ast.Expr, path []string) { + c.exprPath[e.ID()] = path } -func (c *coster) newAstNode(e *exprpb.Expr) *astNode { +func (c *coster) newAstNode(e ast.Expr) *astNode { path := c.getPath(e) if len(path) > 0 && path[0] == parser.AccumulatorName { // only provide paths to root vars; omit accumulator vars path = nil } var derivedSize *SizeEstimate - if size, ok := c.computedSizes[e.GetId()]; ok { + if size, ok := c.computedSizes[e.ID()]; ok { derivedSize = &size } return &astNode{ diff --git a/vendor/github.com/google/cel-go/checker/errors.go b/vendor/github.com/google/cel-go/checker/errors.go index c2b96498d1d..8b3bf0b8b64 100644 --- a/vendor/github.com/google/cel-go/checker/errors.go +++ b/vendor/github.com/google/cel-go/checker/errors.go @@ -15,13 +15,9 @@ package checker import ( - "reflect" - "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // typeErrors is a specialization of Errors. @@ -34,9 +30,9 @@ func (e *typeErrors) fieldTypeMismatch(id int64, l common.Location, name string, name, FormatCELType(field), FormatCELType(value)) } -func (e *typeErrors) incompatibleType(id int64, l common.Location, ex *exprpb.Expr, prev, next *types.Type) { +func (e *typeErrors) incompatibleType(id int64, l common.Location, ex ast.Expr, prev, next *types.Type) { e.errs.ReportErrorAtID(id, l, - "incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next) + "incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.ID(), prev, next) } func (e *typeErrors) noMatchingOverload(id int64, l common.Location, name string, args []*types.Type, isInstance bool) { @@ -49,7 +45,7 @@ func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *type FormatCELType(t)) } -func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field *exprpb.Expr) { +func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field ast.Expr) { e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field) } @@ -61,9 +57,9 @@ func (e *typeErrors) notAMessageType(id int64, l common.Location, typeName strin e.errs.ReportErrorAtID(id, l, "'%s' is not a message type", typeName) } -func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex *exprpb.Expr, prev, next *ast.ReferenceInfo) { +func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex ast.Expr, prev, next *ast.ReferenceInfo) { e.errs.ReportErrorAtID(id, l, - "reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next) + "reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.ID(), prev, next) } func (e *typeErrors) typeDoesNotSupportFieldSelection(id int64, l common.Location, t *types.Type) { @@ -87,6 +83,6 @@ func (e *typeErrors) unexpectedFailedResolution(id int64, l common.Location, typ e.errs.ReportErrorAtID(id, l, "unexpected failed resolution of '%s'", typeName) } -func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex *exprpb.Expr) { - e.errs.ReportErrorAtID(id, l, "unrecognized ast type: %v", reflect.TypeOf(ex)) +func (e *typeErrors) unexpectedASTType(id int64, l common.Location, kind, typeName string) { + e.errs.ReportErrorAtID(id, l, "unexpected %s type: %v", kind, typeName) } diff --git a/vendor/github.com/google/cel-go/checker/printer.go b/vendor/github.com/google/cel-go/checker/printer.go index 15cba06ee97..7a3984f02ce 100644 --- a/vendor/github.com/google/cel-go/checker/printer.go +++ b/vendor/github.com/google/cel-go/checker/printer.go @@ -17,40 +17,40 @@ package checker import ( "sort" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/debug" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) type semanticAdorner struct { - checks *exprpb.CheckedExpr + checked *ast.AST } var _ debug.Adorner = &semanticAdorner{} func (a *semanticAdorner) GetMetadata(elem any) string { result := "" - e, isExpr := elem.(*exprpb.Expr) + e, isExpr := elem.(ast.Expr) if !isExpr { return result } - t := a.checks.TypeMap[e.GetId()] + t := a.checked.TypeMap()[e.ID()] if t != nil { result += "~" - result += FormatCheckedType(t) + result += FormatCELType(t) } - switch e.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr, - *exprpb.Expr_CallExpr, - *exprpb.Expr_StructExpr, - *exprpb.Expr_SelectExpr: - if ref, found := a.checks.ReferenceMap[e.GetId()]; found { - if len(ref.GetOverloadId()) == 0 { + switch e.Kind() { + case ast.IdentKind, + ast.CallKind, + ast.ListKind, + ast.StructKind, + ast.SelectKind: + if ref, found := a.checked.ReferenceMap()[e.ID()]; found { + if len(ref.OverloadIDs) == 0 { result += "^" + ref.Name } else { - sort.Strings(ref.GetOverloadId()) - for i, overload := range ref.GetOverloadId() { + sort.Strings(ref.OverloadIDs) + for i, overload := range ref.OverloadIDs { if i == 0 { result += "^" } else { @@ -68,7 +68,7 @@ func (a *semanticAdorner) GetMetadata(elem any) string { // Print returns a string representation of the Expr message, // annotated with types from the CheckedExpr. The Expr must // be a sub-expression embedded in the CheckedExpr. -func Print(e *exprpb.Expr, checks *exprpb.CheckedExpr) string { - a := &semanticAdorner{checks: checks} +func Print(e ast.Expr, checked *ast.AST) string { + a := &semanticAdorner{checked: checked} return debug.ToAdornedDebugString(e, a) } diff --git a/vendor/github.com/google/cel-go/common/ast/BUILD.bazel b/vendor/github.com/google/cel-go/common/ast/BUILD.bazel index 7269cdff5f7..c92a0f17972 100644 --- a/vendor/github.com/google/cel-go/common/ast/BUILD.bazel +++ b/vendor/github.com/google/cel-go/common/ast/BUILD.bazel @@ -5,7 +5,9 @@ package( "//cel:__subpackages__", "//checker:__subpackages__", "//common:__subpackages__", + "//ext:__subpackages__", "//interpreter:__subpackages__", + "//parser:__subpackages__", ], licenses = ["notice"], # Apache 2.0 ) @@ -14,10 +16,14 @@ go_library( name = "go_default_library", srcs = [ "ast.go", + "conversion.go", "expr.go", + "factory.go", + "navigable.go", ], importpath = "github.com/google/cel-go/common/ast", deps = [ + "//common:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", @@ -29,7 +35,9 @@ go_test( name = "go_default_test", srcs = [ "ast_test.go", + "conversion_test.go", "expr_test.go", + "navigable_test.go", ], embed = [ ":go_default_library", @@ -48,5 +56,6 @@ go_test( "//test/proto3pb:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", ], ) \ No newline at end of file diff --git a/vendor/github.com/google/cel-go/common/ast/ast.go b/vendor/github.com/google/cel-go/common/ast/ast.go index b3c150793a9..c3620eb956f 100644 --- a/vendor/github.com/google/cel-go/common/ast/ast.go +++ b/vendor/github.com/google/cel-go/common/ast/ast.go @@ -16,74 +16,336 @@ package ast import ( - "fmt" - + "github.com/google/cel-go/common" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" +) - structpb "google.golang.org/protobuf/types/known/structpb" +// AST contains a protobuf expression and source info along with CEL-native type and reference information. +type AST struct { + expr Expr + sourceInfo *SourceInfo + typeMap map[int64]*types.Type + refMap map[int64]*ReferenceInfo +} - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" -) +// Expr returns the root ast.Expr value in the AST. +func (a *AST) Expr() Expr { + if a == nil { + return nilExpr + } + return a.expr +} -// CheckedAST contains a protobuf expression and source info along with CEL-native type and reference information. -type CheckedAST struct { - Expr *exprpb.Expr - SourceInfo *exprpb.SourceInfo - TypeMap map[int64]*types.Type - ReferenceMap map[int64]*ReferenceInfo -} - -// CheckedASTToCheckedExpr converts a CheckedAST to a CheckedExpr protobouf. -func CheckedASTToCheckedExpr(ast *CheckedAST) (*exprpb.CheckedExpr, error) { - refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap)) - for id, ref := range ast.ReferenceMap { - r, err := ReferenceInfoToReferenceExpr(ref) - if err != nil { - return nil, err - } - refMap[id] = r +// SourceInfo returns the source metadata associated with the parse / type-check passes. +func (a *AST) SourceInfo() *SourceInfo { + if a == nil { + return nil } - typeMap := make(map[int64]*exprpb.Type, len(ast.TypeMap)) - for id, typ := range ast.TypeMap { - t, err := types.TypeToExprType(typ) - if err != nil { - return nil, err - } - typeMap[id] = t - } - return &exprpb.CheckedExpr{ - Expr: ast.Expr, - SourceInfo: ast.SourceInfo, - ReferenceMap: refMap, - TypeMap: typeMap, - }, nil -} - -// CheckedExprToCheckedAST converts a CheckedExpr protobuf to a CheckedAST instance. -func CheckedExprToCheckedAST(checked *exprpb.CheckedExpr) (*CheckedAST, error) { - refMap := make(map[int64]*ReferenceInfo, len(checked.GetReferenceMap())) - for id, ref := range checked.GetReferenceMap() { - r, err := ReferenceExprToReferenceInfo(ref) - if err != nil { - return nil, err + return a.sourceInfo +} + +// GetType returns the type for the expression at the given id, if one exists, else types.DynType. +func (a *AST) GetType(id int64) *types.Type { + if t, found := a.TypeMap()[id]; found { + return t + } + return types.DynType +} + +// SetType sets the type of the expression node at the given id. +func (a *AST) SetType(id int64, t *types.Type) { + if a == nil { + return + } + a.typeMap[id] = t +} + +// TypeMap returns the map of expression ids to type-checked types. +// +// If the AST is not type-checked, the map will be empty. +func (a *AST) TypeMap() map[int64]*types.Type { + if a == nil { + return map[int64]*types.Type{} + } + return a.typeMap +} + +// GetOverloadIDs returns the set of overload function names for a given expression id. +// +// If the expression id is not a function call, or the AST is not type-checked, the result will be empty. +func (a *AST) GetOverloadIDs(id int64) []string { + if ref, found := a.ReferenceMap()[id]; found { + return ref.OverloadIDs + } + return []string{} +} + +// ReferenceMap returns the map of expression id to identifier, constant, and function references. +func (a *AST) ReferenceMap() map[int64]*ReferenceInfo { + if a == nil { + return map[int64]*ReferenceInfo{} + } + return a.refMap +} + +// SetReference adds a reference to the checked AST type map. +func (a *AST) SetReference(id int64, r *ReferenceInfo) { + if a == nil { + return + } + a.refMap[id] = r +} + +// IsChecked returns whether the AST is type-checked. +func (a *AST) IsChecked() bool { + return a != nil && len(a.TypeMap()) > 0 +} + +// NewAST creates a base AST instance with an ast.Expr and ast.SourceInfo value. +func NewAST(e Expr, sourceInfo *SourceInfo) *AST { + if e == nil { + e = nilExpr + } + return &AST{ + expr: e, + sourceInfo: sourceInfo, + typeMap: make(map[int64]*types.Type), + refMap: make(map[int64]*ReferenceInfo), + } +} + +// NewCheckedAST wraps an parsed AST and augments it with type and reference metadata. +func NewCheckedAST(parsed *AST, typeMap map[int64]*types.Type, refMap map[int64]*ReferenceInfo) *AST { + return &AST{ + expr: parsed.Expr(), + sourceInfo: parsed.SourceInfo(), + typeMap: typeMap, + refMap: refMap, + } +} + +// Copy creates a deep copy of the Expr and SourceInfo values in the input AST. +// +// Copies of the Expr value are generated using an internal default ExprFactory. +func Copy(a *AST) *AST { + if a == nil { + return nil + } + e := defaultFactory.CopyExpr(a.expr) + if !a.IsChecked() { + return NewAST(e, CopySourceInfo(a.SourceInfo())) + } + typesCopy := make(map[int64]*types.Type, len(a.typeMap)) + for id, t := range a.typeMap { + typesCopy[id] = t + } + refsCopy := make(map[int64]*ReferenceInfo, len(a.refMap)) + for id, r := range a.refMap { + refsCopy[id] = r + } + return NewCheckedAST(NewAST(e, CopySourceInfo(a.SourceInfo())), typesCopy, refsCopy) +} + +// MaxID returns the upper-bound, non-inclusive, of ids present within the AST's Expr value. +func MaxID(a *AST) int64 { + visitor := &maxIDVisitor{maxID: 1} + PostOrderVisit(a.Expr(), visitor) + return visitor.maxID + 1 +} + +// NewSourceInfo creates a simple SourceInfo object from an input common.Source value. +func NewSourceInfo(src common.Source) *SourceInfo { + var lineOffsets []int32 + var desc string + if src != nil { + desc = src.Description() + lineOffsets = src.LineOffsets() + } + return &SourceInfo{ + desc: desc, + lines: lineOffsets, + offsetRanges: make(map[int64]OffsetRange), + macroCalls: make(map[int64]Expr), + } +} + +// CopySourceInfo creates a deep copy of the MacroCalls within the input SourceInfo. +// +// Copies of macro Expr values are generated using an internal default ExprFactory. +func CopySourceInfo(info *SourceInfo) *SourceInfo { + if info == nil { + return nil + } + rangesCopy := make(map[int64]OffsetRange, len(info.offsetRanges)) + for id, off := range info.offsetRanges { + rangesCopy[id] = off + } + callsCopy := make(map[int64]Expr, len(info.macroCalls)) + for id, call := range info.macroCalls { + callsCopy[id] = defaultFactory.CopyExpr(call) + } + return &SourceInfo{ + syntax: info.syntax, + desc: info.desc, + lines: info.lines, + offsetRanges: rangesCopy, + macroCalls: callsCopy, + } +} + +// SourceInfo records basic information about the expression as a textual input and +// as a parsed expression value. +type SourceInfo struct { + syntax string + desc string + lines []int32 + offsetRanges map[int64]OffsetRange + macroCalls map[int64]Expr +} + +// SyntaxVersion returns the syntax version associated with the text expression. +func (s *SourceInfo) SyntaxVersion() string { + if s == nil { + return "" + } + return s.syntax +} + +// Description provides information about where the expression came from. +func (s *SourceInfo) Description() string { + if s == nil { + return "" + } + return s.desc +} + +// LineOffsets returns a list of the 0-based character offsets in the input text where newlines appear. +func (s *SourceInfo) LineOffsets() []int32 { + if s == nil { + return []int32{} + } + return s.lines +} + +// MacroCalls returns a map of expression id to ast.Expr value where the id represents the expression +// node where the macro was inserted into the AST, and the ast.Expr value represents the original call +// signature which was replaced. +func (s *SourceInfo) MacroCalls() map[int64]Expr { + if s == nil { + return map[int64]Expr{} + } + return s.macroCalls +} + +// GetMacroCall returns the original ast.Expr value for the given expression if it was generated via +// a macro replacement. +// +// Note, parsing options must be enabled to track macro calls before this method will return a value. +func (s *SourceInfo) GetMacroCall(id int64) (Expr, bool) { + e, found := s.MacroCalls()[id] + return e, found +} + +// SetMacroCall records a macro call at a specific location. +func (s *SourceInfo) SetMacroCall(id int64, e Expr) { + if s != nil { + s.macroCalls[id] = e + } +} + +// ClearMacroCall removes the macro call at the given expression id. +func (s *SourceInfo) ClearMacroCall(id int64) { + if s != nil { + delete(s.macroCalls, id) + } +} + +// OffsetRanges returns a map of expression id to OffsetRange values where the range indicates either: +// the start and end position in the input stream where the expression occurs, or the start position +// only. If the range only captures start position, the stop position of the range will be equal to +// the start. +func (s *SourceInfo) OffsetRanges() map[int64]OffsetRange { + if s == nil { + return map[int64]OffsetRange{} + } + return s.offsetRanges +} + +// GetOffsetRange retrieves an OffsetRange for the given expression id if one exists. +func (s *SourceInfo) GetOffsetRange(id int64) (OffsetRange, bool) { + if s == nil { + return OffsetRange{}, false + } + o, found := s.offsetRanges[id] + return o, found +} + +// SetOffsetRange sets the OffsetRange for the given expression id. +func (s *SourceInfo) SetOffsetRange(id int64, o OffsetRange) { + if s == nil { + return + } + s.offsetRanges[id] = o +} + +// GetStartLocation calculates the human-readable 1-based line and 0-based column of the first character +// of the expression node at the id. +func (s *SourceInfo) GetStartLocation(id int64) common.Location { + if o, found := s.GetOffsetRange(id); found { + line := 1 + col := int(o.Start) + for _, lineOffset := range s.LineOffsets() { + if lineOffset < o.Start { + line++ + col = int(o.Start - lineOffset) + } else { + break + } } - refMap[id] = r + return common.NewLocation(line, col) } - typeMap := make(map[int64]*types.Type, len(checked.GetTypeMap())) - for id, typ := range checked.GetTypeMap() { - t, err := types.ExprTypeToType(typ) - if err != nil { - return nil, err + return common.NoLocation +} + +// GetStopLocation calculates the human-readable 1-based line and 0-based column of the last character for +// the expression node at the given id. +// +// If the SourceInfo was generated from a serialized protobuf representation, the stop location will +// be identical to the start location for the expression. +func (s *SourceInfo) GetStopLocation(id int64) common.Location { + if o, found := s.GetOffsetRange(id); found { + line := 1 + col := int(o.Stop) + for _, lineOffset := range s.LineOffsets() { + if lineOffset < o.Stop { + line++ + col = int(o.Stop - lineOffset) + } else { + break + } } - typeMap[id] = t + return common.NewLocation(line, col) } - return &CheckedAST{ - Expr: checked.GetExpr(), - SourceInfo: checked.GetSourceInfo(), - ReferenceMap: refMap, - TypeMap: typeMap, - }, nil + return common.NoLocation +} + +// ComputeOffset calculates the 0-based character offset from a 1-based line and 0-based column. +func (s *SourceInfo) ComputeOffset(line, col int32) int32 { + if line == 1 { + return col + } + if line < 1 || line > int32(len(s.LineOffsets())) { + return -1 + } + offset := s.LineOffsets()[line-2] + return offset + col +} + +// OffsetRange captures the start and stop positions of a section of text in the input expression. +type OffsetRange struct { + Start int32 + Stop int32 } // ReferenceInfo contains a CEL native representation of an identifier reference which may refer to @@ -149,78 +411,21 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool { return true } -// ReferenceInfoToReferenceExpr converts a ReferenceInfo instance to a protobuf Reference suitable for serialization. -func ReferenceInfoToReferenceExpr(info *ReferenceInfo) (*exprpb.Reference, error) { - c, err := ValToConstant(info.Value) - if err != nil { - return nil, err - } - return &exprpb.Reference{ - Name: info.Name, - OverloadId: info.OverloadIDs, - Value: c, - }, nil +type maxIDVisitor struct { + maxID int64 + *baseVisitor } -// ReferenceExprToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance. -func ReferenceExprToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) { - v, err := ConstantToVal(ref.GetValue()) - if err != nil { - return nil, err +// VisitExpr updates the max identifier if the incoming expression id is greater than previously observed. +func (v *maxIDVisitor) VisitExpr(e Expr) { + if v.maxID < e.ID() { + v.maxID = e.ID() } - return &ReferenceInfo{ - Name: ref.GetName(), - OverloadIDs: ref.GetOverloadId(), - Value: v, - }, nil } -// ValToConstant converts a CEL-native ref.Val to a protobuf Constant. -// -// Only simple scalar types are supported by this method. -func ValToConstant(v ref.Val) (*exprpb.Constant, error) { - if v == nil { - return nil, nil - } - switch v.Type() { - case types.BoolType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}}, nil - case types.BytesType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}}, nil - case types.DoubleType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}}, nil - case types.IntType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}}, nil - case types.NullType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, nil - case types.StringType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: v.Value().(string)}}, nil - case types.UintType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}}, nil - } - return nil, fmt.Errorf("unsupported constant kind: %v", v.Type()) -} - -// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val. -func ConstantToVal(c *exprpb.Constant) (ref.Val, error) { - if c == nil { - return nil, nil - } - switch c.GetConstantKind().(type) { - case *exprpb.Constant_BoolValue: - return types.Bool(c.GetBoolValue()), nil - case *exprpb.Constant_BytesValue: - return types.Bytes(c.GetBytesValue()), nil - case *exprpb.Constant_DoubleValue: - return types.Double(c.GetDoubleValue()), nil - case *exprpb.Constant_Int64Value: - return types.Int(c.GetInt64Value()), nil - case *exprpb.Constant_NullValue: - return types.NullValue, nil - case *exprpb.Constant_StringValue: - return types.String(c.GetStringValue()), nil - case *exprpb.Constant_Uint64Value: - return types.Uint(c.GetUint64Value()), nil - } - return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind()) +// VisitEntryExpr updates the max identifier if the incoming entry id is greater than previously observed. +func (v *maxIDVisitor) VisitEntryExpr(e EntryExpr) { + if v.maxID < e.ID() { + v.maxID = e.ID() + } } diff --git a/vendor/github.com/google/cel-go/common/ast/conversion.go b/vendor/github.com/google/cel-go/common/ast/conversion.go new file mode 100644 index 00000000000..8f2c4bd1e64 --- /dev/null +++ b/vendor/github.com/google/cel-go/common/ast/conversion.go @@ -0,0 +1,632 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import ( + "fmt" + + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + + structpb "google.golang.org/protobuf/types/known/structpb" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// ToProto converts an AST to a CheckedExpr protobouf. +func ToProto(ast *AST) (*exprpb.CheckedExpr, error) { + refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap())) + for id, ref := range ast.ReferenceMap() { + r, err := ReferenceInfoToProto(ref) + if err != nil { + return nil, err + } + refMap[id] = r + } + typeMap := make(map[int64]*exprpb.Type, len(ast.TypeMap())) + for id, typ := range ast.TypeMap() { + t, err := types.TypeToExprType(typ) + if err != nil { + return nil, err + } + typeMap[id] = t + } + e, err := ExprToProto(ast.Expr()) + if err != nil { + return nil, err + } + info, err := SourceInfoToProto(ast.SourceInfo()) + if err != nil { + return nil, err + } + return &exprpb.CheckedExpr{ + Expr: e, + SourceInfo: info, + ReferenceMap: refMap, + TypeMap: typeMap, + }, nil +} + +// ToAST converts a CheckedExpr protobuf to an AST instance. +func ToAST(checked *exprpb.CheckedExpr) (*AST, error) { + refMap := make(map[int64]*ReferenceInfo, len(checked.GetReferenceMap())) + for id, ref := range checked.GetReferenceMap() { + r, err := ProtoToReferenceInfo(ref) + if err != nil { + return nil, err + } + refMap[id] = r + } + typeMap := make(map[int64]*types.Type, len(checked.GetTypeMap())) + for id, typ := range checked.GetTypeMap() { + t, err := types.ExprTypeToType(typ) + if err != nil { + return nil, err + } + typeMap[id] = t + } + info, err := ProtoToSourceInfo(checked.GetSourceInfo()) + if err != nil { + return nil, err + } + root, err := ProtoToExpr(checked.GetExpr()) + if err != nil { + return nil, err + } + ast := NewCheckedAST(NewAST(root, info), typeMap, refMap) + return ast, nil +} + +// ProtoToExpr converts a protobuf Expr value to an ast.Expr value. +func ProtoToExpr(e *exprpb.Expr) (Expr, error) { + factory := NewExprFactory() + return exprInternal(factory, e) +} + +// ProtoToEntryExpr converts a protobuf struct/map entry to an ast.EntryExpr +func ProtoToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) { + factory := NewExprFactory() + switch e.GetKeyKind().(type) { + case *exprpb.Expr_CreateStruct_Entry_FieldKey: + return exprStructField(factory, e.GetId(), e) + case *exprpb.Expr_CreateStruct_Entry_MapKey: + return exprMapEntry(factory, e.GetId(), e) + } + return nil, fmt.Errorf("unsupported expr entry kind: %v", e) +} + +func exprInternal(factory ExprFactory, e *exprpb.Expr) (Expr, error) { + id := e.GetId() + switch e.GetExprKind().(type) { + case *exprpb.Expr_CallExpr: + return exprCall(factory, id, e.GetCallExpr()) + case *exprpb.Expr_ComprehensionExpr: + return exprComprehension(factory, id, e.GetComprehensionExpr()) + case *exprpb.Expr_ConstExpr: + return exprLiteral(factory, id, e.GetConstExpr()) + case *exprpb.Expr_IdentExpr: + return exprIdent(factory, id, e.GetIdentExpr()) + case *exprpb.Expr_ListExpr: + return exprList(factory, id, e.GetListExpr()) + case *exprpb.Expr_SelectExpr: + return exprSelect(factory, id, e.GetSelectExpr()) + case *exprpb.Expr_StructExpr: + s := e.GetStructExpr() + if s.GetMessageName() != "" { + return exprStruct(factory, id, s) + } + return exprMap(factory, id, s) + } + return factory.NewUnspecifiedExpr(id), nil +} + +func exprCall(factory ExprFactory, id int64, call *exprpb.Expr_Call) (Expr, error) { + var err error + args := make([]Expr, len(call.GetArgs())) + for i, a := range call.GetArgs() { + args[i], err = exprInternal(factory, a) + if err != nil { + return nil, err + } + } + if call.GetTarget() == nil { + return factory.NewCall(id, call.GetFunction(), args...), nil + } + + target, err := exprInternal(factory, call.GetTarget()) + if err != nil { + return nil, err + } + return factory.NewMemberCall(id, call.GetFunction(), target, args...), nil +} + +func exprComprehension(factory ExprFactory, id int64, comp *exprpb.Expr_Comprehension) (Expr, error) { + iterRange, err := exprInternal(factory, comp.GetIterRange()) + if err != nil { + return nil, err + } + accuInit, err := exprInternal(factory, comp.GetAccuInit()) + if err != nil { + return nil, err + } + loopCond, err := exprInternal(factory, comp.GetLoopCondition()) + if err != nil { + return nil, err + } + loopStep, err := exprInternal(factory, comp.GetLoopStep()) + if err != nil { + return nil, err + } + result, err := exprInternal(factory, comp.GetResult()) + if err != nil { + return nil, err + } + return factory.NewComprehension(id, + iterRange, + comp.GetIterVar(), + comp.GetAccuVar(), + accuInit, + loopCond, + loopStep, + result), nil +} + +func exprLiteral(factory ExprFactory, id int64, c *exprpb.Constant) (Expr, error) { + val, err := ConstantToVal(c) + if err != nil { + return nil, err + } + return factory.NewLiteral(id, val), nil +} + +func exprIdent(factory ExprFactory, id int64, i *exprpb.Expr_Ident) (Expr, error) { + return factory.NewIdent(id, i.GetName()), nil +} + +func exprList(factory ExprFactory, id int64, l *exprpb.Expr_CreateList) (Expr, error) { + elems := make([]Expr, len(l.GetElements())) + for i, e := range l.GetElements() { + elem, err := exprInternal(factory, e) + if err != nil { + return nil, err + } + elems[i] = elem + } + return factory.NewList(id, elems, l.GetOptionalIndices()), nil +} + +func exprMap(factory ExprFactory, id int64, s *exprpb.Expr_CreateStruct) (Expr, error) { + entries := make([]EntryExpr, len(s.GetEntries())) + var err error + for i, entry := range s.GetEntries() { + entries[i], err = exprMapEntry(factory, entry.GetId(), entry) + if err != nil { + return nil, err + } + } + return factory.NewMap(id, entries), nil +} + +func exprMapEntry(factory ExprFactory, id int64, e *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) { + k, err := exprInternal(factory, e.GetMapKey()) + if err != nil { + return nil, err + } + v, err := exprInternal(factory, e.GetValue()) + if err != nil { + return nil, err + } + return factory.NewMapEntry(id, k, v, e.GetOptionalEntry()), nil +} + +func exprSelect(factory ExprFactory, id int64, s *exprpb.Expr_Select) (Expr, error) { + op, err := exprInternal(factory, s.GetOperand()) + if err != nil { + return nil, err + } + if s.GetTestOnly() { + return factory.NewPresenceTest(id, op, s.GetField()), nil + } + return factory.NewSelect(id, op, s.GetField()), nil +} + +func exprStruct(factory ExprFactory, id int64, s *exprpb.Expr_CreateStruct) (Expr, error) { + fields := make([]EntryExpr, len(s.GetEntries())) + var err error + for i, field := range s.GetEntries() { + fields[i], err = exprStructField(factory, field.GetId(), field) + if err != nil { + return nil, err + } + } + return factory.NewStruct(id, s.GetMessageName(), fields), nil +} + +func exprStructField(factory ExprFactory, id int64, f *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) { + v, err := exprInternal(factory, f.GetValue()) + if err != nil { + return nil, err + } + return factory.NewStructField(id, f.GetFieldKey(), v, f.GetOptionalEntry()), nil +} + +// ExprToProto serializes an ast.Expr value to a protobuf Expr representation. +func ExprToProto(e Expr) (*exprpb.Expr, error) { + if e == nil { + return &exprpb.Expr{}, nil + } + switch e.Kind() { + case CallKind: + return protoCall(e.ID(), e.AsCall()) + case ComprehensionKind: + return protoComprehension(e.ID(), e.AsComprehension()) + case IdentKind: + return protoIdent(e.ID(), e.AsIdent()) + case ListKind: + return protoList(e.ID(), e.AsList()) + case LiteralKind: + return protoLiteral(e.ID(), e.AsLiteral()) + case MapKind: + return protoMap(e.ID(), e.AsMap()) + case SelectKind: + return protoSelect(e.ID(), e.AsSelect()) + case StructKind: + return protoStruct(e.ID(), e.AsStruct()) + case UnspecifiedExprKind: + // Handle the case where a macro reference may be getting translated. + // A nested macro 'pointer' is a non-zero expression id with no kind set. + if e.ID() != 0 { + return &exprpb.Expr{Id: e.ID()}, nil + } + return &exprpb.Expr{}, nil + } + return nil, fmt.Errorf("unsupported expr kind: %v", e) +} + +// EntryExprToProto converts an ast.EntryExpr to a protobuf CreateStruct entry +func EntryExprToProto(e EntryExpr) (*exprpb.Expr_CreateStruct_Entry, error) { + switch e.Kind() { + case MapEntryKind: + return protoMapEntry(e.ID(), e.AsMapEntry()) + case StructFieldKind: + return protoStructField(e.ID(), e.AsStructField()) + case UnspecifiedEntryExprKind: + return &exprpb.Expr_CreateStruct_Entry{}, nil + } + return nil, fmt.Errorf("unsupported expr entry kind: %v", e) +} + +func protoCall(id int64, call CallExpr) (*exprpb.Expr, error) { + var err error + var target *exprpb.Expr + if call.IsMemberFunction() { + target, err = ExprToProto(call.Target()) + if err != nil { + return nil, err + } + } + callArgs := call.Args() + args := make([]*exprpb.Expr, len(callArgs)) + for i, a := range callArgs { + args[i], err = ExprToProto(a) + if err != nil { + return nil, err + } + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Function: call.FunctionName(), + Target: target, + Args: args, + }, + }, + }, nil +} + +func protoComprehension(id int64, comp ComprehensionExpr) (*exprpb.Expr, error) { + iterRange, err := ExprToProto(comp.IterRange()) + if err != nil { + return nil, err + } + accuInit, err := ExprToProto(comp.AccuInit()) + if err != nil { + return nil, err + } + loopCond, err := ExprToProto(comp.LoopCondition()) + if err != nil { + return nil, err + } + loopStep, err := ExprToProto(comp.LoopStep()) + if err != nil { + return nil, err + } + result, err := ExprToProto(comp.Result()) + if err != nil { + return nil, err + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_ComprehensionExpr{ + ComprehensionExpr: &exprpb.Expr_Comprehension{ + IterVar: comp.IterVar(), + IterRange: iterRange, + AccuVar: comp.AccuVar(), + AccuInit: accuInit, + LoopCondition: loopCond, + LoopStep: loopStep, + Result: result, + }, + }, + }, nil +} + +func protoIdent(id int64, name string) (*exprpb.Expr, error) { + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_IdentExpr{ + IdentExpr: &exprpb.Expr_Ident{ + Name: name, + }, + }, + }, nil +} + +func protoList(id int64, list ListExpr) (*exprpb.Expr, error) { + var err error + elems := make([]*exprpb.Expr, list.Size()) + for i, e := range list.Elements() { + elems[i], err = ExprToProto(e) + if err != nil { + return nil, err + } + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_ListExpr{ + ListExpr: &exprpb.Expr_CreateList{ + Elements: elems, + OptionalIndices: list.OptionalIndices(), + }, + }, + }, nil +} + +func protoLiteral(id int64, val ref.Val) (*exprpb.Expr, error) { + c, err := ValToConstant(val) + if err != nil { + return nil, err + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_ConstExpr{ + ConstExpr: c, + }, + }, nil +} + +func protoMap(id int64, m MapExpr) (*exprpb.Expr, error) { + entries := make([]*exprpb.Expr_CreateStruct_Entry, len(m.Entries())) + var err error + for i, e := range m.Entries() { + entries[i], err = EntryExprToProto(e) + if err != nil { + return nil, err + } + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_StructExpr{ + StructExpr: &exprpb.Expr_CreateStruct{ + Entries: entries, + }, + }, + }, nil +} + +func protoMapEntry(id int64, e MapEntry) (*exprpb.Expr_CreateStruct_Entry, error) { + k, err := ExprToProto(e.Key()) + if err != nil { + return nil, err + } + v, err := ExprToProto(e.Value()) + if err != nil { + return nil, err + } + return &exprpb.Expr_CreateStruct_Entry{ + Id: id, + KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{ + MapKey: k, + }, + Value: v, + OptionalEntry: e.IsOptional(), + }, nil +} + +func protoSelect(id int64, s SelectExpr) (*exprpb.Expr, error) { + op, err := ExprToProto(s.Operand()) + if err != nil { + return nil, err + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_SelectExpr{ + SelectExpr: &exprpb.Expr_Select{ + Operand: op, + Field: s.FieldName(), + TestOnly: s.IsTestOnly(), + }, + }, + }, nil +} + +func protoStruct(id int64, s StructExpr) (*exprpb.Expr, error) { + entries := make([]*exprpb.Expr_CreateStruct_Entry, len(s.Fields())) + var err error + for i, e := range s.Fields() { + entries[i], err = EntryExprToProto(e) + if err != nil { + return nil, err + } + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_StructExpr{ + StructExpr: &exprpb.Expr_CreateStruct{ + MessageName: s.TypeName(), + Entries: entries, + }, + }, + }, nil +} + +func protoStructField(id int64, f StructField) (*exprpb.Expr_CreateStruct_Entry, error) { + v, err := ExprToProto(f.Value()) + if err != nil { + return nil, err + } + return &exprpb.Expr_CreateStruct_Entry{ + Id: id, + KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{ + FieldKey: f.Name(), + }, + Value: v, + OptionalEntry: f.IsOptional(), + }, nil +} + +// SourceInfoToProto serializes an ast.SourceInfo value to a protobuf SourceInfo object. +func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) { + if info == nil { + return &exprpb.SourceInfo{}, nil + } + sourceInfo := &exprpb.SourceInfo{ + SyntaxVersion: info.SyntaxVersion(), + Location: info.Description(), + LineOffsets: info.LineOffsets(), + Positions: make(map[int64]int32, len(info.OffsetRanges())), + MacroCalls: make(map[int64]*exprpb.Expr, len(info.MacroCalls())), + } + for id, offset := range info.OffsetRanges() { + sourceInfo.Positions[id] = offset.Start + } + for id, e := range info.MacroCalls() { + call, err := ExprToProto(e) + if err != nil { + return nil, err + } + sourceInfo.MacroCalls[id] = call + } + return sourceInfo, nil +} + +// ProtoToSourceInfo deserializes the protobuf into a native SourceInfo value. +func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) { + sourceInfo := &SourceInfo{ + syntax: info.GetSyntaxVersion(), + desc: info.GetLocation(), + lines: info.GetLineOffsets(), + offsetRanges: make(map[int64]OffsetRange, len(info.GetPositions())), + macroCalls: make(map[int64]Expr, len(info.GetMacroCalls())), + } + for id, offset := range info.GetPositions() { + sourceInfo.SetOffsetRange(id, OffsetRange{Start: offset, Stop: offset}) + } + for id, e := range info.GetMacroCalls() { + call, err := ProtoToExpr(e) + if err != nil { + return nil, err + } + sourceInfo.SetMacroCall(id, call) + } + return sourceInfo, nil +} + +// ReferenceInfoToProto converts a ReferenceInfo instance to a protobuf Reference suitable for serialization. +func ReferenceInfoToProto(info *ReferenceInfo) (*exprpb.Reference, error) { + c, err := ValToConstant(info.Value) + if err != nil { + return nil, err + } + return &exprpb.Reference{ + Name: info.Name, + OverloadId: info.OverloadIDs, + Value: c, + }, nil +} + +// ProtoToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance. +func ProtoToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) { + v, err := ConstantToVal(ref.GetValue()) + if err != nil { + return nil, err + } + return &ReferenceInfo{ + Name: ref.GetName(), + OverloadIDs: ref.GetOverloadId(), + Value: v, + }, nil +} + +// ValToConstant converts a CEL-native ref.Val to a protobuf Constant. +// +// Only simple scalar types are supported by this method. +func ValToConstant(v ref.Val) (*exprpb.Constant, error) { + if v == nil { + return nil, nil + } + switch v.Type() { + case types.BoolType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}}, nil + case types.BytesType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}}, nil + case types.DoubleType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}}, nil + case types.IntType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}}, nil + case types.NullType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, nil + case types.StringType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: v.Value().(string)}}, nil + case types.UintType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}}, nil + } + return nil, fmt.Errorf("unsupported constant kind: %v", v.Type()) +} + +// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val. +func ConstantToVal(c *exprpb.Constant) (ref.Val, error) { + if c == nil { + return nil, nil + } + switch c.GetConstantKind().(type) { + case *exprpb.Constant_BoolValue: + return types.Bool(c.GetBoolValue()), nil + case *exprpb.Constant_BytesValue: + return types.Bytes(c.GetBytesValue()), nil + case *exprpb.Constant_DoubleValue: + return types.Double(c.GetDoubleValue()), nil + case *exprpb.Constant_Int64Value: + return types.Int(c.GetInt64Value()), nil + case *exprpb.Constant_NullValue: + return types.NullValue, nil + case *exprpb.Constant_StringValue: + return types.String(c.GetStringValue()), nil + case *exprpb.Constant_Uint64Value: + return types.Uint(c.GetUint64Value()), nil + } + return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind()) +} diff --git a/vendor/github.com/google/cel-go/common/ast/expr.go b/vendor/github.com/google/cel-go/common/ast/expr.go index b63884a6028..c9d88bbaab7 100644 --- a/vendor/github.com/google/cel-go/common/ast/expr.go +++ b/vendor/github.com/google/cel-go/common/ast/expr.go @@ -15,168 +15,61 @@ package ast import ( - "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // ExprKind represents the expression node kind. type ExprKind int const ( - // UnspecifiedKind represents an unset expression with no specified properties. - UnspecifiedKind ExprKind = iota + // UnspecifiedExprKind represents an unset expression with no specified properties. + UnspecifiedExprKind ExprKind = iota - // LiteralKind represents a primitive scalar literal. - LiteralKind + // CallKind represents a function call. + CallKind + + // ComprehensionKind represents a comprehension expression generated by a macro. + ComprehensionKind // IdentKind represents a simple variable, constant, or type identifier. IdentKind - // SelectKind represents a field selection expression. - SelectKind - - // CallKind represents a function call. - CallKind - // ListKind represents a list literal expression. ListKind + // LiteralKind represents a primitive scalar literal. + LiteralKind + // MapKind represents a map literal expression. MapKind + // SelectKind represents a field selection expression. + SelectKind + // StructKind represents a struct literal expression. StructKind - - // ComprehensionKind represents a comprehension expression generated by a macro. - ComprehensionKind ) -// NavigateCheckedAST converts a CheckedAST to a NavigableExpr -func NavigateCheckedAST(ast *CheckedAST) NavigableExpr { - return newNavigableExpr(nil, ast.Expr, ast.TypeMap) -} - -// ExprMatcher takes a NavigableExpr in and indicates whether the value is a match. +// Expr represents the base expression node in a CEL abstract syntax tree. // -// This function type should be use with the `Match` and `MatchList` calls. -type ExprMatcher func(NavigableExpr) bool - -// ConstantValueMatcher returns an ExprMatcher which will return true if the input NavigableExpr -// is comprised of all constant values, such as a simple literal or even list and map literal. -func ConstantValueMatcher() ExprMatcher { - return matchIsConstantValue -} - -// KindMatcher returns an ExprMatcher which will return true if the input NavigableExpr.Kind() matches -// the specified `kind`. -func KindMatcher(kind ExprKind) ExprMatcher { - return func(e NavigableExpr) bool { - return e.Kind() == kind - } -} - -// FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose -// function name is equal to `funcName`. -func FunctionMatcher(funcName string) ExprMatcher { - return func(e NavigableExpr) bool { - if e.Kind() != CallKind { - return false - } - return e.AsCall().FunctionName() == funcName - } -} - -// AllMatcher returns true for all descendants of a NavigableExpr, effectively flattening them into a list. -// -// Such a result would work well with subsequent MatchList calls. -func AllMatcher() ExprMatcher { - return func(NavigableExpr) bool { - return true - } -} - -// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values of the -// descendants which match. -func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr { - return matchListInternal([]NavigableExpr{expr}, matcher, true) -} - -// MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a -// subset of NavigableExpr values which match. -func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr { - visit := make([]NavigableExpr, len(exprs)) - copy(visit, exprs) - return matchListInternal(visit, matcher, false) -} - -func matchListInternal(visit []NavigableExpr, matcher ExprMatcher, visitDescendants bool) []NavigableExpr { - var matched []NavigableExpr - for len(visit) != 0 { - e := visit[0] - if matcher(e) { - matched = append(matched, e) - } - if visitDescendants { - visit = append(visit[1:], e.Children()...) - } else { - visit = visit[1:] - } - } - return matched -} - -func matchIsConstantValue(e NavigableExpr) bool { - if e.Kind() == LiteralKind { - return true - } - if e.Kind() == StructKind || e.Kind() == MapKind || e.Kind() == ListKind { - for _, child := range e.Children() { - if !matchIsConstantValue(child) { - return false - } - } - return true - } - return false -} - -// NavigableExpr represents the base navigable expression value. -// -// Depending on the `Kind()` value, the NavigableExpr may be converted to a concrete expression types +// Depending on the `Kind()` value, the Expr may be converted to a concrete expression types // as indicated by the `As` methods. -// -// NavigableExpr values and their concrete expression types should be nil-safe. Conversion of an expr -// to the wrong kind should produce a nil value. -type NavigableExpr interface { +type Expr interface { // ID of the expression as it appears in the AST ID() int64 // Kind of the expression node. See ExprKind for the valid enum values. Kind() ExprKind - // Type of the expression node. - Type() *types.Type - - // Parent returns the parent expression node, if one exists. - Parent() (NavigableExpr, bool) - - // Children returns a list of child expression nodes. - Children() []NavigableExpr - - // ToExpr adapts this NavigableExpr to a protobuf representation. - ToExpr() *exprpb.Expr - - // AsCall adapts the expr into a NavigableCallExpr + // AsCall adapts the expr into a CallExpr // // The Kind() must be equal to a CallKind for the conversion to be well-defined. - AsCall() NavigableCallExpr + AsCall() CallExpr - // AsComprehension adapts the expr into a NavigableComprehensionExpr. + // AsComprehension adapts the expr into a ComprehensionExpr. // // The Kind() must be equal to a ComprehensionKind for the conversion to be well-defined. - AsComprehension() NavigableComprehensionExpr + AsComprehension() ComprehensionExpr // AsIdent adapts the expr into an identifier string. // @@ -188,67 +81,123 @@ type NavigableExpr interface { // The Kind() must be equal to a LiteralKind for the conversion to be well-defined. AsLiteral() ref.Val - // AsList adapts the expr into a NavigableListExpr. + // AsList adapts the expr into a ListExpr. // // The Kind() must be equal to a ListKind for the conversion to be well-defined. - AsList() NavigableListExpr + AsList() ListExpr - // AsMap adapts the expr into a NavigableMapExpr. + // AsMap adapts the expr into a MapExpr. // // The Kind() must be equal to a MapKind for the conversion to be well-defined. - AsMap() NavigableMapExpr + AsMap() MapExpr - // AsSelect adapts the expr into a NavigableSelectExpr. + // AsSelect adapts the expr into a SelectExpr. // // The Kind() must be equal to a SelectKind for the conversion to be well-defined. - AsSelect() NavigableSelectExpr + AsSelect() SelectExpr - // AsStruct adapts the expr into a NavigableStructExpr. + // AsStruct adapts the expr into a StructExpr. // // The Kind() must be equal to a StructKind for the conversion to be well-defined. - AsStruct() NavigableStructExpr + AsStruct() StructExpr - // marker interface method - isNavigable() + // RenumberIDs performs an in-place update of the expression and all of its descendents numeric ids. + RenumberIDs(IDGenerator) + + // SetKindCase replaces the contents of the current expression with the contents of the other. + // + // The SetKindCase takes ownership of any expression instances references within the input Expr. + // A shallow copy is made of the Expr value itself, but not a deep one. + // + // This method should only be used during AST rewrites using temporary Expr values. + SetKindCase(Expr) + + // isExpr is a marker interface. + isExpr() } -// NavigableCallExpr defines an interface for inspecting a function call and its arugments. -type NavigableCallExpr interface { +// EntryExprKind represents the possible EntryExpr kinds. +type EntryExprKind int + +const ( + // UnspecifiedEntryExprKind indicates that the entry expr is not set. + UnspecifiedEntryExprKind EntryExprKind = iota + + // MapEntryKind indicates that the entry is a MapEntry type with key and value expressions. + MapEntryKind + + // StructFieldKind indicates that the entry is a StructField with a field name and initializer + // expression. + StructFieldKind +) + +// EntryExpr represents the base entry expression in a CEL map or struct literal. +type EntryExpr interface { + // ID of the entry as it appears in the AST. + ID() int64 + + // Kind of the entry expression node. See EntryExprKind for valid enum values. + Kind() EntryExprKind + + // AsMapEntry casts the EntryExpr to a MapEntry. + // + // The Kind() must be equal to MapEntryKind for the conversion to be well-defined. + AsMapEntry() MapEntry + + // AsStructField casts the EntryExpr to a StructField + // + // The Kind() must be equal to StructFieldKind for the conversion to be well-defined. + AsStructField() StructField + + // RenumberIDs performs an in-place update of the expression and all of its descendents numeric ids. + RenumberIDs(IDGenerator) + + isEntryExpr() +} + +// IDGenerator produces unique ids suitable for tagging expression nodes +type IDGenerator func(originalID int64) int64 + +// CallExpr defines an interface for inspecting a function call and its arugments. +type CallExpr interface { // FunctionName returns the name of the function. FunctionName() string + // IsMemberFunction returns whether the call has a non-nil target indicating it is a member function + IsMemberFunction() bool + // Target returns the target of the expression if one is present. - Target() NavigableExpr + Target() Expr // Args returns the list of call arguments, excluding the target. - Args() []NavigableExpr - - // ReturnType returns the result type of the call. - ReturnType() *types.Type + Args() []Expr // marker interface method - isNavigable() + isExpr() } -// NavigableListExpr defines an interface for inspecting a list literal expression. -type NavigableListExpr interface { +// ListExpr defines an interface for inspecting a list literal expression. +type ListExpr interface { // Elements returns the list elements as navigable expressions. - Elements() []NavigableExpr + Elements() []Expr // OptionalIndicies returns the list of optional indices in the list literal. OptionalIndices() []int32 + // IsOptional indicates whether the given element index is optional. + IsOptional(int32) bool + // Size returns the number of elements in the list. Size() int // marker interface method - isNavigable() + isExpr() } -// NavigableSelectExpr defines an interface for inspecting a select expression. -type NavigableSelectExpr interface { +// SelectExpr defines an interface for inspecting a select expression. +type SelectExpr interface { // Operand returns the selection operand expression. - Operand() NavigableExpr + Operand() Expr // FieldName returns the field name being selected from the operand. FieldName() string @@ -257,67 +206,67 @@ type NavigableSelectExpr interface { IsTestOnly() bool // marker interface method - isNavigable() + isExpr() } -// NavigableMapExpr defines an interface for inspecting a map expression. -type NavigableMapExpr interface { - // Entries returns the map key value pairs as NavigableEntry values. - Entries() []NavigableEntry +// MapExpr defines an interface for inspecting a map expression. +type MapExpr interface { + // Entries returns the map key value pairs as EntryExpr values. + Entries() []EntryExpr // Size returns the number of entries in the map. Size() int // marker interface method - isNavigable() + isExpr() } -// NavigableEntry defines an interface for inspecting a map entry. -type NavigableEntry interface { +// MapEntry defines an interface for inspecting a map entry. +type MapEntry interface { // Key returns the map entry key expression. - Key() NavigableExpr + Key() Expr // Value returns the map entry value expression. - Value() NavigableExpr + Value() Expr // IsOptional returns whether the entry is optional. IsOptional() bool // marker interface method - isNavigable() + isEntryExpr() } -// NavigableStructExpr defines an interfaces for inspecting a struct and its field initializers. -type NavigableStructExpr interface { +// StructExpr defines an interfaces for inspecting a struct and its field initializers. +type StructExpr interface { // TypeName returns the struct type name. TypeName() string - // Fields returns the set of field initializers in the struct expression as NavigableField values. - Fields() []NavigableField + // Fields returns the set of field initializers in the struct expression as EntryExpr values. + Fields() []EntryExpr // marker interface method - isNavigable() + isExpr() } -// NavigableField defines an interface for inspecting a struct field initialization. -type NavigableField interface { - // FieldName returns the name of the field. - FieldName() string +// StructField defines an interface for inspecting a struct field initialization. +type StructField interface { + // Name returns the name of the field. + Name() string // Value returns the field initialization expression. - Value() NavigableExpr + Value() Expr // IsOptional returns whether the field is optional. IsOptional() bool // marker interface method - isNavigable() + isEntryExpr() } -// NavigableComprehensionExpr defines an interface for inspecting a comprehension expression. -type NavigableComprehensionExpr interface { +// ComprehensionExpr defines an interface for inspecting a comprehension expression. +type ComprehensionExpr interface { // IterRange returns the iteration range expression. - IterRange() NavigableExpr + IterRange() Expr // IterVar returns the iteration variable name. IterVar() string @@ -326,384 +275,586 @@ type NavigableComprehensionExpr interface { AccuVar() string // AccuInit returns the accumulation variable initialization expression. - AccuInit() NavigableExpr + AccuInit() Expr // LoopCondition returns the loop condition expression. - LoopCondition() NavigableExpr + LoopCondition() Expr // LoopStep returns the loop step expression. - LoopStep() NavigableExpr + LoopStep() Expr // Result returns the comprehension result expression. - Result() NavigableExpr + Result() Expr // marker interface method - isNavigable() + isExpr() } -func newNavigableExpr(parent NavigableExpr, expr *exprpb.Expr, typeMap map[int64]*types.Type) NavigableExpr { - kind, factory := kindOf(expr) - nav := &navigableExprImpl{ - parent: parent, - kind: kind, - expr: expr, - typeMap: typeMap, - createChildren: factory, +var _ Expr = &expr{} + +type expr struct { + id int64 + exprKindCase +} + +type exprKindCase interface { + Kind() ExprKind + + renumberIDs(IDGenerator) + + isExpr() +} + +func (e *expr) ID() int64 { + if e == nil { + return 0 } - return nav + return e.id } -type navigableExprImpl struct { - parent NavigableExpr - kind ExprKind - expr *exprpb.Expr - typeMap map[int64]*types.Type - createChildren childFactory +func (e *expr) Kind() ExprKind { + if e == nil || e.exprKindCase == nil { + return UnspecifiedExprKind + } + return e.exprKindCase.Kind() } -func (nav *navigableExprImpl) ID() int64 { - return nav.ToExpr().GetId() +func (e *expr) AsCall() CallExpr { + if e.Kind() != CallKind { + return nilCall + } + return e.exprKindCase.(CallExpr) } -func (nav *navigableExprImpl) Kind() ExprKind { - return nav.kind +func (e *expr) AsComprehension() ComprehensionExpr { + if e.Kind() != ComprehensionKind { + return nilCompre + } + return e.exprKindCase.(ComprehensionExpr) } -func (nav *navigableExprImpl) Type() *types.Type { - if t, found := nav.typeMap[nav.ID()]; found { - return t +func (e *expr) AsIdent() string { + if e.Kind() != IdentKind { + return "" } - return types.DynType + return string(e.exprKindCase.(baseIdentExpr)) } -func (nav *navigableExprImpl) Parent() (NavigableExpr, bool) { - if nav.parent != nil { - return nav.parent, true +func (e *expr) AsLiteral() ref.Val { + if e.Kind() != LiteralKind { + return nil } - return nil, false + return e.exprKindCase.(*baseLiteral).Val } -func (nav *navigableExprImpl) Children() []NavigableExpr { - return nav.createChildren(nav) +func (e *expr) AsList() ListExpr { + if e.Kind() != ListKind { + return nilList + } + return e.exprKindCase.(ListExpr) } -func (nav *navigableExprImpl) ToExpr() *exprpb.Expr { - return nav.expr +func (e *expr) AsMap() MapExpr { + if e.Kind() != MapKind { + return nilMap + } + return e.exprKindCase.(MapExpr) } -func (nav *navigableExprImpl) AsCall() NavigableCallExpr { - return navigableCallImpl{navigableExprImpl: nav} +func (e *expr) AsSelect() SelectExpr { + if e.Kind() != SelectKind { + return nilSel + } + return e.exprKindCase.(SelectExpr) } -func (nav *navigableExprImpl) AsComprehension() NavigableComprehensionExpr { - return navigableComprehensionImpl{navigableExprImpl: nav} +func (e *expr) AsStruct() StructExpr { + if e.Kind() != StructKind { + return nilStruct + } + return e.exprKindCase.(StructExpr) } -func (nav *navigableExprImpl) AsIdent() string { - return nav.ToExpr().GetIdentExpr().GetName() +func (e *expr) SetKindCase(other Expr) { + if e == nil { + return + } + if other == nil { + e.exprKindCase = nil + return + } + switch other.Kind() { + case CallKind: + c := other.AsCall() + e.exprKindCase = &baseCallExpr{ + function: c.FunctionName(), + target: c.Target(), + args: c.Args(), + isMember: c.IsMemberFunction(), + } + case ComprehensionKind: + c := other.AsComprehension() + e.exprKindCase = &baseComprehensionExpr{ + iterRange: c.IterRange(), + iterVar: c.IterVar(), + accuVar: c.AccuVar(), + accuInit: c.AccuInit(), + loopCond: c.LoopCondition(), + loopStep: c.LoopStep(), + result: c.Result(), + } + case IdentKind: + e.exprKindCase = baseIdentExpr(other.AsIdent()) + case ListKind: + l := other.AsList() + optIndexMap := make(map[int32]struct{}, len(l.OptionalIndices())) + for _, idx := range l.OptionalIndices() { + optIndexMap[idx] = struct{}{} + } + e.exprKindCase = &baseListExpr{ + elements: l.Elements(), + optIndices: l.OptionalIndices(), + optIndexMap: optIndexMap, + } + case LiteralKind: + e.exprKindCase = &baseLiteral{Val: other.AsLiteral()} + case MapKind: + e.exprKindCase = &baseMapExpr{ + entries: other.AsMap().Entries(), + } + case SelectKind: + s := other.AsSelect() + e.exprKindCase = &baseSelectExpr{ + operand: s.Operand(), + field: s.FieldName(), + testOnly: s.IsTestOnly(), + } + case StructKind: + s := other.AsStruct() + e.exprKindCase = &baseStructExpr{ + typeName: s.TypeName(), + fields: s.Fields(), + } + case UnspecifiedExprKind: + e.exprKindCase = nil + } } -func (nav *navigableExprImpl) AsLiteral() ref.Val { - if nav.Kind() != LiteralKind { - return nil +func (e *expr) RenumberIDs(idGen IDGenerator) { + if e == nil { + return } - val, err := ConstantToVal(nav.ToExpr().GetConstExpr()) - if err != nil { - panic(err) + e.id = idGen(e.id) + if e.exprKindCase != nil { + e.exprKindCase.renumberIDs(idGen) } - return val } -func (nav *navigableExprImpl) AsList() NavigableListExpr { - return navigableListImpl{navigableExprImpl: nav} +type baseCallExpr struct { + function string + target Expr + args []Expr + isMember bool } -func (nav *navigableExprImpl) AsMap() NavigableMapExpr { - return navigableMapImpl{navigableExprImpl: nav} +func (*baseCallExpr) Kind() ExprKind { + return CallKind } -func (nav *navigableExprImpl) AsSelect() NavigableSelectExpr { - return navigableSelectImpl{navigableExprImpl: nav} +func (e *baseCallExpr) FunctionName() string { + if e == nil { + return "" + } + return e.function } -func (nav *navigableExprImpl) AsStruct() NavigableStructExpr { - return navigableStructImpl{navigableExprImpl: nav} +func (e *baseCallExpr) IsMemberFunction() bool { + if e == nil { + return false + } + return e.isMember } -func (nav *navigableExprImpl) createChild(e *exprpb.Expr) NavigableExpr { - return newNavigableExpr(nav, e, nav.typeMap) +func (e *baseCallExpr) Target() Expr { + if e == nil || !e.IsMemberFunction() { + return nilExpr + } + return e.target } -func (nav *navigableExprImpl) isNavigable() {} +func (e *baseCallExpr) Args() []Expr { + if e == nil { + return []Expr{} + } + return e.args +} -type navigableCallImpl struct { - *navigableExprImpl +func (e *baseCallExpr) renumberIDs(idGen IDGenerator) { + if e.IsMemberFunction() { + e.Target().RenumberIDs(idGen) + } + for _, arg := range e.Args() { + arg.RenumberIDs(idGen) + } } -func (call navigableCallImpl) FunctionName() string { - return call.ToExpr().GetCallExpr().GetFunction() +func (*baseCallExpr) isExpr() {} + +var _ ComprehensionExpr = &baseComprehensionExpr{} + +type baseComprehensionExpr struct { + iterRange Expr + iterVar string + accuVar string + accuInit Expr + loopCond Expr + loopStep Expr + result Expr } -func (call navigableCallImpl) Target() NavigableExpr { - t := call.ToExpr().GetCallExpr().GetTarget() - if t != nil { - return call.createChild(t) - } - return nil +func (*baseComprehensionExpr) Kind() ExprKind { + return ComprehensionKind } -func (call navigableCallImpl) Args() []NavigableExpr { - args := call.ToExpr().GetCallExpr().GetArgs() - navArgs := make([]NavigableExpr, len(args)) - for i, a := range args { - navArgs[i] = call.createChild(a) +func (e *baseComprehensionExpr) IterRange() Expr { + if e == nil { + return nilExpr } - return navArgs + return e.iterRange } -func (call navigableCallImpl) ReturnType() *types.Type { - return call.Type() +func (e *baseComprehensionExpr) IterVar() string { + return e.iterVar } -type navigableComprehensionImpl struct { - *navigableExprImpl +func (e *baseComprehensionExpr) AccuVar() string { + return e.accuVar } -func (comp navigableComprehensionImpl) IterRange() NavigableExpr { - return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetIterRange()) +func (e *baseComprehensionExpr) AccuInit() Expr { + if e == nil { + return nilExpr + } + return e.accuInit } -func (comp navigableComprehensionImpl) IterVar() string { - return comp.ToExpr().GetComprehensionExpr().GetIterVar() +func (e *baseComprehensionExpr) LoopCondition() Expr { + if e == nil { + return nilExpr + } + return e.loopCond } -func (comp navigableComprehensionImpl) AccuVar() string { - return comp.ToExpr().GetComprehensionExpr().GetAccuVar() +func (e *baseComprehensionExpr) LoopStep() Expr { + if e == nil { + return nilExpr + } + return e.loopStep } -func (comp navigableComprehensionImpl) AccuInit() NavigableExpr { - return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetAccuInit()) +func (e *baseComprehensionExpr) Result() Expr { + if e == nil { + return nilExpr + } + return e.result } -func (comp navigableComprehensionImpl) LoopCondition() NavigableExpr { - return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetLoopCondition()) +func (e *baseComprehensionExpr) renumberIDs(idGen IDGenerator) { + e.IterRange().RenumberIDs(idGen) + e.AccuInit().RenumberIDs(idGen) + e.LoopCondition().RenumberIDs(idGen) + e.LoopStep().RenumberIDs(idGen) + e.Result().RenumberIDs(idGen) } -func (comp navigableComprehensionImpl) LoopStep() NavigableExpr { - return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetLoopStep()) +func (*baseComprehensionExpr) isExpr() {} + +var _ exprKindCase = baseIdentExpr("") + +type baseIdentExpr string + +func (baseIdentExpr) Kind() ExprKind { + return IdentKind } -func (comp navigableComprehensionImpl) Result() NavigableExpr { - return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetResult()) +func (e baseIdentExpr) renumberIDs(IDGenerator) {} + +func (baseIdentExpr) isExpr() {} + +var _ exprKindCase = &baseLiteral{} +var _ ref.Val = &baseLiteral{} + +type baseLiteral struct { + ref.Val } -type navigableListImpl struct { - *navigableExprImpl +func (*baseLiteral) Kind() ExprKind { + return LiteralKind } -func (l navigableListImpl) Elements() []NavigableExpr { - return l.Children() +func (l *baseLiteral) renumberIDs(IDGenerator) {} + +func (*baseLiteral) isExpr() {} + +var _ ListExpr = &baseListExpr{} + +type baseListExpr struct { + elements []Expr + optIndices []int32 + optIndexMap map[int32]struct{} } -func (l navigableListImpl) OptionalIndices() []int32 { - return l.ToExpr().GetListExpr().GetOptionalIndices() +func (*baseListExpr) Kind() ExprKind { + return ListKind } -func (l navigableListImpl) Size() int { - return len(l.ToExpr().GetListExpr().GetElements()) +func (e *baseListExpr) Elements() []Expr { + if e == nil { + return []Expr{} + } + return e.elements } -type navigableMapImpl struct { - *navigableExprImpl +func (e *baseListExpr) IsOptional(index int32) bool { + _, found := e.optIndexMap[index] + return found } -func (m navigableMapImpl) Entries() []NavigableEntry { - mapExpr := m.ToExpr().GetStructExpr() - entries := make([]NavigableEntry, len(mapExpr.GetEntries())) - for i, e := range mapExpr.GetEntries() { - entries[i] = navigableEntryImpl{ - key: m.createChild(e.GetMapKey()), - val: m.createChild(e.GetValue()), - isOpt: e.GetOptionalEntry(), - } +func (e *baseListExpr) OptionalIndices() []int32 { + if e == nil { + return []int32{} } - return entries + return e.optIndices } -func (m navigableMapImpl) Size() int { - return len(m.ToExpr().GetStructExpr().GetEntries()) +func (e *baseListExpr) Size() int { + return len(e.Elements()) } -type navigableEntryImpl struct { - key NavigableExpr - val NavigableExpr - isOpt bool +func (e *baseListExpr) renumberIDs(idGen IDGenerator) { + for _, elem := range e.Elements() { + elem.RenumberIDs(idGen) + } } -func (e navigableEntryImpl) Key() NavigableExpr { - return e.key -} +func (*baseListExpr) isExpr() {} -func (e navigableEntryImpl) Value() NavigableExpr { - return e.val +type baseMapExpr struct { + entries []EntryExpr } -func (e navigableEntryImpl) IsOptional() bool { - return e.isOpt +func (*baseMapExpr) Kind() ExprKind { + return MapKind } -func (e navigableEntryImpl) isNavigable() {} +func (e *baseMapExpr) Entries() []EntryExpr { + if e == nil { + return []EntryExpr{} + } + return e.entries +} -type navigableSelectImpl struct { - *navigableExprImpl +func (e *baseMapExpr) Size() int { + return len(e.Entries()) } -func (sel navigableSelectImpl) FieldName() string { - return sel.ToExpr().GetSelectExpr().GetField() +func (e *baseMapExpr) renumberIDs(idGen IDGenerator) { + for _, entry := range e.Entries() { + entry.RenumberIDs(idGen) + } } -func (sel navigableSelectImpl) IsTestOnly() bool { - return sel.ToExpr().GetSelectExpr().GetTestOnly() +func (*baseMapExpr) isExpr() {} + +type baseSelectExpr struct { + operand Expr + field string + testOnly bool } -func (sel navigableSelectImpl) Operand() NavigableExpr { - return sel.createChild(sel.ToExpr().GetSelectExpr().GetOperand()) +func (*baseSelectExpr) Kind() ExprKind { + return SelectKind } -type navigableStructImpl struct { - *navigableExprImpl +func (e *baseSelectExpr) Operand() Expr { + if e == nil || e.operand == nil { + return nilExpr + } + return e.operand } -func (s navigableStructImpl) TypeName() string { - return s.ToExpr().GetStructExpr().GetMessageName() +func (e *baseSelectExpr) FieldName() string { + if e == nil { + return "" + } + return e.field } -func (s navigableStructImpl) Fields() []NavigableField { - fieldInits := s.ToExpr().GetStructExpr().GetEntries() - fields := make([]NavigableField, len(fieldInits)) - for i, f := range fieldInits { - fields[i] = navigableFieldImpl{ - name: f.GetFieldKey(), - val: s.createChild(f.GetValue()), - isOpt: f.GetOptionalEntry(), - } +func (e *baseSelectExpr) IsTestOnly() bool { + if e == nil { + return false } - return fields + return e.testOnly } -type navigableFieldImpl struct { - name string - val NavigableExpr - isOpt bool +func (e *baseSelectExpr) renumberIDs(idGen IDGenerator) { + e.Operand().RenumberIDs(idGen) } -func (f navigableFieldImpl) FieldName() string { - return f.name +func (*baseSelectExpr) isExpr() {} + +type baseStructExpr struct { + typeName string + fields []EntryExpr } -func (f navigableFieldImpl) Value() NavigableExpr { - return f.val +func (*baseStructExpr) Kind() ExprKind { + return StructKind } -func (f navigableFieldImpl) IsOptional() bool { - return f.isOpt +func (e *baseStructExpr) TypeName() string { + if e == nil { + return "" + } + return e.typeName } -func (f navigableFieldImpl) isNavigable() {} +func (e *baseStructExpr) Fields() []EntryExpr { + if e == nil { + return []EntryExpr{} + } + return e.fields +} -func kindOf(expr *exprpb.Expr) (ExprKind, childFactory) { - switch expr.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - return LiteralKind, noopFactory - case *exprpb.Expr_IdentExpr: - return IdentKind, noopFactory - case *exprpb.Expr_SelectExpr: - return SelectKind, selectFactory - case *exprpb.Expr_CallExpr: - return CallKind, callArgFactory - case *exprpb.Expr_ListExpr: - return ListKind, listElemFactory - case *exprpb.Expr_StructExpr: - if expr.GetStructExpr().GetMessageName() != "" { - return StructKind, structEntryFactory - } - return MapKind, mapEntryFactory - case *exprpb.Expr_ComprehensionExpr: - return ComprehensionKind, comprehensionFactory - default: - return UnspecifiedKind, noopFactory +func (e *baseStructExpr) renumberIDs(idGen IDGenerator) { + for _, f := range e.Fields() { + f.RenumberIDs(idGen) } } -type childFactory func(*navigableExprImpl) []NavigableExpr +func (*baseStructExpr) isExpr() {} + +type entryExprKindCase interface { + Kind() EntryExprKind + + renumberIDs(IDGenerator) + + isEntryExpr() +} + +var _ EntryExpr = &entryExpr{} + +type entryExpr struct { + id int64 + entryExprKindCase +} -func noopFactory(*navigableExprImpl) []NavigableExpr { - return nil +func (e *entryExpr) ID() int64 { + return e.id } -func selectFactory(nav *navigableExprImpl) []NavigableExpr { - return []NavigableExpr{ - nav.createChild(nav.ToExpr().GetSelectExpr().GetOperand()), +func (e *entryExpr) AsMapEntry() MapEntry { + if e.Kind() != MapEntryKind { + return nilMapEntry } + return e.entryExprKindCase.(MapEntry) } -func callArgFactory(nav *navigableExprImpl) []NavigableExpr { - call := nav.ToExpr().GetCallExpr() - argCount := len(call.GetArgs()) - if call.GetTarget() != nil { - argCount++ +func (e *entryExpr) AsStructField() StructField { + if e.Kind() != StructFieldKind { + return nilStructField } - navExprs := make([]NavigableExpr, argCount) - i := 0 - if call.GetTarget() != nil { - navExprs[i] = nav.createChild(call.GetTarget()) - i++ + return e.entryExprKindCase.(StructField) +} + +func (e *entryExpr) RenumberIDs(idGen IDGenerator) { + e.id = idGen(e.id) + e.entryExprKindCase.renumberIDs(idGen) +} + +type baseMapEntry struct { + key Expr + value Expr + isOptional bool +} + +func (e *baseMapEntry) Kind() EntryExprKind { + return MapEntryKind +} + +func (e *baseMapEntry) Key() Expr { + if e == nil { + return nilExpr } - for _, arg := range call.GetArgs() { - navExprs[i] = nav.createChild(arg) - i++ + return e.key +} + +func (e *baseMapEntry) Value() Expr { + if e == nil { + return nilExpr } - return navExprs + return e.value } -func listElemFactory(nav *navigableExprImpl) []NavigableExpr { - l := nav.ToExpr().GetListExpr() - navExprs := make([]NavigableExpr, len(l.GetElements())) - for i, e := range l.GetElements() { - navExprs[i] = nav.createChild(e) +func (e *baseMapEntry) IsOptional() bool { + if e == nil { + return false } - return navExprs + return e.isOptional } -func structEntryFactory(nav *navigableExprImpl) []NavigableExpr { - s := nav.ToExpr().GetStructExpr() - entries := make([]NavigableExpr, len(s.GetEntries())) - for i, e := range s.GetEntries() { +func (e *baseMapEntry) renumberIDs(idGen IDGenerator) { + e.Key().RenumberIDs(idGen) + e.Value().RenumberIDs(idGen) +} + +func (*baseMapEntry) isEntryExpr() {} + +type baseStructField struct { + field string + value Expr + isOptional bool +} + +func (f *baseStructField) Kind() EntryExprKind { + return StructFieldKind +} - entries[i] = nav.createChild(e.GetValue()) +func (f *baseStructField) Name() string { + if f == nil { + return "" } - return entries + return f.field } -func mapEntryFactory(nav *navigableExprImpl) []NavigableExpr { - s := nav.ToExpr().GetStructExpr() - entries := make([]NavigableExpr, len(s.GetEntries())*2) - j := 0 - for _, e := range s.GetEntries() { - entries[j] = nav.createChild(e.GetMapKey()) - entries[j+1] = nav.createChild(e.GetValue()) - j += 2 +func (f *baseStructField) Value() Expr { + if f == nil { + return nilExpr } - return entries + return f.value } -func comprehensionFactory(nav *navigableExprImpl) []NavigableExpr { - compre := nav.ToExpr().GetComprehensionExpr() - return []NavigableExpr{ - nav.createChild(compre.GetIterRange()), - nav.createChild(compre.GetAccuInit()), - nav.createChild(compre.GetLoopCondition()), - nav.createChild(compre.GetLoopStep()), - nav.createChild(compre.GetResult()), +func (f *baseStructField) IsOptional() bool { + if f == nil { + return false } + return f.isOptional } + +func (f *baseStructField) renumberIDs(idGen IDGenerator) { + f.Value().RenumberIDs(idGen) +} + +func (*baseStructField) isEntryExpr() {} + +var ( + nilExpr *expr = nil + nilCall *baseCallExpr = nil + nilCompre *baseComprehensionExpr = nil + nilList *baseListExpr = nil + nilMap *baseMapExpr = nil + nilMapEntry *baseMapEntry = nil + nilSel *baseSelectExpr = nil + nilStruct *baseStructExpr = nil + nilStructField *baseStructField = nil +) diff --git a/vendor/github.com/google/cel-go/common/ast/factory.go b/vendor/github.com/google/cel-go/common/ast/factory.go new file mode 100644 index 00000000000..b7f36e72a49 --- /dev/null +++ b/vendor/github.com/google/cel-go/common/ast/factory.go @@ -0,0 +1,303 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import "github.com/google/cel-go/common/types/ref" + +// ExprFactory interfaces defines a set of methods necessary for building native expression values. +type ExprFactory interface { + // CopyExpr creates a deep copy of the input Expr value. + CopyExpr(Expr) Expr + + // CopyEntryExpr creates a deep copy of the input EntryExpr value. + CopyEntryExpr(EntryExpr) EntryExpr + + // NewCall creates an Expr value representing a global function call. + NewCall(id int64, function string, args ...Expr) Expr + + // NewComprehension creates an Expr value representing a comprehension over a value range. + NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr + + // NewMemberCall creates an Expr value representing a member function call. + NewMemberCall(id int64, function string, receiver Expr, args ...Expr) Expr + + // NewIdent creates an Expr value representing an identifier. + NewIdent(id int64, name string) Expr + + // NewAccuIdent creates an Expr value representing an accumulator identifier within a + //comprehension. + NewAccuIdent(id int64) Expr + + // NewLiteral creates an Expr value representing a literal value, such as a string or integer. + NewLiteral(id int64, value ref.Val) Expr + + // NewList creates an Expr value representing a list literal expression with optional indices. + // + // Optional indicies will typically be empty unless the CEL optional types are enabled. + NewList(id int64, elems []Expr, optIndices []int32) Expr + + // NewMap creates an Expr value representing a map literal expression + NewMap(id int64, entries []EntryExpr) Expr + + // NewMapEntry creates a MapEntry with a given key, value, and a flag indicating whether + // the key is optionally set. + NewMapEntry(id int64, key, value Expr, isOptional bool) EntryExpr + + // NewPresenceTest creates an Expr representing a field presence test on an operand expression. + NewPresenceTest(id int64, operand Expr, field string) Expr + + // NewSelect creates an Expr representing a field selection on an operand expression. + NewSelect(id int64, operand Expr, field string) Expr + + // NewStruct creates an Expr value representing a struct literal with a given type name and a + // set of field initializers. + NewStruct(id int64, typeName string, fields []EntryExpr) Expr + + // NewStructField creates a StructField with a given field name, value, and a flag indicating + // whether the field is optionally set. + NewStructField(id int64, field string, value Expr, isOptional bool) EntryExpr + + // NewUnspecifiedExpr creates an empty expression node. + NewUnspecifiedExpr(id int64) Expr + + isExprFactory() +} + +type baseExprFactory struct{} + +// NewExprFactory creates an ExprFactory instance. +func NewExprFactory() ExprFactory { + return &baseExprFactory{} +} + +func (fac *baseExprFactory) NewCall(id int64, function string, args ...Expr) Expr { + if len(args) == 0 { + args = []Expr{} + } + return fac.newExpr( + id, + &baseCallExpr{ + function: function, + target: nilExpr, + args: args, + isMember: false, + }) +} + +func (fac *baseExprFactory) NewMemberCall(id int64, function string, target Expr, args ...Expr) Expr { + if len(args) == 0 { + args = []Expr{} + } + return fac.newExpr( + id, + &baseCallExpr{ + function: function, + target: target, + args: args, + isMember: true, + }) +} + +func (fac *baseExprFactory) NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr { + return fac.newExpr( + id, + &baseComprehensionExpr{ + iterRange: iterRange, + iterVar: iterVar, + accuVar: accuVar, + accuInit: accuInit, + loopCond: loopCond, + loopStep: loopStep, + result: result, + }) +} + +func (fac *baseExprFactory) NewIdent(id int64, name string) Expr { + return fac.newExpr(id, baseIdentExpr(name)) +} + +func (fac *baseExprFactory) NewAccuIdent(id int64) Expr { + return fac.NewIdent(id, "__result__") +} + +func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr { + return fac.newExpr(id, &baseLiteral{Val: value}) +} + +func (fac *baseExprFactory) NewList(id int64, elems []Expr, optIndices []int32) Expr { + optIndexMap := make(map[int32]struct{}, len(optIndices)) + for _, idx := range optIndices { + optIndexMap[idx] = struct{}{} + } + return fac.newExpr(id, + &baseListExpr{ + elements: elems, + optIndices: optIndices, + optIndexMap: optIndexMap, + }) +} + +func (fac *baseExprFactory) NewMap(id int64, entries []EntryExpr) Expr { + return fac.newExpr(id, &baseMapExpr{entries: entries}) +} + +func (fac *baseExprFactory) NewMapEntry(id int64, key, value Expr, isOptional bool) EntryExpr { + return fac.newEntryExpr( + id, + &baseMapEntry{ + key: key, + value: value, + isOptional: isOptional, + }) +} + +func (fac *baseExprFactory) NewPresenceTest(id int64, operand Expr, field string) Expr { + return fac.newExpr( + id, + &baseSelectExpr{ + operand: operand, + field: field, + testOnly: true, + }) +} + +func (fac *baseExprFactory) NewSelect(id int64, operand Expr, field string) Expr { + return fac.newExpr( + id, + &baseSelectExpr{ + operand: operand, + field: field, + }) +} + +func (fac *baseExprFactory) NewStruct(id int64, typeName string, fields []EntryExpr) Expr { + return fac.newExpr( + id, + &baseStructExpr{ + typeName: typeName, + fields: fields, + }) +} + +func (fac *baseExprFactory) NewStructField(id int64, field string, value Expr, isOptional bool) EntryExpr { + return fac.newEntryExpr( + id, + &baseStructField{ + field: field, + value: value, + isOptional: isOptional, + }) +} + +func (fac *baseExprFactory) NewUnspecifiedExpr(id int64) Expr { + return fac.newExpr(id, nil) +} + +func (fac *baseExprFactory) CopyExpr(e Expr) Expr { + // unwrap navigable expressions to avoid unnecessary allocations during copying. + if nav, ok := e.(*navigableExprImpl); ok { + e = nav.Expr + } + switch e.Kind() { + case CallKind: + c := e.AsCall() + argsCopy := make([]Expr, len(c.Args())) + for i, arg := range c.Args() { + argsCopy[i] = fac.CopyExpr(arg) + } + if !c.IsMemberFunction() { + return fac.NewCall(e.ID(), c.FunctionName(), argsCopy...) + } + return fac.NewMemberCall(e.ID(), c.FunctionName(), fac.CopyExpr(c.Target()), argsCopy...) + case ComprehensionKind: + compre := e.AsComprehension() + return fac.NewComprehension(e.ID(), + fac.CopyExpr(compre.IterRange()), + compre.IterVar(), + compre.AccuVar(), + fac.CopyExpr(compre.AccuInit()), + fac.CopyExpr(compre.LoopCondition()), + fac.CopyExpr(compre.LoopStep()), + fac.CopyExpr(compre.Result())) + case IdentKind: + return fac.NewIdent(e.ID(), e.AsIdent()) + case ListKind: + l := e.AsList() + elemsCopy := make([]Expr, l.Size()) + for i, elem := range l.Elements() { + elemsCopy[i] = fac.CopyExpr(elem) + } + return fac.NewList(e.ID(), elemsCopy, l.OptionalIndices()) + case LiteralKind: + return fac.NewLiteral(e.ID(), e.AsLiteral()) + case MapKind: + m := e.AsMap() + entriesCopy := make([]EntryExpr, m.Size()) + for i, entry := range m.Entries() { + entriesCopy[i] = fac.CopyEntryExpr(entry) + } + return fac.NewMap(e.ID(), entriesCopy) + case SelectKind: + s := e.AsSelect() + if s.IsTestOnly() { + return fac.NewPresenceTest(e.ID(), fac.CopyExpr(s.Operand()), s.FieldName()) + } + return fac.NewSelect(e.ID(), fac.CopyExpr(s.Operand()), s.FieldName()) + case StructKind: + s := e.AsStruct() + fieldsCopy := make([]EntryExpr, len(s.Fields())) + for i, field := range s.Fields() { + fieldsCopy[i] = fac.CopyEntryExpr(field) + } + return fac.NewStruct(e.ID(), s.TypeName(), fieldsCopy) + default: + return fac.NewUnspecifiedExpr(e.ID()) + } +} + +func (fac *baseExprFactory) CopyEntryExpr(e EntryExpr) EntryExpr { + switch e.Kind() { + case MapEntryKind: + entry := e.AsMapEntry() + return fac.NewMapEntry(e.ID(), + fac.CopyExpr(entry.Key()), fac.CopyExpr(entry.Value()), entry.IsOptional()) + case StructFieldKind: + field := e.AsStructField() + return fac.NewStructField(e.ID(), + field.Name(), fac.CopyExpr(field.Value()), field.IsOptional()) + default: + return fac.newEntryExpr(e.ID(), nil) + } +} + +func (*baseExprFactory) isExprFactory() {} + +func (fac *baseExprFactory) newExpr(id int64, e exprKindCase) Expr { + return &expr{ + id: id, + exprKindCase: e, + } +} + +func (fac *baseExprFactory) newEntryExpr(id int64, e entryExprKindCase) EntryExpr { + return &entryExpr{ + id: id, + entryExprKindCase: e, + } +} + +var ( + defaultFactory = &baseExprFactory{} +) diff --git a/vendor/github.com/google/cel-go/common/ast/navigable.go b/vendor/github.com/google/cel-go/common/ast/navigable.go new file mode 100644 index 00000000000..f5ddf6aac66 --- /dev/null +++ b/vendor/github.com/google/cel-go/common/ast/navigable.go @@ -0,0 +1,652 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import ( + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +// NavigableExpr represents the base navigable expression value with methods to inspect the +// parent and child expressions. +type NavigableExpr interface { + Expr + + // Type of the expression. + // + // If the expression is type-checked, the type check metadata is returned. If the expression + // has not been type-checked, the types.DynType value is returned. + Type() *types.Type + + // Parent returns the parent expression node, if one exists. + Parent() (NavigableExpr, bool) + + // Children returns a list of child expression nodes. + Children() []NavigableExpr + + // Depth indicates the depth in the expression tree. + // + // The root expression has depth 0. + Depth() int +} + +// NavigateAST converts an AST to a NavigableExpr +func NavigateAST(ast *AST) NavigableExpr { + return NavigateExpr(ast, ast.Expr()) +} + +// NavigateExpr creates a NavigableExpr whose type information is backed by the input AST. +// +// If the expression is already a NavigableExpr, the parent and depth information will be +// propagated on the new NavigableExpr value; otherwise, the expr value will be treated +// as though it is the root of the expression graph with a depth of 0. +func NavigateExpr(ast *AST, expr Expr) NavigableExpr { + depth := 0 + var parent NavigableExpr = nil + if nav, ok := expr.(NavigableExpr); ok { + depth = nav.Depth() + parent, _ = nav.Parent() + } + return newNavigableExpr(ast, parent, expr, depth) +} + +// ExprMatcher takes a NavigableExpr in and indicates whether the value is a match. +// +// This function type should be use with the `Match` and `MatchList` calls. +type ExprMatcher func(NavigableExpr) bool + +// ConstantValueMatcher returns an ExprMatcher which will return true if the input NavigableExpr +// is comprised of all constant values, such as a simple literal or even list and map literal. +func ConstantValueMatcher() ExprMatcher { + return matchIsConstantValue +} + +// KindMatcher returns an ExprMatcher which will return true if the input NavigableExpr.Kind() matches +// the specified `kind`. +func KindMatcher(kind ExprKind) ExprMatcher { + return func(e NavigableExpr) bool { + return e.Kind() == kind + } +} + +// FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose +// function name is equal to `funcName`. +func FunctionMatcher(funcName string) ExprMatcher { + return func(e NavigableExpr) bool { + if e.Kind() != CallKind { + return false + } + return e.AsCall().FunctionName() == funcName + } +} + +// AllMatcher returns true for all descendants of a NavigableExpr, effectively flattening them into a list. +// +// Such a result would work well with subsequent MatchList calls. +func AllMatcher() ExprMatcher { + return func(NavigableExpr) bool { + return true + } +} + +// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values +// matching the input criteria in post-order (bottom up). +func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr { + matches := []NavigableExpr{} + navVisitor := &baseVisitor{ + visitExpr: func(e Expr) { + nav := e.(NavigableExpr) + if matcher(nav) { + matches = append(matches, nav) + } + }, + } + visit(expr, navVisitor, postOrder, 0, 0) + return matches +} + +// MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a +// subset of NavigableExpr values which match. +func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr { + matches := []NavigableExpr{} + navVisitor := &baseVisitor{ + visitExpr: func(e Expr) { + nav := e.(NavigableExpr) + if matcher(nav) { + matches = append(matches, nav) + } + }, + } + for _, expr := range exprs { + visit(expr, navVisitor, postOrder, 0, 1) + } + return matches +} + +// Visitor defines an object for visiting Expr and EntryExpr nodes within an expression graph. +type Visitor interface { + // VisitExpr visits the input expression. + VisitExpr(Expr) + + // VisitEntryExpr visits the input entry expression, i.e. a struct field or map entry. + VisitEntryExpr(EntryExpr) +} + +type baseVisitor struct { + visitExpr func(Expr) + visitEntryExpr func(EntryExpr) +} + +// VisitExpr visits the Expr if the internal expr visitor has been configured. +func (v *baseVisitor) VisitExpr(e Expr) { + if v.visitExpr != nil { + v.visitExpr(e) + } +} + +// VisitEntryExpr visits the entry if the internal expr entry visitor has been configured. +func (v *baseVisitor) VisitEntryExpr(e EntryExpr) { + if v.visitEntryExpr != nil { + v.visitEntryExpr(e) + } +} + +// NewExprVisitor creates a visitor which only visits expression nodes. +func NewExprVisitor(v func(Expr)) Visitor { + return &baseVisitor{ + visitExpr: v, + visitEntryExpr: nil, + } +} + +// PostOrderVisit walks the expression graph and calls the visitor in post-order (bottom-up). +func PostOrderVisit(expr Expr, visitor Visitor) { + visit(expr, visitor, postOrder, 0, 0) +} + +// PreOrderVisit walks the expression graph and calls the visitor in pre-order (top-down). +func PreOrderVisit(expr Expr, visitor Visitor) { + visit(expr, visitor, preOrder, 0, 0) +} + +type visitOrder int + +const ( + preOrder = iota + 1 + postOrder +) + +// TODO: consider exposing a way to configure a limit for the max visit depth. +// It's possible that we could want to configure this on the NewExprVisitor() +// and through MatchDescendents() / MaxID(). +func visit(expr Expr, visitor Visitor, order visitOrder, depth, maxDepth int) { + if maxDepth > 0 && depth == maxDepth { + return + } + if order == preOrder { + visitor.VisitExpr(expr) + } + switch expr.Kind() { + case CallKind: + c := expr.AsCall() + if c.IsMemberFunction() { + visit(c.Target(), visitor, order, depth+1, maxDepth) + } + for _, arg := range c.Args() { + visit(arg, visitor, order, depth+1, maxDepth) + } + case ComprehensionKind: + c := expr.AsComprehension() + visit(c.IterRange(), visitor, order, depth+1, maxDepth) + visit(c.AccuInit(), visitor, order, depth+1, maxDepth) + visit(c.LoopCondition(), visitor, order, depth+1, maxDepth) + visit(c.LoopStep(), visitor, order, depth+1, maxDepth) + visit(c.Result(), visitor, order, depth+1, maxDepth) + case ListKind: + l := expr.AsList() + for _, elem := range l.Elements() { + visit(elem, visitor, order, depth+1, maxDepth) + } + case MapKind: + m := expr.AsMap() + for _, e := range m.Entries() { + if order == preOrder { + visitor.VisitEntryExpr(e) + } + entry := e.AsMapEntry() + visit(entry.Key(), visitor, order, depth+1, maxDepth) + visit(entry.Value(), visitor, order, depth+1, maxDepth) + if order == postOrder { + visitor.VisitEntryExpr(e) + } + } + case SelectKind: + visit(expr.AsSelect().Operand(), visitor, order, depth+1, maxDepth) + case StructKind: + s := expr.AsStruct() + for _, f := range s.Fields() { + visitor.VisitEntryExpr(f) + visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth) + } + } + if order == postOrder { + visitor.VisitExpr(expr) + } +} + +func matchIsConstantValue(e NavigableExpr) bool { + if e.Kind() == LiteralKind { + return true + } + if e.Kind() == StructKind || e.Kind() == MapKind || e.Kind() == ListKind { + for _, child := range e.Children() { + if !matchIsConstantValue(child) { + return false + } + } + return true + } + return false +} + +func newNavigableExpr(ast *AST, parent NavigableExpr, expr Expr, depth int) NavigableExpr { + // Reduce navigable expression nesting by unwrapping the embedded Expr value. + if nav, ok := expr.(*navigableExprImpl); ok { + expr = nav.Expr + } + nav := &navigableExprImpl{ + Expr: expr, + depth: depth, + ast: ast, + parent: parent, + createChildren: getChildFactory(expr), + } + return nav +} + +type navigableExprImpl struct { + Expr + depth int + ast *AST + parent NavigableExpr + createChildren childFactory +} + +func (nav *navigableExprImpl) Parent() (NavigableExpr, bool) { + if nav.parent != nil { + return nav.parent, true + } + return nil, false +} + +func (nav *navigableExprImpl) ID() int64 { + return nav.Expr.ID() +} + +func (nav *navigableExprImpl) Kind() ExprKind { + return nav.Expr.Kind() +} + +func (nav *navigableExprImpl) Type() *types.Type { + return nav.ast.GetType(nav.ID()) +} + +func (nav *navigableExprImpl) Children() []NavigableExpr { + return nav.createChildren(nav) +} + +func (nav *navigableExprImpl) Depth() int { + return nav.depth +} + +func (nav *navigableExprImpl) AsCall() CallExpr { + return navigableCallImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) AsComprehension() ComprehensionExpr { + return navigableComprehensionImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) AsIdent() string { + return nav.Expr.AsIdent() +} + +func (nav *navigableExprImpl) AsList() ListExpr { + return navigableListImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) AsLiteral() ref.Val { + return nav.Expr.AsLiteral() +} + +func (nav *navigableExprImpl) AsMap() MapExpr { + return navigableMapImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) AsSelect() SelectExpr { + return navigableSelectImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) AsStruct() StructExpr { + return navigableStructImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) createChild(e Expr) NavigableExpr { + return newNavigableExpr(nav.ast, nav, e, nav.depth+1) +} + +func (nav *navigableExprImpl) isExpr() {} + +type navigableCallImpl struct { + *navigableExprImpl +} + +func (call navigableCallImpl) FunctionName() string { + return call.Expr.AsCall().FunctionName() +} + +func (call navigableCallImpl) IsMemberFunction() bool { + return call.Expr.AsCall().IsMemberFunction() +} + +func (call navigableCallImpl) Target() Expr { + t := call.Expr.AsCall().Target() + if t != nil { + return call.createChild(t) + } + return nil +} + +func (call navigableCallImpl) Args() []Expr { + args := call.Expr.AsCall().Args() + navArgs := make([]Expr, len(args)) + for i, a := range args { + navArgs[i] = call.createChild(a) + } + return navArgs +} + +type navigableComprehensionImpl struct { + *navigableExprImpl +} + +func (comp navigableComprehensionImpl) IterRange() Expr { + return comp.createChild(comp.Expr.AsComprehension().IterRange()) +} + +func (comp navigableComprehensionImpl) IterVar() string { + return comp.Expr.AsComprehension().IterVar() +} + +func (comp navigableComprehensionImpl) AccuVar() string { + return comp.Expr.AsComprehension().AccuVar() +} + +func (comp navigableComprehensionImpl) AccuInit() Expr { + return comp.createChild(comp.Expr.AsComprehension().AccuInit()) +} + +func (comp navigableComprehensionImpl) LoopCondition() Expr { + return comp.createChild(comp.Expr.AsComprehension().LoopCondition()) +} + +func (comp navigableComprehensionImpl) LoopStep() Expr { + return comp.createChild(comp.Expr.AsComprehension().LoopStep()) +} + +func (comp navigableComprehensionImpl) Result() Expr { + return comp.createChild(comp.Expr.AsComprehension().Result()) +} + +type navigableListImpl struct { + *navigableExprImpl +} + +func (l navigableListImpl) Elements() []Expr { + pbElems := l.Expr.AsList().Elements() + elems := make([]Expr, len(pbElems)) + for i := 0; i < len(pbElems); i++ { + elems[i] = l.createChild(pbElems[i]) + } + return elems +} + +func (l navigableListImpl) IsOptional(index int32) bool { + return l.Expr.AsList().IsOptional(index) +} + +func (l navigableListImpl) OptionalIndices() []int32 { + return l.Expr.AsList().OptionalIndices() +} + +func (l navigableListImpl) Size() int { + return l.Expr.AsList().Size() +} + +type navigableMapImpl struct { + *navigableExprImpl +} + +func (m navigableMapImpl) Entries() []EntryExpr { + mapExpr := m.Expr.AsMap() + entries := make([]EntryExpr, len(mapExpr.Entries())) + for i, e := range mapExpr.Entries() { + entry := e.AsMapEntry() + entries[i] = &entryExpr{ + id: e.ID(), + entryExprKindCase: navigableEntryImpl{ + key: m.createChild(entry.Key()), + val: m.createChild(entry.Value()), + isOpt: entry.IsOptional(), + }, + } + } + return entries +} + +func (m navigableMapImpl) Size() int { + return m.Expr.AsMap().Size() +} + +type navigableEntryImpl struct { + key NavigableExpr + val NavigableExpr + isOpt bool +} + +func (e navigableEntryImpl) Kind() EntryExprKind { + return MapEntryKind +} + +func (e navigableEntryImpl) Key() Expr { + return e.key +} + +func (e navigableEntryImpl) Value() Expr { + return e.val +} + +func (e navigableEntryImpl) IsOptional() bool { + return e.isOpt +} + +func (e navigableEntryImpl) renumberIDs(IDGenerator) {} + +func (e navigableEntryImpl) isEntryExpr() {} + +type navigableSelectImpl struct { + *navigableExprImpl +} + +func (sel navigableSelectImpl) FieldName() string { + return sel.Expr.AsSelect().FieldName() +} + +func (sel navigableSelectImpl) IsTestOnly() bool { + return sel.Expr.AsSelect().IsTestOnly() +} + +func (sel navigableSelectImpl) Operand() Expr { + return sel.createChild(sel.Expr.AsSelect().Operand()) +} + +type navigableStructImpl struct { + *navigableExprImpl +} + +func (s navigableStructImpl) TypeName() string { + return s.Expr.AsStruct().TypeName() +} + +func (s navigableStructImpl) Fields() []EntryExpr { + fieldInits := s.Expr.AsStruct().Fields() + fields := make([]EntryExpr, len(fieldInits)) + for i, f := range fieldInits { + field := f.AsStructField() + fields[i] = &entryExpr{ + id: f.ID(), + entryExprKindCase: navigableFieldImpl{ + name: field.Name(), + val: s.createChild(field.Value()), + isOpt: field.IsOptional(), + }, + } + } + return fields +} + +type navigableFieldImpl struct { + name string + val NavigableExpr + isOpt bool +} + +func (f navigableFieldImpl) Kind() EntryExprKind { + return StructFieldKind +} + +func (f navigableFieldImpl) Name() string { + return f.name +} + +func (f navigableFieldImpl) Value() Expr { + return f.val +} + +func (f navigableFieldImpl) IsOptional() bool { + return f.isOpt +} + +func (f navigableFieldImpl) renumberIDs(IDGenerator) {} + +func (f navigableFieldImpl) isEntryExpr() {} + +func getChildFactory(expr Expr) childFactory { + if expr == nil { + return noopFactory + } + switch expr.Kind() { + case LiteralKind: + return noopFactory + case IdentKind: + return noopFactory + case SelectKind: + return selectFactory + case CallKind: + return callArgFactory + case ListKind: + return listElemFactory + case MapKind: + return mapEntryFactory + case StructKind: + return structEntryFactory + case ComprehensionKind: + return comprehensionFactory + default: + return noopFactory + } +} + +type childFactory func(*navigableExprImpl) []NavigableExpr + +func noopFactory(*navigableExprImpl) []NavigableExpr { + return nil +} + +func selectFactory(nav *navigableExprImpl) []NavigableExpr { + return []NavigableExpr{nav.createChild(nav.AsSelect().Operand())} +} + +func callArgFactory(nav *navigableExprImpl) []NavigableExpr { + call := nav.Expr.AsCall() + argCount := len(call.Args()) + if call.IsMemberFunction() { + argCount++ + } + navExprs := make([]NavigableExpr, argCount) + i := 0 + if call.IsMemberFunction() { + navExprs[i] = nav.createChild(call.Target()) + i++ + } + for _, arg := range call.Args() { + navExprs[i] = nav.createChild(arg) + i++ + } + return navExprs +} + +func listElemFactory(nav *navigableExprImpl) []NavigableExpr { + l := nav.Expr.AsList() + navExprs := make([]NavigableExpr, len(l.Elements())) + for i, e := range l.Elements() { + navExprs[i] = nav.createChild(e) + } + return navExprs +} + +func structEntryFactory(nav *navigableExprImpl) []NavigableExpr { + s := nav.Expr.AsStruct() + entries := make([]NavigableExpr, len(s.Fields())) + for i, e := range s.Fields() { + f := e.AsStructField() + entries[i] = nav.createChild(f.Value()) + } + return entries +} + +func mapEntryFactory(nav *navigableExprImpl) []NavigableExpr { + m := nav.Expr.AsMap() + entries := make([]NavigableExpr, len(m.Entries())*2) + j := 0 + for _, e := range m.Entries() { + mapEntry := e.AsMapEntry() + entries[j] = nav.createChild(mapEntry.Key()) + entries[j+1] = nav.createChild(mapEntry.Value()) + j += 2 + } + return entries +} + +func comprehensionFactory(nav *navigableExprImpl) []NavigableExpr { + compre := nav.Expr.AsComprehension() + return []NavigableExpr{ + nav.createChild(compre.IterRange()), + nav.createChild(compre.AccuInit()), + nav.createChild(compre.LoopCondition()), + nav.createChild(compre.LoopStep()), + nav.createChild(compre.Result()), + } +} diff --git a/vendor/github.com/google/cel-go/common/containers/BUILD.bazel b/vendor/github.com/google/cel-go/common/containers/BUILD.bazel index 3f3f0788712..81197f0641c 100644 --- a/vendor/github.com/google/cel-go/common/containers/BUILD.bazel +++ b/vendor/github.com/google/cel-go/common/containers/BUILD.bazel @@ -12,7 +12,7 @@ go_library( ], importpath = "github.com/google/cel-go/common/containers", deps = [ - "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "//common/ast:go_default_library", ], ) @@ -26,6 +26,6 @@ go_test( ":go_default_library", ], deps = [ - "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "//common/ast:go_default_library", ], ) diff --git a/vendor/github.com/google/cel-go/common/containers/container.go b/vendor/github.com/google/cel-go/common/containers/container.go index d46698d3cdf..52153d4cd18 100644 --- a/vendor/github.com/google/cel-go/common/containers/container.go +++ b/vendor/github.com/google/cel-go/common/containers/container.go @@ -20,7 +20,7 @@ import ( "fmt" "strings" - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/ast" ) var ( @@ -297,19 +297,19 @@ func Name(name string) ContainerOption { // ToQualifiedName converts an expression AST into a qualified name if possible, with a boolean // 'found' value that indicates if the conversion is successful. -func ToQualifiedName(e *exprpb.Expr) (string, bool) { - switch e.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr: - id := e.GetIdentExpr() - return id.GetName(), true - case *exprpb.Expr_SelectExpr: - sel := e.GetSelectExpr() +func ToQualifiedName(e ast.Expr) (string, bool) { + switch e.Kind() { + case ast.IdentKind: + id := e.AsIdent() + return id, true + case ast.SelectKind: + sel := e.AsSelect() // Test only expressions are not valid as qualified names. - if sel.GetTestOnly() { + if sel.IsTestOnly() { return "", false } - if qual, found := ToQualifiedName(sel.GetOperand()); found { - return qual + "." + sel.GetField(), true + if qual, found := ToQualifiedName(sel.Operand()); found { + return qual + "." + sel.FieldName(), true } } return "", false diff --git a/vendor/github.com/google/cel-go/common/debug/BUILD.bazel b/vendor/github.com/google/cel-go/common/debug/BUILD.bazel index 1f029839c71..724ed340458 100644 --- a/vendor/github.com/google/cel-go/common/debug/BUILD.bazel +++ b/vendor/github.com/google/cel-go/common/debug/BUILD.bazel @@ -13,6 +13,8 @@ go_library( importpath = "github.com/google/cel-go/common/debug", deps = [ "//common:go_default_library", - "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "//common/ast:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", ], ) diff --git a/vendor/github.com/google/cel-go/common/debug/debug.go b/vendor/github.com/google/cel-go/common/debug/debug.go index 5dab156ef36..e4c01ac6ed2 100644 --- a/vendor/github.com/google/cel-go/common/debug/debug.go +++ b/vendor/github.com/google/cel-go/common/debug/debug.go @@ -22,7 +22,9 @@ import ( "strconv" "strings" - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" ) // Adorner returns debug metadata that will be tacked on to the string @@ -38,7 +40,7 @@ type Writer interface { // Buffer pushes an expression into an internal queue of expressions to // write to a string. - Buffer(e *exprpb.Expr) + Buffer(e ast.Expr) } type emptyDebugAdorner struct { @@ -51,12 +53,12 @@ func (a *emptyDebugAdorner) GetMetadata(e any) string { } // ToDebugString gives the unadorned string representation of the Expr. -func ToDebugString(e *exprpb.Expr) string { +func ToDebugString(e ast.Expr) string { return ToAdornedDebugString(e, emptyAdorner) } // ToAdornedDebugString gives the adorned string representation of the Expr. -func ToAdornedDebugString(e *exprpb.Expr, adorner Adorner) string { +func ToAdornedDebugString(e ast.Expr, adorner Adorner) string { w := newDebugWriter(adorner) w.Buffer(e) return w.String() @@ -78,49 +80,51 @@ func newDebugWriter(a Adorner) *debugWriter { } } -func (w *debugWriter) Buffer(e *exprpb.Expr) { +func (w *debugWriter) Buffer(e ast.Expr) { if e == nil { return } - switch e.ExprKind.(type) { - case *exprpb.Expr_ConstExpr: - w.append(formatLiteral(e.GetConstExpr())) - case *exprpb.Expr_IdentExpr: - w.append(e.GetIdentExpr().Name) - case *exprpb.Expr_SelectExpr: - w.appendSelect(e.GetSelectExpr()) - case *exprpb.Expr_CallExpr: - w.appendCall(e.GetCallExpr()) - case *exprpb.Expr_ListExpr: - w.appendList(e.GetListExpr()) - case *exprpb.Expr_StructExpr: - w.appendStruct(e.GetStructExpr()) - case *exprpb.Expr_ComprehensionExpr: - w.appendComprehension(e.GetComprehensionExpr()) + switch e.Kind() { + case ast.LiteralKind: + w.append(formatLiteral(e.AsLiteral())) + case ast.IdentKind: + w.append(e.AsIdent()) + case ast.SelectKind: + w.appendSelect(e.AsSelect()) + case ast.CallKind: + w.appendCall(e.AsCall()) + case ast.ListKind: + w.appendList(e.AsList()) + case ast.MapKind: + w.appendMap(e.AsMap()) + case ast.StructKind: + w.appendStruct(e.AsStruct()) + case ast.ComprehensionKind: + w.appendComprehension(e.AsComprehension()) } w.adorn(e) } -func (w *debugWriter) appendSelect(sel *exprpb.Expr_Select) { - w.Buffer(sel.GetOperand()) +func (w *debugWriter) appendSelect(sel ast.SelectExpr) { + w.Buffer(sel.Operand()) w.append(".") - w.append(sel.GetField()) - if sel.TestOnly { + w.append(sel.FieldName()) + if sel.IsTestOnly() { w.append("~test-only~") } } -func (w *debugWriter) appendCall(call *exprpb.Expr_Call) { - if call.Target != nil { - w.Buffer(call.GetTarget()) +func (w *debugWriter) appendCall(call ast.CallExpr) { + if call.IsMemberFunction() { + w.Buffer(call.Target()) w.append(".") } - w.append(call.GetFunction()) + w.append(call.FunctionName()) w.append("(") - if len(call.GetArgs()) > 0 { + if len(call.Args()) > 0 { w.addIndent() w.appendLine() - for i, arg := range call.GetArgs() { + for i, arg := range call.Args() { if i > 0 { w.append(",") w.appendLine() @@ -133,12 +137,12 @@ func (w *debugWriter) appendCall(call *exprpb.Expr_Call) { w.append(")") } -func (w *debugWriter) appendList(list *exprpb.Expr_CreateList) { +func (w *debugWriter) appendList(list ast.ListExpr) { w.append("[") - if len(list.GetElements()) > 0 { + if len(list.Elements()) > 0 { w.appendLine() w.addIndent() - for i, elem := range list.GetElements() { + for i, elem := range list.Elements() { if i > 0 { w.append(",") w.appendLine() @@ -151,32 +155,25 @@ func (w *debugWriter) appendList(list *exprpb.Expr_CreateList) { w.append("]") } -func (w *debugWriter) appendStruct(obj *exprpb.Expr_CreateStruct) { - if obj.MessageName != "" { - w.appendObject(obj) - } else { - w.appendMap(obj) - } -} - -func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) { - w.append(obj.GetMessageName()) +func (w *debugWriter) appendStruct(obj ast.StructExpr) { + w.append(obj.TypeName()) w.append("{") - if len(obj.GetEntries()) > 0 { + if len(obj.Fields()) > 0 { w.appendLine() w.addIndent() - for i, entry := range obj.GetEntries() { + for i, f := range obj.Fields() { + field := f.AsStructField() if i > 0 { w.append(",") w.appendLine() } - if entry.GetOptionalEntry() { + if field.IsOptional() { w.append("?") } - w.append(entry.GetFieldKey()) + w.append(field.Name()) w.append(":") - w.Buffer(entry.GetValue()) - w.adorn(entry) + w.Buffer(field.Value()) + w.adorn(f) } w.removeIndent() w.appendLine() @@ -184,23 +181,24 @@ func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) { w.append("}") } -func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) { +func (w *debugWriter) appendMap(m ast.MapExpr) { w.append("{") - if len(obj.GetEntries()) > 0 { + if m.Size() > 0 { w.appendLine() w.addIndent() - for i, entry := range obj.GetEntries() { + for i, e := range m.Entries() { + entry := e.AsMapEntry() if i > 0 { w.append(",") w.appendLine() } - if entry.GetOptionalEntry() { + if entry.IsOptional() { w.append("?") } - w.Buffer(entry.GetMapKey()) + w.Buffer(entry.Key()) w.append(":") - w.Buffer(entry.GetValue()) - w.adorn(entry) + w.Buffer(entry.Value()) + w.adorn(e) } w.removeIndent() w.appendLine() @@ -208,62 +206,62 @@ func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) { w.append("}") } -func (w *debugWriter) appendComprehension(comprehension *exprpb.Expr_Comprehension) { +func (w *debugWriter) appendComprehension(comprehension ast.ComprehensionExpr) { w.append("__comprehension__(") w.addIndent() w.appendLine() w.append("// Variable") w.appendLine() - w.append(comprehension.GetIterVar()) + w.append(comprehension.IterVar()) w.append(",") w.appendLine() w.append("// Target") w.appendLine() - w.Buffer(comprehension.GetIterRange()) + w.Buffer(comprehension.IterRange()) w.append(",") w.appendLine() w.append("// Accumulator") w.appendLine() - w.append(comprehension.GetAccuVar()) + w.append(comprehension.AccuVar()) w.append(",") w.appendLine() w.append("// Init") w.appendLine() - w.Buffer(comprehension.GetAccuInit()) + w.Buffer(comprehension.AccuInit()) w.append(",") w.appendLine() w.append("// LoopCondition") w.appendLine() - w.Buffer(comprehension.GetLoopCondition()) + w.Buffer(comprehension.LoopCondition()) w.append(",") w.appendLine() w.append("// LoopStep") w.appendLine() - w.Buffer(comprehension.GetLoopStep()) + w.Buffer(comprehension.LoopStep()) w.append(",") w.appendLine() w.append("// Result") w.appendLine() - w.Buffer(comprehension.GetResult()) + w.Buffer(comprehension.Result()) w.append(")") w.removeIndent() } -func formatLiteral(c *exprpb.Constant) string { - switch c.GetConstantKind().(type) { - case *exprpb.Constant_BoolValue: - return fmt.Sprintf("%t", c.GetBoolValue()) - case *exprpb.Constant_BytesValue: - return fmt.Sprintf("b\"%s\"", string(c.GetBytesValue())) - case *exprpb.Constant_DoubleValue: - return fmt.Sprintf("%v", c.GetDoubleValue()) - case *exprpb.Constant_Int64Value: - return fmt.Sprintf("%d", c.GetInt64Value()) - case *exprpb.Constant_StringValue: - return strconv.Quote(c.GetStringValue()) - case *exprpb.Constant_Uint64Value: - return fmt.Sprintf("%du", c.GetUint64Value()) - case *exprpb.Constant_NullValue: +func formatLiteral(c ref.Val) string { + switch v := c.(type) { + case types.Bool: + return fmt.Sprintf("%t", v) + case types.Bytes: + return fmt.Sprintf("b\"%s\"", string(v)) + case types.Double: + return fmt.Sprintf("%v", float64(v)) + case types.Int: + return fmt.Sprintf("%d", int64(v)) + case types.String: + return strconv.Quote(string(v)) + case types.Uint: + return fmt.Sprintf("%du", uint64(v)) + case types.Null: return "null" default: panic("Unknown constant type") diff --git a/vendor/github.com/google/cel-go/common/errors.go b/vendor/github.com/google/cel-go/common/errors.go index 63919714ea7..25adc73d8e0 100644 --- a/vendor/github.com/google/cel-go/common/errors.go +++ b/vendor/github.com/google/cel-go/common/errors.go @@ -64,7 +64,7 @@ func (e *Errors) GetErrors() []*Error { // Append creates a new Errors object with the current and input errors. func (e *Errors) Append(errs []*Error) *Errors { return &Errors{ - errors: append(e.errors, errs...), + errors: append(e.errors[:], errs...), source: e.source, numErrors: e.numErrors + len(errs), maxErrorsToReport: e.maxErrorsToReport, diff --git a/vendor/github.com/google/cel-go/common/types/provider.go b/vendor/github.com/google/cel-go/common/types/provider.go index e80b4622e24..d301aa38a15 100644 --- a/vendor/github.com/google/cel-go/common/types/provider.go +++ b/vendor/github.com/google/cel-go/common/types/provider.go @@ -54,6 +54,10 @@ type Provider interface { // Returns false if not found. FindStructType(structType string) (*Type, bool) + // FindStructFieldNames returns thet field names associated with the type, if the type + // is found. + FindStructFieldNames(structType string) ([]string, bool) + // FieldStructFieldType returns the field type for a checked type value. Returns // false if the field could not be found. FindStructFieldType(structType, fieldName string) (*FieldType, bool) @@ -154,7 +158,7 @@ func (p *Registry) EnumValue(enumName string) ref.Val { return Int(enumVal.Value()) } -// FieldFieldType returns the field type for a checked type value. Returns false if +// FindFieldType returns the field type for a checked type value. Returns false if // the field could not be found. // // Deprecated: use FindStructFieldType @@ -173,7 +177,24 @@ func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType, GetFrom: field.GetFrom}, true } -// FieldStructFieldType returns the field type for a checked type value. Returns +// FindStructFieldNames returns the set of field names for the given struct type, +// if the type exists in the registry. +func (p *Registry) FindStructFieldNames(structType string) ([]string, bool) { + msgType, found := p.pbdb.DescribeType(structType) + if !found { + return []string{}, false + } + fieldMap := msgType.FieldMap() + fields := make([]string, len(fieldMap)) + idx := 0 + for f := range fieldMap { + fields[idx] = f + idx++ + } + return fields, true +} + +// FindStructFieldType returns the field type for a checked type value. Returns // false if the field could not be found. func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType, bool) { msgType, found := p.pbdb.DescribeType(structType) diff --git a/vendor/github.com/google/cel-go/interpreter/BUILD.bazel b/vendor/github.com/google/cel-go/interpreter/BUILD.bazel index 3a5219eb5f6..220e23d4752 100644 --- a/vendor/github.com/google/cel-go/interpreter/BUILD.bazel +++ b/vendor/github.com/google/cel-go/interpreter/BUILD.bazel @@ -14,7 +14,6 @@ go_library( "decorators.go", "dispatcher.go", "evalstate.go", - "formatting.go", "interpretable.go", "interpreter.go", "optimizations.go", diff --git a/vendor/github.com/google/cel-go/interpreter/formatting.go b/vendor/github.com/google/cel-go/interpreter/formatting.go deleted file mode 100644 index e3f7533745e..00000000000 --- a/vendor/github.com/google/cel-go/interpreter/formatting.go +++ /dev/null @@ -1,383 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package interpreter - -import ( - "errors" - "fmt" - "strconv" - "strings" - "unicode" - - "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" -) - -type typeVerifier func(int64, ...ref.Type) (bool, error) - -// InterpolateFormattedString checks the syntax and cardinality of any string.format calls present in the expression and reports -// any errors at compile time. -func InterpolateFormattedString(verifier typeVerifier) InterpretableDecorator { - return func(inter Interpretable) (Interpretable, error) { - call, ok := inter.(InterpretableCall) - if !ok { - return inter, nil - } - if call.OverloadID() != "string_format" { - return inter, nil - } - args := call.Args() - if len(args) != 2 { - return nil, fmt.Errorf("wrong number of arguments to string.format (expected 2, got %d)", len(args)) - } - fmtStrInter, ok := args[0].(InterpretableConst) - if !ok { - return inter, nil - } - var fmtArgsInter InterpretableConstructor - fmtArgsInter, ok = args[1].(InterpretableConstructor) - if !ok { - return inter, nil - } - if fmtArgsInter.Type() != types.ListType { - // don't necessarily return an error since the list may be DynType - return inter, nil - } - formatStr := fmtStrInter.Value().Value().(string) - initVals := fmtArgsInter.InitVals() - - formatCheck := &formatCheck{ - args: initVals, - verifier: verifier, - } - // use a placeholder locale, since locale doesn't affect syntax - _, err := ParseFormatString(formatStr, formatCheck, formatCheck, "en_US") - if err != nil { - return nil, err - } - seenArgs := formatCheck.argsRequested - if len(initVals) > seenArgs { - return nil, fmt.Errorf("too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(initVals)) - } - return inter, nil - } -} - -type formatCheck struct { - args []Interpretable - argsRequested int - curArgIndex int64 - enableCheckArgTypes bool - verifier typeVerifier -} - -func (c *formatCheck) String(arg ref.Val, locale string) (string, error) { - valid, err := verifyString(c.args[c.curArgIndex], c.verifier) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps") - } - return "", nil -} - -func (c *formatCheck) Decimal(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - valid, err := c.verifier(id, types.IntType, types.UintType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("integer clause can only be used on integers") - } - return "", nil -} - -func (c *formatCheck) Fixed(precision *int) func(ref.Val, string) (string, error) { - return func(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - // we allow StringType since "NaN", "Infinity", and "-Infinity" are also valid values - valid, err := c.verifier(id, types.DoubleType, types.StringType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("fixed-point clause can only be used on doubles") - } - return "", nil - } -} - -func (c *formatCheck) Scientific(precision *int) func(ref.Val, string) (string, error) { - return func(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - valid, err := c.verifier(id, types.DoubleType, types.StringType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("scientific clause can only be used on doubles") - } - return "", nil - } -} - -func (c *formatCheck) Binary(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - valid, err := c.verifier(id, types.IntType, types.UintType, types.BoolType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("only integers and bools can be formatted as binary") - } - return "", nil -} - -func (c *formatCheck) Hex(useUpper bool) func(ref.Val, string) (string, error) { - return func(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - valid, err := c.verifier(id, types.IntType, types.UintType, types.StringType, types.BytesType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("only integers, byte buffers, and strings can be formatted as hex") - } - return "", nil - } -} - -func (c *formatCheck) Octal(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - valid, err := c.verifier(id, types.IntType, types.UintType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("octal clause can only be used on integers") - } - return "", nil -} - -func (c *formatCheck) Arg(index int64) (ref.Val, error) { - c.argsRequested++ - c.curArgIndex = index - // return a dummy value - this is immediately passed to back to us - // through one of the FormatCallback functions, so anything will do - return types.Int(0), nil -} - -func (c *formatCheck) ArgSize() int64 { - return int64(len(c.args)) -} - -func verifyString(sub Interpretable, verifier typeVerifier) (bool, error) { - subVerified, err := verifier(sub.ID(), - types.ListType, types.MapType, types.IntType, types.UintType, types.DoubleType, - types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType, types.NullType) - if err != nil { - return false, err - } - if !subVerified { - return false, nil - } - con, ok := sub.(InterpretableConstructor) - if ok { - members := con.InitVals() - for _, m := range members { - // recursively verify if we're dealing with a list/map - verified, err := verifyString(m, verifier) - if err != nil { - return false, err - } - if !verified { - return false, nil - } - } - } - return true, nil - -} - -// FormatStringInterpolator is an interface that allows user-defined behavior -// for formatting clause implementations, as well as argument retrieval. -// Each function is expected to support the appropriate types as laid out in -// the string.format documentation, and to return an error if given an inappropriate type. -type FormatStringInterpolator interface { - // String takes a ref.Val and a string representing the current locale identifier - // and returns the Val formatted as a string, or an error if one occurred. - String(ref.Val, string) (string, error) - - // Decimal takes a ref.Val and a string representing the current locale identifier - // and returns the Val formatted as a decimal integer, or an error if one occurred. - Decimal(ref.Val, string) (string, error) - - // Fixed takes an int pointer representing precision (or nil if none was given) and - // returns a function operating in a similar manner to String and Decimal, taking a - // ref.Val and locale and returning the appropriate string. A closure is returned - // so precision can be set without needing an additional function call/configuration. - Fixed(*int) func(ref.Val, string) (string, error) - - // Scientific functions identically to Fixed, except the string returned from the closure - // is expected to be in scientific notation. - Scientific(*int) func(ref.Val, string) (string, error) - - // Binary takes a ref.Val and a string representing the current locale identifier - // and returns the Val formatted as a binary integer, or an error if one occurred. - Binary(ref.Val, string) (string, error) - - // Hex takes a boolean that, if true, indicates the hex string output by the returned - // closure should use uppercase letters for A-F. - Hex(bool) func(ref.Val, string) (string, error) - - // Octal takes a ref.Val and a string representing the current locale identifier and - // returns the Val formatted in octal, or an error if one occurred. - Octal(ref.Val, string) (string, error) -} - -// FormatList is an interface that allows user-defined list-like datatypes to be used -// for formatting clause implementations. -type FormatList interface { - // Arg returns the ref.Val at the given index, or an error if one occurred. - Arg(int64) (ref.Val, error) - // ArgSize returns the length of the argument list. - ArgSize() int64 -} - -type clauseImpl func(ref.Val, string) (string, error) - -// ParseFormatString formats a string according to the string.format syntax, taking the clause implementations -// from the provided FormatCallback and the args from the given FormatList. -func ParseFormatString(formatStr string, callback FormatStringInterpolator, list FormatList, locale string) (string, error) { - i := 0 - argIndex := 0 - var builtStr strings.Builder - for i < len(formatStr) { - if formatStr[i] == '%' { - if i+1 < len(formatStr) && formatStr[i+1] == '%' { - err := builtStr.WriteByte('%') - if err != nil { - return "", fmt.Errorf("error writing format string: %w", err) - } - i += 2 - continue - } else { - argAny, err := list.Arg(int64(argIndex)) - if err != nil { - return "", err - } - if i+1 >= len(formatStr) { - return "", errors.New("unexpected end of string") - } - if int64(argIndex) >= list.ArgSize() { - return "", fmt.Errorf("index %d out of range", argIndex) - } - numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale) - if refErr != nil { - return "", refErr - } - _, err = builtStr.WriteString(val) - if err != nil { - return "", fmt.Errorf("error writing format string: %w", err) - } - i += numRead - argIndex++ - } - } else { - err := builtStr.WriteByte(formatStr[i]) - if err != nil { - return "", fmt.Errorf("error writing format string: %w", err) - } - i++ - } - } - return builtStr.String(), nil -} - -// parseAndFormatClause parses the format clause at the start of the given string with val, and returns -// how many characters were consumed and the substituted string form of val, or an error if one occurred. -func parseAndFormatClause(formatStr string, val ref.Val, callback FormatStringInterpolator, list FormatList, locale string) (int, string, error) { - i := 1 - read, formatter, err := parseFormattingClause(formatStr[i:], callback) - i += read - if err != nil { - return -1, "", fmt.Errorf("could not parse formatting clause: %s", err) - } - - valStr, err := formatter(val, locale) - if err != nil { - return -1, "", fmt.Errorf("error during formatting: %s", err) - } - return i, valStr, nil -} - -func parseFormattingClause(formatStr string, callback FormatStringInterpolator) (int, clauseImpl, error) { - i := 0 - read, precision, err := parsePrecision(formatStr[i:]) - i += read - if err != nil { - return -1, nil, fmt.Errorf("error while parsing precision: %w", err) - } - r := rune(formatStr[i]) - i++ - switch r { - case 's': - return i, callback.String, nil - case 'd': - return i, callback.Decimal, nil - case 'f': - return i, callback.Fixed(precision), nil - case 'e': - return i, callback.Scientific(precision), nil - case 'b': - return i, callback.Binary, nil - case 'x', 'X': - return i, callback.Hex(unicode.IsUpper(r)), nil - case 'o': - return i, callback.Octal, nil - default: - return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r) - } -} - -func parsePrecision(formatStr string) (int, *int, error) { - i := 0 - if formatStr[i] != '.' { - return i, nil, nil - } - i++ - var buffer strings.Builder - for { - if i >= len(formatStr) { - return -1, nil, errors.New("could not find end of precision specifier") - } - if !isASCIIDigit(rune(formatStr[i])) { - break - } - buffer.WriteByte(formatStr[i]) - i++ - } - precision, err := strconv.Atoi(buffer.String()) - if err != nil { - return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err) - } - return i, &precision, nil -} - -func isASCIIDigit(r rune) bool { - return r <= unicode.MaxASCII && unicode.IsDigit(r) -} diff --git a/vendor/github.com/google/cel-go/interpreter/interpreter.go b/vendor/github.com/google/cel-go/interpreter/interpreter.go index 00fc74732c6..0aca74d88b9 100644 --- a/vendor/github.com/google/cel-go/interpreter/interpreter.go +++ b/vendor/github.com/google/cel-go/interpreter/interpreter.go @@ -22,19 +22,13 @@ import ( "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // Interpreter generates a new Interpretable from a checked or unchecked expression. type Interpreter interface { // NewInterpretable creates an Interpretable from a checked expression and an // optional list of InterpretableDecorator values. - NewInterpretable(checked *ast.CheckedAST, decorators ...InterpretableDecorator) (Interpretable, error) - - // NewUncheckedInterpretable returns an Interpretable from a parsed expression - // and an optional list of InterpretableDecorator values. - NewUncheckedInterpretable(expr *exprpb.Expr, decorators ...InterpretableDecorator) (Interpretable, error) + NewInterpretable(exprAST *ast.AST, decorators ...InterpretableDecorator) (Interpretable, error) } // EvalObserver is a functional interface that accepts an expression id and an observed value. @@ -177,7 +171,7 @@ func NewInterpreter(dispatcher Dispatcher, // NewIntepretable implements the Interpreter interface method. func (i *exprInterpreter) NewInterpretable( - checked *ast.CheckedAST, + checked *ast.AST, decorators ...InterpretableDecorator) (Interpretable, error) { p := newPlanner( i.dispatcher, @@ -187,19 +181,5 @@ func (i *exprInterpreter) NewInterpretable( i.container, checked, decorators...) - return p.Plan(checked.Expr) -} - -// NewUncheckedIntepretable implements the Interpreter interface method. -func (i *exprInterpreter) NewUncheckedInterpretable( - expr *exprpb.Expr, - decorators ...InterpretableDecorator) (Interpretable, error) { - p := newUncheckedPlanner( - i.dispatcher, - i.provider, - i.adapter, - i.attrFactory, - i.container, - decorators...) - return p.Plan(expr) + return p.Plan(checked.Expr()) } diff --git a/vendor/github.com/google/cel-go/interpreter/planner.go b/vendor/github.com/google/cel-go/interpreter/planner.go index 757cd080e5c..cf371f95d51 100644 --- a/vendor/github.com/google/cel-go/interpreter/planner.go +++ b/vendor/github.com/google/cel-go/interpreter/planner.go @@ -23,15 +23,12 @@ import ( "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // interpretablePlanner creates an Interpretable evaluation plan from a proto Expr value. type interpretablePlanner interface { // Plan generates an Interpretable value (or error) from the input proto Expr. - Plan(expr *exprpb.Expr) (Interpretable, error) + Plan(expr ast.Expr) (Interpretable, error) } // newPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider, @@ -43,28 +40,7 @@ func newPlanner(disp Dispatcher, adapter types.Adapter, attrFactory AttributeFactory, cont *containers.Container, - checked *ast.CheckedAST, - decorators ...InterpretableDecorator) interpretablePlanner { - return &planner{ - disp: disp, - provider: provider, - adapter: adapter, - attrFactory: attrFactory, - container: cont, - refMap: checked.ReferenceMap, - typeMap: checked.TypeMap, - decorators: decorators, - } -} - -// newUncheckedPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider, -// TypeAdapter, and Container to resolve functions and types at plan time. Namespaces present in -// Select expressions are resolved lazily at evaluation time. -func newUncheckedPlanner(disp Dispatcher, - provider types.Provider, - adapter types.Adapter, - attrFactory AttributeFactory, - cont *containers.Container, + exprAST *ast.AST, decorators ...InterpretableDecorator) interpretablePlanner { return &planner{ disp: disp, @@ -72,8 +48,8 @@ func newUncheckedPlanner(disp Dispatcher, adapter: adapter, attrFactory: attrFactory, container: cont, - refMap: make(map[int64]*ast.ReferenceInfo), - typeMap: make(map[int64]*types.Type), + refMap: exprAST.ReferenceMap(), + typeMap: exprAST.TypeMap(), decorators: decorators, } } @@ -95,22 +71,24 @@ type planner struct { // useful for layering functionality into the evaluation that is not natively understood by CEL, // such as state-tracking, expression re-write, and possibly efficient thread-safe memoization of // repeated expressions. -func (p *planner) Plan(expr *exprpb.Expr) (Interpretable, error) { - switch expr.GetExprKind().(type) { - case *exprpb.Expr_CallExpr: +func (p *planner) Plan(expr ast.Expr) (Interpretable, error) { + switch expr.Kind() { + case ast.CallKind: return p.decorate(p.planCall(expr)) - case *exprpb.Expr_IdentExpr: + case ast.IdentKind: return p.decorate(p.planIdent(expr)) - case *exprpb.Expr_SelectExpr: + case ast.LiteralKind: + return p.decorate(p.planConst(expr)) + case ast.SelectKind: return p.decorate(p.planSelect(expr)) - case *exprpb.Expr_ListExpr: + case ast.ListKind: return p.decorate(p.planCreateList(expr)) - case *exprpb.Expr_StructExpr: + case ast.MapKind: + return p.decorate(p.planCreateMap(expr)) + case ast.StructKind: return p.decorate(p.planCreateStruct(expr)) - case *exprpb.Expr_ComprehensionExpr: + case ast.ComprehensionKind: return p.decorate(p.planComprehension(expr)) - case *exprpb.Expr_ConstExpr: - return p.decorate(p.planConst(expr)) } return nil, fmt.Errorf("unsupported expr: %v", expr) } @@ -132,16 +110,16 @@ func (p *planner) decorate(i Interpretable, err error) (Interpretable, error) { } // planIdent creates an Interpretable that resolves an identifier from an Activation. -func (p *planner) planIdent(expr *exprpb.Expr) (Interpretable, error) { +func (p *planner) planIdent(expr ast.Expr) (Interpretable, error) { // Establish whether the identifier is in the reference map. - if identRef, found := p.refMap[expr.GetId()]; found { - return p.planCheckedIdent(expr.GetId(), identRef) + if identRef, found := p.refMap[expr.ID()]; found { + return p.planCheckedIdent(expr.ID(), identRef) } // Create the possible attribute list for the unresolved reference. - ident := expr.GetIdentExpr() + ident := expr.AsIdent() return &evalAttr{ adapter: p.adapter, - attr: p.attrFactory.MaybeAttribute(expr.GetId(), ident.Name), + attr: p.attrFactory.MaybeAttribute(expr.ID(), ident), }, nil } @@ -174,20 +152,20 @@ func (p *planner) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Inter // a) selects a field from a map or proto. // b) creates a field presence test for a select within a has() macro. // c) resolves the select expression to a namespaced identifier. -func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) { +func (p *planner) planSelect(expr ast.Expr) (Interpretable, error) { // If the Select id appears in the reference map from the CheckedExpr proto then it is either // a namespaced identifier or enum value. - if identRef, found := p.refMap[expr.GetId()]; found { - return p.planCheckedIdent(expr.GetId(), identRef) + if identRef, found := p.refMap[expr.ID()]; found { + return p.planCheckedIdent(expr.ID(), identRef) } - sel := expr.GetSelectExpr() + sel := expr.AsSelect() // Plan the operand evaluation. - op, err := p.Plan(sel.GetOperand()) + op, err := p.Plan(sel.Operand()) if err != nil { return nil, err } - opType := p.typeMap[sel.GetOperand().GetId()] + opType := p.typeMap[sel.Operand().ID()] // If the Select was marked TestOnly, this is a presence test. // @@ -211,14 +189,14 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) { } // Build a qualifier for the attribute. - qual, err := p.attrFactory.NewQualifier(opType, expr.GetId(), sel.GetField(), false) + qual, err := p.attrFactory.NewQualifier(opType, expr.ID(), sel.FieldName(), false) if err != nil { return nil, err } // Modify the attribute to be test-only. - if sel.GetTestOnly() { + if sel.IsTestOnly() { attr = &evalTestOnly{ - id: expr.GetId(), + id: expr.ID(), InterpretableAttribute: attr, } } @@ -230,10 +208,10 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) { // planCall creates a callable Interpretable while specializing for common functions and invocation // patterns. Specifically, conditional operators &&, ||, ?:, and (in)equality functions result in // optimized Interpretable values. -func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) { - call := expr.GetCallExpr() +func (p *planner) planCall(expr ast.Expr) (Interpretable, error) { + call := expr.AsCall() target, fnName, oName := p.resolveFunction(expr) - argCount := len(call.GetArgs()) + argCount := len(call.Args()) var offset int if target != nil { argCount++ @@ -248,7 +226,7 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) { } args[0] = arg } - for i, argExpr := range call.GetArgs() { + for i, argExpr := range call.Args() { arg, err := p.Plan(argExpr) if err != nil { return nil, err @@ -307,7 +285,7 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) { } // planCallZero generates a zero-arity callable Interpretable. -func (p *planner) planCallZero(expr *exprpb.Expr, +func (p *planner) planCallZero(expr ast.Expr, function string, overload string, impl *functions.Overload) (Interpretable, error) { @@ -315,7 +293,7 @@ func (p *planner) planCallZero(expr *exprpb.Expr, return nil, fmt.Errorf("no such overload: %s()", function) } return &evalZeroArity{ - id: expr.GetId(), + id: expr.ID(), function: function, overload: overload, impl: impl.Function, @@ -323,7 +301,7 @@ func (p *planner) planCallZero(expr *exprpb.Expr, } // planCallUnary generates a unary callable Interpretable. -func (p *planner) planCallUnary(expr *exprpb.Expr, +func (p *planner) planCallUnary(expr ast.Expr, function string, overload string, impl *functions.Overload, @@ -340,7 +318,7 @@ func (p *planner) planCallUnary(expr *exprpb.Expr, nonStrict = impl.NonStrict } return &evalUnary{ - id: expr.GetId(), + id: expr.ID(), function: function, overload: overload, arg: args[0], @@ -351,7 +329,7 @@ func (p *planner) planCallUnary(expr *exprpb.Expr, } // planCallBinary generates a binary callable Interpretable. -func (p *planner) planCallBinary(expr *exprpb.Expr, +func (p *planner) planCallBinary(expr ast.Expr, function string, overload string, impl *functions.Overload, @@ -368,7 +346,7 @@ func (p *planner) planCallBinary(expr *exprpb.Expr, nonStrict = impl.NonStrict } return &evalBinary{ - id: expr.GetId(), + id: expr.ID(), function: function, overload: overload, lhs: args[0], @@ -380,7 +358,7 @@ func (p *planner) planCallBinary(expr *exprpb.Expr, } // planCallVarArgs generates a variable argument callable Interpretable. -func (p *planner) planCallVarArgs(expr *exprpb.Expr, +func (p *planner) planCallVarArgs(expr ast.Expr, function string, overload string, impl *functions.Overload, @@ -397,7 +375,7 @@ func (p *planner) planCallVarArgs(expr *exprpb.Expr, nonStrict = impl.NonStrict } return &evalVarArgs{ - id: expr.GetId(), + id: expr.ID(), function: function, overload: overload, args: args, @@ -408,41 +386,41 @@ func (p *planner) planCallVarArgs(expr *exprpb.Expr, } // planCallEqual generates an equals (==) Interpretable. -func (p *planner) planCallEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) { +func (p *planner) planCallEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalEq{ - id: expr.GetId(), + id: expr.ID(), lhs: args[0], rhs: args[1], }, nil } // planCallNotEqual generates a not equals (!=) Interpretable. -func (p *planner) planCallNotEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) { +func (p *planner) planCallNotEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalNe{ - id: expr.GetId(), + id: expr.ID(), lhs: args[0], rhs: args[1], }, nil } // planCallLogicalAnd generates a logical and (&&) Interpretable. -func (p *planner) planCallLogicalAnd(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) { +func (p *planner) planCallLogicalAnd(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalAnd{ - id: expr.GetId(), + id: expr.ID(), terms: args, }, nil } // planCallLogicalOr generates a logical or (||) Interpretable. -func (p *planner) planCallLogicalOr(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) { +func (p *planner) planCallLogicalOr(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalOr{ - id: expr.GetId(), + id: expr.ID(), terms: args, }, nil } // planCallConditional generates a conditional / ternary (c ? t : f) Interpretable. -func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) { +func (p *planner) planCallConditional(expr ast.Expr, args []Interpretable) (Interpretable, error) { cond := args[0] t := args[1] var tAttr Attribute @@ -464,13 +442,13 @@ func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) ( return &evalAttr{ adapter: p.adapter, - attr: p.attrFactory.ConditionalAttribute(expr.GetId(), cond, tAttr, fAttr), + attr: p.attrFactory.ConditionalAttribute(expr.ID(), cond, tAttr, fAttr), }, nil } // planCallIndex either extends an attribute with the argument to the index operation, or creates // a relative attribute based on the return of a function call or operation. -func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optional bool) (Interpretable, error) { +func (p *planner) planCallIndex(expr ast.Expr, args []Interpretable, optional bool) (Interpretable, error) { op := args[0] ind := args[1] opType := p.typeMap[op.ID()] @@ -489,11 +467,11 @@ func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optiona var qual Qualifier switch ind := ind.(type) { case InterpretableConst: - qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind.Value(), optional) + qual, err = p.attrFactory.NewQualifier(opType, expr.ID(), ind.Value(), optional) case InterpretableAttribute: - qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind, optional) + qual, err = p.attrFactory.NewQualifier(opType, expr.ID(), ind, optional) default: - qual, err = p.relativeAttr(expr.GetId(), ind, optional) + qual, err = p.relativeAttr(expr.ID(), ind, optional) } if err != nil { return nil, err @@ -505,10 +483,10 @@ func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optiona } // planCreateList generates a list construction Interpretable. -func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) { - list := expr.GetListExpr() - optionalIndices := list.GetOptionalIndices() - elements := list.GetElements() +func (p *planner) planCreateList(expr ast.Expr) (Interpretable, error) { + list := expr.AsList() + optionalIndices := list.OptionalIndices() + elements := list.Elements() optionals := make([]bool, len(elements)) for _, index := range optionalIndices { if index < 0 || index >= int32(len(elements)) { @@ -525,7 +503,7 @@ func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) { elems[i] = elemVal } return &evalList{ - id: expr.GetId(), + id: expr.ID(), elems: elems, optionals: optionals, hasOptionals: len(optionals) != 0, @@ -534,31 +512,29 @@ func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) { } // planCreateStruct generates a map or object construction Interpretable. -func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) { - str := expr.GetStructExpr() - if len(str.MessageName) != 0 { - return p.planCreateObj(expr) - } - entries := str.GetEntries() +func (p *planner) planCreateMap(expr ast.Expr) (Interpretable, error) { + m := expr.AsMap() + entries := m.Entries() optionals := make([]bool, len(entries)) keys := make([]Interpretable, len(entries)) vals := make([]Interpretable, len(entries)) - for i, entry := range entries { - keyVal, err := p.Plan(entry.GetMapKey()) + for i, e := range entries { + entry := e.AsMapEntry() + keyVal, err := p.Plan(entry.Key()) if err != nil { return nil, err } keys[i] = keyVal - valVal, err := p.Plan(entry.GetValue()) + valVal, err := p.Plan(entry.Value()) if err != nil { return nil, err } vals[i] = valVal - optionals[i] = entry.GetOptionalEntry() + optionals[i] = entry.IsOptional() } return &evalMap{ - id: expr.GetId(), + id: expr.ID(), keys: keys, vals: vals, optionals: optionals, @@ -568,27 +544,28 @@ func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) { } // planCreateObj generates an object construction Interpretable. -func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) { - obj := expr.GetStructExpr() - typeName, defined := p.resolveTypeName(obj.GetMessageName()) +func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) { + obj := expr.AsStruct() + typeName, defined := p.resolveTypeName(obj.TypeName()) if !defined { - return nil, fmt.Errorf("unknown type: %s", obj.GetMessageName()) - } - entries := obj.GetEntries() - optionals := make([]bool, len(entries)) - fields := make([]string, len(entries)) - vals := make([]Interpretable, len(entries)) - for i, entry := range entries { - fields[i] = entry.GetFieldKey() - val, err := p.Plan(entry.GetValue()) + return nil, fmt.Errorf("unknown type: %s", obj.TypeName()) + } + objFields := obj.Fields() + optionals := make([]bool, len(objFields)) + fields := make([]string, len(objFields)) + vals := make([]Interpretable, len(objFields)) + for i, f := range objFields { + field := f.AsStructField() + fields[i] = field.Name() + val, err := p.Plan(field.Value()) if err != nil { return nil, err } vals[i] = val - optionals[i] = entry.GetOptionalEntry() + optionals[i] = field.IsOptional() } return &evalObj{ - id: expr.GetId(), + id: expr.ID(), typeName: typeName, fields: fields, vals: vals, @@ -599,33 +576,33 @@ func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) { } // planComprehension generates an Interpretable fold operation. -func (p *planner) planComprehension(expr *exprpb.Expr) (Interpretable, error) { - fold := expr.GetComprehensionExpr() - accu, err := p.Plan(fold.GetAccuInit()) +func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) { + fold := expr.AsComprehension() + accu, err := p.Plan(fold.AccuInit()) if err != nil { return nil, err } - iterRange, err := p.Plan(fold.GetIterRange()) + iterRange, err := p.Plan(fold.IterRange()) if err != nil { return nil, err } - cond, err := p.Plan(fold.GetLoopCondition()) + cond, err := p.Plan(fold.LoopCondition()) if err != nil { return nil, err } - step, err := p.Plan(fold.GetLoopStep()) + step, err := p.Plan(fold.LoopStep()) if err != nil { return nil, err } - result, err := p.Plan(fold.GetResult()) + result, err := p.Plan(fold.Result()) if err != nil { return nil, err } return &evalFold{ - id: expr.GetId(), - accuVar: fold.AccuVar, + id: expr.ID(), + accuVar: fold.AccuVar(), accu: accu, - iterVar: fold.IterVar, + iterVar: fold.IterVar(), iterRange: iterRange, cond: cond, step: step, @@ -635,37 +612,8 @@ func (p *planner) planComprehension(expr *exprpb.Expr) (Interpretable, error) { } // planConst generates a constant valued Interpretable. -func (p *planner) planConst(expr *exprpb.Expr) (Interpretable, error) { - val, err := p.constValue(expr.GetConstExpr()) - if err != nil { - return nil, err - } - return NewConstValue(expr.GetId(), val), nil -} - -// constValue converts a proto Constant value to a ref.Val. -func (p *planner) constValue(c *exprpb.Constant) (ref.Val, error) { - switch c.GetConstantKind().(type) { - case *exprpb.Constant_BoolValue: - return p.adapter.NativeToValue(c.GetBoolValue()), nil - case *exprpb.Constant_BytesValue: - return p.adapter.NativeToValue(c.GetBytesValue()), nil - case *exprpb.Constant_DoubleValue: - return p.adapter.NativeToValue(c.GetDoubleValue()), nil - case *exprpb.Constant_DurationValue: - return p.adapter.NativeToValue(c.GetDurationValue().AsDuration()), nil - case *exprpb.Constant_Int64Value: - return p.adapter.NativeToValue(c.GetInt64Value()), nil - case *exprpb.Constant_NullValue: - return p.adapter.NativeToValue(c.GetNullValue()), nil - case *exprpb.Constant_StringValue: - return p.adapter.NativeToValue(c.GetStringValue()), nil - case *exprpb.Constant_TimestampValue: - return p.adapter.NativeToValue(c.GetTimestampValue().AsTime()), nil - case *exprpb.Constant_Uint64Value: - return p.adapter.NativeToValue(c.GetUint64Value()), nil - } - return nil, fmt.Errorf("unknown constant type: %v", c) +func (p *planner) planConst(expr ast.Expr) (Interpretable, error) { + return NewConstValue(expr.ID(), expr.AsLiteral()), nil } // resolveTypeName takes a qualified string constructed at parse time, applies the proto @@ -687,17 +635,20 @@ func (p *planner) resolveTypeName(typeName string) (string, bool) { // - The target expression may only consist of ident and select expressions. // - The function is declared in the environment using its fully-qualified name. // - The fully-qualified function name matches the string serialized target value. -func (p *planner) resolveFunction(expr *exprpb.Expr) (*exprpb.Expr, string, string) { +func (p *planner) resolveFunction(expr ast.Expr) (ast.Expr, string, string) { // Note: similar logic exists within the `checker/checker.go`. If making changes here // please consider the impact on checker.go and consolidate implementations or mirror code // as appropriate. - call := expr.GetCallExpr() - target := call.GetTarget() - fnName := call.GetFunction() + call := expr.AsCall() + var target ast.Expr = nil + if call.IsMemberFunction() { + target = call.Target() + } + fnName := call.FunctionName() // Checked expressions always have a reference map entry, and _should_ have the fully qualified // function name as the fnName value. - oRef, hasOverload := p.refMap[expr.GetId()] + oRef, hasOverload := p.refMap[expr.ID()] if hasOverload { if len(oRef.OverloadIDs) == 1 { return target, fnName, oRef.OverloadIDs[0] @@ -771,16 +722,30 @@ func (p *planner) relativeAttr(id int64, eval Interpretable, opt bool) (Interpre // toQualifiedName converts an expression AST into a qualified name if possible, with a boolean // 'found' value that indicates if the conversion is successful. -func (p *planner) toQualifiedName(operand *exprpb.Expr) (string, bool) { +func (p *planner) toQualifiedName(operand ast.Expr) (string, bool) { // If the checker identified the expression as an attribute by the type-checker, then it can't // possibly be part of qualified name in a namespace. - _, isAttr := p.refMap[operand.GetId()] + _, isAttr := p.refMap[operand.ID()] if isAttr { return "", false } // Since functions cannot be both namespaced and receiver functions, if the operand is not an // qualified variable name, return the (possibly) qualified name given the expressions. - return containers.ToQualifiedName(operand) + switch operand.Kind() { + case ast.IdentKind: + id := operand.AsIdent() + return id, true + case ast.SelectKind: + sel := operand.AsSelect() + // Test only expressions are not valid as qualified names. + if sel.IsTestOnly() { + return "", false + } + if qual, found := p.toQualifiedName(sel.Operand()); found { + return qual + "." + sel.FieldName(), true + } + } + return "", false } func stripLeadingDot(name string) string { diff --git a/vendor/github.com/google/cel-go/interpreter/prune.go b/vendor/github.com/google/cel-go/interpreter/prune.go index b8834b1cb85..410d80dc43c 100644 --- a/vendor/github.com/google/cel-go/interpreter/prune.go +++ b/vendor/github.com/google/cel-go/interpreter/prune.go @@ -15,19 +15,18 @@ package interpreter import ( + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" - structpb "google.golang.org/protobuf/types/known/structpb" ) type astPruner struct { - expr *exprpb.Expr - macroCalls map[int64]*exprpb.Expr + ast.ExprFactory + expr ast.Expr + macroCalls map[int64]ast.Expr state EvalState nextExprID int64 } @@ -67,84 +66,44 @@ type astPruner struct { // compiled and constant folded expressions, but is not willing to constant // fold(and thus cache results of) some external calls, then they can prepare // the overloads accordingly. -func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr { +func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *ast.AST { pruneState := NewEvalState() for _, id := range state.IDs() { v, _ := state.Value(id) pruneState.SetValue(id, v) } pruner := &astPruner{ - expr: expr, - macroCalls: macroCalls, - state: pruneState, - nextExprID: getMaxID(expr)} + ExprFactory: ast.NewExprFactory(), + expr: expr, + macroCalls: macroCalls, + state: pruneState, + nextExprID: getMaxID(expr)} newExpr, _ := pruner.maybePrune(expr) - return &exprpb.ParsedExpr{ - Expr: newExpr, - SourceInfo: &exprpb.SourceInfo{MacroCalls: pruner.macroCalls}, - } -} - -func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr { - return &exprpb.Expr{ - Id: id, - ExprKind: &exprpb.Expr_ConstExpr{ - ConstExpr: val, - }, + newInfo := ast.NewSourceInfo(nil) + for id, call := range pruner.macroCalls { + newInfo.SetMacroCall(id, call) } + return ast.NewAST(newExpr, newInfo) } -func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) { +func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) { switch v := val.(type) { - case types.Bool: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: bool(v)}}), true - case types.Bytes: + case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint: p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: []byte(v)}}), true - case types.Double: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: float64(v)}}), true + return p.NewLiteral(id, val), true case types.Duration: p.state.SetValue(id, val) - durationString := string(v.ConvertToType(types.StringType).(types.String)) - return &exprpb.Expr{ - Id: id, - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Function: overloads.TypeConvertDuration, - Args: []*exprpb.Expr{ - p.createLiteral(p.nextID(), - &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: durationString}}), - }, - }, - }, - }, true - case types.Int: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: int64(v)}}), true - case types.Uint: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: uint64(v)}}), true - case types.String: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: string(v)}}), true - case types.Null: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: v.Value().(structpb.NullValue)}}), true + durationString := v.ConvertToType(types.StringType).(types.String) + return p.NewCall(id, overloads.TypeConvertDuration, p.NewLiteral(p.nextID(), durationString)), true + case types.Timestamp: + timestampString := v.ConvertToType(types.StringType).(types.String) + return p.NewCall(id, overloads.TypeConvertTimestamp, p.NewLiteral(p.nextID(), timestampString)), true } // Attempt to build a list literal. if list, isList := val.(traits.Lister); isList { sz := list.Size().(types.Int) - elemExprs := make([]*exprpb.Expr, sz) + elemExprs := make([]ast.Expr, sz) for i := types.Int(0); i < sz; i++ { elem := list.Get(i) if types.IsUnknownOrError(elem) { @@ -157,20 +116,13 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo elemExprs[i] = elemExpr } p.state.SetValue(id, val) - return &exprpb.Expr{ - Id: id, - ExprKind: &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{ - Elements: elemExprs, - }, - }, - }, true + return p.NewList(id, elemExprs, []int32{}), true } // Create a map literal if possible. if mp, isMap := val.(traits.Mapper); isMap { it := mp.Iterator() - entries := make([]*exprpb.Expr_CreateStruct_Entry, mp.Size().(types.Int)) + entries := make([]ast.EntryExpr, mp.Size().(types.Int)) i := 0 for it.HasNext() != types.False { key := it.Next() @@ -186,25 +138,12 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo if !ok { return nil, false } - entry := &exprpb.Expr_CreateStruct_Entry{ - Id: p.nextID(), - KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{ - MapKey: keyExpr, - }, - Value: valExpr, - } + entry := p.NewMapEntry(p.nextID(), keyExpr, valExpr, false) entries[i] = entry i++ } p.state.SetValue(id, val) - return &exprpb.Expr{ - Id: id, - ExprKind: &exprpb.Expr_StructExpr{ - StructExpr: &exprpb.Expr_CreateStruct{ - Entries: entries, - }, - }, - }, true + return p.NewMap(id, entries), true } // TODO(issues/377) To construct message literals, the type provider will need to support @@ -212,215 +151,206 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo return nil, false } -func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) { - elemVal, found := p.value(elem.GetId()) +func (p *astPruner) maybePruneOptional(elem ast.Expr) (ast.Expr, bool) { + elemVal, found := p.value(elem.ID()) if found && elemVal.Type() == types.OptionalType { opt := elemVal.(*types.Optional) if !opt.HasValue() { return nil, true } - if newElem, pruned := p.maybeCreateLiteral(elem.GetId(), opt.GetValue()); pruned { + if newElem, pruned := p.maybeCreateLiteral(elem.ID(), opt.GetValue()); pruned { return newElem, true } } return elem, false } -func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { +func (p *astPruner) maybePruneIn(node ast.Expr) (ast.Expr, bool) { // elem in list - call := node.GetCallExpr() - val, exists := p.maybeValue(call.GetArgs()[1].GetId()) + call := node.AsCall() + val, exists := p.maybeValue(call.Args()[1].ID()) if !exists { return nil, false } if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { - return p.maybeCreateLiteral(node.GetId(), types.False) + return p.maybeCreateLiteral(node.ID(), types.False) } return nil, false } -func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) { - call := node.GetCallExpr() - arg := call.GetArgs()[0] - val, exists := p.maybeValue(arg.GetId()) +func (p *astPruner) maybePruneLogicalNot(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + arg := call.Args()[0] + val, exists := p.maybeValue(arg.ID()) if !exists { return nil, false } if b, ok := val.(types.Bool); ok { - return p.maybeCreateLiteral(node.GetId(), !b) + return p.maybeCreateLiteral(node.ID(), !b) } return nil, false } -func (p *astPruner) maybePruneOr(node *exprpb.Expr) (*exprpb.Expr, bool) { - call := node.GetCallExpr() +func (p *astPruner) maybePruneOr(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. - if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists { + if v, exists := p.maybeValue(call.Args()[0].ID()); exists { if v == types.True { - return p.maybeCreateLiteral(node.GetId(), types.True) + return p.maybeCreateLiteral(node.ID(), types.True) } - return call.GetArgs()[1], true + return call.Args()[1], true } - if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists { + if v, exists := p.maybeValue(call.Args()[1].ID()); exists { if v == types.True { - return p.maybeCreateLiteral(node.GetId(), types.True) + return p.maybeCreateLiteral(node.ID(), types.True) } - return call.GetArgs()[0], true + return call.Args()[0], true } return nil, false } -func (p *astPruner) maybePruneAnd(node *exprpb.Expr) (*exprpb.Expr, bool) { - call := node.GetCallExpr() +func (p *astPruner) maybePruneAnd(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. - if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists { + if v, exists := p.maybeValue(call.Args()[0].ID()); exists { if v == types.False { - return p.maybeCreateLiteral(node.GetId(), types.False) + return p.maybeCreateLiteral(node.ID(), types.False) } - return call.GetArgs()[1], true + return call.Args()[1], true } - if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists { + if v, exists := p.maybeValue(call.Args()[1].ID()); exists { if v == types.False { - return p.maybeCreateLiteral(node.GetId(), types.False) + return p.maybeCreateLiteral(node.ID(), types.False) } - return call.GetArgs()[0], true + return call.Args()[0], true } return nil, false } -func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) { - call := node.GetCallExpr() - cond, exists := p.maybeValue(call.GetArgs()[0].GetId()) +func (p *astPruner) maybePruneConditional(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + cond, exists := p.maybeValue(call.Args()[0].ID()) if !exists { return nil, false } if cond.Value().(bool) { - return call.GetArgs()[1], true + return call.Args()[1], true } - return call.GetArgs()[2], true + return call.Args()[2], true } -func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) { - if _, exists := p.value(node.GetId()); !exists { +func (p *astPruner) maybePruneFunction(node ast.Expr) (ast.Expr, bool) { + if _, exists := p.value(node.ID()); !exists { return nil, false } - call := node.GetCallExpr() - if call.Function == operators.LogicalOr { + call := node.AsCall() + if call.FunctionName() == operators.LogicalOr { return p.maybePruneOr(node) } - if call.Function == operators.LogicalAnd { + if call.FunctionName() == operators.LogicalAnd { return p.maybePruneAnd(node) } - if call.Function == operators.Conditional { + if call.FunctionName() == operators.Conditional { return p.maybePruneConditional(node) } - if call.Function == operators.In { + if call.FunctionName() == operators.In { return p.maybePruneIn(node) } - if call.Function == operators.LogicalNot { + if call.FunctionName() == operators.LogicalNot { return p.maybePruneLogicalNot(node) } return nil, false } -func (p *astPruner) maybePrune(node *exprpb.Expr) (*exprpb.Expr, bool) { +func (p *astPruner) maybePrune(node ast.Expr) (ast.Expr, bool) { return p.prune(node) } -func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { +func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) { if node == nil { return node, false } - val, valueExists := p.maybeValue(node.GetId()) + val, valueExists := p.maybeValue(node.ID()) if valueExists { - if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok { - delete(p.macroCalls, node.GetId()) + if newNode, ok := p.maybeCreateLiteral(node.ID(), val); ok { + delete(p.macroCalls, node.ID()) return newNode, true } } - if macro, found := p.macroCalls[node.GetId()]; found { + if macro, found := p.macroCalls[node.ID()]; found { // Ensure that intermediate values for the comprehension are cleared during pruning - compre := node.GetComprehensionExpr() - if compre != nil { - visit(macro, clearIterVarVisitor(compre.IterVar, p.state)) + if node.Kind() == ast.ComprehensionKind { + compre := node.AsComprehension() + visit(macro, clearIterVarVisitor(compre.IterVar(), p.state)) } // prune the expression in terms of the macro call instead of the expanded form. if newMacro, pruned := p.prune(macro); pruned { - p.macroCalls[node.GetId()] = newMacro + p.macroCalls[node.ID()] = newMacro } } // We have either an unknown/error value, or something we don't want to // transform, or expression was not evaluated. If possible, drill down // more. - switch node.GetExprKind().(type) { - case *exprpb.Expr_SelectExpr: - if operand, pruned := p.maybePrune(node.GetSelectExpr().GetOperand()); pruned { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_SelectExpr{ - SelectExpr: &exprpb.Expr_Select{ - Operand: operand, - Field: node.GetSelectExpr().GetField(), - TestOnly: node.GetSelectExpr().GetTestOnly(), - }, - }, - }, true - } - case *exprpb.Expr_CallExpr: - var prunedCall bool - call := node.GetCallExpr() - args := call.GetArgs() - newArgs := make([]*exprpb.Expr, len(args)) - newCall := &exprpb.Expr_Call{ - Function: call.GetFunction(), - Target: call.GetTarget(), - Args: newArgs, - } - for i, arg := range args { - newArgs[i] = arg - if newArg, prunedArg := p.maybePrune(arg); prunedArg { - prunedCall = true - newArgs[i] = newArg + switch node.Kind() { + case ast.SelectKind: + sel := node.AsSelect() + if operand, isPruned := p.maybePrune(sel.Operand()); isPruned { + if sel.IsTestOnly() { + return p.NewPresenceTest(node.ID(), operand, sel.FieldName()), true } + return p.NewSelect(node.ID(), operand, sel.FieldName()), true } - if newTarget, prunedTarget := p.maybePrune(call.GetTarget()); prunedTarget { - prunedCall = true - newCall.Target = newTarget + case ast.CallKind: + argsPruned := false + call := node.AsCall() + args := call.Args() + newArgs := make([]ast.Expr, len(args)) + for i, a := range args { + newArgs[i] = a + if arg, isPruned := p.maybePrune(a); isPruned { + argsPruned = true + newArgs[i] = arg + } } - newNode := &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: newCall, - }, + if !call.IsMemberFunction() { + newCall := p.NewCall(node.ID(), call.FunctionName(), newArgs...) + if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { + return prunedCall, true + } + return newCall, argsPruned } - if newExpr, pruned := p.maybePruneFunction(newNode); pruned { - newExpr, _ = p.maybePrune(newExpr) - return newExpr, true + newTarget := call.Target() + targetPruned := false + if prunedTarget, isPruned := p.maybePrune(call.Target()); isPruned { + targetPruned = true + newTarget = prunedTarget } - if prunedCall { - return newNode, true + newCall := p.NewMemberCall(node.ID(), call.FunctionName(), newTarget, newArgs...) + if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { + return prunedCall, true } - case *exprpb.Expr_ListExpr: - elems := node.GetListExpr().GetElements() - optIndices := node.GetListExpr().GetOptionalIndices() + return newCall, targetPruned || argsPruned + case ast.ListKind: + l := node.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() optIndexMap := map[int32]bool{} for _, i := range optIndices { optIndexMap[i] = true } newOptIndexMap := make(map[int32]bool, len(optIndexMap)) - newElems := make([]*exprpb.Expr, 0, len(elems)) - var prunedList bool - + newElems := make([]ast.Expr, 0, len(elems)) + var listPruned bool prunedIdx := 0 for i, elem := range elems { _, isOpt := optIndexMap[int32(i)] if isOpt { newElem, pruned := p.maybePruneOptional(elem) if pruned { - prunedList = true + listPruned = true if newElem != nil { newElems = append(newElems, newElem) prunedIdx++ @@ -431,7 +361,7 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { } if newElem, prunedElem := p.maybePrune(elem); prunedElem { newElems = append(newElems, newElem) - prunedList = true + listPruned = true } else { newElems = append(newElems, elem) } @@ -443,76 +373,64 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { optIndices[idx] = i idx++ } - if prunedList { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{ - Elements: newElems, - OptionalIndices: optIndices, - }, - }, - }, true + if listPruned { + return p.NewList(node.ID(), newElems, optIndices), true } - case *exprpb.Expr_StructExpr: - var prunedStruct bool - entries := node.GetStructExpr().GetEntries() - messageType := node.GetStructExpr().GetMessageName() - newEntries := make([]*exprpb.Expr_CreateStruct_Entry, len(entries)) + case ast.MapKind: + var mapPruned bool + m := node.AsMap() + entries := m.Entries() + newEntries := make([]ast.EntryExpr, len(entries)) for i, entry := range entries { newEntries[i] = entry - newKey, prunedKey := p.maybePrune(entry.GetMapKey()) - newValue, prunedValue := p.maybePrune(entry.GetValue()) - if !prunedKey && !prunedValue { + e := entry.AsMapEntry() + newKey, keyPruned := p.maybePrune(e.Key()) + newValue, valuePruned := p.maybePrune(e.Value()) + if !keyPruned && !valuePruned { continue } - prunedStruct = true - newEntry := &exprpb.Expr_CreateStruct_Entry{ - Value: newValue, - } - if messageType != "" { - newEntry.KeyKind = &exprpb.Expr_CreateStruct_Entry_FieldKey{ - FieldKey: entry.GetFieldKey(), - } - } else { - newEntry.KeyKind = &exprpb.Expr_CreateStruct_Entry_MapKey{ - MapKey: newKey, - } - } - newEntry.OptionalEntry = entry.GetOptionalEntry() + mapPruned = true + newEntry := p.NewMapEntry(entry.ID(), newKey, newValue, e.IsOptional()) newEntries[i] = newEntry } - if prunedStruct { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_StructExpr{ - StructExpr: &exprpb.Expr_CreateStruct{ - MessageName: messageType, - Entries: newEntries, - }, - }, - }, true + if mapPruned { + return p.NewMap(node.ID(), newEntries), true } - case *exprpb.Expr_ComprehensionExpr: - compre := node.GetComprehensionExpr() + case ast.StructKind: + var structPruned bool + obj := node.AsStruct() + fields := obj.Fields() + newFields := make([]ast.EntryExpr, len(fields)) + for i, field := range fields { + newFields[i] = field + f := field.AsStructField() + newValue, prunedValue := p.maybePrune(f.Value()) + if !prunedValue { + continue + } + structPruned = true + newEntry := p.NewStructField(field.ID(), f.Name(), newValue, f.IsOptional()) + newFields[i] = newEntry + } + if structPruned { + return p.NewStruct(node.ID(), obj.TypeName(), newFields), true + } + case ast.ComprehensionKind: + compre := node.AsComprehension() // Only the range of the comprehension is pruned since the state tracking only records // the last iteration of the comprehension and not each step in the evaluation which // means that the any residuals computed in between might be inaccurate. - if newRange, pruned := p.maybePrune(compre.GetIterRange()); pruned { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_ComprehensionExpr{ - ComprehensionExpr: &exprpb.Expr_Comprehension{ - IterVar: compre.GetIterVar(), - IterRange: newRange, - AccuVar: compre.GetAccuVar(), - AccuInit: compre.GetAccuInit(), - LoopCondition: compre.GetLoopCondition(), - LoopStep: compre.GetLoopStep(), - Result: compre.GetResult(), - }, - }, - }, true + if newRange, pruned := p.maybePrune(compre.IterRange()); pruned { + return p.NewComprehension( + node.ID(), + newRange, + compre.IterVar(), + compre.AccuVar(), + compre.AccuInit(), + compre.LoopCondition(), + compre.LoopStep(), + compre.Result(), + ), true } } return node, false @@ -539,12 +457,12 @@ func (p *astPruner) nextID() int64 { type astVisitor struct { // visitEntry is called on every expr node, including those within a map/struct entry. - visitExpr func(expr *exprpb.Expr) + visitExpr func(expr ast.Expr) // visitEntry is called before entering the key, value of a map/struct entry. - visitEntry func(entry *exprpb.Expr_CreateStruct_Entry) + visitEntry func(entry ast.EntryExpr) } -func getMaxID(expr *exprpb.Expr) int64 { +func getMaxID(expr ast.Expr) int64 { maxID := int64(1) visit(expr, maxIDVisitor(&maxID)) return maxID @@ -552,10 +470,9 @@ func getMaxID(expr *exprpb.Expr) int64 { func clearIterVarVisitor(varName string, state EvalState) astVisitor { return astVisitor{ - visitExpr: func(e *exprpb.Expr) { - ident := e.GetIdentExpr() - if ident != nil && ident.GetName() == varName { - state.SetValue(e.GetId(), nil) + visitExpr: func(e ast.Expr) { + if e.Kind() == ast.IdentKind && e.AsIdent() == varName { + state.SetValue(e.ID(), nil) } }, } @@ -563,56 +480,63 @@ func clearIterVarVisitor(varName string, state EvalState) astVisitor { func maxIDVisitor(maxID *int64) astVisitor { return astVisitor{ - visitExpr: func(e *exprpb.Expr) { - if e.GetId() >= *maxID { - *maxID = e.GetId() + 1 + visitExpr: func(e ast.Expr) { + if e.ID() >= *maxID { + *maxID = e.ID() + 1 } }, - visitEntry: func(e *exprpb.Expr_CreateStruct_Entry) { - if e.GetId() >= *maxID { - *maxID = e.GetId() + 1 + visitEntry: func(e ast.EntryExpr) { + if e.ID() >= *maxID { + *maxID = e.ID() + 1 } }, } } -func visit(expr *exprpb.Expr, visitor astVisitor) { - exprs := []*exprpb.Expr{expr} +func visit(expr ast.Expr, visitor astVisitor) { + exprs := []ast.Expr{expr} for len(exprs) != 0 { e := exprs[0] if visitor.visitExpr != nil { visitor.visitExpr(e) } exprs = exprs[1:] - switch e.GetExprKind().(type) { - case *exprpb.Expr_SelectExpr: - exprs = append(exprs, e.GetSelectExpr().GetOperand()) - case *exprpb.Expr_CallExpr: - call := e.GetCallExpr() - if call.GetTarget() != nil { - exprs = append(exprs, call.GetTarget()) + switch e.Kind() { + case ast.SelectKind: + exprs = append(exprs, e.AsSelect().Operand()) + case ast.CallKind: + call := e.AsCall() + if call.Target() != nil { + exprs = append(exprs, call.Target()) } - exprs = append(exprs, call.GetArgs()...) - case *exprpb.Expr_ComprehensionExpr: - compre := e.GetComprehensionExpr() + exprs = append(exprs, call.Args()...) + case ast.ComprehensionKind: + compre := e.AsComprehension() exprs = append(exprs, - compre.GetIterRange(), - compre.GetAccuInit(), - compre.GetLoopCondition(), - compre.GetLoopStep(), - compre.GetResult()) - case *exprpb.Expr_ListExpr: - list := e.GetListExpr() - exprs = append(exprs, list.GetElements()...) - case *exprpb.Expr_StructExpr: - for _, entry := range e.GetStructExpr().GetEntries() { + compre.IterRange(), + compre.AccuInit(), + compre.LoopCondition(), + compre.LoopStep(), + compre.Result()) + case ast.ListKind: + list := e.AsList() + exprs = append(exprs, list.Elements()...) + case ast.MapKind: + for _, entry := range e.AsMap().Entries() { + e := entry.AsMapEntry() if visitor.visitEntry != nil { visitor.visitEntry(entry) } - if entry.GetMapKey() != nil { - exprs = append(exprs, entry.GetMapKey()) + exprs = append(exprs, e.Key()) + exprs = append(exprs, e.Value()) + } + case ast.StructKind: + for _, entry := range e.AsStruct().Fields() { + f := entry.AsStructField() + if visitor.visitEntry != nil { + visitor.visitEntry(entry) } - exprs = append(exprs, entry.GetValue()) + exprs = append(exprs, f.Value()) } } } diff --git a/vendor/github.com/google/cel-go/parser/BUILD.bazel b/vendor/github.com/google/cel-go/parser/BUILD.bazel index 67ecc955439..a4a09ad75da 100644 --- a/vendor/github.com/google/cel-go/parser/BUILD.bazel +++ b/vendor/github.com/google/cel-go/parser/BUILD.bazel @@ -20,8 +20,11 @@ go_library( visibility = ["//visibility:public"], deps = [ "//common:go_default_library", + "//common/ast:go_default_library", "//common/operators:go_default_library", "//common/runes:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", "//parser/gen:go_default_library", "@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", @@ -43,7 +46,9 @@ go_test( ":go_default_library", ], deps = [ + "//common/ast:go_default_library", "//common/debug:go_default_library", + "//common/types:go_default_library", "//parser/gen:go_default_library", "//test:go_default_library", "@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library", diff --git a/vendor/github.com/google/cel-go/parser/helper.go b/vendor/github.com/google/cel-go/parser/helper.go index a5f29e3d7ae..07656e139fa 100644 --- a/vendor/github.com/google/cel-go/parser/helper.go +++ b/vendor/github.com/google/cel-go/parser/helper.go @@ -20,281 +20,206 @@ import ( antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4" "github.com/google/cel-go/common" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" ) type parserHelper struct { - source common.Source - nextID int64 - positions map[int64]int32 - macroCalls map[int64]*exprpb.Expr + exprFactory ast.ExprFactory + source common.Source + sourceInfo *ast.SourceInfo + nextID int64 } -func newParserHelper(source common.Source) *parserHelper { +func newParserHelper(source common.Source, fac ast.ExprFactory) *parserHelper { return &parserHelper{ - source: source, - nextID: 1, - positions: make(map[int64]int32), - macroCalls: make(map[int64]*exprpb.Expr), + exprFactory: fac, + source: source, + sourceInfo: ast.NewSourceInfo(source), + nextID: 1, } } -func (p *parserHelper) getSourceInfo() *exprpb.SourceInfo { - return &exprpb.SourceInfo{ - Location: p.source.Description(), - Positions: p.positions, - LineOffsets: p.source.LineOffsets(), - MacroCalls: p.macroCalls} +func (p *parserHelper) getSourceInfo() *ast.SourceInfo { + return p.sourceInfo } -func (p *parserHelper) newLiteral(ctx any, value *exprpb.Constant) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_ConstExpr{ConstExpr: value} - return exprNode +func (p *parserHelper) newLiteral(ctx any, value ref.Val) ast.Expr { + return p.exprFactory.NewLiteral(p.newID(ctx), value) } -func (p *parserHelper) newLiteralBool(ctx any, value bool) *exprpb.Expr { - return p.newLiteral(ctx, - &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: value}}) +func (p *parserHelper) newLiteralBool(ctx any, value bool) ast.Expr { + return p.newLiteral(ctx, types.Bool(value)) } -func (p *parserHelper) newLiteralString(ctx any, value string) *exprpb.Expr { - return p.newLiteral(ctx, - &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: value}}) +func (p *parserHelper) newLiteralString(ctx any, value string) ast.Expr { + return p.newLiteral(ctx, types.String(value)) } -func (p *parserHelper) newLiteralBytes(ctx any, value []byte) *exprpb.Expr { - return p.newLiteral(ctx, - &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: value}}) +func (p *parserHelper) newLiteralBytes(ctx any, value []byte) ast.Expr { + return p.newLiteral(ctx, types.Bytes(value)) } -func (p *parserHelper) newLiteralInt(ctx any, value int64) *exprpb.Expr { - return p.newLiteral(ctx, - &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: value}}) +func (p *parserHelper) newLiteralInt(ctx any, value int64) ast.Expr { + return p.newLiteral(ctx, types.Int(value)) } -func (p *parserHelper) newLiteralUint(ctx any, value uint64) *exprpb.Expr { - return p.newLiteral(ctx, &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: value}}) +func (p *parserHelper) newLiteralUint(ctx any, value uint64) ast.Expr { + return p.newLiteral(ctx, types.Uint(value)) } -func (p *parserHelper) newLiteralDouble(ctx any, value float64) *exprpb.Expr { - return p.newLiteral(ctx, - &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: value}}) +func (p *parserHelper) newLiteralDouble(ctx any, value float64) ast.Expr { + return p.newLiteral(ctx, types.Double(value)) } -func (p *parserHelper) newIdent(ctx any, name string) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_IdentExpr{IdentExpr: &exprpb.Expr_Ident{Name: name}} - return exprNode +func (p *parserHelper) newIdent(ctx any, name string) ast.Expr { + return p.exprFactory.NewIdent(p.newID(ctx), name) } -func (p *parserHelper) newSelect(ctx any, operand *exprpb.Expr, field string) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_SelectExpr{ - SelectExpr: &exprpb.Expr_Select{Operand: operand, Field: field}} - return exprNode +func (p *parserHelper) newSelect(ctx any, operand ast.Expr, field string) ast.Expr { + return p.exprFactory.NewSelect(p.newID(ctx), operand, field) } -func (p *parserHelper) newPresenceTest(ctx any, operand *exprpb.Expr, field string) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_SelectExpr{ - SelectExpr: &exprpb.Expr_Select{Operand: operand, Field: field, TestOnly: true}} - return exprNode +func (p *parserHelper) newPresenceTest(ctx any, operand ast.Expr, field string) ast.Expr { + return p.exprFactory.NewPresenceTest(p.newID(ctx), operand, field) } -func (p *parserHelper) newGlobalCall(ctx any, function string, args ...*exprpb.Expr) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{Function: function, Args: args}} - return exprNode +func (p *parserHelper) newGlobalCall(ctx any, function string, args ...ast.Expr) ast.Expr { + return p.exprFactory.NewCall(p.newID(ctx), function, args...) } -func (p *parserHelper) newReceiverCall(ctx any, function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{Function: function, Target: target, Args: args}} - return exprNode +func (p *parserHelper) newReceiverCall(ctx any, function string, target ast.Expr, args ...ast.Expr) ast.Expr { + return p.exprFactory.NewMemberCall(p.newID(ctx), function, target, args...) } -func (p *parserHelper) newList(ctx any, elements []*exprpb.Expr, optionals ...int32) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{ - Elements: elements, - OptionalIndices: optionals, - }} - return exprNode +func (p *parserHelper) newList(ctx any, elements []ast.Expr, optionals ...int32) ast.Expr { + return p.exprFactory.NewList(p.newID(ctx), elements, optionals) } -func (p *parserHelper) newMap(ctx any, entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_StructExpr{ - StructExpr: &exprpb.Expr_CreateStruct{Entries: entries}} - return exprNode +func (p *parserHelper) newMap(ctx any, entries ...ast.EntryExpr) ast.Expr { + return p.exprFactory.NewMap(p.newID(ctx), entries) } -func (p *parserHelper) newMapEntry(entryID int64, key *exprpb.Expr, value *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { - return &exprpb.Expr_CreateStruct_Entry{ - Id: entryID, - KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{MapKey: key}, - Value: value, - OptionalEntry: optional, - } +func (p *parserHelper) newMapEntry(entryID int64, key ast.Expr, value ast.Expr, optional bool) ast.EntryExpr { + return p.exprFactory.NewMapEntry(entryID, key, value, optional) } -func (p *parserHelper) newObject(ctx any, typeName string, entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_StructExpr{ - StructExpr: &exprpb.Expr_CreateStruct{ - MessageName: typeName, - Entries: entries, - }, - } - return exprNode +func (p *parserHelper) newObject(ctx any, typeName string, fields ...ast.EntryExpr) ast.Expr { + return p.exprFactory.NewStruct(p.newID(ctx), typeName, fields) } -func (p *parserHelper) newObjectField(fieldID int64, field string, value *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { - return &exprpb.Expr_CreateStruct_Entry{ - Id: fieldID, - KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{FieldKey: field}, - Value: value, - OptionalEntry: optional, - } +func (p *parserHelper) newObjectField(fieldID int64, field string, value ast.Expr, optional bool) ast.EntryExpr { + return p.exprFactory.NewStructField(fieldID, field, value, optional) } -func (p *parserHelper) newComprehension(ctx any, iterVar string, - iterRange *exprpb.Expr, +func (p *parserHelper) newComprehension(ctx any, + iterRange ast.Expr, + iterVar string, accuVar string, - accuInit *exprpb.Expr, - condition *exprpb.Expr, - step *exprpb.Expr, - result *exprpb.Expr) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_ComprehensionExpr{ - ComprehensionExpr: &exprpb.Expr_Comprehension{ - AccuVar: accuVar, - AccuInit: accuInit, - IterVar: iterVar, - IterRange: iterRange, - LoopCondition: condition, - LoopStep: step, - Result: result}} - return exprNode -} - -func (p *parserHelper) newExpr(ctx any) *exprpb.Expr { - id, isID := ctx.(int64) - if isID { - return &exprpb.Expr{Id: id} + accuInit ast.Expr, + condition ast.Expr, + step ast.Expr, + result ast.Expr) ast.Expr { + return p.exprFactory.NewComprehension( + p.newID(ctx), iterRange, iterVar, accuVar, accuInit, condition, step, result) +} + +func (p *parserHelper) newID(ctx any) int64 { + if id, isID := ctx.(int64); isID { + return id } - return &exprpb.Expr{Id: p.id(ctx)} + return p.id(ctx) +} + +func (p *parserHelper) newExpr(ctx any) ast.Expr { + return p.exprFactory.NewUnspecifiedExpr(p.newID(ctx)) } func (p *parserHelper) id(ctx any) int64 { - var location common.Location + var offset ast.OffsetRange switch c := ctx.(type) { case antlr.ParserRuleContext: - token := c.GetStart() - location = p.source.NewLocation(token.GetLine(), token.GetColumn()) + start, stop := c.GetStart(), c.GetStop() + if stop == nil { + stop = start + } + offset.Start = p.sourceInfo.ComputeOffset(int32(start.GetLine()), int32(start.GetColumn())) + offset.Stop = p.sourceInfo.ComputeOffset(int32(stop.GetLine()), int32(stop.GetColumn())) case antlr.Token: - token := c - location = p.source.NewLocation(token.GetLine(), token.GetColumn()) + offset.Start = p.sourceInfo.ComputeOffset(int32(c.GetLine()), int32(c.GetColumn())) + offset.Stop = offset.Start case common.Location: - location = c + offset.Start = p.sourceInfo.ComputeOffset(int32(c.Line()), int32(c.Column())) + offset.Stop = offset.Start + case ast.OffsetRange: + offset = c default: // This should only happen if the ctx is nil return -1 } id := p.nextID - p.positions[id], _ = p.source.LocationOffset(location) + p.sourceInfo.SetOffsetRange(id, offset) p.nextID++ return id } func (p *parserHelper) getLocation(id int64) common.Location { - characterOffset := p.positions[id] - location, _ := p.source.OffsetLocation(characterOffset) - return location + return p.sourceInfo.GetStartLocation(id) } // buildMacroCallArg iterates the expression and returns a new expression // where all macros have been replaced by their IDs in MacroCalls -func (p *parserHelper) buildMacroCallArg(expr *exprpb.Expr) *exprpb.Expr { - if _, found := p.macroCalls[expr.GetId()]; found { - return &exprpb.Expr{Id: expr.GetId()} +func (p *parserHelper) buildMacroCallArg(expr ast.Expr) ast.Expr { + if _, found := p.sourceInfo.GetMacroCall(expr.ID()); found { + return p.exprFactory.NewUnspecifiedExpr(expr.ID()) } - switch expr.GetExprKind().(type) { - case *exprpb.Expr_CallExpr: + switch expr.Kind() { + case ast.CallKind: // Iterate the AST from `expr` recursively looking for macros. Because we are at most // starting from the top level macro, this recursion is bounded by the size of the AST. This // means that the depth check on the AST during parsing will catch recursion overflows // before we get to here. - macroTarget := expr.GetCallExpr().GetTarget() - if macroTarget != nil { - macroTarget = p.buildMacroCallArg(macroTarget) - } - macroArgs := make([]*exprpb.Expr, len(expr.GetCallExpr().GetArgs())) - for index, arg := range expr.GetCallExpr().GetArgs() { + call := expr.AsCall() + macroArgs := make([]ast.Expr, len(call.Args())) + for index, arg := range call.Args() { macroArgs[index] = p.buildMacroCallArg(arg) } - return &exprpb.Expr{ - Id: expr.GetId(), - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Target: macroTarget, - Function: expr.GetCallExpr().GetFunction(), - Args: macroArgs, - }, - }, + if !call.IsMemberFunction() { + return p.exprFactory.NewCall(expr.ID(), call.FunctionName(), macroArgs...) } - case *exprpb.Expr_ListExpr: - listExpr := expr.GetListExpr() - macroListArgs := make([]*exprpb.Expr, len(listExpr.GetElements())) - for i, elem := range listExpr.GetElements() { + macroTarget := p.buildMacroCallArg(call.Target()) + return p.exprFactory.NewMemberCall(expr.ID(), call.FunctionName(), macroTarget, macroArgs...) + case ast.ListKind: + list := expr.AsList() + macroListArgs := make([]ast.Expr, list.Size()) + for i, elem := range list.Elements() { macroListArgs[i] = p.buildMacroCallArg(elem) } - return &exprpb.Expr{ - Id: expr.GetId(), - ExprKind: &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{ - Elements: macroListArgs, - OptionalIndices: listExpr.GetOptionalIndices(), - }, - }, - } + return p.exprFactory.NewList(expr.ID(), macroListArgs, list.OptionalIndices()) } - return expr } // addMacroCall adds the macro the the MacroCalls map in source info. If a macro has args/subargs/target // that are macros, their ID will be stored instead for later self-lookups. -func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) { - macroTarget := target - if target != nil { - if _, found := p.macroCalls[target.GetId()]; found { - macroTarget = &exprpb.Expr{Id: target.GetId()} - } else { - macroTarget = p.buildMacroCallArg(target) - } - } - - macroArgs := make([]*exprpb.Expr, len(args)) +func (p *parserHelper) addMacroCall(exprID int64, function string, target ast.Expr, args ...ast.Expr) { + macroArgs := make([]ast.Expr, len(args)) for index, arg := range args { macroArgs[index] = p.buildMacroCallArg(arg) } - - p.macroCalls[exprID] = &exprpb.Expr{ - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Target: macroTarget, - Function: function, - Args: macroArgs, - }, - }, + if target == nil { + p.sourceInfo.SetMacroCall(exprID, p.exprFactory.NewCall(0, function, macroArgs...)) + return } + macroTarget := target + if _, found := p.sourceInfo.GetMacroCall(target.ID()); found { + macroTarget = p.exprFactory.NewUnspecifiedExpr(target.ID()) + } else { + macroTarget = p.buildMacroCallArg(target) + } + p.sourceInfo.SetMacroCall(exprID, p.exprFactory.NewMemberCall(0, function, macroTarget, macroArgs...)) } // logicManager compacts logical trees into a more efficient structure which is semantically @@ -309,71 +234,71 @@ func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprp // controversial choice as it alters the traditional order of execution assumptions present in most // expressions. type logicManager struct { - helper *parserHelper + exprFactory ast.ExprFactory function string - terms []*exprpb.Expr + terms []ast.Expr ops []int64 variadicASTs bool } // newVariadicLogicManager creates a logic manager instance bound to a specific function and its first term. -func newVariadicLogicManager(h *parserHelper, function string, term *exprpb.Expr) *logicManager { +func newVariadicLogicManager(fac ast.ExprFactory, function string, term ast.Expr) *logicManager { return &logicManager{ - helper: h, + exprFactory: fac, function: function, - terms: []*exprpb.Expr{term}, + terms: []ast.Expr{term}, ops: []int64{}, variadicASTs: true, } } // newBalancingLogicManager creates a logic manager instance bound to a specific function and its first term. -func newBalancingLogicManager(h *parserHelper, function string, term *exprpb.Expr) *logicManager { +func newBalancingLogicManager(fac ast.ExprFactory, function string, term ast.Expr) *logicManager { return &logicManager{ - helper: h, + exprFactory: fac, function: function, - terms: []*exprpb.Expr{term}, + terms: []ast.Expr{term}, ops: []int64{}, variadicASTs: false, } } // addTerm adds an operation identifier and term to the set of terms to be balanced. -func (l *logicManager) addTerm(op int64, term *exprpb.Expr) { +func (l *logicManager) addTerm(op int64, term ast.Expr) { l.terms = append(l.terms, term) l.ops = append(l.ops, op) } // toExpr renders the logic graph into an Expr value, either balancing a tree of logical // operations or creating a variadic representation of the logical operator. -func (l *logicManager) toExpr() *exprpb.Expr { +func (l *logicManager) toExpr() ast.Expr { if len(l.terms) == 1 { return l.terms[0] } if l.variadicASTs { - return l.helper.newGlobalCall(l.ops[0], l.function, l.terms...) + return l.exprFactory.NewCall(l.ops[0], l.function, l.terms...) } return l.balancedTree(0, len(l.ops)-1) } // balancedTree recursively balances the terms provided to a commutative operator. -func (l *logicManager) balancedTree(lo, hi int) *exprpb.Expr { +func (l *logicManager) balancedTree(lo, hi int) ast.Expr { mid := (lo + hi + 1) / 2 - var left *exprpb.Expr + var left ast.Expr if mid == lo { left = l.terms[mid] } else { left = l.balancedTree(lo, mid-1) } - var right *exprpb.Expr + var right ast.Expr if mid == hi { right = l.terms[mid+1] } else { right = l.balancedTree(mid+1, hi) } - return l.helper.newGlobalCall(l.ops[mid], l.function, left, right) + return l.exprFactory.NewCall(l.ops[mid], l.function, left, right) } type exprHelper struct { @@ -387,202 +312,151 @@ func (e *exprHelper) nextMacroID() int64 { // Copy implements the ExprHelper interface method by producing a copy of the input Expr value // with a fresh set of numeric identifiers the Expr and all its descendants. -func (e *exprHelper) Copy(expr *exprpb.Expr) *exprpb.Expr { - copy := e.parserHelper.newExpr(e.parserHelper.getLocation(expr.GetId())) - switch expr.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - copy.ExprKind = &exprpb.Expr_ConstExpr{ConstExpr: expr.GetConstExpr()} - case *exprpb.Expr_IdentExpr: - copy.ExprKind = &exprpb.Expr_IdentExpr{IdentExpr: expr.GetIdentExpr()} - case *exprpb.Expr_SelectExpr: - op := expr.GetSelectExpr().GetOperand() - copy.ExprKind = &exprpb.Expr_SelectExpr{SelectExpr: &exprpb.Expr_Select{ - Operand: e.Copy(op), - Field: expr.GetSelectExpr().GetField(), - TestOnly: expr.GetSelectExpr().GetTestOnly(), - }} - case *exprpb.Expr_CallExpr: - call := expr.GetCallExpr() - target := call.GetTarget() - if target != nil { - target = e.Copy(target) +func (e *exprHelper) Copy(expr ast.Expr) ast.Expr { + offsetRange, _ := e.parserHelper.sourceInfo.GetOffsetRange(expr.ID()) + copyID := e.parserHelper.newID(offsetRange) + switch expr.Kind() { + case ast.LiteralKind: + return e.exprFactory.NewLiteral(copyID, expr.AsLiteral()) + case ast.IdentKind: + return e.exprFactory.NewIdent(copyID, expr.AsIdent()) + case ast.SelectKind: + sel := expr.AsSelect() + op := e.Copy(sel.Operand()) + if sel.IsTestOnly() { + return e.exprFactory.NewPresenceTest(copyID, op, sel.FieldName()) } - args := call.GetArgs() - argsCopy := make([]*exprpb.Expr, len(args)) + return e.exprFactory.NewSelect(copyID, op, sel.FieldName()) + case ast.CallKind: + call := expr.AsCall() + args := call.Args() + argsCopy := make([]ast.Expr, len(args)) for i, arg := range args { argsCopy[i] = e.Copy(arg) } - copy.ExprKind = &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Function: call.GetFunction(), - Target: target, - Args: argsCopy, - }, + if !call.IsMemberFunction() { + return e.exprFactory.NewCall(copyID, call.FunctionName(), argsCopy...) } - case *exprpb.Expr_ListExpr: - elems := expr.GetListExpr().GetElements() - elemsCopy := make([]*exprpb.Expr, len(elems)) + return e.exprFactory.NewMemberCall(copyID, call.FunctionName(), e.Copy(call.Target()), argsCopy...) + case ast.ListKind: + list := expr.AsList() + elems := list.Elements() + elemsCopy := make([]ast.Expr, len(elems)) for i, elem := range elems { elemsCopy[i] = e.Copy(elem) } - copy.ExprKind = &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{Elements: elemsCopy}, - } - case *exprpb.Expr_StructExpr: - entries := expr.GetStructExpr().GetEntries() - entriesCopy := make([]*exprpb.Expr_CreateStruct_Entry, len(entries)) - for i, entry := range entries { - entryCopy := &exprpb.Expr_CreateStruct_Entry{} - entryCopy.Id = e.nextMacroID() - switch entry.GetKeyKind().(type) { - case *exprpb.Expr_CreateStruct_Entry_FieldKey: - entryCopy.KeyKind = &exprpb.Expr_CreateStruct_Entry_FieldKey{ - FieldKey: entry.GetFieldKey(), - } - case *exprpb.Expr_CreateStruct_Entry_MapKey: - entryCopy.KeyKind = &exprpb.Expr_CreateStruct_Entry_MapKey{ - MapKey: e.Copy(entry.GetMapKey()), - } - } - entryCopy.Value = e.Copy(entry.GetValue()) - entriesCopy[i] = entryCopy + return e.exprFactory.NewList(copyID, elemsCopy, list.OptionalIndices()) + case ast.MapKind: + m := expr.AsMap() + entries := m.Entries() + entriesCopy := make([]ast.EntryExpr, len(entries)) + for i, en := range entries { + entry := en.AsMapEntry() + entryID := e.nextMacroID() + entriesCopy[i] = e.exprFactory.NewMapEntry(entryID, + e.Copy(entry.Key()), e.Copy(entry.Value()), entry.IsOptional()) } - copy.ExprKind = &exprpb.Expr_StructExpr{ - StructExpr: &exprpb.Expr_CreateStruct{ - MessageName: expr.GetStructExpr().GetMessageName(), - Entries: entriesCopy, - }, - } - case *exprpb.Expr_ComprehensionExpr: - iterRange := e.Copy(expr.GetComprehensionExpr().GetIterRange()) - accuInit := e.Copy(expr.GetComprehensionExpr().GetAccuInit()) - cond := e.Copy(expr.GetComprehensionExpr().GetLoopCondition()) - step := e.Copy(expr.GetComprehensionExpr().GetLoopStep()) - result := e.Copy(expr.GetComprehensionExpr().GetResult()) - copy.ExprKind = &exprpb.Expr_ComprehensionExpr{ - ComprehensionExpr: &exprpb.Expr_Comprehension{ - IterRange: iterRange, - IterVar: expr.GetComprehensionExpr().GetIterVar(), - AccuInit: accuInit, - AccuVar: expr.GetComprehensionExpr().GetAccuVar(), - LoopCondition: cond, - LoopStep: step, - Result: result, - }, + return e.exprFactory.NewMap(copyID, entriesCopy) + case ast.StructKind: + s := expr.AsStruct() + fields := s.Fields() + fieldsCopy := make([]ast.EntryExpr, len(fields)) + for i, f := range fields { + field := f.AsStructField() + fieldID := e.nextMacroID() + fieldsCopy[i] = e.exprFactory.NewStructField(fieldID, + field.Name(), e.Copy(field.Value()), field.IsOptional()) } + return e.exprFactory.NewStruct(copyID, s.TypeName(), fieldsCopy) + case ast.ComprehensionKind: + compre := expr.AsComprehension() + iterRange := e.Copy(compre.IterRange()) + accuInit := e.Copy(compre.AccuInit()) + cond := e.Copy(compre.LoopCondition()) + step := e.Copy(compre.LoopStep()) + result := e.Copy(compre.Result()) + return e.exprFactory.NewComprehension(copyID, + iterRange, compre.IterVar(), compre.AccuVar(), accuInit, cond, step, result) } - return copy -} - -// LiteralBool implements the ExprHelper interface method. -func (e *exprHelper) LiteralBool(value bool) *exprpb.Expr { - return e.parserHelper.newLiteralBool(e.nextMacroID(), value) -} - -// LiteralBytes implements the ExprHelper interface method. -func (e *exprHelper) LiteralBytes(value []byte) *exprpb.Expr { - return e.parserHelper.newLiteralBytes(e.nextMacroID(), value) -} - -// LiteralDouble implements the ExprHelper interface method. -func (e *exprHelper) LiteralDouble(value float64) *exprpb.Expr { - return e.parserHelper.newLiteralDouble(e.nextMacroID(), value) -} - -// LiteralInt implements the ExprHelper interface method. -func (e *exprHelper) LiteralInt(value int64) *exprpb.Expr { - return e.parserHelper.newLiteralInt(e.nextMacroID(), value) + return e.exprFactory.NewUnspecifiedExpr(copyID) } -// LiteralString implements the ExprHelper interface method. -func (e *exprHelper) LiteralString(value string) *exprpb.Expr { - return e.parserHelper.newLiteralString(e.nextMacroID(), value) -} - -// LiteralUint implements the ExprHelper interface method. -func (e *exprHelper) LiteralUint(value uint64) *exprpb.Expr { - return e.parserHelper.newLiteralUint(e.nextMacroID(), value) +// NewLiteral implements the ExprHelper interface method. +func (e *exprHelper) NewLiteral(value ref.Val) ast.Expr { + return e.exprFactory.NewLiteral(e.nextMacroID(), value) } // NewList implements the ExprHelper interface method. -func (e *exprHelper) NewList(elems ...*exprpb.Expr) *exprpb.Expr { - return e.parserHelper.newList(e.nextMacroID(), elems) +func (e *exprHelper) NewList(elems ...ast.Expr) ast.Expr { + return e.exprFactory.NewList(e.nextMacroID(), elems, []int32{}) } // NewMap implements the ExprHelper interface method. -func (e *exprHelper) NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { - return e.parserHelper.newMap(e.nextMacroID(), entries...) +func (e *exprHelper) NewMap(entries ...ast.EntryExpr) ast.Expr { + return e.exprFactory.NewMap(e.nextMacroID(), entries) } // NewMapEntry implements the ExprHelper interface method. -func (e *exprHelper) NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { - return e.parserHelper.newMapEntry(e.nextMacroID(), key, val, optional) +func (e *exprHelper) NewMapEntry(key ast.Expr, val ast.Expr, optional bool) ast.EntryExpr { + return e.exprFactory.NewMapEntry(e.nextMacroID(), key, val, optional) } -// NewObject implements the ExprHelper interface method. -func (e *exprHelper) NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { - return e.parserHelper.newObject(e.nextMacroID(), typeName, fieldInits...) +// NewStruct implements the ExprHelper interface method. +func (e *exprHelper) NewStruct(typeName string, fieldInits ...ast.EntryExpr) ast.Expr { + return e.exprFactory.NewStruct(e.nextMacroID(), typeName, fieldInits) } -// NewObjectFieldInit implements the ExprHelper interface method. -func (e *exprHelper) NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { - return e.parserHelper.newObjectField(e.nextMacroID(), field, init, optional) +// NewStructField implements the ExprHelper interface method. +func (e *exprHelper) NewStructField(field string, init ast.Expr, optional bool) ast.EntryExpr { + return e.exprFactory.NewStructField(e.nextMacroID(), field, init, optional) } -// Fold implements the ExprHelper interface method. -func (e *exprHelper) Fold(iterVar string, - iterRange *exprpb.Expr, +// NewComprehension implements the ExprHelper interface method. +func (e *exprHelper) NewComprehension( + iterRange ast.Expr, + iterVar string, accuVar string, - accuInit *exprpb.Expr, - condition *exprpb.Expr, - step *exprpb.Expr, - result *exprpb.Expr) *exprpb.Expr { - return e.parserHelper.newComprehension( - e.nextMacroID(), iterVar, iterRange, accuVar, accuInit, condition, step, result) + accuInit ast.Expr, + condition ast.Expr, + step ast.Expr, + result ast.Expr) ast.Expr { + return e.exprFactory.NewComprehension( + e.nextMacroID(), iterRange, iterVar, accuVar, accuInit, condition, step, result) } -// Ident implements the ExprHelper interface method. -func (e *exprHelper) Ident(name string) *exprpb.Expr { - return e.parserHelper.newIdent(e.nextMacroID(), name) +// NewIdent implements the ExprHelper interface method. +func (e *exprHelper) NewIdent(name string) ast.Expr { + return e.exprFactory.NewIdent(e.nextMacroID(), name) } -// AccuIdent implements the ExprHelper interface method. -func (e *exprHelper) AccuIdent() *exprpb.Expr { - return e.parserHelper.newIdent(e.nextMacroID(), AccumulatorName) +// NewAccuIdent implements the ExprHelper interface method. +func (e *exprHelper) NewAccuIdent() ast.Expr { + return e.exprFactory.NewAccuIdent(e.nextMacroID()) } -// GlobalCall implements the ExprHelper interface method. -func (e *exprHelper) GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr { - return e.parserHelper.newGlobalCall(e.nextMacroID(), function, args...) +// NewGlobalCall implements the ExprHelper interface method. +func (e *exprHelper) NewCall(function string, args ...ast.Expr) ast.Expr { + return e.exprFactory.NewCall(e.nextMacroID(), function, args...) } -// ReceiverCall implements the ExprHelper interface method. -func (e *exprHelper) ReceiverCall(function string, - target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr { - return e.parserHelper.newReceiverCall(e.nextMacroID(), function, target, args...) +// NewMemberCall implements the ExprHelper interface method. +func (e *exprHelper) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr { + return e.exprFactory.NewMemberCall(e.nextMacroID(), function, target, args...) } -// PresenceTest implements the ExprHelper interface method. -func (e *exprHelper) PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr { - return e.parserHelper.newPresenceTest(e.nextMacroID(), operand, field) +// NewPresenceTest implements the ExprHelper interface method. +func (e *exprHelper) NewPresenceTest(operand ast.Expr, field string) ast.Expr { + return e.exprFactory.NewPresenceTest(e.nextMacroID(), operand, field) } -// Select implements the ExprHelper interface method. -func (e *exprHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr { - return e.parserHelper.newSelect(e.nextMacroID(), operand, field) +// NewSelect implements the ExprHelper interface method. +func (e *exprHelper) NewSelect(operand ast.Expr, field string) ast.Expr { + return e.exprFactory.NewSelect(e.nextMacroID(), operand, field) } // OffsetLocation implements the ExprHelper interface method. func (e *exprHelper) OffsetLocation(exprID int64) common.Location { - offset, found := e.parserHelper.positions[exprID] - if !found { - return common.NoLocation - } - location, found := e.parserHelper.source.OffsetLocation(offset) - if !found { - return common.NoLocation - } - return location + return e.parserHelper.sourceInfo.GetStartLocation(exprID) } // NewError associates an error message with a given expression id, populating the source offset location of the error if possible. diff --git a/vendor/github.com/google/cel-go/parser/macro.go b/vendor/github.com/google/cel-go/parser/macro.go index 6066e8ef4f8..5b1775bedb9 100644 --- a/vendor/github.com/google/cel-go/parser/macro.go +++ b/vendor/github.com/google/cel-go/parser/macro.go @@ -18,9 +18,10 @@ import ( "fmt" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" ) // NewGlobalMacro creates a Macro for a global function with the specified arg count. @@ -142,58 +143,38 @@ func makeVarArgMacroKey(name string, receiverStyle bool) string { // and produces as output an Expr ast node. // // Note: when the Macro.IsReceiverStyle() method returns true, the target argument will be nil. -type MacroExpander func(eh ExprHelper, - target *exprpb.Expr, - args []*exprpb.Expr) (*exprpb.Expr, *common.Error) +type MacroExpander func(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) -// ExprHelper assists with the manipulation of proto-based Expr values in a manner which is -// consistent with the source position and expression id generation code leveraged by both -// the parser and type-checker. +// ExprHelper assists with the creation of Expr values in a manner which is consistent +// the internal semantics and id generation behaviors of the parser and checker libraries. type ExprHelper interface { // Copy the input expression with a brand new set of identifiers. - Copy(*exprpb.Expr) *exprpb.Expr - - // LiteralBool creates an Expr value for a bool literal. - LiteralBool(value bool) *exprpb.Expr - - // LiteralBytes creates an Expr value for a byte literal. - LiteralBytes(value []byte) *exprpb.Expr - - // LiteralDouble creates an Expr value for double literal. - LiteralDouble(value float64) *exprpb.Expr + Copy(ast.Expr) ast.Expr - // LiteralInt creates an Expr value for an int literal. - LiteralInt(value int64) *exprpb.Expr + // Literal creates an Expr value for a scalar literal value. + NewLiteral(value ref.Val) ast.Expr - // LiteralString creates am Expr value for a string literal. - LiteralString(value string) *exprpb.Expr - - // LiteralUint creates an Expr value for a uint literal. - LiteralUint(value uint64) *exprpb.Expr - - // NewList creates a CreateList instruction where the list is comprised of the optional set - // of elements provided as arguments. - NewList(elems ...*exprpb.Expr) *exprpb.Expr + // NewList creates a list literal instruction with an optional set of elements. + NewList(elems ...ast.Expr) ast.Expr // NewMap creates a CreateStruct instruction for a map where the map is comprised of the // optional set of key, value entries. - NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr + NewMap(entries ...ast.EntryExpr) ast.Expr // NewMapEntry creates a Map Entry for the key, value pair. - NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry + NewMapEntry(key ast.Expr, val ast.Expr, optional bool) ast.EntryExpr - // NewObject creates a CreateStruct instruction for an object with a given type name and - // optional set of field initializers. - NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr + // NewStruct creates a struct literal expression with an optional set of field initializers. + NewStruct(typeName string, fieldInits ...ast.EntryExpr) ast.Expr - // NewObjectFieldInit creates a new Object field initializer from the field name and value. - NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry + // NewStructField creates a new struct field initializer from the field name and value. + NewStructField(field string, init ast.Expr, optional bool) ast.EntryExpr - // Fold creates a fold comprehension instruction. + // NewComprehension creates a new comprehension instruction. // - // - iterVar is the iteration variable name. // - iterRange represents the expression that resolves to a list or map where the elements or // keys (respectively) will be iterated over. + // - iterVar is the iteration variable name. // - accuVar is the accumulation variable name, typically parser.AccumulatorName. // - accuInit is the initial expression whose value will be set for the accuVar prior to // folding. @@ -204,31 +185,31 @@ type ExprHelper interface { // The accuVar should not shadow variable names that you would like to reference within the // environment in the step and condition expressions. Presently, the name __result__ is commonly // used by built-in macros but this may change in the future. - Fold(iterVar string, - iterRange *exprpb.Expr, + NewComprehension(iterRange ast.Expr, + iterVar string, accuVar string, - accuInit *exprpb.Expr, - condition *exprpb.Expr, - step *exprpb.Expr, - result *exprpb.Expr) *exprpb.Expr + accuInit ast.Expr, + condition ast.Expr, + step ast.Expr, + result ast.Expr) ast.Expr - // Ident creates an identifier Expr value. - Ident(name string) *exprpb.Expr + // NewIdent creates an identifier Expr value. + NewIdent(name string) ast.Expr - // AccuIdent returns an accumulator identifier for use with comprehension results. - AccuIdent() *exprpb.Expr + // NewAccuIdent returns an accumulator identifier for use with comprehension results. + NewAccuIdent() ast.Expr - // GlobalCall creates a function call Expr value for a global (free) function. - GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr + // NewCall creates a function call Expr value for a global (free) function. + NewCall(function string, args ...ast.Expr) ast.Expr - // ReceiverCall creates a function call Expr value for a receiver-style function. - ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr + // NewMemberCall creates a function call Expr value for a receiver-style function. + NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr - // PresenceTest creates a Select TestOnly Expr value for modelling has() semantics. - PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr + // NewPresenceTest creates a Select TestOnly Expr value for modelling has() semantics. + NewPresenceTest(operand ast.Expr, field string) ast.Expr - // Select create a field traversal Expr value. - Select(operand *exprpb.Expr, field string) *exprpb.Expr + // NewSelect create a field traversal Expr value. + NewSelect(operand ast.Expr, field string) ast.Expr // OffsetLocation returns the Location of the expression identifier. OffsetLocation(exprID int64) common.Location @@ -296,21 +277,21 @@ const ( // MakeAll expands the input call arguments into a comprehension that returns true if all of the // elements in the range match the predicate expressions: // .all(, ) -func MakeAll(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MakeAll(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { return makeQuantifier(quantifierAll, eh, target, args) } // MakeExists expands the input call arguments into a comprehension that returns true if any of the // elements in the range match the predicate expressions: // .exists(, ) -func MakeExists(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MakeExists(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { return makeQuantifier(quantifierExists, eh, target, args) } // MakeExistsOne expands the input call arguments into a comprehension that returns true if exactly // one of the elements in the range match the predicate expressions: // .exists_one(, ) -func MakeExistsOne(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MakeExistsOne(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { return makeQuantifier(quantifierExistsOne, eh, target, args) } @@ -324,14 +305,14 @@ func MakeExistsOne(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*ex // // In the second form only iterVar values which return true when provided to the predicate expression // are transformed. -func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MakeMap(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { v, found := extractIdent(args[0]) if !found { - return nil, eh.NewError(args[0].GetId(), "argument is not an identifier") + return nil, eh.NewError(args[0].ID(), "argument is not an identifier") } - var fn *exprpb.Expr - var filter *exprpb.Expr + var fn ast.Expr + var filter ast.Expr if len(args) == 3 { filter = args[1] @@ -341,84 +322,85 @@ func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.E fn = args[1] } - accuExpr := eh.Ident(AccumulatorName) + accuExpr := eh.NewAccuIdent() init := eh.NewList() - condition := eh.LiteralBool(true) - step := eh.GlobalCall(operators.Add, accuExpr, eh.NewList(fn)) + condition := eh.NewLiteral(types.True) + step := eh.NewCall(operators.Add, accuExpr, eh.NewList(fn)) if filter != nil { - step = eh.GlobalCall(operators.Conditional, filter, step, accuExpr) + step = eh.NewCall(operators.Conditional, filter, step, accuExpr) } - return eh.Fold(v, target, AccumulatorName, init, condition, step, accuExpr), nil + return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, accuExpr), nil } // MakeFilter expands the input call arguments into a comprehension which produces a list which contains // only elements which match the provided predicate expression: // .filter(, ) -func MakeFilter(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MakeFilter(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { v, found := extractIdent(args[0]) if !found { - return nil, eh.NewError(args[0].GetId(), "argument is not an identifier") + return nil, eh.NewError(args[0].ID(), "argument is not an identifier") } filter := args[1] - accuExpr := eh.Ident(AccumulatorName) + accuExpr := eh.NewAccuIdent() init := eh.NewList() - condition := eh.LiteralBool(true) - step := eh.GlobalCall(operators.Add, accuExpr, eh.NewList(args[0])) - step = eh.GlobalCall(operators.Conditional, filter, step, accuExpr) - return eh.Fold(v, target, AccumulatorName, init, condition, step, accuExpr), nil + condition := eh.NewLiteral(types.True) + step := eh.NewCall(operators.Add, accuExpr, eh.NewList(args[0])) + step = eh.NewCall(operators.Conditional, filter, step, accuExpr) + return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, accuExpr), nil } // MakeHas expands the input call arguments into a presence test, e.g. has(.field) -func MakeHas(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { - if s, ok := args[0].ExprKind.(*exprpb.Expr_SelectExpr); ok { - return eh.PresenceTest(s.SelectExpr.GetOperand(), s.SelectExpr.GetField()), nil +func MakeHas(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + if args[0].Kind() == ast.SelectKind { + s := args[0].AsSelect() + return eh.NewPresenceTest(s.Operand(), s.FieldName()), nil } - return nil, eh.NewError(args[0].GetId(), "invalid argument to has() macro") + return nil, eh.NewError(args[0].ID(), "invalid argument to has() macro") } -func makeQuantifier(kind quantifierKind, eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func makeQuantifier(kind quantifierKind, eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { v, found := extractIdent(args[0]) if !found { - return nil, eh.NewError(args[0].GetId(), "argument must be a simple name") + return nil, eh.NewError(args[0].ID(), "argument must be a simple name") } - var init *exprpb.Expr - var condition *exprpb.Expr - var step *exprpb.Expr - var result *exprpb.Expr + var init ast.Expr + var condition ast.Expr + var step ast.Expr + var result ast.Expr switch kind { case quantifierAll: - init = eh.LiteralBool(true) - condition = eh.GlobalCall(operators.NotStrictlyFalse, eh.AccuIdent()) - step = eh.GlobalCall(operators.LogicalAnd, eh.AccuIdent(), args[1]) - result = eh.AccuIdent() + init = eh.NewLiteral(types.True) + condition = eh.NewCall(operators.NotStrictlyFalse, eh.NewAccuIdent()) + step = eh.NewCall(operators.LogicalAnd, eh.NewAccuIdent(), args[1]) + result = eh.NewAccuIdent() case quantifierExists: - init = eh.LiteralBool(false) - condition = eh.GlobalCall( + init = eh.NewLiteral(types.False) + condition = eh.NewCall( operators.NotStrictlyFalse, - eh.GlobalCall(operators.LogicalNot, eh.AccuIdent())) - step = eh.GlobalCall(operators.LogicalOr, eh.AccuIdent(), args[1]) - result = eh.AccuIdent() + eh.NewCall(operators.LogicalNot, eh.NewAccuIdent())) + step = eh.NewCall(operators.LogicalOr, eh.NewAccuIdent(), args[1]) + result = eh.NewAccuIdent() case quantifierExistsOne: - zeroExpr := eh.LiteralInt(0) - oneExpr := eh.LiteralInt(1) + zeroExpr := eh.NewLiteral(types.Int(0)) + oneExpr := eh.NewLiteral(types.Int(1)) init = zeroExpr - condition = eh.LiteralBool(true) - step = eh.GlobalCall(operators.Conditional, args[1], - eh.GlobalCall(operators.Add, eh.AccuIdent(), oneExpr), eh.AccuIdent()) - result = eh.GlobalCall(operators.Equals, eh.AccuIdent(), oneExpr) + condition = eh.NewLiteral(types.True) + step = eh.NewCall(operators.Conditional, args[1], + eh.NewCall(operators.Add, eh.NewAccuIdent(), oneExpr), eh.NewAccuIdent()) + result = eh.NewCall(operators.Equals, eh.NewAccuIdent(), oneExpr) default: - return nil, eh.NewError(args[0].GetId(), fmt.Sprintf("unrecognized quantifier '%v'", kind)) + return nil, eh.NewError(args[0].ID(), fmt.Sprintf("unrecognized quantifier '%v'", kind)) } - return eh.Fold(v, target, AccumulatorName, init, condition, step, result), nil + return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, result), nil } -func extractIdent(e *exprpb.Expr) (string, bool) { - switch e.ExprKind.(type) { - case *exprpb.Expr_IdentExpr: - return e.GetIdentExpr().GetName(), true +func extractIdent(e ast.Expr) (string, bool) { + switch e.Kind() { + case ast.IdentKind: + return e.AsIdent(), true } return "", false } diff --git a/vendor/github.com/google/cel-go/parser/parser.go b/vendor/github.com/google/cel-go/parser/parser.go index 109326a9399..c408bd35531 100644 --- a/vendor/github.com/google/cel-go/parser/parser.go +++ b/vendor/github.com/google/cel-go/parser/parser.go @@ -26,12 +26,11 @@ import ( antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/runes" + "github.com/google/cel-go/common/types" "github.com/google/cel-go/parser/gen" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" - structpb "google.golang.org/protobuf/types/known/structpb" ) // Parser encapsulates the context necessary to perform parsing for different expressions. @@ -88,11 +87,13 @@ func mustNewParser(opts ...Option) *Parser { } // Parse parses the expression represented by source and returns the result. -func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) { +func (p *Parser) Parse(source common.Source) (*ast.AST, *common.Errors) { errs := common.NewErrors(source) + fac := ast.NewExprFactory() impl := parser{ errors: &parseErrors{errs}, - helper: newParserHelper(source), + exprFactory: fac, + helper: newParserHelper(source, fac), macros: p.macros, maxRecursionDepth: p.maxRecursionDepth, errorReportingLimit: p.errorReportingLimit, @@ -106,18 +107,15 @@ func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors if !ok { buf = runes.NewBuffer(source.Content()) } - var e *exprpb.Expr + var out ast.Expr if buf.Len() > p.expressionSizeCodePointLimit { - e = impl.reportError(common.NoLocation, + out = impl.reportError(common.NoLocation, "expression code point size exceeds limit: size: %d, limit %d", buf.Len(), p.expressionSizeCodePointLimit) } else { - e = impl.parse(buf, source.Description()) + out = impl.parse(buf, source.Description()) } - return &exprpb.ParsedExpr{ - Expr: e, - SourceInfo: impl.helper.getSourceInfo(), - }, errs + return ast.NewAST(out, impl.helper.getSourceInfo()), errs } // reservedIds are not legal to use as variables. We exclude them post-parse, as they *are* valid @@ -150,7 +148,7 @@ var reservedIds = map[string]struct{}{ // This function calls ParseWithMacros with AllMacros. // // Deprecated: Use NewParser().Parse() instead. -func Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) { +func Parse(source common.Source) (*ast.AST, *common.Errors) { return mustNewParser(Macros(AllMacros...)).Parse(source) } @@ -287,6 +285,7 @@ var _ antlr.ErrorStrategy = &recoveryLimitErrorStrategy{} type parser struct { gen.BaseCELVisitor errors *parseErrors + exprFactory ast.ExprFactory helper *parserHelper macros map[string]Macro recursionDepth int @@ -320,7 +319,7 @@ var ( } ) -func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr { +func (p *parser) parse(expr runes.Buffer, desc string) ast.Expr { // TODO: get rid of these pools once https://github.com/antlr/antlr4/pull/3571 is in a release lexer := lexerPool.Get().(*gen.CELLexer) prsr := parserPool.Get().(*gen.CELParser) @@ -373,7 +372,7 @@ func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr { } }() - return p.Visit(prsr.Start()).(*exprpb.Expr) + return p.Visit(prsr.Start()).(ast.Expr) } // Visitor implementations. @@ -470,26 +469,26 @@ func (p *parser) VisitStart(ctx *gen.StartContext) any { // Visit a parse tree produced by CELParser#expr. func (p *parser) VisitExpr(ctx *gen.ExprContext) any { - result := p.Visit(ctx.GetE()).(*exprpb.Expr) + result := p.Visit(ctx.GetE()).(ast.Expr) if ctx.GetOp() == nil { return result } opID := p.helper.id(ctx.GetOp()) - ifTrue := p.Visit(ctx.GetE1()).(*exprpb.Expr) - ifFalse := p.Visit(ctx.GetE2()).(*exprpb.Expr) + ifTrue := p.Visit(ctx.GetE1()).(ast.Expr) + ifFalse := p.Visit(ctx.GetE2()).(ast.Expr) return p.globalCallOrMacro(opID, operators.Conditional, result, ifTrue, ifFalse) } // Visit a parse tree produced by CELParser#conditionalOr. func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) any { - result := p.Visit(ctx.GetE()).(*exprpb.Expr) + result := p.Visit(ctx.GetE()).(ast.Expr) l := p.newLogicManager(operators.LogicalOr, result) rest := ctx.GetE1() for i, op := range ctx.GetOps() { if i >= len(rest) { return p.reportError(ctx, "unexpected character, wanted '||'") } - next := p.Visit(rest[i]).(*exprpb.Expr) + next := p.Visit(rest[i]).(ast.Expr) opID := p.helper.id(op) l.addTerm(opID, next) } @@ -498,14 +497,14 @@ func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) any { // Visit a parse tree produced by CELParser#conditionalAnd. func (p *parser) VisitConditionalAnd(ctx *gen.ConditionalAndContext) any { - result := p.Visit(ctx.GetE()).(*exprpb.Expr) + result := p.Visit(ctx.GetE()).(ast.Expr) l := p.newLogicManager(operators.LogicalAnd, result) rest := ctx.GetE1() for i, op := range ctx.GetOps() { if i >= len(rest) { return p.reportError(ctx, "unexpected character, wanted '&&'") } - next := p.Visit(rest[i]).(*exprpb.Expr) + next := p.Visit(rest[i]).(ast.Expr) opID := p.helper.id(op) l.addTerm(opID, next) } @@ -519,9 +518,9 @@ func (p *parser) VisitRelation(ctx *gen.RelationContext) any { opText = ctx.GetOp().GetText() } if op, found := operators.Find(opText); found { - lhs := p.Visit(ctx.Relation(0)).(*exprpb.Expr) + lhs := p.Visit(ctx.Relation(0)).(ast.Expr) opID := p.helper.id(ctx.GetOp()) - rhs := p.Visit(ctx.Relation(1)).(*exprpb.Expr) + rhs := p.Visit(ctx.Relation(1)).(ast.Expr) return p.globalCallOrMacro(opID, op, lhs, rhs) } return p.reportError(ctx, "operator not found") @@ -534,9 +533,9 @@ func (p *parser) VisitCalc(ctx *gen.CalcContext) any { opText = ctx.GetOp().GetText() } if op, found := operators.Find(opText); found { - lhs := p.Visit(ctx.Calc(0)).(*exprpb.Expr) + lhs := p.Visit(ctx.Calc(0)).(ast.Expr) opID := p.helper.id(ctx.GetOp()) - rhs := p.Visit(ctx.Calc(1)).(*exprpb.Expr) + rhs := p.Visit(ctx.Calc(1)).(ast.Expr) return p.globalCallOrMacro(opID, op, lhs, rhs) } return p.reportError(ctx, "operator not found") @@ -552,7 +551,7 @@ func (p *parser) VisitLogicalNot(ctx *gen.LogicalNotContext) any { return p.Visit(ctx.Member()) } opID := p.helper.id(ctx.GetOps()[0]) - target := p.Visit(ctx.Member()).(*exprpb.Expr) + target := p.Visit(ctx.Member()).(ast.Expr) return p.globalCallOrMacro(opID, operators.LogicalNot, target) } @@ -561,13 +560,13 @@ func (p *parser) VisitNegate(ctx *gen.NegateContext) any { return p.Visit(ctx.Member()) } opID := p.helper.id(ctx.GetOps()[0]) - target := p.Visit(ctx.Member()).(*exprpb.Expr) + target := p.Visit(ctx.Member()).(ast.Expr) return p.globalCallOrMacro(opID, operators.Negate, target) } // VisitSelect visits a parse tree produced by CELParser#Select. func (p *parser) VisitSelect(ctx *gen.SelectContext) any { - operand := p.Visit(ctx.Member()).(*exprpb.Expr) + operand := p.Visit(ctx.Member()).(ast.Expr) // Handle the error case where no valid identifier is specified. if ctx.GetId() == nil || ctx.GetOp() == nil { return p.helper.newExpr(ctx) @@ -588,7 +587,7 @@ func (p *parser) VisitSelect(ctx *gen.SelectContext) any { // VisitMemberCall visits a parse tree produced by CELParser#MemberCall. func (p *parser) VisitMemberCall(ctx *gen.MemberCallContext) any { - operand := p.Visit(ctx.Member()).(*exprpb.Expr) + operand := p.Visit(ctx.Member()).(ast.Expr) // Handle the error case where no valid identifier is specified. if ctx.GetId() == nil { return p.helper.newExpr(ctx) @@ -600,13 +599,13 @@ func (p *parser) VisitMemberCall(ctx *gen.MemberCallContext) any { // Visit a parse tree produced by CELParser#Index. func (p *parser) VisitIndex(ctx *gen.IndexContext) any { - target := p.Visit(ctx.Member()).(*exprpb.Expr) + target := p.Visit(ctx.Member()).(ast.Expr) // Handle the error case where no valid identifier is specified. if ctx.GetOp() == nil { return p.helper.newExpr(ctx) } opID := p.helper.id(ctx.GetOp()) - index := p.Visit(ctx.GetIndex()).(*exprpb.Expr) + index := p.Visit(ctx.GetIndex()).(ast.Expr) operator := operators.Index if ctx.GetOpt() != nil { if !p.enableOptionalSyntax { @@ -630,7 +629,7 @@ func (p *parser) VisitCreateMessage(ctx *gen.CreateMessageContext) any { messageName = "." + messageName } objID := p.helper.id(ctx.GetOp()) - entries := p.VisitIFieldInitializerList(ctx.GetEntries()).([]*exprpb.Expr_CreateStruct_Entry) + entries := p.VisitIFieldInitializerList(ctx.GetEntries()).([]ast.EntryExpr) return p.helper.newObject(objID, messageName, entries...) } @@ -638,16 +637,16 @@ func (p *parser) VisitCreateMessage(ctx *gen.CreateMessageContext) any { func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext) any { if ctx == nil || ctx.GetFields() == nil { // This is the result of a syntax error handled elswhere, return empty. - return []*exprpb.Expr_CreateStruct_Entry{} + return []ast.EntryExpr{} } - result := make([]*exprpb.Expr_CreateStruct_Entry, len(ctx.GetFields())) + result := make([]ast.EntryExpr, len(ctx.GetFields())) cols := ctx.GetCols() vals := ctx.GetValues() for i, f := range ctx.GetFields() { if i >= len(cols) || i >= len(vals) { // This is the result of a syntax error detected elsewhere. - return []*exprpb.Expr_CreateStruct_Entry{} + return []ast.EntryExpr{} } initID := p.helper.id(cols[i]) optField := f.(*gen.OptFieldContext) @@ -659,10 +658,10 @@ func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext // The field may be empty due to a prior error. id := optField.IDENTIFIER() if id == nil { - return []*exprpb.Expr_CreateStruct_Entry{} + return []ast.EntryExpr{} } fieldName := id.GetText() - value := p.Visit(vals[i]).(*exprpb.Expr) + value := p.Visit(vals[i]).(ast.Expr) field := p.helper.newObjectField(initID, fieldName, value, optional) result[i] = field } @@ -702,9 +701,9 @@ func (p *parser) VisitCreateList(ctx *gen.CreateListContext) any { // Visit a parse tree produced by CELParser#CreateStruct. func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) any { structID := p.helper.id(ctx.GetOp()) - entries := []*exprpb.Expr_CreateStruct_Entry{} + entries := []ast.EntryExpr{} if ctx.GetEntries() != nil { - entries = p.Visit(ctx.GetEntries()).([]*exprpb.Expr_CreateStruct_Entry) + entries = p.Visit(ctx.GetEntries()).([]ast.EntryExpr) } return p.helper.newMap(structID, entries...) } @@ -713,17 +712,17 @@ func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) any { func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) any { if ctx == nil || ctx.GetKeys() == nil { // This is the result of a syntax error handled elswhere, return empty. - return []*exprpb.Expr_CreateStruct_Entry{} + return []ast.EntryExpr{} } - result := make([]*exprpb.Expr_CreateStruct_Entry, len(ctx.GetCols())) + result := make([]ast.EntryExpr, len(ctx.GetCols())) keys := ctx.GetKeys() vals := ctx.GetValues() for i, col := range ctx.GetCols() { colID := p.helper.id(col) if i >= len(keys) || i >= len(vals) { // This is the result of a syntax error detected elsewhere. - return []*exprpb.Expr_CreateStruct_Entry{} + return []ast.EntryExpr{} } optKey := keys[i] optional := optKey.GetOpt() != nil @@ -731,8 +730,8 @@ func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) any p.reportError(optKey, "unsupported syntax '?'") continue } - key := p.Visit(optKey.GetE()).(*exprpb.Expr) - value := p.Visit(vals[i]).(*exprpb.Expr) + key := p.Visit(optKey.GetE()).(ast.Expr) + value := p.Visit(vals[i]).(ast.Expr) entry := p.helper.newMapEntry(colID, key, value, optional) result[i] = entry } @@ -812,30 +811,27 @@ func (p *parser) VisitBoolFalse(ctx *gen.BoolFalseContext) any { // Visit a parse tree produced by CELParser#Null. func (p *parser) VisitNull(ctx *gen.NullContext) any { - return p.helper.newLiteral(ctx, - &exprpb.Constant{ - ConstantKind: &exprpb.Constant_NullValue{ - NullValue: structpb.NullValue_NULL_VALUE}}) + return p.helper.exprFactory.NewLiteral(p.helper.newID(ctx), types.NullValue) } -func (p *parser) visitExprList(ctx gen.IExprListContext) []*exprpb.Expr { +func (p *parser) visitExprList(ctx gen.IExprListContext) []ast.Expr { if ctx == nil { - return []*exprpb.Expr{} + return []ast.Expr{} } return p.visitSlice(ctx.GetE()) } -func (p *parser) visitListInit(ctx gen.IListInitContext) ([]*exprpb.Expr, []int32) { +func (p *parser) visitListInit(ctx gen.IListInitContext) ([]ast.Expr, []int32) { if ctx == nil { - return []*exprpb.Expr{}, []int32{} + return []ast.Expr{}, []int32{} } elements := ctx.GetElems() - result := make([]*exprpb.Expr, len(elements)) + result := make([]ast.Expr, len(elements)) optionals := []int32{} for i, e := range elements { - ex := p.Visit(e.GetE()).(*exprpb.Expr) + ex := p.Visit(e.GetE()).(ast.Expr) if ex == nil { - return []*exprpb.Expr{}, []int32{} + return []ast.Expr{}, []int32{} } result[i] = ex if e.GetOpt() != nil { @@ -849,13 +845,13 @@ func (p *parser) visitListInit(ctx gen.IListInitContext) ([]*exprpb.Expr, []int3 return result, optionals } -func (p *parser) visitSlice(expressions []gen.IExprContext) []*exprpb.Expr { +func (p *parser) visitSlice(expressions []gen.IExprContext) []ast.Expr { if expressions == nil { - return []*exprpb.Expr{} + return []ast.Expr{} } - result := make([]*exprpb.Expr, len(expressions)) + result := make([]ast.Expr, len(expressions)) for i, e := range expressions { - ex := p.Visit(e).(*exprpb.Expr) + ex := p.Visit(e).(ast.Expr) result[i] = ex } return result @@ -870,24 +866,24 @@ func (p *parser) unquote(ctx any, value string, isBytes bool) string { return text } -func (p *parser) newLogicManager(function string, term *exprpb.Expr) *logicManager { +func (p *parser) newLogicManager(function string, term ast.Expr) *logicManager { if p.enableVariadicOperatorASTs { - return newVariadicLogicManager(p.helper, function, term) + return newVariadicLogicManager(p.exprFactory, function, term) } - return newBalancingLogicManager(p.helper, function, term) + return newBalancingLogicManager(p.exprFactory, function, term) } -func (p *parser) reportError(ctx any, format string, args ...any) *exprpb.Expr { +func (p *parser) reportError(ctx any, format string, args ...any) ast.Expr { var location common.Location err := p.helper.newExpr(ctx) switch c := ctx.(type) { case common.Location: location = c case antlr.Token, antlr.ParserRuleContext: - location = p.helper.getLocation(err.GetId()) + location = p.helper.getLocation(err.ID()) } // Provide arguments to the report error. - p.errors.reportErrorAtID(err.GetId(), location, format, args...) + p.errors.reportErrorAtID(err.ID(), location, format, args...) return err } @@ -924,21 +920,21 @@ func (p *parser) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DF // Intentional } -func (p *parser) globalCallOrMacro(exprID int64, function string, args ...*exprpb.Expr) *exprpb.Expr { +func (p *parser) globalCallOrMacro(exprID int64, function string, args ...ast.Expr) ast.Expr { if expr, found := p.expandMacro(exprID, function, nil, args...); found { return expr } return p.helper.newGlobalCall(exprID, function, args...) } -func (p *parser) receiverCallOrMacro(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr { +func (p *parser) receiverCallOrMacro(exprID int64, function string, target ast.Expr, args ...ast.Expr) ast.Expr { if expr, found := p.expandMacro(exprID, function, target, args...); found { return expr } return p.helper.newReceiverCall(exprID, function, target, args...) } -func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) (*exprpb.Expr, bool) { +func (p *parser) expandMacro(exprID int64, function string, target ast.Expr, args ...ast.Expr) (ast.Expr, bool) { macro, found := p.macros[makeMacroKey(function, len(args), target != nil)] if !found { macro, found = p.macros[makeVarArgMacroKey(function, target != nil)] @@ -964,7 +960,7 @@ func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr, return nil, false } if p.populateMacroCalls { - p.helper.addMacroCall(expr.GetId(), function, target, args...) + p.helper.addMacroCall(expr.ID(), function, target, args...) } return expr, true } diff --git a/vendor/github.com/google/cel-go/parser/unparser.go b/vendor/github.com/google/cel-go/parser/unparser.go index c3c40a0dd36..91cf7294475 100644 --- a/vendor/github.com/google/cel-go/parser/unparser.go +++ b/vendor/github.com/google/cel-go/parser/unparser.go @@ -20,9 +20,9 @@ import ( "strconv" "strings" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/types" ) // Unparse takes an input expression and source position information and generates a human-readable @@ -39,7 +39,7 @@ import ( // // This function optionally takes in one or more UnparserOption to alter the unparsing behavior, such as // performing word wrapping on expressions. -func Unparse(expr *exprpb.Expr, info *exprpb.SourceInfo, opts ...UnparserOption) (string, error) { +func Unparse(expr ast.Expr, info *ast.SourceInfo, opts ...UnparserOption) (string, error) { unparserOpts := &unparserOption{ wrapOnColumn: defaultWrapOnColumn, wrapAfterColumnLimit: defaultWrapAfterColumnLimit, @@ -68,12 +68,12 @@ func Unparse(expr *exprpb.Expr, info *exprpb.SourceInfo, opts ...UnparserOption) // unparser visits an expression to reconstruct a human-readable string from an AST. type unparser struct { str strings.Builder - info *exprpb.SourceInfo + info *ast.SourceInfo options *unparserOption lastWrappedIndex int } -func (un *unparser) visit(expr *exprpb.Expr) error { +func (un *unparser) visit(expr ast.Expr) error { if expr == nil { return errors.New("unsupported expression") } @@ -81,27 +81,29 @@ func (un *unparser) visit(expr *exprpb.Expr) error { if visited || err != nil { return err } - switch expr.GetExprKind().(type) { - case *exprpb.Expr_CallExpr: + switch expr.Kind() { + case ast.CallKind: return un.visitCall(expr) - case *exprpb.Expr_ConstExpr: + case ast.LiteralKind: return un.visitConst(expr) - case *exprpb.Expr_IdentExpr: + case ast.IdentKind: return un.visitIdent(expr) - case *exprpb.Expr_ListExpr: + case ast.ListKind: return un.visitList(expr) - case *exprpb.Expr_SelectExpr: + case ast.MapKind: + return un.visitStructMap(expr) + case ast.SelectKind: return un.visitSelect(expr) - case *exprpb.Expr_StructExpr: - return un.visitStruct(expr) + case ast.StructKind: + return un.visitStructMsg(expr) default: return fmt.Errorf("unsupported expression: %v", expr) } } -func (un *unparser) visitCall(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - fun := c.GetFunction() +func (un *unparser) visitCall(expr ast.Expr) error { + c := expr.AsCall() + fun := c.FunctionName() switch fun { // ternary operator case operators.Conditional: @@ -141,10 +143,10 @@ func (un *unparser) visitCall(expr *exprpb.Expr) error { } } -func (un *unparser) visitCallBinary(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - fun := c.GetFunction() - args := c.GetArgs() +func (un *unparser) visitCallBinary(expr ast.Expr) error { + c := expr.AsCall() + fun := c.FunctionName() + args := c.Args() lhs := args[0] // add parens if the current operator is lower precedence than the lhs expr operator. lhsParen := isComplexOperatorWithRespectTo(fun, lhs) @@ -168,9 +170,9 @@ func (un *unparser) visitCallBinary(expr *exprpb.Expr) error { return un.visitMaybeNested(rhs, rhsParen) } -func (un *unparser) visitCallConditional(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - args := c.GetArgs() +func (un *unparser) visitCallConditional(expr ast.Expr) error { + c := expr.AsCall() + args := c.Args() // add parens if operand is a conditional itself. nested := isSamePrecedence(operators.Conditional, args[0]) || isComplexOperator(args[0]) @@ -196,13 +198,13 @@ func (un *unparser) visitCallConditional(expr *exprpb.Expr) error { return un.visitMaybeNested(args[2], nested) } -func (un *unparser) visitCallFunc(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - fun := c.GetFunction() - args := c.GetArgs() - if c.GetTarget() != nil { - nested := isBinaryOrTernaryOperator(c.GetTarget()) - err := un.visitMaybeNested(c.GetTarget(), nested) +func (un *unparser) visitCallFunc(expr ast.Expr) error { + c := expr.AsCall() + fun := c.FunctionName() + args := c.Args() + if c.IsMemberFunction() { + nested := isBinaryOrTernaryOperator(c.Target()) + err := un.visitMaybeNested(c.Target(), nested) if err != nil { return err } @@ -223,17 +225,17 @@ func (un *unparser) visitCallFunc(expr *exprpb.Expr) error { return nil } -func (un *unparser) visitCallIndex(expr *exprpb.Expr) error { +func (un *unparser) visitCallIndex(expr ast.Expr) error { return un.visitCallIndexInternal(expr, "[") } -func (un *unparser) visitCallOptIndex(expr *exprpb.Expr) error { +func (un *unparser) visitCallOptIndex(expr ast.Expr) error { return un.visitCallIndexInternal(expr, "[?") } -func (un *unparser) visitCallIndexInternal(expr *exprpb.Expr, op string) error { - c := expr.GetCallExpr() - args := c.GetArgs() +func (un *unparser) visitCallIndexInternal(expr ast.Expr, op string) error { + c := expr.AsCall() + args := c.Args() nested := isBinaryOrTernaryOperator(args[0]) err := un.visitMaybeNested(args[0], nested) if err != nil { @@ -248,10 +250,10 @@ func (un *unparser) visitCallIndexInternal(expr *exprpb.Expr, op string) error { return nil } -func (un *unparser) visitCallUnary(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - fun := c.GetFunction() - args := c.GetArgs() +func (un *unparser) visitCallUnary(expr ast.Expr) error { + c := expr.AsCall() + fun := c.FunctionName() + args := c.Args() unmangled, found := operators.FindReverse(fun) if !found { return fmt.Errorf("cannot unmangle operator: %s", fun) @@ -261,35 +263,34 @@ func (un *unparser) visitCallUnary(expr *exprpb.Expr) error { return un.visitMaybeNested(args[0], nested) } -func (un *unparser) visitConst(expr *exprpb.Expr) error { - c := expr.GetConstExpr() - switch c.GetConstantKind().(type) { - case *exprpb.Constant_BoolValue: - un.str.WriteString(strconv.FormatBool(c.GetBoolValue())) - case *exprpb.Constant_BytesValue: +func (un *unparser) visitConst(expr ast.Expr) error { + val := expr.AsLiteral() + switch val := val.(type) { + case types.Bool: + un.str.WriteString(strconv.FormatBool(bool(val))) + case types.Bytes: // bytes constants are surrounded with b"" - b := c.GetBytesValue() un.str.WriteString(`b"`) - un.str.WriteString(bytesToOctets(b)) + un.str.WriteString(bytesToOctets([]byte(val))) un.str.WriteString(`"`) - case *exprpb.Constant_DoubleValue: + case types.Double: // represent the float using the minimum required digits - d := strconv.FormatFloat(c.GetDoubleValue(), 'g', -1, 64) + d := strconv.FormatFloat(float64(val), 'g', -1, 64) un.str.WriteString(d) if !strings.Contains(d, ".") { un.str.WriteString(".0") } - case *exprpb.Constant_Int64Value: - i := strconv.FormatInt(c.GetInt64Value(), 10) + case types.Int: + i := strconv.FormatInt(int64(val), 10) un.str.WriteString(i) - case *exprpb.Constant_NullValue: + case types.Null: un.str.WriteString("null") - case *exprpb.Constant_StringValue: + case types.String: // strings will be double quoted with quotes escaped. - un.str.WriteString(strconv.Quote(c.GetStringValue())) - case *exprpb.Constant_Uint64Value: + un.str.WriteString(strconv.Quote(string(val))) + case types.Uint: // uint literals have a 'u' suffix. - ui := strconv.FormatUint(c.GetUint64Value(), 10) + ui := strconv.FormatUint(uint64(val), 10) un.str.WriteString(ui) un.str.WriteString("u") default: @@ -298,16 +299,16 @@ func (un *unparser) visitConst(expr *exprpb.Expr) error { return nil } -func (un *unparser) visitIdent(expr *exprpb.Expr) error { - un.str.WriteString(expr.GetIdentExpr().GetName()) +func (un *unparser) visitIdent(expr ast.Expr) error { + un.str.WriteString(expr.AsIdent()) return nil } -func (un *unparser) visitList(expr *exprpb.Expr) error { - l := expr.GetListExpr() - elems := l.GetElements() +func (un *unparser) visitList(expr ast.Expr) error { + l := expr.AsList() + elems := l.Elements() optIndices := make(map[int]bool, len(elems)) - for _, idx := range l.GetOptionalIndices() { + for _, idx := range l.OptionalIndices() { optIndices[int(idx)] = true } un.str.WriteString("[") @@ -327,20 +328,20 @@ func (un *unparser) visitList(expr *exprpb.Expr) error { return nil } -func (un *unparser) visitOptSelect(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - args := c.GetArgs() +func (un *unparser) visitOptSelect(expr ast.Expr) error { + c := expr.AsCall() + args := c.Args() operand := args[0] - field := args[1].GetConstExpr().GetStringValue() - return un.visitSelectInternal(operand, false, ".?", field) + field := args[1].AsLiteral().(types.String) + return un.visitSelectInternal(operand, false, ".?", string(field)) } -func (un *unparser) visitSelect(expr *exprpb.Expr) error { - sel := expr.GetSelectExpr() - return un.visitSelectInternal(sel.GetOperand(), sel.GetTestOnly(), ".", sel.GetField()) +func (un *unparser) visitSelect(expr ast.Expr) error { + sel := expr.AsSelect() + return un.visitSelectInternal(sel.Operand(), sel.IsTestOnly(), ".", sel.FieldName()) } -func (un *unparser) visitSelectInternal(operand *exprpb.Expr, testOnly bool, op string, field string) error { +func (un *unparser) visitSelectInternal(operand ast.Expr, testOnly bool, op string, field string) error { // handle the case when the select expression was generated by the has() macro. if testOnly { un.str.WriteString("has(") @@ -358,34 +359,25 @@ func (un *unparser) visitSelectInternal(operand *exprpb.Expr, testOnly bool, op return nil } -func (un *unparser) visitStruct(expr *exprpb.Expr) error { - s := expr.GetStructExpr() - // If the message name is non-empty, then this should be treated as message construction. - if s.GetMessageName() != "" { - return un.visitStructMsg(expr) - } - // Otherwise, build a map. - return un.visitStructMap(expr) -} - -func (un *unparser) visitStructMsg(expr *exprpb.Expr) error { - m := expr.GetStructExpr() - entries := m.GetEntries() - un.str.WriteString(m.GetMessageName()) +func (un *unparser) visitStructMsg(expr ast.Expr) error { + m := expr.AsStruct() + fields := m.Fields() + un.str.WriteString(m.TypeName()) un.str.WriteString("{") - for i, entry := range entries { - f := entry.GetFieldKey() - if entry.GetOptionalEntry() { + for i, f := range fields { + field := f.AsStructField() + f := field.Name() + if field.IsOptional() { un.str.WriteString("?") } un.str.WriteString(f) un.str.WriteString(": ") - v := entry.GetValue() + v := field.Value() err := un.visit(v) if err != nil { return err } - if i < len(entries)-1 { + if i < len(fields)-1 { un.str.WriteString(", ") } } @@ -393,13 +385,14 @@ func (un *unparser) visitStructMsg(expr *exprpb.Expr) error { return nil } -func (un *unparser) visitStructMap(expr *exprpb.Expr) error { - m := expr.GetStructExpr() - entries := m.GetEntries() +func (un *unparser) visitStructMap(expr ast.Expr) error { + m := expr.AsMap() + entries := m.Entries() un.str.WriteString("{") - for i, entry := range entries { - k := entry.GetMapKey() - if entry.GetOptionalEntry() { + for i, e := range entries { + entry := e.AsMapEntry() + k := entry.Key() + if entry.IsOptional() { un.str.WriteString("?") } err := un.visit(k) @@ -407,7 +400,7 @@ func (un *unparser) visitStructMap(expr *exprpb.Expr) error { return err } un.str.WriteString(": ") - v := entry.GetValue() + v := entry.Value() err = un.visit(v) if err != nil { return err @@ -420,16 +413,15 @@ func (un *unparser) visitStructMap(expr *exprpb.Expr) error { return nil } -func (un *unparser) visitMaybeMacroCall(expr *exprpb.Expr) (bool, error) { - macroCalls := un.info.GetMacroCalls() - call, found := macroCalls[expr.GetId()] +func (un *unparser) visitMaybeMacroCall(expr ast.Expr) (bool, error) { + call, found := un.info.GetMacroCall(expr.ID()) if !found { return false, nil } return true, un.visit(call) } -func (un *unparser) visitMaybeNested(expr *exprpb.Expr, nested bool) error { +func (un *unparser) visitMaybeNested(expr ast.Expr, nested bool) error { if nested { un.str.WriteString("(") } @@ -453,12 +445,12 @@ func isLeftRecursive(op string) bool { // precedence of the (possible) operation represented in the input Expr. // // If the expr is not a Call, the result is false. -func isSamePrecedence(op string, expr *exprpb.Expr) bool { - if expr.GetCallExpr() == nil { +func isSamePrecedence(op string, expr ast.Expr) bool { + if expr.Kind() != ast.CallKind { return false } - c := expr.GetCallExpr() - other := c.GetFunction() + c := expr.AsCall() + other := c.FunctionName() return operators.Precedence(op) == operators.Precedence(other) } @@ -466,16 +458,16 @@ func isSamePrecedence(op string, expr *exprpb.Expr) bool { // than the (possible) operation represented in the input Expr. // // If the expr is not a Call, the result is false. -func isLowerPrecedence(op string, expr *exprpb.Expr) bool { - c := expr.GetCallExpr() - other := c.GetFunction() +func isLowerPrecedence(op string, expr ast.Expr) bool { + c := expr.AsCall() + other := c.FunctionName() return operators.Precedence(op) < operators.Precedence(other) } // Indicates whether the expr is a complex operator, i.e., a call expression // with 2 or more arguments. -func isComplexOperator(expr *exprpb.Expr) bool { - if expr.GetCallExpr() != nil && len(expr.GetCallExpr().GetArgs()) >= 2 { +func isComplexOperator(expr ast.Expr) bool { + if expr.Kind() == ast.CallKind && len(expr.AsCall().Args()) >= 2 { return true } return false @@ -484,19 +476,19 @@ func isComplexOperator(expr *exprpb.Expr) bool { // Indicates whether it is a complex operation compared to another. // expr is *not* considered complex if it is not a call expression or has // less than two arguments, or if it has a higher precedence than op. -func isComplexOperatorWithRespectTo(op string, expr *exprpb.Expr) bool { - if expr.GetCallExpr() == nil || len(expr.GetCallExpr().GetArgs()) < 2 { +func isComplexOperatorWithRespectTo(op string, expr ast.Expr) bool { + if expr.Kind() != ast.CallKind || len(expr.AsCall().Args()) < 2 { return false } return isLowerPrecedence(op, expr) } // Indicate whether this is a binary or ternary operator. -func isBinaryOrTernaryOperator(expr *exprpb.Expr) bool { - if expr.GetCallExpr() == nil || len(expr.GetCallExpr().GetArgs()) < 2 { +func isBinaryOrTernaryOperator(expr ast.Expr) bool { + if expr.Kind() != ast.CallKind || len(expr.AsCall().Args()) < 2 { return false } - _, isBinaryOp := operators.FindReverseBinaryOperator(expr.GetCallExpr().GetFunction()) + _, isBinaryOp := operators.FindReverseBinaryOperator(expr.AsCall().FunctionName()) return isBinaryOp || isSamePrecedence(operators.Conditional, expr) } diff --git a/vendor/modules.txt b/vendor/modules.txt index 340e452b7f7..4857fc5cc30 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -388,7 +388,7 @@ github.com/golang/protobuf/ptypes/duration github.com/golang/protobuf/ptypes/struct github.com/golang/protobuf/ptypes/timestamp github.com/golang/protobuf/ptypes/wrappers -# github.com/google/cel-go v0.17.6 +# github.com/google/cel-go v0.18.0 ## explicit; go 1.18 github.com/google/cel-go/cel github.com/google/cel-go/checker @@ -895,7 +895,7 @@ google.golang.org/appengine/urlfetch # google.golang.org/genproto v0.0.0-20230803162519-f966b187b2e5 ## explicit; go 1.19 google.golang.org/genproto/internal -# google.golang.org/genproto/googleapis/api v0.0.0-20230726155614-23370e0ffb3e +# google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 ## explicit; go 1.19 google.golang.org/genproto/googleapis/api google.golang.org/genproto/googleapis/api/annotations