diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index 0c5ddb55b..594a8019e 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -1,35 +1,27 @@ //! Prelude extension - available in all contexts, defining common types, //! operations and constants. +use itertools::Itertools; use lazy_static::lazy_static; -use crate::extension::simple_op::MakeOpDef; -use crate::ops::constant::{CustomCheckFailure, ValueName}; -use crate::ops::{ExtensionOp, OpName}; -use crate::types::{FuncValueType, SumType, TypeName, TypeRV}; -use crate::{ - extension::{ExtensionId, TypeDefBound}, - ops::constant::CustomConst, - type_row, - types::{ - type_param::{TypeArg, TypeParam}, - CustomType, PolyFuncTypeRV, Signature, Type, TypeBound, - }, - Extension, +use crate::extension::const_fold::fold_out_row; +use crate::extension::simple_op::{ + try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; +use crate::extension::{ + ConstFold, ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDefBound, +}; +use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName}; +use crate::ops::{ExtensionOp, NamedOp, OpName, Value}; +use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::{ + CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeBound, + TypeName, TypeRV, TypeRow, TypeRowRV, +}; +use crate::utils::sorted_consts; +use crate::{type_row, Extension}; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; -use crate::{ - extension::{ - const_fold::fold_out_row, - simple_op::{try_from_name, MakeExtensionOp, MakeRegisteredOp, OpLoadError}, - ConstFold, ExtensionSet, OpDef, SignatureError, SignatureFunc, - }, - ops::{NamedOp, Value}, - types::{PolyFuncType, TypeRow}, - utils::sorted_consts, -}; - use super::{ExtensionRegistry, SignatureFromArgs}; struct ArrayOpCustom; @@ -255,8 +247,102 @@ pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE); pub const ERROR_TYPE_NAME: TypeName = TypeName::new_inline("error"); /// Return a Sum type with the first variant as the given type and the second an Error. -pub fn sum_with_error(ty: Type) -> SumType { - SumType::new([ty, ERROR_TYPE]) +pub fn sum_with_error(ty: impl Into) -> SumType { + either_type(ty, ERROR_TYPE) +} + +/// An optional type, i.e. a Sum type with the first variant as the given type and the second as an empty tuple. +#[inline] +pub fn option_type(ty: impl Into) -> SumType { + either_type(ty, TypeRow::new()) +} + +/// An "either" type, i.e. a Sum type with a "left" and a "right" variant. +/// +/// When used as a fallible value, the "left" variant represents a successful computation, +/// and the "right" variant represents a failure. +#[inline] +pub fn either_type(ty_ok: impl Into, ty_err: impl Into) -> SumType { + SumType::new([ty_ok.into(), ty_err.into()]) +} + +/// A constant optional value with a given value. +/// +/// See [option_type]. +pub fn const_some(value: Value) -> Value { + const_some_tuple([value]) +} + +/// A constant optional value with a row of values. +/// +/// For single values, use [const_some]. +/// +/// See [option_type]. +pub fn const_some_tuple(values: impl IntoIterator) -> Value { + const_left_tuple(values, TypeRow::new()) +} + +/// A constant optional value with no value. +/// +/// See [option_type]. +pub fn const_none(ty: impl Into) -> Value { + const_right_tuple(ty, []) +} + +/// A constant Either value with a left variant. +/// +/// In fallible computations, this represents a successful result. +/// +/// See [either_type]. +pub fn const_left(value: Value, ty_right: impl Into) -> Value { + const_left_tuple([value], ty_right) +} + +/// A constant Either value with a row of left values. +/// +/// In fallible computations, this represents a successful result. +/// +/// See [either_type]. +pub fn const_left_tuple( + values: impl IntoIterator, + ty_right: impl Into, +) -> Value { + let values = values.into_iter().collect_vec(); + let types: TypeRowRV = values + .iter() + .map(|v| TypeRV::from(v.get_type())) + .collect_vec() + .into(); + let typ = either_type(types, ty_right); + Value::sum(0, values, typ).unwrap() +} + +/// A constant Either value with a right variant. +/// +/// In fallible computations, this represents a failure. +/// +/// See [either_type]. +pub fn const_right(ty_left: impl Into, value: Value) -> Value { + const_right_tuple(ty_left, [value]) +} + +/// A constant Either value with a row of right values. +/// +/// In fallible computations, this represents a failure. +/// +/// See [either_type]. +pub fn const_right_tuple( + ty_left: impl Into, + values: impl IntoIterator, +) -> Value { + let values = values.into_iter().collect_vec(); + let types: TypeRowRV = values + .iter() + .map(|v| TypeRV::from(v.get_type())) + .collect_vec() + .into(); + let typ = either_type(ty_left, types); + Value::sum(1, values, typ).unwrap() } #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] @@ -806,6 +892,8 @@ impl MakeRegisteredOp for Lift { #[cfg(test)] mod test { + use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; + use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use crate::{ builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr}, utils::test_quantum_extension::cx_gate, @@ -897,6 +985,35 @@ mod test { b.finish_prelude_hugr_with_outputs(out.outputs()).unwrap(); } + #[test] + fn test_option() { + let typ: Type = option_type(BOOL_T).into(); + let const_val1 = const_some(Value::true_val()); + let const_val2 = const_none(BOOL_T); + + let mut b = DFGBuilder::new(inout_sig(type_row![], vec![typ.clone(), typ])).unwrap(); + + let some = b.add_load_value(const_val1); + let none = b.add_load_value(const_val2); + + b.finish_prelude_hugr_with_outputs([some, none]).unwrap(); + } + + #[test] + fn test_result() { + let typ: Type = either_type(BOOL_T, FLOAT64_TYPE).into(); + let const_bool = const_left(Value::true_val(), FLOAT64_TYPE); + let const_float = const_right(BOOL_T, ConstF64::new(0.5).into()); + + let mut b = DFGBuilder::new(inout_sig(type_row![], vec![typ.clone(), typ])).unwrap(); + + let bool = b.add_load_value(const_bool); + let float = b.add_load_value(const_float); + + b.finish_hugr_with_outputs([bool, float], &FLOAT_OPS_REGISTRY) + .unwrap(); + } + #[test] /// test the prelude error type and panic op. fn test_error_type() { diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index c6544b5bf..045198b74 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -9,6 +9,8 @@ from hugr.utils import ser_it if TYPE_CHECKING: + from collections.abc import Iterable + from hugr import ext @@ -303,6 +305,52 @@ def __repr__(self) -> str: return f"Tuple{tuple(self.variant_rows[0])}" +@dataclass(eq=False) +class Option(Sum): + """Optional tuple of elements. + + Instances of this type correspond to :class:`Sum` with two variants. + The first variant is the tuple of elements, the second is empty. + """ + + def __init__(self, *tys: Type): + self.variant_rows = [list(tys), []] + + def __repr__(self) -> str: + return f"Option({', '.join(map(repr, self.variant_rows[0]))})" + + +@dataclass(eq=False) +class Either(Sum): + """Two-variant tuple of elements. + + Instances of this type correspond to :class:`Sum` with a Left and a Right variant. + + In fallible contexts, the Left variant is used to represent success, and the + Right variant is used to represent failure. + + Example: + >>> either = Either([Bool, Bool], [Bool]) + >>> either + Either(left=[Bool, Bool], right=[Bool]) + >>> str(either) + 'Either((Bool, Bool), Bool)' + """ + + def __init__(self, left: Iterable[Type], right: Iterable[Type]): + self.variant_rows = [list(left), list(right)] + + def __repr__(self) -> str: # pragma: no cover + left, right = self.variant_rows + return f"Either(left={left}, right={right})" + + def __str__(self) -> str: + left, right = self.variant_rows + left_str = left[0] if len(left) == 1 else tuple(left) + right_str = right[0] if len(right) == 1 else tuple(right) + return f"Either({left_str}, {right_str})" + + @dataclass(frozen=True) class Variable(Type): """A type variable with a given bound, identified by index.""" diff --git a/hugr-py/src/hugr/val.py b/hugr-py/src/hugr/val.py index 210b19dff..7757636b1 100644 --- a/hugr-py/src/hugr/val.py +++ b/hugr-py/src/hugr/val.py @@ -11,6 +11,8 @@ from hugr.utils import ser_it if TYPE_CHECKING: + from collections.abc import Iterable + from hugr.hugr import Hugr @@ -149,6 +151,133 @@ def __repr__(self) -> str: return f"Tuple({', '.join(map(repr, self.vals))})" +@dataclass +class Some(Sum): + """Optional tuple of value, containing a list of values. + + Example: + >>> some = Some(TRUE, FALSE) + >>> some + Some(TRUE, FALSE) + >>> str(some) + 'Some(TRUE, FALSE)' + >>> some.type_() + Option(Bool, Bool) + + """ + + #: The values of this tuple. + vals: list[Value] + + def __init__(self, *vals: Value): + val_list = list(vals) + super().__init__( + tag=0, typ=tys.Option(*(v.type_() for v in val_list)), vals=val_list + ) + + def __repr__(self) -> str: + return f"Some({', '.join(map(repr, self.vals))})" + + +@dataclass +class None_(Sum): + """Optional tuple of value, containing no values. + + Example: + >>> none = None_(tys.Bool) + >>> none + None(Bool) + >>> str(none) + 'None' + >>> none.type_() + Option(Bool) + + """ + + def __init__(self, *types: tys.Type): + super().__init__(tag=1, typ=tys.Option(*types), vals=[]) + + def __repr__(self) -> str: + return f"None({', '.join(map(repr, self.typ.variant_rows[0]))})" + + def __str__(self) -> str: + return "None" + + +@dataclass +class Left(Sum): + """Left variant of a :class:`tys.Either` type, containing a list of values. + + In fallible contexts, this represents the success variant. + + Example: + >>> left = Left([TRUE, FALSE], [tys.Bool]) + >>> left + Left(vals=[TRUE, FALSE], right_typ=[Bool]) + >>> str(left) + 'Left(TRUE, FALSE)' + >>> str(left.type_()) + 'Either((Bool, Bool), Bool)' + """ + + #: The values of this tuple. + vals: list[Value] + + def __init__(self, vals: Iterable[Value], right_typ: Iterable[tys.Type]): + val_list = list(vals) + super().__init__( + tag=0, + typ=tys.Either([v.type_() for v in val_list], right_typ), + vals=val_list, + ) + + def __repr__(self) -> str: + _, right_typ = self.typ.variant_rows + return f"Left(vals={self.vals}, right_typ={list(right_typ)})" + + def __str__(self) -> str: + vals_str = ", ".join(map(str, self.vals)) + return f"Left({vals_str})" + + +@dataclass +class Right(Sum): + """Right variant of a :class:`tys.Either` type, containing a list of values. + + In fallible contexts, this represents the failure variant. + + Internally a :class:`Sum` with two variant rows. + + Example: + >>> right = Right([tys.Bool, tys.Bool, tys.Bool], [TRUE, FALSE]) + >>> right + Right(left_typ=[Bool, Bool, Bool], vals=[TRUE, FALSE]) + >>> str(right) + 'Right(TRUE, FALSE)' + >>> str(right.type_()) + 'Either((Bool, Bool, Bool), (Bool, Bool))' + """ + + #: The values of this tuple. + vals: list[Value] + + def __init__(self, left_typ: Iterable[tys.Type], vals: Iterable[Value]): + val_list = list(vals) + super().__init__( + tag=1, + typ=tys.Either(left_typ, [v.type_() for v in val_list]), + vals=val_list, + ) + + def __repr__(self) -> str: + left_typ, _ = self.typ.variant_rows + return f"Right(left_typ={list(left_typ)}, vals={self.vals})" + + def __str__(self) -> str: + vals_str = ", ".join(map(str, self.vals)) + return f"Right({vals_str})" + + @dataclass class Function(Value): """Higher order function value, defined by a :class:`Hugr `."""