diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index ebfbec14f9..32e0ef0e9c 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -305,7 +305,7 @@ def interpret_operator_eqn(self, eqn: jax.core.JaxprEqn) -> None: *wires, *invals, op=eqn.primitive.name, - params_len=len(eqn.invars) - eqn.params["n_wires"], + ctrl_value_len=0, **kwargs, ) diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py index a720ac97ae..1a5823f8f8 100644 --- a/frontend/catalyst/jax_extras/tracing.py +++ b/frontend/catalyst/jax_extras/tracing.py @@ -969,9 +969,9 @@ def bind_flexible_primitive(primitive, flexible_args: dict[str, Any], *dyn_args, The flexible_args is a dictionary containing the flexible arguments. These are the arguments that can either be static or dynamic. This method - will bind a flexible argument as static only if it is an integer, float, or boolean - literal. In the static case, the binded primitive's param name is the flexible arg's key, - and the jaxpr param value is the flexible arg's value. + will bind a flexible argument as static only if it is a single or a list of only integer, float, + or boolean literals. In the static case, the binded primitive's param name is the flexible arg's + key, and the jaxpr param value is the flexible arg's value. If a flexible argument is received as a tracer, it will be binded dynamically with the flexible arg's value. @@ -984,7 +984,12 @@ def bind_flexible_primitive(primitive, flexible_args: dict[str, Any], *dyn_args, for flex_arg_name, flex_arg_value in flexible_args.items(): if type(flex_arg_value) in static_literal_pool: - static_args |= {flex_arg_name: flex_arg_value} + static_args[flex_arg_name] = flex_arg_value + elif isinstance(flex_arg_value, list): + if flex_arg_value and all(type(arg) in static_literal_pool for arg in flex_arg_value): + static_args[flex_arg_name] = flex_arg_value + else: + dyn_args += (*flex_arg_value,) else: dyn_args += (flex_arg_value,) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 3032bd09d3..9f58dff29b 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -86,6 +86,7 @@ SetBasisStateOp, SetStateOp, StateOp, + StaticCustomOp, TensorOp, VarianceOp, ) @@ -972,13 +973,17 @@ def _gphase_lowering( # @qinst_p.def_abstract_eval def _qinst_abstract_eval( - *qubits_or_params, op=None, qubits_len=0, params_len=0, ctrl_len=0, adjoint=False + *qubits_or_params, + op=None, + qubits_len=0, + ctrl_len=0, + ctrl_value_len=0, + adjoint=False, + static_params=None, ): # The signature here is: (using * to denote zero or more) - # qubits*, params*, ctrl_qubits*, ctrl_values* - qubits = qubits_or_params[:qubits_len] - ctrl_qubits = qubits_or_params[-2 * ctrl_len : -ctrl_len] - all_qubits = qubits + ctrl_qubits + # qubits*, ctrl_qubits*, ctrl_values*, params* + all_qubits = qubits_or_params[: qubits_len + ctrl_len] for idx in range(qubits_len + ctrl_len): qubit = all_qubits[idx] assert isinstance(qubit, AbstractQbit) @@ -996,17 +1001,19 @@ def _qinst_lowering( *qubits_or_params, op=None, qubits_len=0, - params_len=0, ctrl_len=0, + ctrl_value_len=0, adjoint=False, + static_params=None, ): + assert ctrl_value_len == ctrl_len, "Control values must be the same length as control qubits" ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True qubits = qubits_or_params[:qubits_len] - params = qubits_or_params[qubits_len : qubits_len + params_len] - ctrl_qubits = qubits_or_params[qubits_len + params_len : qubits_len + params_len + ctrl_len] - ctrl_values = qubits_or_params[qubits_len + params_len + ctrl_len :] + ctrl_qubits = qubits_or_params[qubits_len : qubits_len + ctrl_len] + ctrl_values = qubits_or_params[qubits_len + ctrl_len : qubits_len + ctrl_len + ctrl_value_len] + params = qubits_or_params[qubits_len + ctrl_len + ctrl_value_len :] for qubit in qubits: assert ir.OpaqueType.isinstance(qubit.type) @@ -1029,13 +1036,31 @@ def _qinst_lowering( p = TensorExtractOp(ir.IntegerType.get_signless(1), v, []).result ctrl_values_i1.append(p) + params_attr = ( + None + if not static_params + else ir.DenseF64ArrayAttr.get([ir.FloatAttr.get_f64(val) for val in static_params]) + ) + if len(float_params) > 0: + assert ( + params_attr is None + ), "Static parameters are not allowed when having dynamic parameters" + name_attr = ir.StringAttr.get(op) name_str = str(name_attr) name_str = name_str.replace('"', "") if name_str == "MultiRZ": - assert len(float_params) == 1, "MultiRZ takes one float parameter" - float_param = float_params[0] + assert len(float_params) <= 1, "MultiRZ takes at most one dynamic float parameter" + assert ( + not static_params or len(static_params) <= 1 + ), "MultiRZ takes at most one static float parameter" + # TODO: Add support for MultiRZ with static params + float_param = ( + TensorExtractOp(ir.F64Type.get(), mlir.ir_constant(static_params[0]), []) + if len(float_params) == 0 + else float_params[0] + ) return MultiRZOp( out_qubits=[qubit.type for qubit in qubits], out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], @@ -1045,7 +1070,17 @@ def _qinst_lowering( in_ctrl_values=ctrl_values_i1, adjoint=adjoint, ).results - + if params_attr: + return StaticCustomOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + static_params=params_attr, + in_qubits=qubits, + gate_name=name_attr, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results return CustomOp( out_qubits=[qubit.type for qubit in qubits], out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 3e7f2fea29..2eb4dc7396 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -691,12 +691,14 @@ def bind_native_operation(qrp, op, controlled_wires, controlled_values, adjoint= else: qubits = qrp.extract(op.wires) controlled_qubits = qrp.extract(controlled_wires) - qubits2 = qinst_p.bind( - *[*qubits, *op.parameters, *controlled_qubits, *controlled_values], + qubits2 = bind_flexible_primitive( + qinst_p, + {"static_params": op.parameters}, + *[*qubits, *controlled_qubits, *controlled_values], op=op.name, qubits_len=len(qubits), - params_len=len(op.parameters), ctrl_len=len(controlled_qubits), + ctrl_value_len=len(controlled_values), adjoint=adjoint, ) qrp.insert(op.wires, qubits2[: len(qubits)]) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index fc2c5efcbc..4336010600 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -169,6 +169,8 @@ def get_enforce_runtime_invariants_stage(_options: CompileOptions) -> List[str]: "split-multiple-tapes", # Run the transform sequence defined in the MLIR module "builtin.module(apply-transform-sequence)", + # Lower the static custom ops to regular custom ops with dynamic parameters. + "static-custom-lowering", # Nested modules are something that will be used in the future # for making device specific transformations. # Since at the moment, nothing in the runtime is using them diff --git a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py index 0dcc8dd992..adc6545473 100644 --- a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py +++ b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py @@ -428,10 +428,17 @@ def change_instruction(ctx, eqn): op = params["op"] cuda_inst_name = from_catalyst_to_cuda[op] qubits_len = params["qubits_len"] + static_params = params.get("static_params") # Now, we can map to the correct op # For now just assume rx - cuda_inst(ctx.kernel, *qubits_or_params, inst=cuda_inst_name, qubits_len=qubits_len) + cuda_inst( + ctx.kernel, + *qubits_or_params, + inst=cuda_inst_name, + qubits_len=qubits_len, + static_params=static_params, + ) # Finally determine how many are qubits. qubits = qubits_or_params[:qubits_len] diff --git a/frontend/catalyst/third_party/cuda/primitives/__init__.py b/frontend/catalyst/third_party/cuda/primitives/__init__.py index e1679d3bcb..5c00100308 100644 --- a/frontend/catalyst/third_party/cuda/primitives/__init__.py +++ b/frontend/catalyst/third_party/cuda/primitives/__init__.py @@ -301,27 +301,33 @@ def make_primitive_for_gate(): kernel_gate_p = jax.core.Primitive("kernel_inst") kernel_gate_p.multiple_results = True - def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1): + def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params=None): """Convenience. Quantum operations in CUDA-quantum return no values. But JAXPR expects return values. We can just say that multiple_results = True and return an empty tuple. """ - kernel_gate_p.bind(kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len) + kernel_gate_p.bind( + kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len, static_params=static_params + ) return tuple() @kernel_gate_p.def_impl - def gate_impl(kernel, *qubits_or_params, inst=None, qubits_len=-1): + def gate_impl(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params=None): """Concrete implementation.""" assert inst and qubits_len > 0 + if static_params is None: + static_params = [] method = getattr(cudaq.Kernel, inst) targets = qubits_or_params[:qubits_len] params = qubits_or_params[qubits_len:] + if not params: + params = static_params method(kernel, *params, *targets) return tuple() @kernel_gate_p.def_abstract_eval - def gate_abs(_kernel, *_qubits_or_params, inst=None, qubits_len=-1): + def gate_abs(_kernel, *_qubits_or_params, inst=None, qubits_len=-1, static_params=None): """Abstract evaluation.""" return tuple() diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 04263d33ec..522a453013 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -133,10 +133,8 @@ def test_decompose_s(): @qml.qnode(dev) # CHECK-LABEL: public @jit_decompose_s def decompose_s(): - # CHECK-NOT: name="S" - # CHECK: [[pi_div_2:%.+]] = arith.constant 1.57079{{.+}} : f64 # CHECK-NOT: name = "S" - # CHECK: {{%.+}} = quantum.custom "PhaseShift"([[pi_div_2]]) + # CHECK: {{%.+}} = quantum.static_custom "PhaseShift" [1.570796e+00] # CHECK-NOT: name = "S" qml.S(wires=0) return measure(wires=0) diff --git a/frontend/test/lit/test_if_else.py b/frontend/test/lit/test_if_else.py index 0aec735804..bd586c7c62 100644 --- a/frontend/test/lit/test_if_else.py +++ b/frontend/test/lit/test_if_else.py @@ -94,7 +94,7 @@ def circuit_single_gate(n: int): # CHECK: [[b6:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t6]] # CHECK: [[qreg_out1:%.+]] = scf.if [[b6]] # CHECK-DAG: [[q4:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out]] - # CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.custom "RX"({{%.+}}) [[q4]] + # CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.static_custom "RX" [3.140000e+00] [[q4]] # CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q5]] # CHECK: scf.yield [[qreg_3]] # CHECK: else diff --git a/frontend/test/lit/test_measurements.py b/frontend/test/lit/test_measurements.py index eed5161843..21e4e300a0 100644 --- a/frontend/test/lit/test_measurements.py +++ b/frontend/test/lit/test_measurements.py @@ -38,7 +38,7 @@ def sample1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = quantum.custom "RZ" + # COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliZ] @@ -54,7 +54,7 @@ def sample2(x: float, y: float): qml.RX(x, wires=0) # COM: CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = quantum.custom "RZ" + # COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs1:%.+]] = quantum.namedobs [[q1]][ PauliX] @@ -77,7 +77,7 @@ def sample3(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]] @@ -145,7 +145,7 @@ def test_sample_dynamic(shots: int): def counts1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = quantum.custom "RZ" + # COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliZ] @@ -160,7 +160,7 @@ def counts2(x: float, y: float): qml.RX(x, wires=0) # COM: CHECK: [[q1:%.+]] = "quantum.custom"({{%.+}}, {{%.+}}) {gate_name = "RY" qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = "quantum.custom"({{%.+}}, {{%.+}}) {gate_name = "RZ" + # COM: CHECK: [[q0:%.+]] = "quantum.static_custom"({{%.+}}, {{%.+}}) {gate_name = "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs1:%.+]] = "quantum.namedobs"([[q1]]) {type = #quantum} @@ -183,7 +183,7 @@ def counts3(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]] @@ -228,7 +228,7 @@ def test_counts_dynamic(shots: int): def expval1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -247,7 +247,7 @@ def expval2(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) # CHECK: [[p1:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -304,7 +304,7 @@ def expval5(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) B = np.array( @@ -334,7 +334,7 @@ def expval5(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) coeffs = np.array([0.2, -0.543]) @@ -426,7 +426,7 @@ def expval9(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) # CHECK: [[p1:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -448,7 +448,7 @@ def expval10(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) B = np.array( @@ -476,7 +476,7 @@ def expval10(x: float, y: float): def var1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -519,7 +519,7 @@ def probs1(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # qml.probs() # unsupported by PennyLane @@ -540,7 +540,7 @@ def state1(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # qml.state(wires=[0]) # unsupported by PennyLane diff --git a/frontend/test/lit/test_quantum_control.py b/frontend/test/lit/test_quantum_control.py index 3cd794dc9e..730692c855 100644 --- a/frontend/test/lit/test_quantum_control.py +++ b/frontend/test/lit/test_quantum_control.py @@ -98,7 +98,7 @@ def test_native_controlled_custom(): @qml.qnode(dev) # CHECK-LABEL: public @jit_native_controlled def native_controlled(): - # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.custom "Rot" + # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.static_custom "Rot" # CHECK-SAME: ctrls # CHECK-SAME: ctrlvals(%true, %true) qml.ctrl(qml.Rot(0.3, 0.4, 0.5, wires=[0]), control=[1, 2]) diff --git a/frontend/test/lit/test_static_circuit.py b/frontend/test/lit/test_static_circuit.py new file mode 100644 index 0000000000..40de0c68d3 --- /dev/null +++ b/frontend/test/lit/test_static_circuit.py @@ -0,0 +1,68 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# RUN: %PYTHON %s | FileCheck %s + +""" +Test quantum circuits with static (knonw at compile time) specifications. +""" + +import pennylane as qml + +from catalyst import qjit + + +def test_static_params(): + """Test operations with static params.""" + + @qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=4)) + def circuit(): + x = 3.14 + y = 0.6 + qml.Rot(x, y, x + y, wires=0) + + qml.RX(x, wires=0) + qml.RY(y, wires=1) + qml.RZ(x, wires=2) + + qml.IsingXX(x, wires=[0, 1]) + qml.IsingXX(y, wires=[1, 2]) + qml.IsingZZ(x, wires=[0, 1]) + + qml.CRX(x, wires=[0, 1]) + qml.CRY(x, wires=[0, 1]) + qml.CRZ(x, wires=[0, 1]) + + return qml.state() + + print(circuit.mlir) + + +# CHECK-LABEL: public @jit_circuit +# CHECK: %[[REG:.*]] = quantum.alloc( 4) : !quantum.reg +# CHECK: %[[BIT1:.*]] = quantum.extract %[[REG]][ 0] : !quantum.reg -> !quantum.bit +# CHECK: %[[ROT:.*]] = quantum.static_custom "Rot" +# CHECK: %[[RX:.*]] = quantum.static_custom "RX" +# CHECK: %[[BIT1:.*]] = quantum.extract %[[REG]][ 1] +# CHECK: %[[RY1:.*]] = quantum.static_custom "RY" +# CHECK: %[[XX1:.*]] = quantum.static_custom "IsingXX" +# CHECK: %[[BIT2:.*]] = quantum.extract %[[REG]][ 2] +# CHECK: %[[RZ:.*]] = quantum.static_custom "RZ" +# CHECK: %[[XX2:.*]] = quantum.static_custom "IsingXX" +# CHECK: %[[ZZ:.*]] = quantum.static_custom "IsingZZ" +# CHECK: %[[CRX:.*]] = quantum.static_custom "CRX" +# CHECK: %[[CRY:.*]] = quantum.static_custom "CRY" +# CHECK: %[[CRZ:.*]] = quantum.static_custom "CRZ" +test_static_params() diff --git a/mlir/include/Quantum/IR/QuantumInterfaces.td b/mlir/include/Quantum/IR/QuantumInterfaces.td index 145e47ae46..dc1390c636 100644 --- a/mlir/include/Quantum/IR/QuantumInterfaces.td +++ b/mlir/include/Quantum/IR/QuantumInterfaces.td @@ -164,6 +164,23 @@ def QuantumGate : OpInterface<"QuantumGate", [QuantumOperation]> { }]; } +def StaticGate : OpInterface<"StaticGate", [QuantumGate]> { + let description = [{ + This interface provides a generic way to interact with quantum + instructions with static parameters (known at compile time). These parameters + are specified by a set of constant literals in the form of an array attribute. + }]; + + let cppNamespace = "::catalyst::quantum"; + + let methods = [ + InterfaceMethod< + "Return all operands which are considered gate parameters.", + "mlir::DenseF64ArrayAttr", "getAllParams" + >, + ]; +} + def ParametrizedGate : OpInterface<"ParametrizedGate", [QuantumGate]> { let description = [{ This interface provides a generic way to interact with parametrized diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 246b23b92f..7c8fa3d4f2 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -457,6 +457,45 @@ def CustomOp : UnitaryGate_Op<"custom", [DifferentiableGate, NoMemoryEffect, let hasCanonicalizeMethod = 1; } +def StaticCustomOp : UnitaryGate_Op<"static_custom", [NoMemoryEffect, + AttrSizedOperandSegments, + AttrSizedResultSegments]> { + let summary = "A generic quantum gate with static parameters in form of a DenseF64ArrayAttr."; + let description = [{ + This operation represents a quantum gate with parameters defined statically as a + DenseF64ArrayAttr, rather than passed dynamically as operands. This is useful for gates + with parameters known at compile-time. + }]; + + let arguments = (ins + DenseF64ArrayAttr:$static_params, + Variadic:$in_qubits, + StrAttr:$gate_name, + OptionalAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + $gate_name $static_params $in_qubits attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) + (`ctrls` type($out_ctrl_qubits)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + llvm::ArrayRef getAllParams() { + return getStaticParams(); + } + }]; + let hasCanonicalizeMethod = 1; +} + + def GlobalPhaseOp : UnitaryGate_Op<"gphase", [DifferentiableGate, AttrSizedOperandSegments]> { let summary = "Global Phase."; let description = [{ diff --git a/mlir/include/Quantum/Transforms/Passes.h b/mlir/include/Quantum/Transforms/Passes.h index 65a930bb17..d43490d7cd 100644 --- a/mlir/include/Quantum/Transforms/Passes.h +++ b/mlir/include/Quantum/Transforms/Passes.h @@ -30,5 +30,6 @@ std::unique_ptr createAnnotateFunctionPass(); std::unique_ptr createSplitMultipleTapesPass(); std::unique_ptr createMergeRotationsPass(); std::unique_ptr createIonsDecompositionPass(); +std::unique_ptr createStaticCustomLoweringPass(); } // namespace catalyst diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td index 21ddcc0a82..5a5325adb3 100644 --- a/mlir/include/Quantum/Transforms/Passes.td +++ b/mlir/include/Quantum/Transforms/Passes.td @@ -88,6 +88,13 @@ def SplitMultipleTapesPass : Pass<"split-multiple-tapes"> { let constructor = "catalyst::createSplitMultipleTapesPass()"; } +def StaticCustomLoweringPass : Pass<"static-custom-lowering"> { + let summary = "Lower static custom ops to regular custom op with dynamic parameters."; + + let constructor = "catalyst::createStaticCustomLoweringPass()"; +} + + // ----- Quantum circuit transformation passes begin ----- // // For example, automatic compiler peephole opts, etc. diff --git a/mlir/include/Quantum/Transforms/Patterns.h b/mlir/include/Quantum/Transforms/Patterns.h index 1daf201018..cdf638624e 100644 --- a/mlir/include/Quantum/Transforms/Patterns.h +++ b/mlir/include/Quantum/Transforms/Patterns.h @@ -28,6 +28,7 @@ void populateAdjointPatterns(mlir::RewritePatternSet &); void populateSelfInversePatterns(mlir::RewritePatternSet &); void populateMergeRotationsPatterns(mlir::RewritePatternSet &); void populateIonsDecompositionPatterns(mlir::RewritePatternSet &); +void populateStaticCustomPatterns(mlir::RewritePatternSet &); } // namespace quantum } // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 20624898aa..eb5a79feae 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -51,6 +51,7 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createMergeRotationsPass); mlir::registerPass(catalyst::createScatterLoweringPass); mlir::registerPass(catalyst::createSplitMultipleTapesPass); + mlir::registerPass(catalyst::createStaticCustomLoweringPass); mlir::registerPass(catalyst::createTestPass); mlir::registerPass(catalyst::createIonsDecompositionPass); mlir::registerPass(catalyst::createQuantumToIonPass); diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index c9702ae555..ff2c571fb7 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -31,6 +31,7 @@ void createEnforceRuntimeInvariantsPipeline(OpPassManager &pm) { pm.addPass(catalyst::createSplitMultipleTapesPass()); pm.addPass(catalyst::createApplyTransformSequencePass()); + pm.addPass(catalyst::createStaticCustomLoweringPass()); pm.addPass(catalyst::createInlineNestedModulePass()); } void createHloLoweringPipeline(OpPassManager &pm) diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 695ec45efd..3c38a384d5 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -40,30 +40,64 @@ static const mlir::StringSet<> hermitianOps = {"Hadamard", "PauliX", "PauliY", " "CY", "CZ", "SWAP", "Toffoli"}; static const mlir::StringSet<> rotationsOps = {"RX", "RY", "RZ", "PhaseShift", "CRX", "CRY", "CRZ", "ControlledPhaseShift"}; -LogicalResult CustomOp::canonicalize(CustomOp op, mlir::PatternRewriter &rewriter) + +LogicalResult StaticCustomOp::canonicalize(StaticCustomOp op, mlir::PatternRewriter &rewriter) { - if (op.getAdjoint()) { - auto name = op.getGateName(); - if (hermitianOps.contains(name)) { - op.setAdjoint(false); - return success(); - } - else if (rotationsOps.contains(name)) { - auto params = op.getParams(); - SmallVector paramsNeg; - for (auto param : params) { - auto paramNeg = rewriter.create(op.getLoc(), param); - paramsNeg.push_back(paramNeg); - } - - rewriter.replaceOpWithNewOp( - op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramsNeg, - op.getInQubits(), name, nullptr, op.getInCtrlQubits(), op.getInCtrlValues()); + if (!op.getAdjoint()) { + return failure(); + } + auto name = op.getGateName(); - return success(); + if (hermitianOps.contains(name)) { + rewriter.modifyOpInPlace(op, [&op]() { op.setAdjoint(false); }); + return success(); + } + + if (rotationsOps.contains(name)) { + auto params = op.getStaticParams(); + SmallVector paramsNeg; + for (auto param : params) { + auto paramNeg = -1 * param; + paramsNeg.push_back(paramNeg); } + + rewriter.replaceOpWithNewOp( + op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), + rewriter.getDenseF64ArrayAttr(paramsNeg), op.getInQubits(), name, nullptr, + op.getInCtrlQubits(), op.getInCtrlValues()); + + return success(); + } + + return failure(); +} + +LogicalResult CustomOp::canonicalize(CustomOp op, mlir::PatternRewriter &rewriter) +{ + if (!op.getAdjoint()) { return failure(); } + auto name = op.getGateName(); + + if (hermitianOps.contains(name)) { + rewriter.modifyOpInPlace(op, [&op]() { op.setAdjoint(false); }); + return success(); + } + + if (rotationsOps.contains(name)) { + auto params = op.getParams(); + SmallVector paramsNeg; + for (auto param : params) { + auto paramNeg = rewriter.create(op.getLoc(), param); + paramsNeg.push_back(paramNeg); + } + + rewriter.replaceOpWithNewOp( + op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramsNeg, + op.getInQubits(), name, nullptr, op.getInCtrlQubits(), op.getInCtrlValues()); + + return success(); + } return failure(); } diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index fcbff39b76..2483d4618d 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -17,6 +17,8 @@ file(GLOB SRC MergeRotationsPatterns.cpp ions_decompositions.cpp IonsDecompositionPatterns.cpp + static_custom_lowering.cpp + StaticCustomPatterns.cpp ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 5d49eff575..eb9665d828 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -49,7 +49,7 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern vpga(op); if (!vpga.getVerifierResult()) { return failure(); } @@ -67,6 +67,41 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + /// We simplify consecutive Hermitian quantum gates by removing them. + /// Hermitian gates are self-inverse and applying the same gate twice in succession + /// cancels out the effect. This pattern rewrites such redundant operations by + /// replacing the operation with its "grandparent" operation in the quantum circuit. + mlir::LogicalResult matchAndRewrite(StaticCustomOp op, + mlir::PatternRewriter &rewriter) const override + { + LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); + + StringRef OpGateName = op.getGateName(); + if (!HermitianOps.contains(OpGateName)) { + return failure(); + } + + VerifyParentGateAndNameAnalysis vpga(op); + if (!vpga.getVerifierResult()) { + return failure(); + } + + // Replace uses + ValueRange InQubits = op.getInQubits(); + auto parentOp = cast(InQubits[0].getDefiningOp()); + + // TODO: it would make more sense for getQubitOperands() + // to return ValueRange, like the other getters + std::vector originalQubits = parentOp.getQubitOperands(); + + rewriter.replaceOp(op, originalQubits); + return success(); + } +}; + template struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -74,8 +109,8 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { bool verifyParentGateParams(OpType op, OpType parentOp) const { // Verify that the parent gate has the same parameters - ValueRange opParams = op.getAllParams(); - ValueRange parentOpParams = parentOp.getAllParams(); + auto opParams = op.getAllParams(); + auto parentOpParams = parentOp.getAllParams(); if (opParams.size() != parentOpParams.size()) { return false; @@ -109,7 +144,13 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); if (isa(op)) { - VerifyParentGateAndNameAnalysis vpga(cast(op)); + VerifyParentGateAndNameAnalysis vpga(cast(op)); + if (!vpga.getVerifierResult()) { + return failure(); + } + } + else if (isa(op)) { + VerifyParentGateAndNameAnalysis vpga(cast(op)); if (!vpga.getVerifierResult()) { return failure(); } @@ -151,12 +192,14 @@ namespace quantum { void populateSelfInversePatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext(), 1); patterns.add(patterns.getContext(), 1); // TODO: better organize the quantum dialect // There is an interface `QuantumGate` for all the unitary gate operations, // but interfaces cannot be accepted by pattern matchers, since pattern // matchers require the target operations to have concrete names in the IR. + patterns.add>(patterns.getContext(), 1); patterns.add>(patterns.getContext(), 1); patterns.add>(patterns.getContext(), 1); patterns.add>(patterns.getContext(), 1); diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index d449773e1b..c63e4461cf 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -44,7 +44,7 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { ValueRange inQubits = op.getInQubits(); auto parentOp = dyn_cast_or_null(inQubits[0].getDefiningOp()); - VerifyParentGateAndNameAnalysis vpga(op); + VerifyParentGateAndNameAnalysis vpga(op); if (!vpga.getVerifierResult()) { return failure(); } @@ -73,6 +73,47 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { } }; +struct MergeRotationsStaticRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(StaticCustomOp op, + mlir::PatternRewriter &rewriter) const override + { + LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); + auto loc = op.getLoc(); + StringRef opGateName = op.getGateName(); + if (!rotationsSet.contains(opGateName)) + return failure(); + ValueRange inQubits = op.getInQubits(); + auto parentOp = dyn_cast_or_null(inQubits[0].getDefiningOp()); + + VerifyParentGateAndNameAnalysis vpga(op); + if (!vpga.getVerifierResult()) { + return failure(); + } + + TypeRange outQubitsTypes = op.getOutQubits().getTypes(); + TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); + ValueRange parentInQubits = parentOp.getInQubits(); + ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); + ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); + + auto parentParams = parentOp.getStaticParams(); + auto params = op.getStaticParams(); + SmallVector sumParams; + for (auto [param, parentParam] : llvm::zip(params, parentParams)) { + sumParams.push_back(parentParam + param); + }; + auto mergeOp = rewriter.create( + loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, parentInQubits, opGateName, nullptr, + parentInCtrlQubits, parentInCtrlValues); + + op.replaceAllUsesWith(mergeOp); + + return success(); + } +}; + struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -117,6 +158,7 @@ namespace quantum { void populateMergeRotationsPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext(), 1); patterns.add(patterns.getContext(), 1); patterns.add(patterns.getContext(), 1); } diff --git a/mlir/lib/Quantum/Transforms/StaticCustomPatterns.cpp b/mlir/lib/Quantum/Transforms/StaticCustomPatterns.cpp new file mode 100644 index 0000000000..64d8e49529 --- /dev/null +++ b/mlir/lib/Quantum/Transforms/StaticCustomPatterns.cpp @@ -0,0 +1,62 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "static-custom" + +#include "Quantum/IR/QuantumOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Debug.h" + +using llvm::dbgs; +using namespace mlir; +using namespace catalyst::quantum; + +namespace { + +struct LowerStaticCustomOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(StaticCustomOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + + { + LLVM_DEBUG(dbgs() << "Lowering the following static custom operation:\n" << op << "\n"); + SmallVector paramValues; + auto staticParams = op.getStaticParams(); + for (auto param : staticParams) { + auto constant = rewriter.create(op.getLoc(), rewriter.getF64Type(), + rewriter.getF64FloatAttr(param)); + paramValues.push_back(constant); + } + + rewriter.replaceOpWithNewOp(op, op.getGateName(), op.getInQubits(), + op.getInCtrlQubits(), op.getInCtrlValues(), + paramValues, op.getAdjointFlag()); + return success(); + } +}; + +} // namespace + +namespace catalyst { +namespace quantum { + +void populateStaticCustomPatterns(RewritePatternSet &patterns) +{ + patterns.add(patterns.getContext(), 1); +} + +} // namespace quantum +} // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp index 6630fee909..3de93d87b0 100644 --- a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp +++ b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp @@ -134,15 +134,15 @@ template class VerifyParentGateAnalysis { } }; -class VerifyParentGateAndNameAnalysis : public VerifyParentGateAnalysis { +template +class VerifyParentGateAndNameAnalysis : public VerifyParentGateAnalysis { // If OpType is quantum.custom, also verify that parent gate has the // same gate name. public: - VerifyParentGateAndNameAnalysis(quantum::CustomOp gate) - : VerifyParentGateAnalysis(gate) + VerifyParentGateAndNameAnalysis(OpType gate) : VerifyParentGateAnalysis(gate) { ValueRange inQubits = gate.getInQubits(); - auto parentGate = dyn_cast_or_null(inQubits[0].getDefiningOp()); + auto parentGate = dyn_cast_or_null(inQubits[0].getDefiningOp()); if (!parentGate) { this->setVerifierResult(false); @@ -156,7 +156,7 @@ class VerifyParentGateAndNameAnalysis : public VerifyParentGateAnalysis { Operation *module = getOperation(); RewritePatternSet patternsCanonicalization(&getContext()); + catalyst::quantum::StaticCustomOp::getCanonicalizationPatterns(patternsCanonicalization, + &getContext()); catalyst::quantum::CustomOp::getCanonicalizationPatterns(patternsCanonicalization, &getContext()); catalyst::quantum::MultiRZOp::getCanonicalizationPatterns(patternsCanonicalization, diff --git a/mlir/lib/Quantum/Transforms/static_custom_lowering.cpp b/mlir/lib/Quantum/Transforms/static_custom_lowering.cpp new file mode 100644 index 0000000000..2e55614acb --- /dev/null +++ b/mlir/lib/Quantum/Transforms/static_custom_lowering.cpp @@ -0,0 +1,68 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "static-costum" + +#include "Catalyst/IR/CatalystDialect.h" +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/Patterns.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +using namespace llvm; +using namespace mlir; +using namespace catalyst::quantum; + +namespace catalyst { + +namespace quantum { + +#define GEN_PASS_DEF_STATICCUSTOMLOWERINGPASS +#define GEN_PASS_DECL_STATICCUSTOMLOWERINGPASS +#include "Quantum/Transforms/Passes.h.inc" + +struct StaticCustomLoweringPass : impl::StaticCustomLoweringPassBase { + using StaticCustomLoweringPassBase::StaticCustomLoweringPassBase; + void runOnOperation() final + { + LLVM_DEBUG(dbgs() << "static custom op lowering pass" + << "\n"); + auto module = getOperation(); + auto &context = getContext(); + RewritePatternSet patterns(&context); + ConversionTarget target(context); + + target.addLegalOp(); + target.addLegalOp(); + target.addIllegalOp(); + + populateStaticCustomPatterns(patterns); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace quantum + +std::unique_ptr createStaticCustomLoweringPass() +{ + return std::make_unique(); +} + +} // namespace catalyst diff --git a/mlir/test/Quantum/ChainedSelfInverseTest.mlir b/mlir/test/Quantum/ChainedSelfInverseTest.mlir index d7362bb0b7..d3f70c4250 100644 --- a/mlir/test/Quantum/ChainedSelfInverseTest.mlir +++ b/mlir/test/Quantum/ChainedSelfInverseTest.mlir @@ -323,6 +323,31 @@ func.func @test_chained_self_inverse(%arg0: tensor) -> !quantum.bit { // ----- +// test quantum.static_custom with static parameters + +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse() -> !quantum.bit { + // CHECK: quantum.alloc + // CHECK: [[IN:%.+]] = quantum.extract + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + %out_qubits = quantum.static_custom "RX" [2.000000e-01] %1 : !quantum.bit + %out_qubits_1 = quantum.static_custom "RX" [2.000000e-01] %out_qubits {adjoint} : !quantum.bit + + + %out_qubits_2 = quantum.static_custom "RX" [2.000000e-01] %out_qubits_1 {adjoint} : !quantum.bit + %out_qubits_3 = quantum.static_custom "RX" [2.000000e-01] %out_qubits_2 : !quantum.bit + + // CHECK-NOT: quantum.static_custom + // CHECK: return [[IN]] + return %out_qubits_3 : !quantum.bit +} + + +// ----- + + // test quantum.custom labeled both with adjoints // CHECK-LABEL: test_chained_self_inverse diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir index d6cd91ba0e..2d15ac8511 100644 --- a/mlir/test/Quantum/MergeRotationsTest.mlir +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -296,6 +296,22 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64) -> !quantum.bit { // ----- +func.func @test_merge_rotations(%arg0: f64, %arg1: f64) -> !quantum.bit { + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[ret:%.+]] = quantum.static_custom "RX" [-5.000000e-01] [[qubit]] : !quantum.bit + %2 = quantum.static_custom "RX" [2.000000e-01] %1 {adjoint}: !quantum.bit + %3 = quantum.static_custom "RX" [3.000000e-01] %2 {adjoint}: !quantum.bit + + // CHECK: return [[ret]] + return %3 : !quantum.bit +} + +// ----- + + func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) { // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg // CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit diff --git a/mlir/test/Quantum/StaticCustomTest.mlir b/mlir/test/Quantum/StaticCustomTest.mlir new file mode 100644 index 0000000000..6b84b664a3 --- /dev/null +++ b/mlir/test/Quantum/StaticCustomTest.mlir @@ -0,0 +1,29 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt --static-custom-lowering --split-input-file -verify-diagnostics %s | FileCheck %s + +func.func public @circuit() -> !quantum.bit { + // CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[sum1:%.+]] = arith.constant 2.000000e-01 : f64 + // CHECK: [[ret1:%.+]] = quantum.custom "RX"([[sum1]]) [[qubit]] : !quantum.bit + // CHECK: [[sum2:%.+]] = arith.constant 1.000000e-01 : f64 + // CHECK: [[ret2:%.+]] = quantum.custom "RY"([[sum2]]) [[ret1]] : !quantum.bit + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits1 = quantum.static_custom "RX" [2.000000e-01] %1 : !quantum.bit + %out_qubits2 = quantum.static_custom "RY" [1.000000e-01] %out_qubits1 : !quantum.bit + return %out_qubits2 : !quantum.bit +}