Skip to content

Commit

Permalink
refactor!: flatten CustomOp in to OpType (#1429)
Browse files Browse the repository at this point in the history
In new workflow most of the time we just want to work with extension
ops, and opaque are just used for serialisation. To that end flatten
`CustomOp` out. If there are `OpaqueOp`s in the HUGR it means the
extension was not available to resolve in to. Will be disallowed in
follow up PR.

Relates to #1362 


BREAKING CHANGE: `CustomOp` removed, `OpType` now contains `ExtensionOp`
and `OpaqueOp` directly. `CustomOpError` renamed to`OpaqueOpError`.
  • Loading branch information
ss2165 committed Aug 14, 2024
1 parent e59464b commit 8e8bba5
Show file tree
Hide file tree
Showing 27 changed files with 554 additions and 722 deletions.
6 changes: 3 additions & 3 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ mod test {
DataflowSubContainer,
},
extension::prelude::BOOL_T,
ops::{custom::OpaqueOp, CustomOp},
ops::custom::OpaqueOp,
type_row,
types::Signature,
};
Expand Down Expand Up @@ -297,13 +297,13 @@ mod test {
#[test]
fn with_nonlinear_and_outputs() {
let missing_ext: ExtensionId = "MissingExt".try_into().unwrap();
let my_custom_op = CustomOp::new_opaque(OpaqueOp::new(
let my_custom_op = OpaqueOp::new(
missing_ext.clone(),
"MyOp",
"unknown op".to_string(),
vec![],
Signature::new(vec![QB, NAT], vec![QB]),
));
);
let build_res = build_main(
Signature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
.with_extension_delta(ExtensionSet::from_iter([
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ impl CustomConcrete for OpaqueOp {
type Identifier = OpName;

fn def_name(&self) -> &OpName {
self.name()
self.op_name()
}

fn type_args(&self) -> &[TypeArg] {
Expand Down
12 changes: 5 additions & 7 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,10 @@ pub enum LowerFunc {
FixedHugr {
/// The extensions required by the [`Hugr`]
extensions: ExtensionSet,
/// The [`Hugr`] to be used to replace [CustomOp]s matching the parent
/// The [`Hugr`] to be used to replace [ExtensionOp]s matching the parent
/// [OpDef]
///
/// [CustomOp]: crate::ops::CustomOp
/// [ExtensionOp]: crate::ops::ExtensionOp
hugr: Hugr,
},
/// Custom binary function that can (fallibly) compute a Hugr
Expand Down Expand Up @@ -495,7 +495,7 @@ pub(super) mod test {
use crate::extension::prelude::USIZE_T;
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY};
use crate::ops::{CustomOp, OpName};
use crate::ops::OpName;
use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME};
use crate::types::type_param::{TypeArgError, TypeParam};
use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV};
Expand Down Expand Up @@ -615,10 +615,8 @@ pub(super) mod test {
Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: USIZE_T }])?);
let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?;
let rev = dfg.add_dataflow_op(
CustomOp::new_extension(
e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], &reg)
.unwrap(),
),
e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], &reg)
.unwrap(),
dfg.input_wires(),
)?;
dfg.finish_hugr_with_outputs(rev.outputs(), &reg)?;
Expand Down
5 changes: 2 additions & 3 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use lazy_static::lazy_static;

use crate::ops::constant::{CustomCheckFailure, ValueName};
use crate::ops::{CustomOp, OpName};
use crate::ops::{ExtensionOp, OpName};
use crate::types::{FuncValueType, SumType, TypeName};
use crate::{
extension::{ExtensionId, TypeDefBound},
Expand Down Expand Up @@ -200,7 +200,7 @@ pub const NEW_ARRAY_OP_ID: OpName = OpName::new_inline("new_array");
pub const PANIC_OP_ID: OpName = OpName::new_inline("panic");

/// Initialize a new array op of element type `element_ty` of length `size`
pub fn new_array_op(element_ty: Type, size: u64) -> CustomOp {
pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp {
PRELUDE
.instantiate_extension_op(
&NEW_ARRAY_OP_ID,
Expand All @@ -211,7 +211,6 @@ pub fn new_array_op(element_ty: Type, size: u64) -> CustomOp {
&PRELUDE_REGISTRY,
)
.unwrap()
.into()
}

/// Name of the string type.
Expand Down
22 changes: 8 additions & 14 deletions hugr-core/src/extension/simple_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

use strum::IntoEnumIterator;

use crate::ops::{CustomOp, OpName, OpNameRef};
use crate::ops::{ExtensionOp, OpName, OpNameRef};
use crate::{
ops::{custom::ExtensionOp, NamedOp, OpType},
ops::{NamedOp, OpType},
types::TypeArg,
Extension,
};
Expand Down Expand Up @@ -87,14 +87,11 @@ pub trait MakeOpDef: NamedOp {
}

/// If the definition can be loaded from a string, load from an [ExtensionOp].
fn from_op(custom_op: &CustomOp) -> Result<Self, OpLoadError>
fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
where
Self: Sized + std::str::FromStr,
{
match custom_op {
CustomOp::Extension(ext) => Self::from_extension_op(ext),
CustomOp::Opaque(opaque) => try_from_name(opaque.name(), opaque.extension()),
}
Self::from_extension_op(ext_op)
}
}

Expand All @@ -112,15 +109,12 @@ pub trait HasDef: MakeExtensionOp {
/// Associated [HasConcrete] type.
type Def: HasConcrete<Concrete = Self> + std::str::FromStr;

/// Load the operation from a [CustomOp].
fn from_op(custom_op: &CustomOp) -> Result<Self, OpLoadError>
/// Load the operation from a [ExtensionOp].
fn from_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
where
Self: Sized,
{
match custom_op {
CustomOp::Extension(ext) => Self::from_extension_op(ext),
CustomOp::Opaque(opaque) => Self::Def::from_op(custom_op)?.instantiate(opaque.args()),
}
Self::from_extension_op(ext_op)
}
}

Expand All @@ -137,7 +131,7 @@ pub trait MakeExtensionOp: NamedOp {
where
Self: Sized,
{
let ext: &ExtensionOp = op.as_custom_op()?.as_extension_op()?;
let ext: &ExtensionOp = op.as_extension_op()?;
Self::from_extension_op(ext).ok()
}

Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ mod test {
}
fn extension_ops(h: &impl HugrView) -> Vec<Node> {
h.nodes()
.filter(|n| matches!(h.get_optype(*n), OpType::CustomOp(_)))
.filter(|n| matches!(h.get_optype(*n), OpType::ExtensionOp(_)))
.collect()
}

Expand Down
22 changes: 6 additions & 16 deletions hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ mod test {
use crate::hugr::internal::HugrMutInternals;
use crate::hugr::rewrite::replace::WhichHugr;
use crate::hugr::{HugrMut, Rewrite};
use crate::ops::custom::{CustomOp, OpaqueOp};
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::dataflow::DataflowOpTrait;
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG};
Expand All @@ -477,14 +477,12 @@ mod test {
.instantiate([TypeArg::Type { ty: USIZE_T }])
.unwrap(),
);
let pop: CustomOp = collections::EXTENSION
let pop: ExtensionOp = collections::EXTENSION
.instantiate_extension_op("pop", [TypeArg::Type { ty: USIZE_T }], &reg)
.unwrap()
.into();
let push: CustomOp = collections::EXTENSION
.unwrap();
let push: ExtensionOp = collections::EXTENSION
.instantiate_extension_op("push", [TypeArg::Type { ty: USIZE_T }], &reg)
.unwrap()
.into();
.unwrap();
let just_list = TypeRow::from(vec![listy.clone()]);
let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]);

Expand Down Expand Up @@ -643,15 +641,7 @@ mod test {
fn test_invalid() {
let unknown_ext: ExtensionId = "unknown_ext".try_into().unwrap();
let utou = Signature::new_endo(vec![USIZE_T]);
let mk_op = |s| {
CustomOp::new_opaque(OpaqueOp::new(
unknown_ext.clone(),
s,
String::new(),
vec![],
utou.clone(),
))
};
let mk_op = |s| OpaqueOp::new(unknown_ext.clone(), s, String::new(), vec![], utou.clone());
let mut h = DFGBuilder::new(
Signature::new(type_row![USIZE_T, BOOL_T], type_row![USIZE_T])
.with_extension_delta(unknown_ext.clone()),
Expand Down
23 changes: 20 additions & 3 deletions hugr-core/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,14 @@ pub fn check_hugr(lhs: &Hugr, rhs: &Hugr) {
let new_op = rhs.get_optype(node);
let old_op = h_canon.get_optype(node);
if !new_op.is_const() {
assert_eq!(new_op, old_op);
match (new_op, old_op) {
(OpType::ExtensionOp(ext), OpType::OpaqueOp(opaque))
| (OpType::OpaqueOp(opaque), OpType::ExtensionOp(ext)) => {
let ext_opaque: OpaqueOp = ext.clone().into();
assert_eq!(ext_opaque, opaque.clone());
}
_ => assert_eq!(new_op, old_op),
}
}
}

Expand Down Expand Up @@ -533,7 +540,7 @@ fn std_extensions_valid() {
mod proptest {
use super::check_testing_roundtrip;
use super::{NodeSer, SimpleOpDef};
use crate::ops::{OpType, Value};
use crate::ops::{OpType, OpaqueOp, Value};
use crate::types::{PolyFuncTypeRV, Type};
use proptest::prelude::*;

Expand All @@ -545,7 +552,17 @@ mod proptest {
(0..i32::MAX as usize).prop_map(|x| portgraph::NodeIndex::new(x).into()),
any::<OpType>(),
)
.prop_map(|(parent, op)| NodeSer { parent, op })
.prop_map(|(parent, op)| {
if let OpType::ExtensionOp(ext_op) = op {
let opaque: OpaqueOp = ext_op.into();
NodeSer {
parent,
op: opaque.into(),
}
} else {
NodeSer { parent, op }
}
})
.boxed()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
},
{
"parent": 0,
"op": "CustomOp",
"op": "Extension",
"extension": "logic",
"name": "And",
"description": "logical 'and'",
Expand Down
57 changes: 23 additions & 34 deletions hugr-core/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use thiserror::Error;
use crate::extension::{ExtensionRegistry, SignatureError, TO_BE_INFERRED};

use crate::ops::constant::ConstTypeError;
use crate::ops::custom::{resolve_opaque_op, CustomOp, CustomOpError};
use crate::ops::custom::{resolve_opaque_op, ExtensionOp, OpaqueOpError};
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::{FuncDefn, OpParent, OpTag, OpTrait, OpType, ValidateOp};
use crate::types::type_param::TypeParam;
Expand Down Expand Up @@ -566,37 +566,26 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
) -> Result<(), ValidationError> {
let op_type = self.hugr.get_optype(node);
// The op_type must be defined only in terms of type variables defined outside the node
// TODO consider turning this match into a trait method?

let validate_ext = |ext_op: &ExtensionOp| -> Result<(), ValidationError> {
// Check TypeArgs are valid, and if we can, fit the declared TypeParams
ext_op
.def()
.validate_args(ext_op.args(), self.extension_registry, var_decls)
.map_err(|cause| ValidationError::SignatureError { node, cause })
};
match op_type {
OpType::CustomOp(op) => {
OpType::ExtensionOp(ext_op) => validate_ext(ext_op)?,
OpType::OpaqueOp(opaque) => {
// Try to resolve serialized names to actual OpDefs in Extensions.
let temp: CustomOp;
let resolved = match op {
CustomOp::Opaque(opaque) => {
// If resolve_extension_ops has been called first, this would always return Ok(None)
match resolve_opaque_op(node, opaque, self.extension_registry)? {
Some(exten) => {
temp = CustomOp::new_extension(exten);
&temp
}
None => op,
}
}
CustomOp::Extension(_) => op,
};
// Check TypeArgs are valid, and if we can, fit the declared TypeParams
match resolved {
CustomOp::Extension(exten) => exten
.def()
.validate_args(exten.args(), self.extension_registry, var_decls)
.map_err(|cause| ValidationError::SignatureError { node, cause })?,
CustomOp::Opaque(opaque) => {
// Best effort. Just check TypeArgs are valid in themselves, allowing any of them
// to contain type vars (we don't know how many are binary params, so accept if in doubt)
for arg in opaque.args() {
arg.validate(self.extension_registry, var_decls)
.map_err(|cause| ValidationError::SignatureError { node, cause })?;
}
if let Some(ext_op) = resolve_opaque_op(node, opaque, self.extension_registry)? {
validate_ext(&ext_op)?;
} else {
// Best effort. Just check TypeArgs are valid in themselves, allowing any of them
// to contain type vars (we don't know how many are binary params, so accept if in doubt)
for arg in opaque.args() {
arg.validate(self.extension_registry, var_decls)
.map_err(|cause| ValidationError::SignatureError { node, cause })?;
}
}
}
Expand Down Expand Up @@ -748,12 +737,12 @@ pub enum ValidationError {
#[source]
cause: SignatureError,
},
/// Error in a [CustomOp] serialized as an [Opaque].
/// Error in a [ExtensionOp] serialized as an [Opaque].
///
/// [CustomOp]: crate::ops::CustomOp
/// [Opaque]: crate::ops::CustomOp::Opaque
/// [ExtensionOp]: crate::ops::ExtensionOp
/// [Opaque]: crate::ops::OpaqueOp
#[error(transparent)]
CustomOpError(#[from] CustomOpError),
OpaqueOpError(#[from] OpaqueOpError),
/// A [Const] contained a [Value] of unexpected [Type].
///
/// [Const]: crate::ops::Const
Expand Down
12 changes: 8 additions & 4 deletions hugr-core/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use enum_dispatch::enum_dispatch;

pub use constant::{Const, Value};
pub use controlflow::{BasicBlock, Case, Conditional, DataflowBlock, ExitBlock, TailLoop, CFG};
pub use custom::CustomOp;
pub use custom::{ExtensionOp, OpaqueOp};
pub use dataflow::{
Call, CallIndirect, DataflowOpTrait, DataflowParent, Input, LoadConstant, LoadFunction, Output,
DFG,
Expand Down Expand Up @@ -53,7 +53,10 @@ pub enum OpType {
LoadConstant,
LoadFunction,
DFG,
CustomOp,
#[serde(skip_deserializing, rename = "Extension")]
ExtensionOp,
#[serde(rename = "Extension")]
OpaqueOp,
Noop,
MakeTuple,
UnpackTuple,
Expand Down Expand Up @@ -112,7 +115,7 @@ impl_op_ref_try_into!(CallIndirect);
impl_op_ref_try_into!(LoadConstant);
impl_op_ref_try_into!(LoadFunction);
impl_op_ref_try_into!(DFG, dfg);
impl_op_ref_try_into!(CustomOp);
impl_op_ref_try_into!(ExtensionOp);
impl_op_ref_try_into!(Noop);
impl_op_ref_try_into!(MakeTuple);
impl_op_ref_try_into!(UnpackTuple);
Expand Down Expand Up @@ -427,7 +430,8 @@ impl OpParent for Call {}
impl OpParent for CallIndirect {}
impl OpParent for LoadConstant {}
impl OpParent for LoadFunction {}
impl OpParent for CustomOp {}
impl OpParent for ExtensionOp {}
impl OpParent for OpaqueOp {}
impl OpParent for Noop {}
impl OpParent for MakeTuple {}
impl OpParent for UnpackTuple {}
Expand Down
Loading

0 comments on commit 8e8bba5

Please sign in to comment.