Skip to content

Commit

Permalink
remove FoldOutput
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Nov 24, 2023
1 parent 8ee49da commit 520de7c
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 56 deletions.
21 changes: 4 additions & 17 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

use crate::{
extension::ConstFoldResult,
ops::{custom::ExternalOp, Const, LeafOp, OpType},
values::Value,
IncomingPort, OutgoingPort,
ops::{Const, OpType},
IncomingPort,
};

/// For a given op and consts, attempt to evaluate the op.
Expand All @@ -18,7 +17,7 @@ pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldRes
#[cfg(test)]
mod test {
use crate::{
extension::{ExtensionRegistry, FoldOutput, PRELUDE, PRELUDE_REGISTRY},
extension::{ExtensionRegistry, PRELUDE},
ops::LeafOp,
std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES},
types::TypeArg,
Expand Down Expand Up @@ -57,18 +56,6 @@ mod test {
let add_op: OpType = u64_add().into();
let out = fold_const(&add_op, &consts).unwrap();

assert_eq!(&out[..], &[(0.into(), FoldOutput::Value(Box::new(i2c(c))))]);
}

#[test]
// a = a + 0
fn test_zero_add() {
for in_port in [0, 1] {
let other_in = 1 - in_port;
let consts = vec![(in_port.into(), i2c(0))];
let add_op: OpType = u64_add().into();
let out = fold_const(&add_op, &consts).unwrap();
assert_eq!(&out[..], &[(0.into(), FoldOutput::Input(other_in.into()))]);
}
assert_eq!(&out[..], &[(0.into(), i2c(c))]);
}
}
2 changes: 1 addition & 1 deletion src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub use type_def::{TypeDef, TypeDefBound};
mod const_fold;
pub mod prelude;
pub mod validate;
pub use const_fold::{ConstFold, ConstFoldResult, FoldOutput};
pub use const_fold::{ConstFold, ConstFoldResult};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

/// Extension Registries store extensions to be looked up e.g. during validation.
Expand Down
19 changes: 1 addition & 18 deletions src/extension/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,9 @@ use crate::types::TypeArg;

use crate::OutgoingPort;

use crate::IncomingPort;

use crate::ops;
use derive_more::From;

#[derive(From, Clone, PartialEq, Debug)]
pub enum FoldOutput {
/// Value from port can be replaced with a constant
Value(Box<ops::Const>),
/// Value from port corresponds to one of the incoming values.
Input(IncomingPort),
}

impl From<ops::Const> for FoldOutput {
fn from(value: ops::Const) -> Self {
Self::Value(Box::new(value))
}
}

pub type ConstFoldResult = Option<Vec<(OutgoingPort, FoldOutput)>>;
pub type ConstFoldResult = Option<Vec<(OutgoingPort, ops::Const)>>;

pub trait ConstFold: Send + Sync {
fn fold(
Expand Down
2 changes: 1 addition & 1 deletion src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::{

use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FunctionType, PolyFuncType};
use crate::{ops, Hugr, IncomingPort};
use crate::Hugr;

/// Trait necessary for binary computations of OpDef signature
pub trait CustomSignatureFunc: Send + Sync {
Expand Down
2 changes: 1 addition & 1 deletion src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, S
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::{HugrView, NodeType};
use crate::types::{type_param::TypeArg, FunctionType};
use crate::{ops, Hugr, IncomingPort, Node, OutgoingPort};
use crate::{ops, Hugr, IncomingPort, Node};

use super::tag::OpTag;
use super::{LeafOp, OpTrait, OpType};
Expand Down
19 changes: 3 additions & 16 deletions src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

use super::int_types::{get_log_width, int_type_var, ConstIntU, INT_TYPES, LOG_WIDTH_TYPE_PARAM};
use crate::extension::prelude::{sum_with_error, BOOL_T};
use crate::extension::{ConstFoldResult, CustomValidator, FoldOutput, ValidateJustArgs};
use crate::extension::{ConstFoldResult, CustomValidator, ValidateJustArgs};
use crate::types::{FunctionType, PolyFuncType};
use crate::utils::collect_array;
use crate::values::Value;

use crate::{
extension::{ExtensionId, ExtensionSet, SignatureError},
types::{type_param::TypeArg, Type, TypeRow},
Expand Down Expand Up @@ -72,22 +72,10 @@ fn idivmod_sig() -> PolyFuncType {
int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)])
}

fn zero(width: u8) -> ops::Const {
ops::Const::new(
ConstIntU::new(width, 0).unwrap().into(),
INT_TYPES[5].to_owned(),
)
.unwrap()
}

fn iadd_fold(consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult {
// TODO get width from const
let width = 5;
match consts {
[(p, c)] if c == &zero(width) => {
let other_port: IncomingPort = if &IncomingPort::from(0) == p { 1 } else { 0 }.into();
Some(vec![(0.into(), other_port.into())])
}
[(_, c1), (_, c2)] => {
let [c1, c2]: [&ConstIntU; 2] = [c1, c2].map(|c| c.get_custom_value().unwrap());

Expand All @@ -99,8 +87,7 @@ fn iadd_fold(consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult {
.into(),
INT_TYPES[5].to_owned(),
)
.unwrap()
.into(),
.unwrap(),
)])
}

Expand Down
4 changes: 2 additions & 2 deletions src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use downcast_rs::{impl_downcast, Downcast};
use smol_str::SmolStr;

use crate::macros::impl_box_clone;
use crate::ops::OpType;
use crate::{Hugr, HugrView, IncomingPort, OutgoingPort};

use crate::{Hugr, HugrView};

use crate::types::{CustomCheckFailure, CustomType};

Expand Down

0 comments on commit 520de7c

Please sign in to comment.