Skip to content

Commit

Permalink
wip: nested closures
Browse files Browse the repository at this point in the history
  • Loading branch information
divarvel committed Jan 3, 2024
1 parent 0f6c3ad commit 0d029b0
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 42 deletions.
3 changes: 3 additions & 0 deletions biscuit-auth/examples/testcases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1903,6 +1903,9 @@ fn expressions_v5(target: &str, root: &KeyPair, test: bool) -> TestResult {
check if [1,2,3].any($p -> $p > 2);
//any
check if ![1,2,3].any($p -> $p > 3);
// nested closures
check if [1,2,3].any($p -> $p > 1 && [3,4,5].any($q -> $p == $q));
"#
)
.build_with_rng(&root, SymbolTable::default(), &mut rng)
Expand Down
138 changes: 96 additions & 42 deletions biscuit-auth/src/datalog/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,39 +94,30 @@ impl Binary {
mut right: Vec<Op>,
params: &[u32],
ops: &mut Vec<Op>,
values: &HashMap<u32, Term>,
values: &mut HashMap<u32, Term>,
symbols: &mut TemporarySymbolTable,
) -> Result<Term, error::Expression> {
println!("Recursing, values before: {values:?}");
match (self, left, params) {
(Binary::LazyOr, Term::Bool(true), []) => Ok(Term::Bool(true)),
(Binary::LazyOr, Term::Bool(false), []) => {
ops.push(Op::Binary(Binary::Or));
right.reverse();
for op in right {
ops.push(op);
}
Ok(Term::Bool(false))
let e = Expression { ops: right.clone() };
e.evaluate(values, symbols)
}
(Binary::LazyAnd, Term::Bool(false), []) => Ok(Term::Bool(false)),
(Binary::LazyAnd, Term::Bool(true), []) => {
ops.push(Op::Binary(Binary::And));
for op in right {
ops.push(op);
}
Ok(Term::Bool(true))
let e = Expression { ops: right.clone() };
e.evaluate(values, symbols)
}
(Binary::All, Term::Set(set_values), [param]) => {
for value in set_values.iter() {
let ops = right
.clone()
.iter()
.map(|op| match op {
Op::Value(Term::Variable(v)) if v == param => Op::Value(value.clone()),
_ => op.clone(),
})
.collect::<Vec<_>>();
let e = Expression { ops };
match e.evaluate(values, symbols)? {
values.insert(*param, value.clone());
println!("Recursing, values during: {values:?}");
let e = Expression { ops: right.clone() };
let result = e.evaluate(values, symbols);
values.remove(param);
println!("Recursing, values after: {values:?}");
match result? {
Term::Bool(true) => {}
Term::Bool(false) => return Ok(Term::Bool(false)),
_ => return Err(error::Expression::InvalidType),
Expand All @@ -136,16 +127,13 @@ impl Binary {
}
(Binary::Any, Term::Set(set_values), [param]) => {
for value in set_values.iter() {
let ops = right
.clone()
.iter()
.map(|op| match op {
Op::Value(Term::Variable(v)) if v == param => Op::Value(value.clone()),
_ => op.clone(),
})
.collect::<Vec<_>>();
let e = Expression { ops };
match e.evaluate(values, symbols)? {
values.insert(*param, value.clone());
println!("Recursing, values during: {values:?}");
let e = Expression { ops: right.clone() };
let result = e.evaluate(values, symbols);
values.remove(param);
println!("Recursing, values after: {values:?}");
match result? {
Term::Bool(false) => {}
Term::Bool(true) => return Ok(Term::Bool(true)),
_ => return Err(error::Expression::InvalidType),
Expand Down Expand Up @@ -336,10 +324,12 @@ impl Expression {
let mut ops = self.ops.clone();
ops.reverse();

println!("-- begin -- {values:?}");
while let Some(op) = ops.pop() {
// println!("ops: {ops:?}");
// println!("op: {:?}\t| stack: {:?}", op, stack);
println!("ops: {ops:?}");
println!("op: {:?}\t| stack: {:?}", op, stack);

let opop = op.clone();
match op {
Op::Value(Term::Variable(i)) => match values.get(&i) {
Some(term) => stack.push(StackElem::Term(term.clone())),
Expand All @@ -354,7 +344,7 @@ impl Expression {
stack.push(StackElem::Term(unary.evaluate(term, symbols)?))
}
_ => {
//println!("expected a value on the stack");
println!("expected a value on the stack");
return Err(error::Expression::InvalidStack);
}
},
Expand All @@ -366,12 +356,26 @@ impl Expression {
(
Some(StackElem::Closure(params, right_ops)),
Some(StackElem::Term(left_term)),
) => stack.push(StackElem::Term(binary.evaluate_with_closure(
left_term, right_ops, &params, &mut ops, values, symbols,
)?)),
) => {
let mut values = values.clone();
stack.push(StackElem::Term(binary.evaluate_with_closure(
left_term,
right_ops,
&params,
&mut ops,
&mut values,
symbols,
)?))
}

_ => {
//println!("expected two values on the stack");
e => {
println!(
"while evaluating {}",
self.print(&SymbolTable::new()).unwrap()
);
println!("while evaluating {opop:?}");
println!("with context {values:?}");
println!("expected two values on the stack, got: {e:?}");
return Err(error::Expression::InvalidStack);
}
},
Expand All @@ -380,14 +384,18 @@ impl Expression {
}
}
}
//println!("stack: {stack:?}");
println!("-- end {stack:?}");

if stack.len() == 1 {
match stack.remove(0) {
StackElem::Term(t) => Ok(t),
_ => Err(error::Expression::InvalidStack),
_ => {
println!("expected a term on the stack after evaluation");
Err(error::Expression::InvalidStack)
}
}
} else {
println!("expected one value the stack after evaluation");
Err(error::Expression::InvalidStack)
}
}
Expand Down Expand Up @@ -673,4 +681,50 @@ mod tests {
let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap();
assert_eq!(res1, Term::Bool(true));
}

#[test]
fn nested_closures() {
let mut symbols = SymbolTable::new();
let p = symbols.insert("p") as u32;
let q = symbols.insert("q") as u32;
let mut tmp_symbols = TemporarySymbolTable::new(&symbols);

let ops1 = vec![
Op::Value(Term::Set(
[Term::Integer(1), Term::Integer(2), Term::Integer(3)].into(),
)),
Op::Closure(
vec![p],
vec![
Op::Value(Term::Variable(p)),
Op::Value(Term::Integer(1)),
Op::Binary(Binary::GreaterThan),
Op::Closure(
vec![],
vec![
Op::Value(Term::Set(
[Term::Integer(3), Term::Integer(4), Term::Integer(5)].into(),
)),
Op::Closure(
vec![q],
vec![
Op::Value(Term::Variable(p)),
Op::Value(Term::Variable(q)),
Op::Binary(Binary::Equal),
],
),
Op::Binary(Binary::Any),
],
),
Op::Binary(Binary::LazyAnd),
],
),
Op::Binary(Binary::Any),
];
let e1 = Expression { ops: ops1 };
println!("{}", e1.print(&symbols).unwrap());

let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap();
assert_eq!(res1, Term::Bool(true));
}
}

0 comments on commit 0d029b0

Please sign in to comment.