Skip to content

Commit

Permalink
topdown: Adding scope component to virtual cache key to capture pre-e…
Browse files Browse the repository at this point in the history
…val ref unification

Fixes: open-policy-agent#6926
Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Aug 23, 2024
1 parent c3867a3 commit fd9cdb7
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 6 deletions.
125 changes: 119 additions & 6 deletions topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -2407,6 +2407,15 @@ type evalVirtualPartialCacheHint struct {
full bool
}

func (h *evalVirtualPartialCacheHint) keyWithoutScope() ast.Ref {
if h.key != nil {
if _, ok := h.key[len(h.key)-1].Value.(vcKeyScope); ok {
return h.key[:len(h.key)-1]
}
}
return h.key
}

func (e evalVirtualPartial) eval(iter unifyIterator) error {

unknown := e.e.unknown(e.ref[:e.pos+1], e.bindings)
Expand Down Expand Up @@ -2485,7 +2494,7 @@ func (e evalVirtualPartial) evalEachRule(iter unifyIterator, unknown bool) error
}

if hint.key != nil {
if v, err := result.Value.Find(hint.key[e.pos+1:]); err == nil && v != nil {
if v, err := result.Value.Find(hint.keyWithoutScope()[e.pos+1:]); err == nil && v != nil {
e.e.virtualCache.Put(hint.key, ast.NewTerm(v))
}
}
Expand Down Expand Up @@ -2832,6 +2841,8 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac
plugged := e.bindings.Plug(e.ref[e.pos+1])

if _, ok := plugged.Value.(ast.Var); ok {
// Note: we might have additional opportunity to optimize here, if we consider that ground values
// right of e.pos could create a smaller eval "scope" through ref bi-unification before evaluating rules.
hint.full = true
hint.key = e.plugged[:e.pos+1]
e.e.instr.counterIncr(evalOpVirtualCacheMiss)
Expand All @@ -2840,19 +2851,45 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac

m := maxRefLength(e.ir.Rules, len(e.ref))

scoping := false
hintKeyEnd := 0
for i := e.pos + 1; i < m; i++ {
plugged = e.bindings.Plug(e.ref[i])

if !plugged.IsGround() {
break
if plugged.IsGround() && !scoping {
hintKeyEnd = i
hint.key = append(e.plugged[:i], plugged)
} else {
scoping = true
hl := len(hint.key)
if hl == 0 {
break
}
if scope, ok := hint.key[hl-1].Value.(vcKeyScope); ok {
scope.Ref = append(scope.Ref, plugged)
hint.key[len(hint.key)-1] = ast.NewTerm(scope)
} else {
scope = vcKeyScope{}
scope.Ref = append(scope.Ref, plugged)
hint.key = append(hint.key, ast.NewTerm(scope))
}
}

hint.key = append(e.plugged[:i], plugged)

if cached, _ := e.e.virtualCache.Get(hint.key); cached != nil {
e.e.instr.counterIncr(evalOpVirtualCacheHit)
hint.hit = true
return hint, e.evalTerm(iter, i+1, cached, e.bindings)
return hint, e.evalTerm(iter, hintKeyEnd+1, cached, e.bindings)
}
}

if hl := len(hint.key); hl > 0 {
if scope, ok := hint.key[hl-1].Value.(vcKeyScope); ok {
scope = scope.reduce()
if scope.empty() {
hint.key = hint.key[:hl-1]
} else {
hint.key[hl-1].Value = scope
}
}
}

Expand All @@ -2861,6 +2898,82 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac
return hint, nil
}

// vcKeyScope represents the scoping that pre-rule-eval ref unification imposes on a virtual cache entry.
type vcKeyScope struct {
ast.Ref
}

func (q vcKeyScope) Compare(other ast.Value) int {
if q2, ok := other.(vcKeyScope); ok {
r1 := q.Ref
r2 := q2.Ref
if len(r1) != len(r2) {
return -1
}

for i := range r1 {
_, v1IsVar := r1[i].Value.(ast.Var)
_, v2IsVar := r2[i].Value.(ast.Var)
if !v1IsVar && !v2IsVar && r1[i].Value.Compare(r2[i].Value) != 0 {
return -1
}
}

return 0
}
return 1
}

func (q vcKeyScope) Find(_ ast.Ref) (ast.Value, error) {
return nil, nil
}

func (q vcKeyScope) Hash() int {
var hash int
for _, v := range q.Ref {
if _, ok := v.Value.(ast.Var); ok {
// all vars are equal
hash += 1
} else {
hash += v.Value.Hash()
}
}
return hash
}

func (q vcKeyScope) IsGround() bool {
return false
}

func (q vcKeyScope) String() string {
buf := make([]string, 0, len(q.Ref))
for _, t := range q.Ref {
if _, ok := t.Value.(ast.Var); ok {
buf = append(buf, "_")
} else {
buf = append(buf, t.String())
}
}
return fmt.Sprintf("<%s>", strings.Join(buf, ","))
}

// reduce removes vars from the tail of the ref.
func (q vcKeyScope) reduce() vcKeyScope {
ref := q.Ref.Copy()
var i int
for i = len(q.Ref) - 1; i >= 0; i-- {
if _, ok := q.Ref[i].Value.(ast.Var); !ok {
break
}
}
ref = ref[:i+1]
return vcKeyScope{ref}
}

func (q vcKeyScope) empty() bool {
return len(q.Ref) == 0
}

func getNestedObject(ref ast.Ref, rootObj *ast.Object, b *bindings, l *ast.Location) (*ast.Object, error) {
current := rootObj
for _, term := range ref {
Expand Down
81 changes: 81 additions & 0 deletions topdown/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,87 @@ func TestTopdownVirtualCache(t *testing.T) {
miss: 3, // 'data.test.p = true' + 'data.test.q[[y, 1]] = z' + 'data.test.q = x'
exp: 1,
},
{
note: "partial object, ref-head, ref with unification scope",
module: `package test
import rego.v1
a[x][y][z] := x + y + z if {
some x in [1, 2]
some y in [3, 4]
some z in [5, 6]
}
p if {
x := a[1][_][5] # miss, cache key: data.test.a[1][<_,5>]
some foo
y := a[1][foo][5] # hit, cache key: data.test.a[1][<_,5>]
x == y
}`,
query: `data.test.p = x`,
hit: 1, // data.test.a[1][_][5]
miss: 2, // data.test.p + data.test.a[1][_][5]
},
{
note: "partial object, ref-head, ref with unification scope, diverging key scope",
module: `package test
import rego.v1
a[x][y][z] := x + y + z if {
some x in [1, 2]
some y in [3, 4]
some z in [5, 6]
}
p if {
x := a[1][_][5] # miss, cache key: data.test.a[1][<_,5>]
y := a[1][_][6] # miss, cache key: data.test.a[1][<_,6>]
z := a[1][_][5] # hit, cache key: data.test.a[1][<_,5>]
x != y
x == z
}`,
query: `data.test.p = x`,
hit: 1, // data.test.a[1][_][5]
miss: 3, // data.test.p + data.test.a[1][_][5] + data.test.a[1][_][6]
},
{
note: "partial object, ref-head, ref with unification scope, trailing vars don't contribute to key scope",
module: `package test
import rego.v1
a[x][y][z][x] := x + y + z if {
some x in [1, 2]
some y in [3, 4]
some z in [5, 6]
}
p if {
x := a[1][_][5][_] # miss, cache key: data.test.a[1][<_,5>]
y := a[1][_][5] # hit, cache key: data.test.a[1][<_,5>]
x == y[_]
}`,
query: `data.test.p = x`,
hit: 1, // data.test.a[1][_][5]
miss: 2, // data.test.p + data.test.a[1][_][5]
},
{
// Regression test for https://github.com/open-policy-agent/opa/issues/6926
note: "partial object, ref-head, leaf set, ref with unification scope",
module: `package p
import rego.v1
obj.sub[x][x] contains x if some x in ["one", "two"]
obj[x][x] contains x if x := "whatever"
main contains x if {
[1 | obj.sub[_].one[_]] # miss, cache key: data.p.obj.sub[<_,one>]
x := obj.sub[_][_][_] # miss, cache key: data.p.obj.sub
}`,
query: `data.p.main = x`,
hit: 0,
miss: 3, // data.p.main + data.p.obj.sub[<_,one>] + data.p.obj.sub
},
}

for _, tc := range tests {
Expand Down

0 comments on commit fd9cdb7

Please sign in to comment.