Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Constraint::Plus stores an ExtensionSet, which is a BTreeSet #636

Merged
merged 11 commits into from
Nov 6, 2023
24 changes: 12 additions & 12 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! system (outside the `types` module), which also parses nested [`OpDef`]s.

use std::collections::hash_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;

Expand Down Expand Up @@ -301,18 +301,13 @@ pub enum ExtensionBuildError {
}

/// A set of extensions identified by their unique [`ExtensionId`].
#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ExtensionSet(HashSet<ExtensionId>);
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ExtensionSet(BTreeSet<ExtensionId>);

impl ExtensionSet {
/// Creates a new empty extension set.
pub fn new() -> Self {
Self(HashSet::new())
}

/// Creates a new extension set from some extensions.
pub fn new_from_extensions(extensions: impl Into<HashSet<ExtensionId>>) -> Self {
Self(extensions.into())
pub const fn new() -> Self {
Self(BTreeSet::new())
}

/// Adds a extension to the set.
Expand Down Expand Up @@ -350,13 +345,18 @@ impl ExtensionSet {

/// The things in other which are in not in self
pub fn missing_from(&self, other: &Self) -> Self {
ExtensionSet(HashSet::from_iter(other.0.difference(&self.0).cloned()))
ExtensionSet::from_iter(other.0.difference(&self.0).cloned())
}

/// Iterate over the contained ExtensionIds
pub fn iter(&self) -> impl Iterator<Item = &ExtensionId> {
self.0.iter()
}

/// True if this set contains no [ExtensionId]s
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}

impl Display for ExtensionSet {
Expand All @@ -367,6 +367,6 @@ impl Display for ExtensionSet {

impl FromIterator<ExtensionId> for ExtensionSet {
fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
Self(HashSet::from_iter(iter))
Self(BTreeSet::from_iter(iter))
}
}
82 changes: 35 additions & 47 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
//! depend on these open variables, then the validation check for extensions
//! will succeed regardless of what the variable is instantiated to.

use super::{ExtensionId, ExtensionSet};
use super::ExtensionSet;
use crate::{
hugr::views::HugrView,
ops::{OpTag, OpTrait},
Expand Down Expand Up @@ -65,8 +65,8 @@ impl Meta {
enum Constraint {
/// A variable has the same value as another variable
Equal(Meta),
/// Variable extends the value of another by one extension
Plus(ExtensionId, Meta),
/// Variable extends the value of another by a set of extensions
Plus(ExtensionSet, Meta),
}

#[derive(Debug, Clone, PartialEq, Error)]
Expand Down Expand Up @@ -230,26 +230,6 @@ impl UnificationContext {
self.solved.get(&self.resolve(*m))
}

/// Convert an extension *set* difference in terms of a sequence of fresh
/// metas with `Plus` constraints which each add only one extension req.
fn gen_union_constraint(&mut self, input: Meta, output: Meta, delta: ExtensionSet) {
let mut last_meta = input;
// Create fresh metavariables with `Plus` constraints for
// each extension that should be added by the node
// Hence a extension delta [A, B] would lead to
// > ma = fresh_meta()
// > add_constraint(ma, Plus(a, input)
// > mb = fresh_meta()
// > add_constraint(mb, Plus(b, ma)
// > add_constraint(output, Equal(mb))
for r in delta.0.into_iter() {
let curr_meta = self.fresh_meta();
self.add_constraint(curr_meta, Constraint::Plus(r, last_meta));
last_meta = curr_meta;
}
self.add_constraint(output, Constraint::Equal(last_meta));
}

/// Return the metavariable corresponding to the given location on the
/// graph, either by making a new meta, or looking it up
fn make_or_get_meta(&mut self, node: Node, dir: Direction) -> Meta {
Expand Down Expand Up @@ -311,11 +291,13 @@ impl UnificationContext {
match node_type.signature() {
// Input extensions are open
None => {
self.gen_union_constraint(
m_input,
m_output,
node_type.op_signature().extension_reqs,
);
let delta = node_type.op_signature().extension_reqs;
let c = if delta.is_empty() {
Constraint::Equal(m_input)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative would be to store this with the empty ExtensionSet, and then we would only be storing one kind of Constraint (so, a struct, not an enum - might be syntactically easier to work with!). Then we'd be testing ExtensionSet::is_empty instead of matching on Constraint::Plus vs ::Equals...

} else {
Constraint::Plus(delta, m_input)
};
self.add_constraint(m_output, c);
if matches!(
node_type.tag(),
OpTag::Alias | OpTag::Function | OpTag::FuncDefn
Expand Down Expand Up @@ -510,8 +492,7 @@ impl UnificationContext {
// to a set which already contained it.
Constraint::Plus(r, other_meta) => {
if let Some(rs) = self.get_solution(other_meta) {
let mut rrs = rs.clone();
rrs.insert(r);
let rrs = rs.clone().union(r);
match self.get_solution(&meta) {
// Let's check that this is right?
Some(rs) => {
Expand Down Expand Up @@ -664,19 +645,19 @@ impl UnificationContext {
// Handle the case where the constraints for `m` contain a self
// reference, i.e. "m = Plus(E, m)", in which case the variable
// should be instantiated to E rather than the empty set.
let solution =
ExtensionSet::from_iter(self.get_constraints(&m).unwrap().iter().filter_map(
|c| match c {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => {
Some(x.clone())
}
_ => None,
},
));
let solution = self
.get_constraints(&m)
.unwrap()
.iter()
.filter_map(|c| match c {
// If `m` has been merged, [`self.variables`] entry
// will have already been updated to the merged
// value by [`self.merge_equal_metas`] so we don't
// need to worry about resolving it.
Constraint::Plus(x, other_m) if m == self.resolve(*other_m) => Some(x),
_ => None,
})
.fold(ExtensionSet::new(), ExtensionSet::union);
self.add_solution(m, solution);
}
}
Expand All @@ -690,6 +671,7 @@ mod test {

use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::extension::ExtensionId;
use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet};
use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType};
use crate::macros::const_extension_ids;
Expand Down Expand Up @@ -807,8 +789,14 @@ mod test {

ctx.solved.insert(metas[2], ExtensionSet::singleton(&A));
ctx.add_constraint(metas[1], Constraint::Equal(metas[2]));
ctx.add_constraint(metas[0], Constraint::Plus(B, metas[2]));
ctx.add_constraint(metas[4], Constraint::Plus(C, metas[0]));
ctx.add_constraint(
metas[0],
Constraint::Plus(ExtensionSet::singleton(&B), metas[2]),
);
ctx.add_constraint(
metas[4],
Constraint::Plus(ExtensionSet::singleton(&C), metas[0]),
);
ctx.add_constraint(metas[3], Constraint::Equal(metas[4]));
ctx.add_constraint(metas[5], Constraint::Equal(metas[0]));
ctx.main_loop()?;
Expand Down Expand Up @@ -881,8 +869,8 @@ mod test {
.insert((NodeIndex::new(4).into(), Direction::Incoming), ab);
ctx.variables.insert(a);
ctx.variables.insert(b);
ctx.add_constraint(ab, Constraint::Plus(A, b));
ctx.add_constraint(ab, Constraint::Plus(B, a));
ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&A), b));
ctx.add_constraint(ab, Constraint::Plus(ExtensionSet::singleton(&B), a));
let solution = ctx.main_loop()?;
// We'll only find concrete solutions for the Incoming extension reqs of
// the main node created by `Hugr::default`
Expand Down
6 changes: 2 additions & 4 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
//! Conversions between integer and floating-point values.

use std::collections::HashSet;

use crate::{
extension::{ExtensionId, ExtensionSet, SignatureError},
type_row,
Expand Down Expand Up @@ -39,10 +37,10 @@ fn itof_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
pub fn extension() -> Extension {
let mut extension = Extension::new_with_reqs(
EXTENSION_ID,
ExtensionSet::new_from_extensions(HashSet::from_iter(vec![
ExtensionSet::from_iter(vec![
super::int_types::EXTENSION_ID,
super::float_types::EXTENSION_ID,
])),
]),
);

extension
Expand Down