From 48edd026813d60b2cb9fa082fc82b2867871c4ec Mon Sep 17 00:00:00 2001 From: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> Date: Thu, 19 Dec 2024 14:17:33 -0500 Subject: [PATCH] Added Static CustomOp with lowering to regular custom Op (#1387) **Context:** Currently, the classical parameters of quantum gates are compiled dynamically even if they are known at compile time. We would like the IR to be extended to support literal values as opposed to SSA Values for such static parameters. **Description of the Change:** This is achieved by adding a new type of gate called StaticCustomOp which is similar to CustomOp except that it uses a DenseF64ArrayAttr which lists all the literal values for parameters in square bracket. For example the following IR: ``` %c = llvm.mlir.constant(2.000000e-01 : f64) %result = quantum.custom "RX"(%c) %qubit ``` would change to: ``` %result = quantum.static_custom "RX" [2.000000e-01] %qubit ``` The static custom ops are then lowered to the regular custom ops after apply-transform-sequence pass for the rest of the compilation process. Currently we have **not** supported Multirz, and GlobalShift. **Benefits:** More flexibility in case there is a lack of support for dynamic quantum circuits. **Possible Drawbacks:** **Related GitHub Issues:** Static parameters [sc-73581] --------- Co-authored-by: Erick Ochoa Lopez --- frontend/catalyst/from_plxpr.py | 2 +- frontend/catalyst/jax_extras/tracing.py | 13 ++-- frontend/catalyst/jax_primitives.py | 59 +++++++++++---- frontend/catalyst/jax_tracer.py | 8 ++- frontend/catalyst/pipelines.py | 2 + .../cuda/catalyst_to_cuda_interpreter.py | 9 ++- .../third_party/cuda/primitives/__init__.py | 14 ++-- frontend/test/lit/test_decomposition.py | 4 +- frontend/test/lit/test_if_else.py | 2 +- frontend/test/lit/test_measurements.py | 30 ++++---- frontend/test/lit/test_quantum_control.py | 2 +- frontend/test/lit/test_static_circuit.py | 68 ++++++++++++++++++ mlir/include/Quantum/IR/QuantumInterfaces.td | 17 +++++ mlir/include/Quantum/IR/QuantumOps.td | 39 ++++++++++ mlir/include/Quantum/Transforms/Passes.h | 1 + mlir/include/Quantum/Transforms/Passes.td | 7 ++ mlir/include/Quantum/Transforms/Patterns.h | 1 + .../Catalyst/Transforms/RegisterAllPasses.cpp | 1 + mlir/lib/Driver/Pipelines.cpp | 1 + mlir/lib/Quantum/IR/QuantumOps.cpp | 72 ++++++++++++++----- mlir/lib/Quantum/Transforms/CMakeLists.txt | 2 + .../Transforms/ChainedSelfInversePatterns.cpp | 51 +++++++++++-- .../Transforms/MergeRotationsPatterns.cpp | 44 +++++++++++- .../Transforms/StaticCustomPatterns.cpp | 62 ++++++++++++++++ .../Transforms/VerifyParentGateAnalysis.hpp | 10 +-- .../lib/Quantum/Transforms/merge_rotation.cpp | 2 + .../Transforms/static_custom_lowering.cpp | 68 ++++++++++++++++++ mlir/test/Quantum/ChainedSelfInverseTest.mlir | 25 +++++++ mlir/test/Quantum/MergeRotationsTest.mlir | 16 +++++ mlir/test/Quantum/StaticCustomTest.mlir | 29 ++++++++ 30 files changed, 587 insertions(+), 74 deletions(-) create mode 100644 frontend/test/lit/test_static_circuit.py create mode 100644 mlir/lib/Quantum/Transforms/StaticCustomPatterns.cpp create mode 100644 mlir/lib/Quantum/Transforms/static_custom_lowering.cpp create mode 100644 mlir/test/Quantum/StaticCustomTest.mlir diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index ebfbec14f9..32e0ef0e9c 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -305,7 +305,7 @@ def interpret_operator_eqn(self, eqn: jax.core.JaxprEqn) -> None: *wires, *invals, op=eqn.primitive.name, - params_len=len(eqn.invars) - eqn.params["n_wires"], + ctrl_value_len=0, **kwargs, ) diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py index a720ac97ae..1a5823f8f8 100644 --- a/frontend/catalyst/jax_extras/tracing.py +++ b/frontend/catalyst/jax_extras/tracing.py @@ -969,9 +969,9 @@ def bind_flexible_primitive(primitive, flexible_args: dict[str, Any], *dyn_args, The flexible_args is a dictionary containing the flexible arguments. These are the arguments that can either be static or dynamic. This method - will bind a flexible argument as static only if it is an integer, float, or boolean - literal. In the static case, the binded primitive's param name is the flexible arg's key, - and the jaxpr param value is the flexible arg's value. + will bind a flexible argument as static only if it is a single or a list of only integer, float, + or boolean literals. In the static case, the binded primitive's param name is the flexible arg's + key, and the jaxpr param value is the flexible arg's value. If a flexible argument is received as a tracer, it will be binded dynamically with the flexible arg's value. @@ -984,7 +984,12 @@ def bind_flexible_primitive(primitive, flexible_args: dict[str, Any], *dyn_args, for flex_arg_name, flex_arg_value in flexible_args.items(): if type(flex_arg_value) in static_literal_pool: - static_args |= {flex_arg_name: flex_arg_value} + static_args[flex_arg_name] = flex_arg_value + elif isinstance(flex_arg_value, list): + if flex_arg_value and all(type(arg) in static_literal_pool for arg in flex_arg_value): + static_args[flex_arg_name] = flex_arg_value + else: + dyn_args += (*flex_arg_value,) else: dyn_args += (flex_arg_value,) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 3032bd09d3..9f58dff29b 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -86,6 +86,7 @@ SetBasisStateOp, SetStateOp, StateOp, + StaticCustomOp, TensorOp, VarianceOp, ) @@ -972,13 +973,17 @@ def _gphase_lowering( # @qinst_p.def_abstract_eval def _qinst_abstract_eval( - *qubits_or_params, op=None, qubits_len=0, params_len=0, ctrl_len=0, adjoint=False + *qubits_or_params, + op=None, + qubits_len=0, + ctrl_len=0, + ctrl_value_len=0, + adjoint=False, + static_params=None, ): # The signature here is: (using * to denote zero or more) - # qubits*, params*, ctrl_qubits*, ctrl_values* - qubits = qubits_or_params[:qubits_len] - ctrl_qubits = qubits_or_params[-2 * ctrl_len : -ctrl_len] - all_qubits = qubits + ctrl_qubits + # qubits*, ctrl_qubits*, ctrl_values*, params* + all_qubits = qubits_or_params[: qubits_len + ctrl_len] for idx in range(qubits_len + ctrl_len): qubit = all_qubits[idx] assert isinstance(qubit, AbstractQbit) @@ -996,17 +1001,19 @@ def _qinst_lowering( *qubits_or_params, op=None, qubits_len=0, - params_len=0, ctrl_len=0, + ctrl_value_len=0, adjoint=False, + static_params=None, ): + assert ctrl_value_len == ctrl_len, "Control values must be the same length as control qubits" ctx = jax_ctx.module_context.context ctx.allow_unregistered_dialects = True qubits = qubits_or_params[:qubits_len] - params = qubits_or_params[qubits_len : qubits_len + params_len] - ctrl_qubits = qubits_or_params[qubits_len + params_len : qubits_len + params_len + ctrl_len] - ctrl_values = qubits_or_params[qubits_len + params_len + ctrl_len :] + ctrl_qubits = qubits_or_params[qubits_len : qubits_len + ctrl_len] + ctrl_values = qubits_or_params[qubits_len + ctrl_len : qubits_len + ctrl_len + ctrl_value_len] + params = qubits_or_params[qubits_len + ctrl_len + ctrl_value_len :] for qubit in qubits: assert ir.OpaqueType.isinstance(qubit.type) @@ -1029,13 +1036,31 @@ def _qinst_lowering( p = TensorExtractOp(ir.IntegerType.get_signless(1), v, []).result ctrl_values_i1.append(p) + params_attr = ( + None + if not static_params + else ir.DenseF64ArrayAttr.get([ir.FloatAttr.get_f64(val) for val in static_params]) + ) + if len(float_params) > 0: + assert ( + params_attr is None + ), "Static parameters are not allowed when having dynamic parameters" + name_attr = ir.StringAttr.get(op) name_str = str(name_attr) name_str = name_str.replace('"', "") if name_str == "MultiRZ": - assert len(float_params) == 1, "MultiRZ takes one float parameter" - float_param = float_params[0] + assert len(float_params) <= 1, "MultiRZ takes at most one dynamic float parameter" + assert ( + not static_params or len(static_params) <= 1 + ), "MultiRZ takes at most one static float parameter" + # TODO: Add support for MultiRZ with static params + float_param = ( + TensorExtractOp(ir.F64Type.get(), mlir.ir_constant(static_params[0]), []) + if len(float_params) == 0 + else float_params[0] + ) return MultiRZOp( out_qubits=[qubit.type for qubit in qubits], out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], @@ -1045,7 +1070,17 @@ def _qinst_lowering( in_ctrl_values=ctrl_values_i1, adjoint=adjoint, ).results - + if params_attr: + return StaticCustomOp( + out_qubits=[qubit.type for qubit in qubits], + out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], + static_params=params_attr, + in_qubits=qubits, + gate_name=name_attr, + in_ctrl_qubits=ctrl_qubits, + in_ctrl_values=ctrl_values_i1, + adjoint=adjoint, + ).results return CustomOp( out_qubits=[qubit.type for qubit in qubits], out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits], diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 3e7f2fea29..2eb4dc7396 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -691,12 +691,14 @@ def bind_native_operation(qrp, op, controlled_wires, controlled_values, adjoint= else: qubits = qrp.extract(op.wires) controlled_qubits = qrp.extract(controlled_wires) - qubits2 = qinst_p.bind( - *[*qubits, *op.parameters, *controlled_qubits, *controlled_values], + qubits2 = bind_flexible_primitive( + qinst_p, + {"static_params": op.parameters}, + *[*qubits, *controlled_qubits, *controlled_values], op=op.name, qubits_len=len(qubits), - params_len=len(op.parameters), ctrl_len=len(controlled_qubits), + ctrl_value_len=len(controlled_values), adjoint=adjoint, ) qrp.insert(op.wires, qubits2[: len(qubits)]) diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index fc2c5efcbc..4336010600 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -169,6 +169,8 @@ def get_enforce_runtime_invariants_stage(_options: CompileOptions) -> List[str]: "split-multiple-tapes", # Run the transform sequence defined in the MLIR module "builtin.module(apply-transform-sequence)", + # Lower the static custom ops to regular custom ops with dynamic parameters. + "static-custom-lowering", # Nested modules are something that will be used in the future # for making device specific transformations. # Since at the moment, nothing in the runtime is using them diff --git a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py index 0dcc8dd992..adc6545473 100644 --- a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py +++ b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py @@ -428,10 +428,17 @@ def change_instruction(ctx, eqn): op = params["op"] cuda_inst_name = from_catalyst_to_cuda[op] qubits_len = params["qubits_len"] + static_params = params.get("static_params") # Now, we can map to the correct op # For now just assume rx - cuda_inst(ctx.kernel, *qubits_or_params, inst=cuda_inst_name, qubits_len=qubits_len) + cuda_inst( + ctx.kernel, + *qubits_or_params, + inst=cuda_inst_name, + qubits_len=qubits_len, + static_params=static_params, + ) # Finally determine how many are qubits. qubits = qubits_or_params[:qubits_len] diff --git a/frontend/catalyst/third_party/cuda/primitives/__init__.py b/frontend/catalyst/third_party/cuda/primitives/__init__.py index e1679d3bcb..5c00100308 100644 --- a/frontend/catalyst/third_party/cuda/primitives/__init__.py +++ b/frontend/catalyst/third_party/cuda/primitives/__init__.py @@ -301,27 +301,33 @@ def make_primitive_for_gate(): kernel_gate_p = jax.core.Primitive("kernel_inst") kernel_gate_p.multiple_results = True - def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1): + def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params=None): """Convenience. Quantum operations in CUDA-quantum return no values. But JAXPR expects return values. We can just say that multiple_results = True and return an empty tuple. """ - kernel_gate_p.bind(kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len) + kernel_gate_p.bind( + kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len, static_params=static_params + ) return tuple() @kernel_gate_p.def_impl - def gate_impl(kernel, *qubits_or_params, inst=None, qubits_len=-1): + def gate_impl(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params=None): """Concrete implementation.""" assert inst and qubits_len > 0 + if static_params is None: + static_params = [] method = getattr(cudaq.Kernel, inst) targets = qubits_or_params[:qubits_len] params = qubits_or_params[qubits_len:] + if not params: + params = static_params method(kernel, *params, *targets) return tuple() @kernel_gate_p.def_abstract_eval - def gate_abs(_kernel, *_qubits_or_params, inst=None, qubits_len=-1): + def gate_abs(_kernel, *_qubits_or_params, inst=None, qubits_len=-1, static_params=None): """Abstract evaluation.""" return tuple() diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 04263d33ec..522a453013 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -133,10 +133,8 @@ def test_decompose_s(): @qml.qnode(dev) # CHECK-LABEL: public @jit_decompose_s def decompose_s(): - # CHECK-NOT: name="S" - # CHECK: [[pi_div_2:%.+]] = arith.constant 1.57079{{.+}} : f64 # CHECK-NOT: name = "S" - # CHECK: {{%.+}} = quantum.custom "PhaseShift"([[pi_div_2]]) + # CHECK: {{%.+}} = quantum.static_custom "PhaseShift" [1.570796e+00] # CHECK-NOT: name = "S" qml.S(wires=0) return measure(wires=0) diff --git a/frontend/test/lit/test_if_else.py b/frontend/test/lit/test_if_else.py index 0aec735804..bd586c7c62 100644 --- a/frontend/test/lit/test_if_else.py +++ b/frontend/test/lit/test_if_else.py @@ -94,7 +94,7 @@ def circuit_single_gate(n: int): # CHECK: [[b6:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t6]] # CHECK: [[qreg_out1:%.+]] = scf.if [[b6]] # CHECK-DAG: [[q4:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out]] - # CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.custom "RX"({{%.+}}) [[q4]] + # CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.static_custom "RX" [3.140000e+00] [[q4]] # CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q5]] # CHECK: scf.yield [[qreg_3]] # CHECK: else diff --git a/frontend/test/lit/test_measurements.py b/frontend/test/lit/test_measurements.py index eed5161843..21e4e300a0 100644 --- a/frontend/test/lit/test_measurements.py +++ b/frontend/test/lit/test_measurements.py @@ -38,7 +38,7 @@ def sample1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = quantum.custom "RZ" + # COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliZ] @@ -54,7 +54,7 @@ def sample2(x: float, y: float): qml.RX(x, wires=0) # COM: CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = quantum.custom "RZ" + # COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs1:%.+]] = quantum.namedobs [[q1]][ PauliX] @@ -77,7 +77,7 @@ def sample3(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]] @@ -145,7 +145,7 @@ def test_sample_dynamic(shots: int): def counts1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = quantum.custom "RZ" + # COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliZ] @@ -160,7 +160,7 @@ def counts2(x: float, y: float): qml.RX(x, wires=0) # COM: CHECK: [[q1:%.+]] = "quantum.custom"({{%.+}}, {{%.+}}) {gate_name = "RY" qml.RY(y, wires=1) - # COM: CHECK: [[q0:%.+]] = "quantum.custom"({{%.+}}, {{%.+}}) {gate_name = "RZ" + # COM: CHECK: [[q0:%.+]] = "quantum.static_custom"({{%.+}}, {{%.+}}) {gate_name = "RZ" qml.RZ(0.1, wires=0) # COM: CHECK: [[obs1:%.+]] = "quantum.namedobs"([[q1]]) {type = #quantum} @@ -183,7 +183,7 @@ def counts3(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]] @@ -228,7 +228,7 @@ def test_counts_dynamic(shots: int): def expval1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -247,7 +247,7 @@ def expval2(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) # CHECK: [[p1:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -304,7 +304,7 @@ def expval5(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) B = np.array( @@ -334,7 +334,7 @@ def expval5(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) coeffs = np.array([0.2, -0.543]) @@ -426,7 +426,7 @@ def expval9(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) # CHECK: [[p1:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -448,7 +448,7 @@ def expval10(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q2:%.+]] = quantum.custom "RZ" + # CHECK: [[q2:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=2) B = np.array( @@ -476,7 +476,7 @@ def expval10(x: float, y: float): def var1(x: float, y: float): qml.RX(x, wires=0) qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliX] @@ -519,7 +519,7 @@ def probs1(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # qml.probs() # unsupported by PennyLane @@ -540,7 +540,7 @@ def state1(x: float, y: float): qml.RX(x, wires=0) # CHECK: [[q1:%.+]] = quantum.custom "RY" qml.RY(y, wires=1) - # CHECK: [[q0:%.+]] = quantum.custom "RZ" + # CHECK: [[q0:%.+]] = quantum.static_custom "RZ" qml.RZ(0.1, wires=0) # qml.state(wires=[0]) # unsupported by PennyLane diff --git a/frontend/test/lit/test_quantum_control.py b/frontend/test/lit/test_quantum_control.py index 3cd794dc9e..730692c855 100644 --- a/frontend/test/lit/test_quantum_control.py +++ b/frontend/test/lit/test_quantum_control.py @@ -98,7 +98,7 @@ def test_native_controlled_custom(): @qml.qnode(dev) # CHECK-LABEL: public @jit_native_controlled def native_controlled(): - # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.custom "Rot" + # CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.static_custom "Rot" # CHECK-SAME: ctrls # CHECK-SAME: ctrlvals(%true, %true) qml.ctrl(qml.Rot(0.3, 0.4, 0.5, wires=[0]), control=[1, 2]) diff --git a/frontend/test/lit/test_static_circuit.py b/frontend/test/lit/test_static_circuit.py new file mode 100644 index 0000000000..40de0c68d3 --- /dev/null +++ b/frontend/test/lit/test_static_circuit.py @@ -0,0 +1,68 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# RUN: %PYTHON %s | FileCheck %s + +""" +Test quantum circuits with static (knonw at compile time) specifications. +""" + +import pennylane as qml + +from catalyst import qjit + + +def test_static_params(): + """Test operations with static params.""" + + @qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=4)) + def circuit(): + x = 3.14 + y = 0.6 + qml.Rot(x, y, x + y, wires=0) + + qml.RX(x, wires=0) + qml.RY(y, wires=1) + qml.RZ(x, wires=2) + + qml.IsingXX(x, wires=[0, 1]) + qml.IsingXX(y, wires=[1, 2]) + qml.IsingZZ(x, wires=[0, 1]) + + qml.CRX(x, wires=[0, 1]) + qml.CRY(x, wires=[0, 1]) + qml.CRZ(x, wires=[0, 1]) + + return qml.state() + + print(circuit.mlir) + + +# CHECK-LABEL: public @jit_circuit +# CHECK: %[[REG:.*]] = quantum.alloc( 4) : !quantum.reg +# CHECK: %[[BIT1:.*]] = quantum.extract %[[REG]][ 0] : !quantum.reg -> !quantum.bit +# CHECK: %[[ROT:.*]] = quantum.static_custom "Rot" +# CHECK: %[[RX:.*]] = quantum.static_custom "RX" +# CHECK: %[[BIT1:.*]] = quantum.extract %[[REG]][ 1] +# CHECK: %[[RY1:.*]] = quantum.static_custom "RY" +# CHECK: %[[XX1:.*]] = quantum.static_custom "IsingXX" +# CHECK: %[[BIT2:.*]] = quantum.extract %[[REG]][ 2] +# CHECK: %[[RZ:.*]] = quantum.static_custom "RZ" +# CHECK: %[[XX2:.*]] = quantum.static_custom "IsingXX" +# CHECK: %[[ZZ:.*]] = quantum.static_custom "IsingZZ" +# CHECK: %[[CRX:.*]] = quantum.static_custom "CRX" +# CHECK: %[[CRY:.*]] = quantum.static_custom "CRY" +# CHECK: %[[CRZ:.*]] = quantum.static_custom "CRZ" +test_static_params() diff --git a/mlir/include/Quantum/IR/QuantumInterfaces.td b/mlir/include/Quantum/IR/QuantumInterfaces.td index 145e47ae46..dc1390c636 100644 --- a/mlir/include/Quantum/IR/QuantumInterfaces.td +++ b/mlir/include/Quantum/IR/QuantumInterfaces.td @@ -164,6 +164,23 @@ def QuantumGate : OpInterface<"QuantumGate", [QuantumOperation]> { }]; } +def StaticGate : OpInterface<"StaticGate", [QuantumGate]> { + let description = [{ + This interface provides a generic way to interact with quantum + instructions with static parameters (known at compile time). These parameters + are specified by a set of constant literals in the form of an array attribute. + }]; + + let cppNamespace = "::catalyst::quantum"; + + let methods = [ + InterfaceMethod< + "Return all operands which are considered gate parameters.", + "mlir::DenseF64ArrayAttr", "getAllParams" + >, + ]; +} + def ParametrizedGate : OpInterface<"ParametrizedGate", [QuantumGate]> { let description = [{ This interface provides a generic way to interact with parametrized diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 246b23b92f..7c8fa3d4f2 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -457,6 +457,45 @@ def CustomOp : UnitaryGate_Op<"custom", [DifferentiableGate, NoMemoryEffect, let hasCanonicalizeMethod = 1; } +def StaticCustomOp : UnitaryGate_Op<"static_custom", [NoMemoryEffect, + AttrSizedOperandSegments, + AttrSizedResultSegments]> { + let summary = "A generic quantum gate with static parameters in form of a DenseF64ArrayAttr."; + let description = [{ + This operation represents a quantum gate with parameters defined statically as a + DenseF64ArrayAttr, rather than passed dynamically as operands. This is useful for gates + with parameters known at compile-time. + }]; + + let arguments = (ins + DenseF64ArrayAttr:$static_params, + Variadic:$in_qubits, + StrAttr:$gate_name, + OptionalAttr:$adjoint, + Variadic:$in_ctrl_qubits, + Variadic:$in_ctrl_values + ); + + let results = (outs + Variadic:$out_qubits, + Variadic:$out_ctrl_qubits + ); + + let assemblyFormat = [{ + $gate_name $static_params $in_qubits attr-dict ( `ctrls` `(` $in_ctrl_qubits^ `)` )? + ( `ctrlvals` `(` $in_ctrl_values^ `)` )? `:` type($out_qubits) + (`ctrls` type($out_ctrl_qubits)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + llvm::ArrayRef getAllParams() { + return getStaticParams(); + } + }]; + let hasCanonicalizeMethod = 1; +} + + def GlobalPhaseOp : UnitaryGate_Op<"gphase", [DifferentiableGate, AttrSizedOperandSegments]> { let summary = "Global Phase."; let description = [{ diff --git a/mlir/include/Quantum/Transforms/Passes.h b/mlir/include/Quantum/Transforms/Passes.h index 65a930bb17..d43490d7cd 100644 --- a/mlir/include/Quantum/Transforms/Passes.h +++ b/mlir/include/Quantum/Transforms/Passes.h @@ -30,5 +30,6 @@ std::unique_ptr createAnnotateFunctionPass(); std::unique_ptr createSplitMultipleTapesPass(); std::unique_ptr createMergeRotationsPass(); std::unique_ptr createIonsDecompositionPass(); +std::unique_ptr createStaticCustomLoweringPass(); } // namespace catalyst diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td index 21ddcc0a82..5a5325adb3 100644 --- a/mlir/include/Quantum/Transforms/Passes.td +++ b/mlir/include/Quantum/Transforms/Passes.td @@ -88,6 +88,13 @@ def SplitMultipleTapesPass : Pass<"split-multiple-tapes"> { let constructor = "catalyst::createSplitMultipleTapesPass()"; } +def StaticCustomLoweringPass : Pass<"static-custom-lowering"> { + let summary = "Lower static custom ops to regular custom op with dynamic parameters."; + + let constructor = "catalyst::createStaticCustomLoweringPass()"; +} + + // ----- Quantum circuit transformation passes begin ----- // // For example, automatic compiler peephole opts, etc. diff --git a/mlir/include/Quantum/Transforms/Patterns.h b/mlir/include/Quantum/Transforms/Patterns.h index 1daf201018..cdf638624e 100644 --- a/mlir/include/Quantum/Transforms/Patterns.h +++ b/mlir/include/Quantum/Transforms/Patterns.h @@ -28,6 +28,7 @@ void populateAdjointPatterns(mlir::RewritePatternSet &); void populateSelfInversePatterns(mlir::RewritePatternSet &); void populateMergeRotationsPatterns(mlir::RewritePatternSet &); void populateIonsDecompositionPatterns(mlir::RewritePatternSet &); +void populateStaticCustomPatterns(mlir::RewritePatternSet &); } // namespace quantum } // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 20624898aa..eb5a79feae 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -51,6 +51,7 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createMergeRotationsPass); mlir::registerPass(catalyst::createScatterLoweringPass); mlir::registerPass(catalyst::createSplitMultipleTapesPass); + mlir::registerPass(catalyst::createStaticCustomLoweringPass); mlir::registerPass(catalyst::createTestPass); mlir::registerPass(catalyst::createIonsDecompositionPass); mlir::registerPass(catalyst::createQuantumToIonPass); diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index c9702ae555..ff2c571fb7 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -31,6 +31,7 @@ void createEnforceRuntimeInvariantsPipeline(OpPassManager &pm) { pm.addPass(catalyst::createSplitMultipleTapesPass()); pm.addPass(catalyst::createApplyTransformSequencePass()); + pm.addPass(catalyst::createStaticCustomLoweringPass()); pm.addPass(catalyst::createInlineNestedModulePass()); } void createHloLoweringPipeline(OpPassManager &pm) diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 695ec45efd..3c38a384d5 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -40,30 +40,64 @@ static const mlir::StringSet<> hermitianOps = {"Hadamard", "PauliX", "PauliY", " "CY", "CZ", "SWAP", "Toffoli"}; static const mlir::StringSet<> rotationsOps = {"RX", "RY", "RZ", "PhaseShift", "CRX", "CRY", "CRZ", "ControlledPhaseShift"}; -LogicalResult CustomOp::canonicalize(CustomOp op, mlir::PatternRewriter &rewriter) + +LogicalResult StaticCustomOp::canonicalize(StaticCustomOp op, mlir::PatternRewriter &rewriter) { - if (op.getAdjoint()) { - auto name = op.getGateName(); - if (hermitianOps.contains(name)) { - op.setAdjoint(false); - return success(); - } - else if (rotationsOps.contains(name)) { - auto params = op.getParams(); - SmallVector paramsNeg; - for (auto param : params) { - auto paramNeg = rewriter.create(op.getLoc(), param); - paramsNeg.push_back(paramNeg); - } - - rewriter.replaceOpWithNewOp( - op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramsNeg, - op.getInQubits(), name, nullptr, op.getInCtrlQubits(), op.getInCtrlValues()); + if (!op.getAdjoint()) { + return failure(); + } + auto name = op.getGateName(); - return success(); + if (hermitianOps.contains(name)) { + rewriter.modifyOpInPlace(op, [&op]() { op.setAdjoint(false); }); + return success(); + } + + if (rotationsOps.contains(name)) { + auto params = op.getStaticParams(); + SmallVector paramsNeg; + for (auto param : params) { + auto paramNeg = -1 * param; + paramsNeg.push_back(paramNeg); } + + rewriter.replaceOpWithNewOp( + op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), + rewriter.getDenseF64ArrayAttr(paramsNeg), op.getInQubits(), name, nullptr, + op.getInCtrlQubits(), op.getInCtrlValues()); + + return success(); + } + + return failure(); +} + +LogicalResult CustomOp::canonicalize(CustomOp op, mlir::PatternRewriter &rewriter) +{ + if (!op.getAdjoint()) { return failure(); } + auto name = op.getGateName(); + + if (hermitianOps.contains(name)) { + rewriter.modifyOpInPlace(op, [&op]() { op.setAdjoint(false); }); + return success(); + } + + if (rotationsOps.contains(name)) { + auto params = op.getParams(); + SmallVector paramsNeg; + for (auto param : params) { + auto paramNeg = rewriter.create(op.getLoc(), param); + paramsNeg.push_back(paramNeg); + } + + rewriter.replaceOpWithNewOp( + op, op.getOutQubits().getTypes(), op.getOutCtrlQubits().getTypes(), paramsNeg, + op.getInQubits(), name, nullptr, op.getInCtrlQubits(), op.getInCtrlValues()); + + return success(); + } return failure(); } diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index fcbff39b76..2483d4618d 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -17,6 +17,8 @@ file(GLOB SRC MergeRotationsPatterns.cpp ions_decompositions.cpp IonsDecompositionPatterns.cpp + static_custom_lowering.cpp + StaticCustomPatterns.cpp ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) diff --git a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp index 5d49eff575..eb9665d828 100644 --- a/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ChainedSelfInversePatterns.cpp @@ -49,7 +49,7 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern vpga(op); if (!vpga.getVerifierResult()) { return failure(); } @@ -67,6 +67,41 @@ struct ChainedNamedHermitianOpRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + /// We simplify consecutive Hermitian quantum gates by removing them. + /// Hermitian gates are self-inverse and applying the same gate twice in succession + /// cancels out the effect. This pattern rewrites such redundant operations by + /// replacing the operation with its "grandparent" operation in the quantum circuit. + mlir::LogicalResult matchAndRewrite(StaticCustomOp op, + mlir::PatternRewriter &rewriter) const override + { + LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); + + StringRef OpGateName = op.getGateName(); + if (!HermitianOps.contains(OpGateName)) { + return failure(); + } + + VerifyParentGateAndNameAnalysis vpga(op); + if (!vpga.getVerifierResult()) { + return failure(); + } + + // Replace uses + ValueRange InQubits = op.getInQubits(); + auto parentOp = cast(InQubits[0].getDefiningOp()); + + // TODO: it would make more sense for getQubitOperands() + // to return ValueRange, like the other getters + std::vector originalQubits = parentOp.getQubitOperands(); + + rewriter.replaceOp(op, originalQubits); + return success(); + } +}; + template struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -74,8 +109,8 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { bool verifyParentGateParams(OpType op, OpType parentOp) const { // Verify that the parent gate has the same parameters - ValueRange opParams = op.getAllParams(); - ValueRange parentOpParams = parentOp.getAllParams(); + auto opParams = op.getAllParams(); + auto parentOpParams = parentOp.getAllParams(); if (opParams.size() != parentOpParams.size()) { return false; @@ -109,7 +144,13 @@ struct ChainedUUadjOpRewritePattern : public mlir::OpRewritePattern { LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); if (isa(op)) { - VerifyParentGateAndNameAnalysis vpga(cast(op)); + VerifyParentGateAndNameAnalysis vpga(cast(op)); + if (!vpga.getVerifierResult()) { + return failure(); + } + } + else if (isa(op)) { + VerifyParentGateAndNameAnalysis vpga(cast(op)); if (!vpga.getVerifierResult()) { return failure(); } @@ -151,12 +192,14 @@ namespace quantum { void populateSelfInversePatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext(), 1); patterns.add(patterns.getContext(), 1); // TODO: better organize the quantum dialect // There is an interface `QuantumGate` for all the unitary gate operations, // but interfaces cannot be accepted by pattern matchers, since pattern // matchers require the target operations to have concrete names in the IR. + patterns.add>(patterns.getContext(), 1); patterns.add>(patterns.getContext(), 1); patterns.add>(patterns.getContext(), 1); patterns.add>(patterns.getContext(), 1); diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index d449773e1b..c63e4461cf 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -44,7 +44,7 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { ValueRange inQubits = op.getInQubits(); auto parentOp = dyn_cast_or_null(inQubits[0].getDefiningOp()); - VerifyParentGateAndNameAnalysis vpga(op); + VerifyParentGateAndNameAnalysis vpga(op); if (!vpga.getVerifierResult()) { return failure(); } @@ -73,6 +73,47 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { } }; +struct MergeRotationsStaticRewritePattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(StaticCustomOp op, + mlir::PatternRewriter &rewriter) const override + { + LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n"); + auto loc = op.getLoc(); + StringRef opGateName = op.getGateName(); + if (!rotationsSet.contains(opGateName)) + return failure(); + ValueRange inQubits = op.getInQubits(); + auto parentOp = dyn_cast_or_null(inQubits[0].getDefiningOp()); + + VerifyParentGateAndNameAnalysis vpga(op); + if (!vpga.getVerifierResult()) { + return failure(); + } + + TypeRange outQubitsTypes = op.getOutQubits().getTypes(); + TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); + ValueRange parentInQubits = parentOp.getInQubits(); + ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); + ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); + + auto parentParams = parentOp.getStaticParams(); + auto params = op.getStaticParams(); + SmallVector sumParams; + for (auto [param, parentParam] : llvm::zip(params, parentParams)) { + sumParams.push_back(parentParam + param); + }; + auto mergeOp = rewriter.create( + loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, parentInQubits, opGateName, nullptr, + parentInCtrlQubits, parentInCtrlValues); + + op.replaceAllUsesWith(mergeOp); + + return success(); + } +}; + struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -117,6 +158,7 @@ namespace quantum { void populateMergeRotationsPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext(), 1); patterns.add(patterns.getContext(), 1); patterns.add(patterns.getContext(), 1); } diff --git a/mlir/lib/Quantum/Transforms/StaticCustomPatterns.cpp b/mlir/lib/Quantum/Transforms/StaticCustomPatterns.cpp new file mode 100644 index 0000000000..64d8e49529 --- /dev/null +++ b/mlir/lib/Quantum/Transforms/StaticCustomPatterns.cpp @@ -0,0 +1,62 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "static-custom" + +#include "Quantum/IR/QuantumOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Debug.h" + +using llvm::dbgs; +using namespace mlir; +using namespace catalyst::quantum; + +namespace { + +struct LowerStaticCustomOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(StaticCustomOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + + { + LLVM_DEBUG(dbgs() << "Lowering the following static custom operation:\n" << op << "\n"); + SmallVector paramValues; + auto staticParams = op.getStaticParams(); + for (auto param : staticParams) { + auto constant = rewriter.create(op.getLoc(), rewriter.getF64Type(), + rewriter.getF64FloatAttr(param)); + paramValues.push_back(constant); + } + + rewriter.replaceOpWithNewOp(op, op.getGateName(), op.getInQubits(), + op.getInCtrlQubits(), op.getInCtrlValues(), + paramValues, op.getAdjointFlag()); + return success(); + } +}; + +} // namespace + +namespace catalyst { +namespace quantum { + +void populateStaticCustomPatterns(RewritePatternSet &patterns) +{ + patterns.add(patterns.getContext(), 1); +} + +} // namespace quantum +} // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp index 6630fee909..3de93d87b0 100644 --- a/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp +++ b/mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp @@ -134,15 +134,15 @@ template class VerifyParentGateAnalysis { } }; -class VerifyParentGateAndNameAnalysis : public VerifyParentGateAnalysis { +template +class VerifyParentGateAndNameAnalysis : public VerifyParentGateAnalysis { // If OpType is quantum.custom, also verify that parent gate has the // same gate name. public: - VerifyParentGateAndNameAnalysis(quantum::CustomOp gate) - : VerifyParentGateAnalysis(gate) + VerifyParentGateAndNameAnalysis(OpType gate) : VerifyParentGateAnalysis(gate) { ValueRange inQubits = gate.getInQubits(); - auto parentGate = dyn_cast_or_null(inQubits[0].getDefiningOp()); + auto parentGate = dyn_cast_or_null(inQubits[0].getDefiningOp()); if (!parentGate) { this->setVerifierResult(false); @@ -156,7 +156,7 @@ class VerifyParentGateAndNameAnalysis : public VerifyParentGateAnalysis { Operation *module = getOperation(); RewritePatternSet patternsCanonicalization(&getContext()); + catalyst::quantum::StaticCustomOp::getCanonicalizationPatterns(patternsCanonicalization, + &getContext()); catalyst::quantum::CustomOp::getCanonicalizationPatterns(patternsCanonicalization, &getContext()); catalyst::quantum::MultiRZOp::getCanonicalizationPatterns(patternsCanonicalization, diff --git a/mlir/lib/Quantum/Transforms/static_custom_lowering.cpp b/mlir/lib/Quantum/Transforms/static_custom_lowering.cpp new file mode 100644 index 0000000000..2e55614acb --- /dev/null +++ b/mlir/lib/Quantum/Transforms/static_custom_lowering.cpp @@ -0,0 +1,68 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "static-costum" + +#include "Catalyst/IR/CatalystDialect.h" +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/Patterns.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +using namespace llvm; +using namespace mlir; +using namespace catalyst::quantum; + +namespace catalyst { + +namespace quantum { + +#define GEN_PASS_DEF_STATICCUSTOMLOWERINGPASS +#define GEN_PASS_DECL_STATICCUSTOMLOWERINGPASS +#include "Quantum/Transforms/Passes.h.inc" + +struct StaticCustomLoweringPass : impl::StaticCustomLoweringPassBase { + using StaticCustomLoweringPassBase::StaticCustomLoweringPassBase; + void runOnOperation() final + { + LLVM_DEBUG(dbgs() << "static custom op lowering pass" + << "\n"); + auto module = getOperation(); + auto &context = getContext(); + RewritePatternSet patterns(&context); + ConversionTarget target(context); + + target.addLegalOp(); + target.addLegalOp(); + target.addIllegalOp(); + + populateStaticCustomPatterns(patterns); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace quantum + +std::unique_ptr createStaticCustomLoweringPass() +{ + return std::make_unique(); +} + +} // namespace catalyst diff --git a/mlir/test/Quantum/ChainedSelfInverseTest.mlir b/mlir/test/Quantum/ChainedSelfInverseTest.mlir index d7362bb0b7..d3f70c4250 100644 --- a/mlir/test/Quantum/ChainedSelfInverseTest.mlir +++ b/mlir/test/Quantum/ChainedSelfInverseTest.mlir @@ -323,6 +323,31 @@ func.func @test_chained_self_inverse(%arg0: tensor) -> !quantum.bit { // ----- +// test quantum.static_custom with static parameters + +// CHECK-LABEL: test_chained_self_inverse +func.func @test_chained_self_inverse() -> !quantum.bit { + // CHECK: quantum.alloc + // CHECK: [[IN:%.+]] = quantum.extract + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + %out_qubits = quantum.static_custom "RX" [2.000000e-01] %1 : !quantum.bit + %out_qubits_1 = quantum.static_custom "RX" [2.000000e-01] %out_qubits {adjoint} : !quantum.bit + + + %out_qubits_2 = quantum.static_custom "RX" [2.000000e-01] %out_qubits_1 {adjoint} : !quantum.bit + %out_qubits_3 = quantum.static_custom "RX" [2.000000e-01] %out_qubits_2 : !quantum.bit + + // CHECK-NOT: quantum.static_custom + // CHECK: return [[IN]] + return %out_qubits_3 : !quantum.bit +} + + +// ----- + + // test quantum.custom labeled both with adjoints // CHECK-LABEL: test_chained_self_inverse diff --git a/mlir/test/Quantum/MergeRotationsTest.mlir b/mlir/test/Quantum/MergeRotationsTest.mlir index d6cd91ba0e..2d15ac8511 100644 --- a/mlir/test/Quantum/MergeRotationsTest.mlir +++ b/mlir/test/Quantum/MergeRotationsTest.mlir @@ -296,6 +296,22 @@ func.func @test_merge_rotations(%arg0: f64, %arg1: f64) -> !quantum.bit { // ----- +func.func @test_merge_rotations(%arg0: f64, %arg1: f64) -> !quantum.bit { + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[ret:%.+]] = quantum.static_custom "RX" [-5.000000e-01] [[qubit]] : !quantum.bit + %2 = quantum.static_custom "RX" [2.000000e-01] %1 {adjoint}: !quantum.bit + %3 = quantum.static_custom "RX" [3.000000e-01] %2 {adjoint}: !quantum.bit + + // CHECK: return [[ret]] + return %3 : !quantum.bit +} + +// ----- + + func.func @test_merge_rotations(%arg0: f64, %arg1: f64, %arg2: f64) -> (!quantum.bit, !quantum.bit) { // CHECK: [[reg:%.+]] = quantum.alloc( 2) : !quantum.reg // CHECK: [[qubit1:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit diff --git a/mlir/test/Quantum/StaticCustomTest.mlir b/mlir/test/Quantum/StaticCustomTest.mlir new file mode 100644 index 0000000000..6b84b664a3 --- /dev/null +++ b/mlir/test/Quantum/StaticCustomTest.mlir @@ -0,0 +1,29 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt --static-custom-lowering --split-input-file -verify-diagnostics %s | FileCheck %s + +func.func public @circuit() -> !quantum.bit { + // CHECK: [[reg:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[qubit:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[sum1:%.+]] = arith.constant 2.000000e-01 : f64 + // CHECK: [[ret1:%.+]] = quantum.custom "RX"([[sum1]]) [[qubit]] : !quantum.bit + // CHECK: [[sum2:%.+]] = arith.constant 1.000000e-01 : f64 + // CHECK: [[ret2:%.+]] = quantum.custom "RY"([[sum2]]) [[ret1]] : !quantum.bit + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits1 = quantum.static_custom "RX" [2.000000e-01] %1 : !quantum.bit + %out_qubits2 = quantum.static_custom "RY" [1.000000e-01] %out_qubits1 : !quantum.bit + return %out_qubits2 : !quantum.bit +}