Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Move int conversions to conversions ext, add to/from usize #1490

Merged
merged 5 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions hugr-core/src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,15 +351,13 @@ pub enum OpaqueOpError {
#[cfg(test)]
mod test {

use crate::std_extensions::arithmetic::conversions::{self, CONVERT_OPS_REGISTRY};
use crate::{
extension::{
prelude::{BOOL_T, QB_T, USIZE_T},
SignatureFunc,
},
std_extensions::arithmetic::{
int_ops::{self, INT_OPS_REGISTRY},
int_types::INT_TYPES,
},
std_extensions::arithmetic::int_types::INT_TYPES,
types::FuncValueType,
Extension,
};
Expand Down Expand Up @@ -387,10 +385,10 @@ mod test {

#[test]
fn resolve_opaque_op() {
let registry = &INT_OPS_REGISTRY;
let registry = &CONVERT_OPS_REGISTRY;
let i0 = &INT_TYPES[0];
let opaque = OpaqueOp::new(
int_ops::EXTENSION_ID,
conversions::EXTENSION_ID,
"itobool",
"description".into(),
vec![],
Expand Down
165 changes: 132 additions & 33 deletions hugr-core/src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::extension::prelude::{BOOL_T, STRING_TYPE, USIZE_T};
use crate::extension::simple_op::{HasConcrete, HasDef};
use crate::ops::OpName;
use crate::std_extensions::arithmetic::int_ops::int_polytype;
use crate::std_extensions::arithmetic::int_types::int_type;
use crate::{
extension::{
prelude::sum_with_error,
Expand All @@ -12,12 +16,12 @@ use crate::{
},
ops::{custom::ExtensionOp, NamedOp},
type_row,
types::{FuncValueType, PolyFuncTypeRV, TypeArg, TypeRV},
types::{TypeArg, TypeRV},
Extension,
};

use super::int_types::int_tv;
use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
use super::float_types::FLOAT64_TYPE;
use super::int_types::{get_log_width, int_tv};
use lazy_static::lazy_static;
mod const_fold;
/// The extension identifier.
Expand All @@ -34,6 +38,12 @@ pub enum ConvertOpDef {
trunc_s,
convert_u,
convert_s,
itobool,
ifrombool,
itostring_u,
itostring_s,
itousize,
ifromusize,
}

impl MakeOpDef for ConvertOpDef {
Expand All @@ -47,18 +57,19 @@ impl MakeOpDef for ConvertOpDef {

fn signature(&self) -> SignatureFunc {
use ConvertOpDef::*;
PolyFuncTypeRV::new(
vec![LOG_WIDTH_TYPE_PARAM],
match self {
trunc_s | trunc_u => FuncValueType::new(
type_row![FLOAT64_TYPE],
TypeRV::from(sum_with_error(int_tv(0))),
),
convert_s | convert_u => {
FuncValueType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE])
}
},
)
match self {
trunc_s | trunc_u => int_polytype(
1,
type_row![FLOAT64_TYPE],
TypeRV::from(sum_with_error(int_tv(0))),
),
convert_s | convert_u => int_polytype(1, vec![int_tv(0)], type_row![FLOAT64_TYPE]),
itobool => int_polytype(0, vec![int_type(0)], vec![BOOL_T]),
ifrombool => int_polytype(0, vec![BOOL_T], vec![int_type(0)]),
itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![STRING_TYPE]),
itousize => int_polytype(0, vec![int_type(6)], vec![USIZE_T]),
ifromusize => int_polytype(0, vec![USIZE_T], vec![int_type(6)]),
}
.into()
}

Expand All @@ -69,6 +80,12 @@ impl MakeOpDef for ConvertOpDef {
trunc_s => "float to signed int",
convert_u => "unsigned int to float",
convert_s => "signed int to float",
itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
itostring_s => "convert a signed integer to its string representation",
itostring_u => "convert an unsigned integer to its string representation",
itousize => "convert a 64b unsigned integer to its usize representation",
ifromusize => "convert a usize to a 64b unsigned integer",
}
.to_string()
}
Expand All @@ -79,19 +96,32 @@ impl MakeOpDef for ConvertOpDef {
}

impl ConvertOpDef {
/// Initialise a conversion op with an integer log width type argument.
/// Initialize a [ConvertOpType] from a [ConvertOpDef] which requires no
/// integer widths set.
pub fn without_log_width(self) -> ConvertOpType {
ConvertOpType {
def: self,
log_widths: vec![],
}
}
/// Initialize a [ConvertOpType] from a [ConvertOpDef] which requires one
/// integer width set.
pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
ConvertOpType {
def: self,
log_width,
log_widths: vec![log_width],
}
}
}
/// Concrete convert operation with integer log width set.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvertOpType {
def: ConvertOpDef,
log_width: u8,
/// The kind of conversion op.
pub def: ConvertOpDef,
/// The integer width parameters of the conversion op. These are interpreted
/// differently, depending on `def`. The integer types in the inputs and
/// outputs of the op will have [int_type]s of these widths.
pub log_widths: Vec<u8>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if its only ever 0, 1, shouldn't this be option?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently only 0 or 1. But yes, I just made the field private so it should be ok to change to an option.

}

impl NamedOp for ConvertOpType {
Expand All @@ -103,20 +133,11 @@ impl NamedOp for ConvertOpType {
impl MakeExtensionOp for ConvertOpType {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
let def = ConvertOpDef::from_def(ext_op.def())?;
let log_width: u64 = match *ext_op.args() {
[TypeArg::BoundedNat { n }] => n,
_ => return Err(SignatureError::InvalidTypeArgs.into()),
};
Ok(Self {
def,
log_width: u8::try_from(log_width).unwrap(),
})
def.instantiate(ext_op.args())
}

fn type_args(&self) -> Vec<crate::types::TypeArg> {
vec![TypeArg::BoundedNat {
n: self.log_width as u64,
}]
fn type_args(&self) -> Vec<TypeArg> {
self.log_widths.iter().map(|&n| (n as u64).into()).collect()
}
}

Expand Down Expand Up @@ -157,17 +178,95 @@ impl MakeRegisteredOp for ConvertOpType {
}
}

impl HasConcrete for ConvertOpDef {
type Concrete = ConvertOpType;

fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
let log_widths: Vec<u8> = type_args
.iter()
.map(|a| get_log_width(a).map_err(|_| SignatureError::InvalidTypeArgs))
.collect::<Result<_, _>>()?;
Ok(ConvertOpType {
def: *self,
log_widths,
})
}
}

impl HasDef for ConvertOpType {
type Def = ConvertOpDef;
}

#[cfg(test)]
mod test {
use rstest::rstest;

use crate::extension::prelude::ConstUsize;
use crate::ops::Value;
use crate::std_extensions::arithmetic::int_types::ConstInt;
use crate::IncomingPort;

use super::*;

#[test]
fn test_conversions_extension() {
let r = &EXTENSION;
assert_eq!(r.name() as &str, "arithmetic.conversions");
assert_eq!(r.types().count(), 0);
for (name, _) in r.operations() {
assert!(name.as_str().starts_with("convert") || name.as_str().starts_with("trunc"));
}

#[test]
fn test_conversions() {
// Initialization with an invalid number of type arguments should fail.
assert!(
ConvertOpDef::itobool
.with_log_width(1)
.to_extension_op()
.is_none(),
"type arguments invalid"
);

// This should work
let o = ConvertOpDef::itobool.without_log_width();
let ext_op: ExtensionOp = o.clone().to_extension_op().unwrap();

assert_eq!(ConvertOpType::from_op(&ext_op).unwrap(), o);
assert_eq!(
ConvertOpDef::from_op(&ext_op).unwrap(),
ConvertOpDef::itobool
);
}

#[rstest]
#[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])]
#[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])]
#[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])]
#[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])]
#[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])]
#[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])]
fn convert_fold(
#[case] op: ConvertOpType,
#[case] inputs: &[Value],
#[case] outputs: &[Value],
) {
use crate::ops::Value;

let consts: Vec<(IncomingPort, Value)> = inputs
.iter()
.enumerate()
.map(|(i, v)| (i.into(), v.clone()))
.collect();

let res = op
.to_extension_op()
.unwrap()
.constant_fold(&consts)
.unwrap();

for (i, expected) in outputs.iter().enumerate() {
let res_val: &Value = &res.get(i).unwrap().1;

assert_eq!(res_val, expected);
}
}
}
Loading
Loading