Skip to content

Commit

Permalink
integer addition tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Nov 24, 2023
1 parent b84766b commit 8ee49da
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 59 deletions.
26 changes: 18 additions & 8 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
//! Constant folding routines.

use crate::{
extension::ConstFoldResult,
ops::{custom::ExternalOp, Const, LeafOp, OpType},
values::Value,
IncomingPort, OutgoingPort,
};

/// For a given op and consts, attempt to evaluate the op.
pub fn fold_const(
op: &OpType,
consts: &[(IncomingPort, Const)],
) -> Option<Vec<(OutgoingPort, Const)>> {
pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult {
let op = op.as_leaf_op()?;
let ext_op = op.as_extension_op()?;

Expand All @@ -20,7 +18,7 @@ pub fn fold_const(
#[cfg(test)]
mod test {
use crate::{
extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY},
extension::{ExtensionRegistry, FoldOutput, PRELUDE, PRELUDE_REGISTRY},
ops::LeafOp,
std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES},
types::TypeArg,
Expand Down Expand Up @@ -53,12 +51,24 @@ mod test {
#[case(0, 0, 0)]
#[case(0, 1, 1)]
#[case(23, 435, 458)]
// c = a && b
fn test_and(#[case] a: u64, #[case] b: u64, #[case] c: u64) {
// c = a + b
fn test_add(#[case] a: u64, #[case] b: u64, #[case] c: u64) {
let consts = vec![(0.into(), i2c(a)), (1.into(), i2c(b))];
let add_op: OpType = u64_add().into();
let out = fold_const(&add_op, &consts).unwrap();

assert_eq!(&out[..], &[(0.into(), i2c(c))]);
assert_eq!(&out[..], &[(0.into(), FoldOutput::Value(Box::new(i2c(c))))]);
}

#[test]
// a = a + 0
fn test_zero_add() {
for in_port in [0, 1] {
let other_in = 1 - in_port;
let consts = vec![(in_port.into(), i2c(0))];
let add_op: OpType = u64_add().into();
let out = fold_const(&add_op, &consts).unwrap();
assert_eq!(&out[..], &[(0.into(), FoldOutput::Input(other_in.into()))]);
}
}
}
3 changes: 2 additions & 1 deletion src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ pub use op_def::{
};
mod type_def;
pub use type_def::{TypeDef, TypeDefBound};
mod const_fold;
pub mod prelude;
pub mod validate;

pub use const_fold::{ConstFold, ConstFoldResult, FoldOutput};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

/// Extension Registries store extensions to be looked up e.g. during validation.
Expand Down
61 changes: 61 additions & 0 deletions src/extension/const_fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use std::fmt::Formatter;

use std::fmt::Debug;

use crate::types::TypeArg;

use crate::OutgoingPort;

use crate::IncomingPort;

use crate::ops;
use derive_more::From;

#[derive(From, Clone, PartialEq, Debug)]
pub enum FoldOutput {
/// Value from port can be replaced with a constant
Value(Box<ops::Const>),
/// Value from port corresponds to one of the incoming values.
Input(IncomingPort),
}

impl From<ops::Const> for FoldOutput {
fn from(value: ops::Const) -> Self {
Self::Value(Box::new(value))
}
}

pub type ConstFoldResult = Option<Vec<(OutgoingPort, FoldOutput)>>;

pub trait ConstFold: Send + Sync {
fn fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, crate::ops::Const)],
) -> ConstFoldResult;
}

impl Debug for Box<dyn ConstFold> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "<custom constant folding>")
}
}

impl Default for Box<dyn ConstFold> {
fn default() -> Self {
Box::new(|&_: &_| None)
}
}

impl<T> ConstFold for T
where
T: Fn(&[(crate::IncomingPort, crate::ops::Const)]) -> ConstFoldResult + Send + Sync,
{
fn fold(
&self,
_type_args: &[TypeArg],
consts: &[(crate::IncomingPort, crate::ops::Const)],
) -> ConstFoldResult {
self(consts)
}
}
45 changes: 4 additions & 41 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ use std::sync::Arc;
use smol_str::SmolStr;

use super::{
Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError,
ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry,
ExtensionSet, SignatureError,
};

use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FunctionType, PolyFuncType};
use crate::Hugr;
use crate::{ops, Hugr, IncomingPort};

/// Trait necessary for binary computations of OpDef signature
pub trait CustomSignatureFunc: Send + Sync {
Expand Down Expand Up @@ -245,44 +246,6 @@ impl Debug for LowerFunc {
}
}

type ConstFoldResult = Option<Vec<(crate::OutgoingPort, crate::ops::Const)>>;
pub trait ConstFold: Send + Sync {
fn fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, crate::ops::Const)],
) -> ConstFoldResult;
}

impl Debug for Box<dyn ConstFold> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "<custom constant folding>")
}
}

impl Default for Box<dyn ConstFold> {
fn default() -> Self {
Box::new(|&_: &_| None)
}
}

impl<T> ConstFold for T
where
T: Fn(
&[(crate::IncomingPort, crate::ops::Const)],
) -> Option<Vec<(crate::OutgoingPort, crate::ops::Const)>>
+ Send
+ Sync,
{
fn fold(
&self,
_type_args: &[TypeArg],
consts: &[(crate::IncomingPort, crate::ops::Const)],
) -> ConstFoldResult {
self(consts)
}
}

/// Serializable definition for dynamically loaded operations.
///
/// TODO: Define a way to construct new OpDef's from a serialized definition.
Expand Down Expand Up @@ -441,7 +404,7 @@ impl OpDef {
self.misc.insert(k.to_string(), v)
}

pub fn add_constant_folding(&mut self, fold: impl ConstFold + 'static) {
pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) {
self.constant_folder = Box::new(fold)
}

Expand Down
7 changes: 2 additions & 5 deletions src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use smol_str::SmolStr;
use std::sync::Arc;
use thiserror::Error;

use crate::extension::{ExtensionId, ExtensionRegistry, OpDef, SignatureError};
use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::{HugrView, NodeType};
use crate::types::{type_param::TypeArg, FunctionType};
Expand Down Expand Up @@ -128,10 +128,7 @@ impl ExtensionOp {
self.def.as_ref()
}

pub fn constant_fold(
&self,
consts: &[(IncomingPort, ops::Const)],
) -> Option<Vec<(OutgoingPort, ops::Const)>> {
pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult {
self.def().constant_fold(self.args(), consts)
}
}
Expand Down
46 changes: 42 additions & 4 deletions src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
//! Basic integer operations.

use super::int_types::{get_log_width, int_type_var, LOG_WIDTH_TYPE_PARAM};
use super::int_types::{get_log_width, int_type_var, ConstIntU, INT_TYPES, LOG_WIDTH_TYPE_PARAM};
use crate::extension::prelude::{sum_with_error, BOOL_T};
use crate::extension::{CustomValidator, ValidateJustArgs};
use crate::type_row;
use crate::extension::{ConstFoldResult, CustomValidator, FoldOutput, ValidateJustArgs};
use crate::types::{FunctionType, PolyFuncType};
use crate::utils::collect_array;
use crate::values::Value;
use crate::{
extension::{ExtensionId, ExtensionSet, SignatureError},
types::{type_param::TypeArg, Type, TypeRow},
Extension,
};
use crate::{ops, type_row, IncomingPort};

use lazy_static::lazy_static;

Expand Down Expand Up @@ -71,6 +72,42 @@ fn idivmod_sig() -> PolyFuncType {
int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)])
}

fn zero(width: u8) -> ops::Const {
ops::Const::new(
ConstIntU::new(width, 0).unwrap().into(),
INT_TYPES[5].to_owned(),
)
.unwrap()
}

fn iadd_fold(consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult {
// TODO get width from const
let width = 5;
match consts {
[(p, c)] if c == &zero(width) => {
let other_port: IncomingPort = if &IncomingPort::from(0) == p { 1 } else { 0 }.into();
Some(vec![(0.into(), other_port.into())])
}
[(_, c1), (_, c2)] => {
let [c1, c2]: [&ConstIntU; 2] = [c1, c2].map(|c| c.get_custom_value().unwrap());

Some(vec![(
0.into(),
ops::Const::new(
ConstIntU::new(width, c1.value() + c2.value())
.unwrap()
.into(),
INT_TYPES[5].to_owned(),
)
.unwrap()
.into(),
)])
}

_ => None,
}
}

/// Extension for basic integer operations.
fn extension() -> Extension {
let itob_sig = int_polytype(1, vec![int_type_var(0)], type_row![BOOL_T]);
Expand Down Expand Up @@ -246,13 +283,14 @@ fn extension() -> Extension {
ibinop_sig(),
)
.unwrap();
extension
let iadd = extension
.add_op(
"iadd".into(),
"addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(),
ibinop_sig(),
)
.unwrap();
iadd.set_constant_folder(iadd_fold);
extension
.add_op(
"isub".into(),
Expand Down

0 comments on commit 8ee49da

Please sign in to comment.