Skip to content

Commit

Permalink
chore(deps): Upgrade CEL (cerbos#2412)
Browse files Browse the repository at this point in the history
Implements cerbos#2379, except optional
types, which will be in a separate PR. ePDP changes are coming, too.

---------

Signed-off-by: Dennis Buduev <dbuduev@users.noreply.github.com>
Co-authored-by: Charith Ellawala <charithe@users.noreply.github.com>
  • Loading branch information
dbuduev and charithe authored Jan 8, 2025
1 parent d0c26dd commit 5559df2
Show file tree
Hide file tree
Showing 10 changed files with 666 additions and 68 deletions.
46 changes: 38 additions & 8 deletions docs/modules/policies/pages/conditions.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ NOTE: The IP address functions are Cerbos-specific extensions to CEL.
"attr": {
"id": "125",
"teams": ["design", "communications", "product", "commercial"],
"limits": {
"design": 10,
"product": 25
},
"clients": {
"acme": {"active": true},
"bb inc": {"active": true}
Expand All @@ -373,17 +377,28 @@ NOTE: The IP address functions are Cerbos-specific extensions to CEL.
| Operator/Function | Description | Example
| + | Concatenates lists | P.attr.teams + ["design", "engineering"]
| [] | Index into a list or a map | P.attr.teams[0] == "design" && P.attr.clients["acme"]["active"] == true
| all | Check whether all elements in a list match the predicate | P.attr.teams.all(t, size(t) > 3)
| except | Produces the set difference of two lists | P.attr.teams.except(["design", "engineering"]) == ["communications", "product", "commercial"]
| exists | Check whether at least one element matching the predicate exists | P.attr.teams.exists(t, t.startsWith("comm"))
| exists_one | Check that only one element matching the predicate exists | P.attr.teams.exists_one(t, t.startsWith("comm")) == false
| filter | Filter a list using the predicate | size(P.attr.teams.filter(t, t.matches("^comm"))) == 2
| all | Check whether all elements in a list match the predicate. | P.attr.teams.all(t, size(t) > 3) && [1, 2, 3].all(i, j, i < j)
| distinct | Returns the distinct elements of a list | [1, 2, 2, 3, 3, 3].distinct() == [1, 2, 3]
| except | Produces the set difference of two lists | P.attr.teams.except(["design", "engineering"]) == ["communications", "product", "commercial"]
| exists | Check whether at least one element matching the predicate exists in a list or map. | P.attr.teams.exists(t, t.startsWith("comm")) && P.attr.limits.exists(k, v, k == "design" && v > 0)
| exists_one | Check that only one element matching the predicate exists. | P.attr.teams.exists_one(t, t.startsWith("comm")) == false && P.attr.limits.exists_one(k, v, k == "design" && v > 0) == false
| filter | Filter a list using the predicate. | size(P.attr.teams.filter(t, t.matches("^comm"))) == 2
| flatten | Flattens a list. If an optional depth is provided, the list is flattened to the specified level | [1,2,[],[],[3,4]].flatten() == [1, 2, 3, 4] && [1,[2,[3,[4]]]].flatten(2) == [1, 2, 3, [4]]
| hasIntersection| Checks whether the lists have at least one common element | hasIntersection(["design", "engineering"], P.attr.teams)
| in | Check whether the given element is contained in the list or map | ("design" in P.attr.teams) && ("acme" in P.attr.clients)
| intersect| Produces the set intersection of two lists | intersect(["design", "engineering"], P.attr.teams) == ["design"]
| isSubset| Checks whether the list is a subset of another list | ["design", "engineering"].isSubset(P.attr.teams) == false
| isSubset | Checks whether the list is a subset of another list | ["design", "engineering"].isSubset(P.attr.teams) == false
| lists.range | Returns a list of integers from 0 to n-1 | lists.range(5) == [0, 1, 2, 3, 4]
| map | Transform each element in a list | "DESIGN" in P.attr.teams.map(t, t.upperAscii())
| reverse | Returns the elements of a list in reverse order | [5, 3, 1, 2].reverse() == [2, 1, 3, 5]
| size | Number of elements in a list or map | size(P.attr.teams) == 4 && size(P.attr.clients) == 2
| slice | Returns a new sub-list using the indexes provided | [1,2,3,4].slice(1, 3) == [2, 3]
| sort | Sorts a list with comparable elements | [3, 2, 1].sort() == [1, 2, 3]
| sortBy | Sorts a list by a key value, i.e., the order is determined by the result of an expression applied to each element of the list | [{ "name": "foo", "score": 0 },{ "name": "bar", "score": -10 },{ "name": "baz", "score": 1000 }].sortBy(e, e.score).map(e, e.name) == ["bar", "foo", "baz"]
| transformList | Converts a map or a list into a list value. The output expression determines the contents of the output list. Elements in the list may optionally be filtered | [1, 2, 3].transformList(i, v, i > 0, 2 * v) == [4, 6] && +
[1, 2, 3].transformList(i, v, 2 * v) == [2, 4, 6]
| transformMap | Converts a map or a list into a map value. The key remains unchanged and only the value is changed. | [1, 2, 3].transformMap(i, v, i > 0, 2 * v) == {1: 4, 2: 6}
| transformMapEntry | Converts a map or a list into a map value; however, this transform expects the entry expression be a map literal. Elements in the map may optionally be filtered | {'greeting': 'hello'}.transformMapEntry(k, v, {v: k}) == {'hello': 'greeting'}
|===


Expand All @@ -393,8 +408,23 @@ NOTE: The IP address functions are Cerbos-specific extensions to CEL.
[%header,cols=".^1m,.^2,4m",grid=rows]
|===
| Function | Description | Example
| math.greatest | Get the greatest valued number present in the arguments | math.greatest([1, 3, 5]) == 5
| math.least | Get the least valued number present in the arguments | math.least([1, 3, 5]) == 1
| math.abs | Returns the absolute value of the numeric type provided as input | math.abs(1.2) == 1.2 && math.abs(-2) == 2
| math.bitAnd | Performs a bitwise-AND operation over two int or uint values | math.bitAnd(3u, 2u) == 2u && math.bitAnd(3, 5) == 3 && math.bitAnd(-3, -5) == -7
| math.bitNot | Function which accepts a single int or uint and performs a bitwise-NOT ones-complement of the given binary value | math.bitNot(1) == -1 && math.bitNot(-1) == 0 && math.bitNot(0u) == 18446744073709551615u
| math.bitOr | Performs a bitwise-OR operation over two int or uint values | math.bitOr(1u, 2u) == 3u && math.bitOr(-2, -4) == -2
| math.bitShiftLeft | Perform a left shift of bits on the first parameter, by the amount of bits specified in the second parameter. The first parameter is either a uint or an int. The second parameter must be an int | math.bitShiftLeft(1, 2) == 4 && math.bitShiftLeft(-1, 2) == -4 && math.bitShiftLeft(1u, 2) == 4u && math.bitShiftLeft(1u, 200) == 0u
| math.bitShiftRight | Perform a right shift of bits on the first parameter, by the amount of bits specified in the second parameter. The first parameter is either a uint or an int. The second parameter must be an int | math.bitShiftRight(1024, 2) == 256 && math.bitShiftRight(1024u, 2) == 256u && math.bitShiftLeft(1024u, 64) == 0u
| math.bitXor | Performs a bitwise-XOR operation over two int or uint values | math.bitXor(3u, 5u) == 6u && math.bitXor(1, 3) == 2
| math.ceil | Compute the ceiling of a double value | math.ceil(1.2) == 2.0 && math.ceil(-1.2) == -1.0
| math.floor | Compute the floor of a double value | math.floor(1.2) == 1.0 && math.floor(-1.2) == -2.0
| math.greatest | Get the greatest valued number present in the arguments | math.greatest([1, 3, 5]) == 5 && math.greatest(1, 3, 5) == 5
| math.isFinite | Returns true if the value is a finite number | !math.isFinite(0.0/0.0) && math.isFinite(1.2)
| math.isInf | Returns true if the input double value is -Inf or +Inf | math.isInf(1.0/0.0) && !math.isInf(1.2)
| math.isNaN | Returns true if the input double value is NaN, false otherwise | math.isNaN(0.0/0.0) && !math.isNaN(1.2)
| math.least | Get the least valued number present in the arguments | math.least([1, 3, 5]) == 1 && math.least(1, 3, 5) == 1
| math.round | Rounds the double value to the nearest whole number with ties rounding away from zero, e.g. 1.5 -> 2.0, -1.5 -> -2.0 | math.round(1.2) == 1.0 && math.round(1.5) == 2.0 && math.round(-1.5) == -2.0
| math.sign | Returns the sign of the numeric type, either -1, 0, 1 | math.sign(1.2) == 1 && math.sign(-2) == -1 && math.sign(0) == 0
| math.trunc | Truncates the fractional portion of the double value | math.trunc(1.2) == 1.0 && math.trunc(-1.2) == -1.0
|===


Expand Down
3 changes: 3 additions & 0 deletions internal/conditions/cel.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ var (
}

StdEnvOptions = []cel.EnvOption{
ext.TwoVarComprehensions(),
cel.CrossTypeNumericComparisons(true),
cel.Types(&enginev1.Request{}, &enginev1.Request_Principal{}, &enginev1.Request_Resource{}, &enginev1.Runtime{}),
cel.Declarations(StdEnvDecls...),
ext.Lists(),
ext.Bindings(),
ext.Strings(),
ext.Encoders(),
ext.Math(),
Expand Down
126 changes: 84 additions & 42 deletions internal/engine/planner/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ const (
Index = "index"
All = "all"
Filter = "filter"
TransformMap = "transformMap"
TransformMapEntry = "transformMapEntry"
TransformList = "transformList"
Exists = "exists"
ExistsOne = "exists_one"
Map = "map"
Expand Down Expand Up @@ -158,7 +161,7 @@ func replaceVarsGen(e *exprpb.Expr, f replaceVarsFunc) (output *exprpb.Expr, err

// This functions wraps references to known resource attributes in an `id` function call, which simply returns its argument.
// E.g. Replace R.attr.field1 with id(R.attr.field1) iif R.attr.field1 is passed in the request to the Query Planner API.
// This trick is necessary evaluate expression like `P.attr.struct1[R.attr.field1]`, otherwise CEL tries to use `R.attr.field1`
// This trick is necessary to evaluate expression like `P.attr.struct1[R.attr.field1]`, otherwise CEL tries to use `R.attr.field1`
// as a qualifier for `P.attr.struct1` and produces the error https://github.com/cerbos/cerbos/issues/1340
func replaceResourceVals(e *exprpb.Expr, vals map[string]*structpb.Value) (output *exprpb.Expr, err error) {
return replaceVarsGen(e, func(ex *exprpb.Expr) (output *exprpb.Expr, matched bool, err error) {
Expand Down Expand Up @@ -312,7 +315,7 @@ func mkListExpr(elems []*exprpb.Expr) *exprpb.Expr {
}
}

func mkExprOpExpr(op string, args ...*enginev1.PlanResourcesFilter_Expression_Operand) *enginev1.PlanResourcesFilter_Expression_Operand_Expression {
func mkExprOpExpr(op string, args ...*enginev1.PlanResourcesFilter_Expression_Operand) *exprOpExpr {
return &enginev1.PlanResourcesFilter_Expression_Operand_Expression{
Expression: &enginev1.PlanResourcesFilter_Expression{Operator: op, Operands: args},
}
Expand All @@ -322,31 +325,32 @@ func buildExpr(expr *exprpb.Expr, acc *enginev1.PlanResourcesFilter_Expression_O
return buildExprImpl(expr, acc, nil)
}

type (
exprOp = enginev1.PlanResourcesFilter_Expression_Operand
exprOpExpr = enginev1.PlanResourcesFilter_Expression_Operand_Expression
exprOpValue = enginev1.PlanResourcesFilter_Expression_Operand_Value
exprOpVar = enginev1.PlanResourcesFilter_Expression_Operand_Variable
)

func buildExprImpl(cur *exprpb.Expr, acc *enginev1.PlanResourcesFilter_Expression_Operand, parent *exprpb.Expr) error {
type (
ExprOp = enginev1.PlanResourcesFilter_Expression_Operand
ExprOpExpr = enginev1.PlanResourcesFilter_Expression_Operand_Expression
ExprOpValue = enginev1.PlanResourcesFilter_Expression_Operand_Value
ExprOpVar = enginev1.PlanResourcesFilter_Expression_Operand_Variable
)
switch expr := cur.ExprKind.(type) {
case *exprpb.Expr_CallExpr:
fn, _ := opFromCLE(expr.CallExpr.Function)
var offset int
if expr.CallExpr.Target != nil {
offset++
}
operands := make([]*ExprOp, len(expr.CallExpr.Args)+offset)
operands := make([]*exprOp, len(expr.CallExpr.Args)+offset)
if expr.CallExpr.Target != nil {
operands[0] = new(ExprOp)
operands[0] = new(exprOp)
err := buildExprImpl(expr.CallExpr.Target, operands[0], cur)
if err != nil {
return err
}
}

for i, arg := range expr.CallExpr.Args {
operands[i+offset] = new(ExprOp)
operands[i+offset] = new(exprOp)
err := buildExprImpl(arg, operands[i+offset], cur)
if err != nil {
return err
Expand All @@ -358,12 +362,12 @@ func buildExprImpl(cur *exprpb.Expr, acc *enginev1.PlanResourcesFilter_Expressio
if err != nil {
return err
}
acc.Node = &ExprOpValue{Value: value}
acc.Node = &exprOpValue{Value: value}
case *exprpb.Expr_IdentExpr:
acc.Node = &ExprOpVar{Variable: expr.IdentExpr.Name}
acc.Node = &exprOpVar{Variable: expr.IdentExpr.Name}
case *exprpb.Expr_SelectExpr:
if expr.SelectExpr.TestOnly {
acc.Node = &ExprOpValue{Value: structpb.NewBoolValue(true)}
acc.Node = &exprOpValue{Value: structpb.NewBoolValue(true)}
break
}
var names []string
Expand Down Expand Up @@ -391,14 +395,14 @@ func buildExprImpl(cur *exprpb.Expr, acc *enginev1.PlanResourcesFilter_Expressio
}
}
// This is a compound "a.b.c" variable
acc.Node = &ExprOpVar{Variable: sb.String()}
acc.Node = &exprOpVar{Variable: sb.String()}
} else {
op := new(ExprOp)
op := new(exprOp)
err := buildExprImpl(expr.SelectExpr.Operand, op, cur)
if err != nil {
return err
}
acc.Node = mkExprOpExpr(GetField, op, &ExprOp{Node: &ExprOpVar{Variable: expr.SelectExpr.Field}})
acc.Node = mkExprOpExpr(GetField, op, &exprOp{Node: &exprOpVar{Variable: expr.SelectExpr.Field}})
}
case *exprpb.Expr_ListExpr:
x := expr.ListExpr
Expand All @@ -417,12 +421,12 @@ func buildExprImpl(cur *exprpb.Expr, acc *enginev1.PlanResourcesFilter_Expressio
}
listValue.Values[i] = value
}
acc.Node = &ExprOpValue{Value: structpb.NewListValue(&listValue)}
acc.Node = &exprOpValue{Value: structpb.NewListValue(&listValue)}
} else {
// list of expressions
operands := make([]*ExprOp, len(x.Elements))
operands := make([]*exprOp, len(x.Elements))
for i := range operands {
operands[i] = new(ExprOp)
operands[i] = new(exprOp)
err := buildExprImpl(x.Elements[i], operands[i], cur)
if err != nil {
return err
Expand All @@ -445,23 +449,23 @@ func buildExprImpl(cur *exprpb.Expr, acc *enginev1.PlanResourcesFilter_Expressio
return nil
}
}
operands := make([]*ExprOp, len(x.Entries))
operands := make([]*exprOp, len(x.Entries))
for i, entry := range x.Entries {
k, v := new(ExprOp), new(ExprOp)
k, v := new(exprOp), new(exprOp)
switch entry := entry.KeyKind.(type) {
case *exprpb.Expr_CreateStruct_Entry_MapKey:
err := buildExprImpl(entry.MapKey, k, cur)
if err != nil {
return err
}
case *exprpb.Expr_CreateStruct_Entry_FieldKey:
k.Node = &ExprOpValue{Value: structpb.NewStringValue(entry.FieldKey)}
k.Node = &exprOpValue{Value: structpb.NewStringValue(entry.FieldKey)}
}
err := buildExprImpl(entry.Value, v, cur)
if err != nil {
return err
}
operands[i] = new(ExprOp)
operands[i] = new(exprOp)
operands[i].Node = mkExprOpExpr(SetField, k, v)
}
acc.Node = mkExprOpExpr(Struct, operands...)
Expand All @@ -470,34 +474,68 @@ func buildExprImpl(cur *exprpb.Expr, acc *enginev1.PlanResourcesFilter_Expressio
if err != nil {
return err
}
iterRange := lambdaAst.iterRange
if x, ok := iterRange.ExprKind.(*exprpb.Expr_StructExpr); ok {
iterRange = mkListExpr(structKeys(x.StructExpr))
}
lambda := new(ExprOp)
err = buildExprImpl(lambdaAst.lambdaExpr, lambda, cur)
if err != nil {
return err
}
if _, ok := lambda.Node.(*ExprOpExpr); !ok {
if _, ok := lambda.Node.(*ExprOpVar); !ok {
return fmt.Errorf("expected expression or variable, got %T", lambda.Node)
}
}
op := new(ExprOp)
err = buildExprImpl(iterRange, op, cur)
acc.Node, err = lambdaAst.mkNode(cur)
if err != nil {
return err
}

acc.Node = mkExprOpExpr(lambdaAst.operator, op, &ExprOp{Node: mkExprOpExpr(Lambda, lambda, &ExprOp{Node: &ExprOpVar{Variable: lambdaAst.iterVar}})})
default:
return fmt.Errorf("buildExprImpl: unsupported expression: %v", expr)
}

return nil
}

func (lambdaAst *lambdaAST) mkNode(cur *exprpb.Expr) (*exprOpExpr, error) {
lambda, err := buildLambdaExprOp(lambdaAst.expr, cur)
if err != nil {
return nil, err
}
lambda2, err := buildLambdaExprOp(lambdaAst.expr2, cur)
if err != nil {
return nil, err
}
lambdaArgs := []*exprOp{lambda}
if lambda2 != nil {
lambdaArgs = append(lambdaArgs, lambda2)
}
lambdaArgs = append(lambdaArgs, &exprOp{Node: &exprOpVar{Variable: lambdaAst.iterVar}})
if lambdaAst.iterVar2 != "" {
lambdaArgs = append(lambdaArgs, &exprOp{Node: &exprOpVar{Variable: lambdaAst.iterVar2}})
}
target, err := lambdaAst.buildIterRangeOp(cur)
if err != nil {
return nil, err
}
return mkExprOpExpr(lambdaAst.operator, target, &exprOp{Node: mkExprOpExpr(Lambda, lambdaArgs...)}), nil
}

func (lambdaAst *lambdaAST) buildIterRangeOp(cur *exprpb.Expr) (*exprOp, error) {
ir := lambdaAst.iterRange
if x, ok := ir.ExprKind.(*exprpb.Expr_StructExpr); ok && !canOperateOnStruct(lambdaAst.operator) {
ir = mkListExpr(structKeys(x.StructExpr))
}
irExprOp := new(exprOp)
err := buildExprImpl(ir, irExprOp, cur)
return irExprOp, err
}

func buildLambdaExprOp(expr, cur *exprpb.Expr) (*exprOp, error) {
if expr == nil {
return nil, nil
}
lambda := new(exprOp)
err := buildExprImpl(expr, lambda, cur)
if err != nil {
return nil, err
}
if _, ok := lambda.Node.(*exprOpExpr); !ok {
if _, ok = lambda.Node.(*exprOpVar); !ok {
return nil, fmt.Errorf("expected expression or variable, got %T", lambda.Node)
}
}
return lambda, nil
}

func visitConst(c *exprpb.Constant) (*structpb.Value, error) {
switch v := c.ConstantKind.(type) {
case *exprpb.Constant_BoolValue:
Expand Down Expand Up @@ -810,3 +848,7 @@ func filterExprOpExprToString(b *strings.Builder, expr *enginev1.PlanResourcesFi

b.WriteString(")")
}

func canOperateOnStruct(op string) bool {
return op == TransformMap || op == TransformMapEntry || op == TransformList
}
Loading

0 comments on commit 5559df2

Please sign in to comment.