Skip to content

Commit

Permalink
Fix tests...
Browse files Browse the repository at this point in the history
* DFContext reinstate fn hugr(), drop AsRef requirement (fixes StackOverflow)
* test_tail_loop_iterates_twice: use tail_loop_builder_exts, fix from #1332(?)
* Fix only-one-DataflowContext asserts using Arc::ptr_eq
  • Loading branch information
acl-cqc committed Aug 6, 2024
1 parent 8fed0cd commit a244c4b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 21 deletions.
18 changes: 10 additions & 8 deletions hugr-passes/src/const_fold2/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ mod utils;
use context::DataflowContext;
pub use utils::{TailLoopTermination, ValueRow, IO, PV};

pub trait DFContext: AsRef<Hugr> + Clone + Eq + Hash + std::ops::Deref<Target = Hugr> {}
pub trait DFContext: Clone + Eq + Hash + std::ops::Deref<Target = Hugr> {
fn hugr(&self) -> &impl HugrView;
}

ascent::ascent! {
// The trait-indirection layer here means we can just write 'C' but in practice ATM
Expand All @@ -34,9 +36,9 @@ ascent::ascent! {

node(c, n) <-- context(c), for n in c.nodes();

in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c, *n);
in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c.hugr(), *n);

out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c, *n);
out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c.hugr(), *n);

parent_of_node(c, parent, child) <--
node(c, child), if let Some(parent) = c.get_parent(*child);
Expand All @@ -55,8 +57,8 @@ ascent::ascent! {
out_wire_value(c, m, op, v);


node_in_value_row(c, n, utils::bottom_row(c, *n)) <-- node(c, n);
node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v);
node_in_value_row(c, n, utils::bottom_row(c.hugr(), *n)) <-- node(c, n);
node_in_value_row(c, n, utils::singleton_in_row(c.hugr(), n, p, v.clone())) <-- in_wire_value(c, n, p, v);


// Per node-type rules
Expand All @@ -67,7 +69,7 @@ ascent::ascent! {
relation load_constant_node(C, Node);
load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant();

out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c, *n)) <--
out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c.hugr(), *n)) <--
load_constant_node(c, n);


Expand Down Expand Up @@ -116,7 +118,7 @@ ascent::ascent! {
if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0
if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(),
let variant_len = tailloop.just_inputs.len(),
for (out_p, v) in out_in_row.iter(c, *out_n).flat_map(
for (out_p, v) in out_in_row.iter(c.hugr(), *out_n).flat_map(
|(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v)
);

Expand All @@ -127,7 +129,7 @@ ascent::ascent! {
if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1
if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(),
let variant_len = tailloop.just_outputs.len(),
for (out_p, v) in out_in_row.iter(c, *out_n).flat_map(
for (out_p, v) in out_in_row.iter(c.hugr(), *out_n).flat_map(
|(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v)
);

Expand Down
19 changes: 8 additions & 11 deletions hugr-passes/src/const_fold2/datalog/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::Arc;

use hugr_core::hugr::internal::HugrInternals;
use hugr_core::{Hugr, HugrView};

use super::DFContext;
Expand All @@ -25,13 +24,13 @@ impl<H: HugrView> Clone for DataflowContext<H> {
}

impl<H: HugrView> Hash for DataflowContext<H> {
fn hash<I: Hasher>(&self, state: &mut I) {}
fn hash<I: Hasher>(&self, _state: &mut I) {}
}

impl<H: HugrView> PartialEq for DataflowContext<H> {
fn eq(&self, other: &Self) -> bool {
// Any AscentProgram should have only one DataflowContext
assert_eq!(self as *const _, other as *const _);
// Any AscentProgram should have only one DataflowContext (maybe cloned)
assert!(Arc::ptr_eq(&self.0, &other.0));
true
}
}
Expand All @@ -40,8 +39,8 @@ impl<H: HugrView> Eq for DataflowContext<H> {}

impl<H: HugrView> PartialOrd for DataflowContext<H> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
// Any AscentProgram should have only one DataflowContext
assert_eq!(self as *const _, other as *const _);
// Any AscentProgram should have only one DataflowContext (maybe cloned)
assert!(Arc::ptr_eq(&self.0, &other.0));
Some(std::cmp::Ordering::Equal)
}
}
Expand All @@ -54,10 +53,8 @@ impl<H: HugrView> Deref for DataflowContext<H> {
}
}

impl<H: HugrView> AsRef<Hugr> for DataflowContext<H> {
fn as_ref(&self) -> &Hugr {
self.base_hugr()
impl<H: HugrView> DFContext for DataflowContext<H> {
fn hugr(&self) -> &impl HugrView {
self.0.as_ref()
}
}

impl<H: HugrView> DFContext for DataflowContext<H> {}
9 changes: 7 additions & 2 deletions hugr-passes/src/const_fold2/datalog/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,12 @@ fn test_tail_loop_iterates_twice() {
// let r_w = builder
// .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap());
let tlb = builder
.tail_loop_builder([], [(BOOL_T, false_w), (BOOL_T, true_w)], vec![].into())
.tail_loop_builder_exts(
[],
[(BOOL_T, false_w), (BOOL_T, true_w)],
vec![].into(),
ExtensionSet::new(),
)
.unwrap();
assert_eq!(
tlb.loop_signature().unwrap().dataflow_signature().unwrap(),
Expand All @@ -157,7 +162,7 @@ fn test_tail_loop_iterates_twice() {
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
// TODO once we can do conditionals put these wires inside `just_outputs` and
// we should be able to propagate their values
let [o_w1, o_w2, _] = tail_loop.outputs_arr();
let [o_w1, o_w2] = tail_loop.outputs_arr();

let mut machine = Machine::new();
machine.run_hugr(&hugr);
Expand Down

0 comments on commit a244c4b

Please sign in to comment.