Skip to content

Commit

Permalink
refactor!: use enum op traits for floats + conversions (#755)
Browse files Browse the repository at this point in the history
BREAKING CHANGES: extension() function replaced with EXTENSION static
ref for float_ops and conversions
  • Loading branch information
ss2165 committed Dec 21, 2023
1 parent 268f120 commit 0bcab0a
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 159 deletions.
2 changes: 1 addition & 1 deletion src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ mod test {
use super::*;

fn test_registry() -> ExtensionRegistry {
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::extension()]).unwrap()
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap()
}

#[test]
Expand Down
164 changes: 116 additions & 48 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,63 +1,131 @@
//! Conversions between integer and floating-point values.

use smol_str::SmolStr;
use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::{
extension::{prelude::sum_with_error, ExtensionId, ExtensionSet},
extension::{
prelude::sum_with_error,
simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError},
ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc,
},
ops::{custom::ExtensionOp, OpName},
type_row,
types::{FunctionType, PolyFuncType},
types::{FunctionType, PolyFuncType, TypeArg},
Extension,
};

use super::int_types::int_tv;
use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
use lazy_static::lazy_static;

/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");

/// Extension for basic arithmetic operations.
pub fn extension() -> Extension {
let ftoi_sig = PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]),
);

let itof_sig = PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]),
);

let mut extension = Extension::new_with_reqs(
EXTENSION_ID,
ExtensionSet::from_iter(vec![
super::int_types::EXTENSION_ID,
super::float_types::EXTENSION_ID,
]),
);
extension
.add_op(
"trunc_u".into(),
"float to unsigned int".to_owned(),
ftoi_sig.clone(),
)
.unwrap();
extension
.add_op("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig)
.unwrap();
extension
.add_op(
"convert_u".into(),
"unsigned int to float".to_owned(),
itof_sig.clone(),
)
.unwrap();
extension
.add_op(
"convert_s".into(),
"signed int to float".to_owned(),
itof_sig,
)
.unwrap();

extension
/// Extensiop for conversions between floats and integers.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs, non_camel_case_types)]
pub enum ConvertOpDef {
trunc_u,
trunc_s,
convert_u,
convert_s,
}

impl MakeOpDef for ConvertOpDef {
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
crate::extension::simple_op::try_from_name(op_def.name())
}

fn signature(&self) -> SignatureFunc {
use ConvertOpDef::*;
match self {
trunc_s | trunc_u => PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]),
),

convert_s | convert_u => PolyFuncType::new(
vec![LOG_WIDTH_TYPE_PARAM],
FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]),
),
}
.into()
}

fn description(&self) -> String {
use ConvertOpDef::*;
match self {
trunc_u => "float to unsigned int",
trunc_s => "float to signed int",
convert_u => "unsigned int to float",
convert_s => "signed int to float",
}
.to_string()
}
}

/// Concrete convert operation with integer width set.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvertOpType {
def: ConvertOpDef,
width: u64,
}

impl OpName for ConvertOpType {
fn name(&self) -> SmolStr {
self.def.name()
}
}

impl MakeExtensionOp for ConvertOpType {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
let def = ConvertOpDef::from_def(ext_op.def())?;
let width = match *ext_op.args() {
[TypeArg::BoundedNat { n }] => n,
_ => return Err(SignatureError::InvalidTypeArgs.into()),
};
Ok(Self { def, width })
}

fn type_args(&self) -> Vec<crate::types::TypeArg> {
vec![TypeArg::BoundedNat { n: self.width }]
}
}

lazy_static! {
/// Extension for conversions between integers and floats.
pub static ref EXTENSION: Extension = {
let mut extension = Extension::new_with_reqs(
EXTENSION_ID,
ExtensionSet::from_iter(vec![
super::int_types::EXTENSION_ID,
super::float_types::EXTENSION_ID,
]),
);

ConvertOpDef::load_all_ops(&mut extension).unwrap();

extension
};

/// Registry of extensions required to validate integer operations.
pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
super::int_types::EXTENSION.to_owned(),
super::float_types::EXTENSION.to_owned(),
EXTENSION.to_owned(),
])
.unwrap();
}

impl MakeRegisteredOp for ConvertOpType {
fn extension_id(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry {
&CONVERT_OPS_REGISTRY
}
}

#[cfg(test)]
Expand All @@ -66,7 +134,7 @@ mod test {

#[test]
fn test_conversions_extension() {
let r = extension();
let r = &EXTENSION;
assert_eq!(r.name() as &str, "arithmetic.conversions");
assert_eq!(r.types().count(), 0);
for (name, _) in r.operations() {
Expand Down
Loading

0 comments on commit 0bcab0a

Please sign in to comment.