Skip to content

Commit

Permalink
Added Static CustomOp with lowering to regular custom Op (#1387)
Browse files Browse the repository at this point in the history
**Context:**
Currently, the classical parameters of quantum gates are compiled
dynamically even if they are known at compile time. We would like the IR
to be extended to support literal values as opposed to SSA Values for
such static parameters.

**Description of the Change:**
This is achieved by adding a new type of gate called StaticCustomOp
which is similar to CustomOp except that it uses a DenseF64ArrayAttr
which lists all the literal values for parameters in square bracket. For
example the following IR:
```
%c = llvm.mlir.constant(2.000000e-01 : f64)
%result = quantum.custom "RX"(%c) %qubit
```
would change to:
```
%result = quantum.static_custom "RX" [2.000000e-01] %qubit
```
The static custom ops are then lowered to the regular custom ops after
apply-transform-sequence pass for the rest of the compilation process.
Currently we have **not** supported Multirz, and GlobalShift.

**Benefits:**
More flexibility in case there is a lack of support for dynamic quantum
circuits.

**Possible Drawbacks:**

**Related GitHub Issues:**
Static parameters [sc-73581]

---------

Co-authored-by: Erick Ochoa Lopez <erick.ochoalopez@xanadu.ai>
  • Loading branch information
mehrdad2m and erick-xanadu authored Dec 19, 2024
1 parent 11e4e4f commit 48edd02
Show file tree
Hide file tree
Showing 30 changed files with 587 additions and 74 deletions.
2 changes: 1 addition & 1 deletion frontend/catalyst/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def interpret_operator_eqn(self, eqn: jax.core.JaxprEqn) -> None:
*wires,
*invals,
op=eqn.primitive.name,
params_len=len(eqn.invars) - eqn.params["n_wires"],
ctrl_value_len=0,
**kwargs,
)

Expand Down
13 changes: 9 additions & 4 deletions frontend/catalyst/jax_extras/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,9 +969,9 @@ def bind_flexible_primitive(primitive, flexible_args: dict[str, Any], *dyn_args,
The flexible_args is a dictionary containing the flexible arguments.
These are the arguments that can either be static or dynamic. This method
will bind a flexible argument as static only if it is an integer, float, or boolean
literal. In the static case, the binded primitive's param name is the flexible arg's key,
and the jaxpr param value is the flexible arg's value.
will bind a flexible argument as static only if it is a single or a list of only integer, float,
or boolean literals. In the static case, the binded primitive's param name is the flexible arg's
key, and the jaxpr param value is the flexible arg's value.
If a flexible argument is received as a tracer, it will be binded dynamically with
the flexible arg's value.
Expand All @@ -984,7 +984,12 @@ def bind_flexible_primitive(primitive, flexible_args: dict[str, Any], *dyn_args,

for flex_arg_name, flex_arg_value in flexible_args.items():
if type(flex_arg_value) in static_literal_pool:
static_args |= {flex_arg_name: flex_arg_value}
static_args[flex_arg_name] = flex_arg_value
elif isinstance(flex_arg_value, list):
if flex_arg_value and all(type(arg) in static_literal_pool for arg in flex_arg_value):
static_args[flex_arg_name] = flex_arg_value
else:
dyn_args += (*flex_arg_value,)
else:
dyn_args += (flex_arg_value,)

Expand Down
59 changes: 47 additions & 12 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
SetBasisStateOp,
SetStateOp,
StateOp,
StaticCustomOp,
TensorOp,
VarianceOp,
)
Expand Down Expand Up @@ -972,13 +973,17 @@ def _gphase_lowering(
#
@qinst_p.def_abstract_eval
def _qinst_abstract_eval(
*qubits_or_params, op=None, qubits_len=0, params_len=0, ctrl_len=0, adjoint=False
*qubits_or_params,
op=None,
qubits_len=0,
ctrl_len=0,
ctrl_value_len=0,
adjoint=False,
static_params=None,
):
# The signature here is: (using * to denote zero or more)
# qubits*, params*, ctrl_qubits*, ctrl_values*
qubits = qubits_or_params[:qubits_len]
ctrl_qubits = qubits_or_params[-2 * ctrl_len : -ctrl_len]
all_qubits = qubits + ctrl_qubits
# qubits*, ctrl_qubits*, ctrl_values*, params*
all_qubits = qubits_or_params[: qubits_len + ctrl_len]
for idx in range(qubits_len + ctrl_len):
qubit = all_qubits[idx]
assert isinstance(qubit, AbstractQbit)
Expand All @@ -996,17 +1001,19 @@ def _qinst_lowering(
*qubits_or_params,
op=None,
qubits_len=0,
params_len=0,
ctrl_len=0,
ctrl_value_len=0,
adjoint=False,
static_params=None,
):
assert ctrl_value_len == ctrl_len, "Control values must be the same length as control qubits"
ctx = jax_ctx.module_context.context
ctx.allow_unregistered_dialects = True

qubits = qubits_or_params[:qubits_len]
params = qubits_or_params[qubits_len : qubits_len + params_len]
ctrl_qubits = qubits_or_params[qubits_len + params_len : qubits_len + params_len + ctrl_len]
ctrl_values = qubits_or_params[qubits_len + params_len + ctrl_len :]
ctrl_qubits = qubits_or_params[qubits_len : qubits_len + ctrl_len]
ctrl_values = qubits_or_params[qubits_len + ctrl_len : qubits_len + ctrl_len + ctrl_value_len]
params = qubits_or_params[qubits_len + ctrl_len + ctrl_value_len :]

for qubit in qubits:
assert ir.OpaqueType.isinstance(qubit.type)
Expand All @@ -1029,13 +1036,31 @@ def _qinst_lowering(
p = TensorExtractOp(ir.IntegerType.get_signless(1), v, []).result
ctrl_values_i1.append(p)

params_attr = (
None
if not static_params
else ir.DenseF64ArrayAttr.get([ir.FloatAttr.get_f64(val) for val in static_params])
)
if len(float_params) > 0:
assert (
params_attr is None
), "Static parameters are not allowed when having dynamic parameters"

name_attr = ir.StringAttr.get(op)
name_str = str(name_attr)
name_str = name_str.replace('"', "")

if name_str == "MultiRZ":
assert len(float_params) == 1, "MultiRZ takes one float parameter"
float_param = float_params[0]
assert len(float_params) <= 1, "MultiRZ takes at most one dynamic float parameter"
assert (
not static_params or len(static_params) <= 1
), "MultiRZ takes at most one static float parameter"
# TODO: Add support for MultiRZ with static params
float_param = (
TensorExtractOp(ir.F64Type.get(), mlir.ir_constant(static_params[0]), [])
if len(float_params) == 0
else float_params[0]
)
return MultiRZOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
Expand All @@ -1045,7 +1070,17 @@ def _qinst_lowering(
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if params_attr:
return StaticCustomOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
static_params=params_attr,
in_qubits=qubits,
gate_name=name_attr,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results
return CustomOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
Expand Down
8 changes: 5 additions & 3 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,12 +691,14 @@ def bind_native_operation(qrp, op, controlled_wires, controlled_values, adjoint=
else:
qubits = qrp.extract(op.wires)
controlled_qubits = qrp.extract(controlled_wires)
qubits2 = qinst_p.bind(
*[*qubits, *op.parameters, *controlled_qubits, *controlled_values],
qubits2 = bind_flexible_primitive(
qinst_p,
{"static_params": op.parameters},
*[*qubits, *controlled_qubits, *controlled_values],
op=op.name,
qubits_len=len(qubits),
params_len=len(op.parameters),
ctrl_len=len(controlled_qubits),
ctrl_value_len=len(controlled_values),
adjoint=adjoint,
)
qrp.insert(op.wires, qubits2[: len(qubits)])
Expand Down
2 changes: 2 additions & 0 deletions frontend/catalyst/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def get_enforce_runtime_invariants_stage(_options: CompileOptions) -> List[str]:
"split-multiple-tapes",
# Run the transform sequence defined in the MLIR module
"builtin.module(apply-transform-sequence)",
# Lower the static custom ops to regular custom ops with dynamic parameters.
"static-custom-lowering",
# Nested modules are something that will be used in the future
# for making device specific transformations.
# Since at the moment, nothing in the runtime is using them
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,17 @@ def change_instruction(ctx, eqn):
op = params["op"]
cuda_inst_name = from_catalyst_to_cuda[op]
qubits_len = params["qubits_len"]
static_params = params.get("static_params")

# Now, we can map to the correct op
# For now just assume rx
cuda_inst(ctx.kernel, *qubits_or_params, inst=cuda_inst_name, qubits_len=qubits_len)
cuda_inst(
ctx.kernel,
*qubits_or_params,
inst=cuda_inst_name,
qubits_len=qubits_len,
static_params=static_params,
)

# Finally determine how many are qubits.
qubits = qubits_or_params[:qubits_len]
Expand Down
14 changes: 10 additions & 4 deletions frontend/catalyst/third_party/cuda/primitives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,27 +301,33 @@ def make_primitive_for_gate():
kernel_gate_p = jax.core.Primitive("kernel_inst")
kernel_gate_p.multiple_results = True

def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1):
def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params=None):
"""Convenience.
Quantum operations in CUDA-quantum return no values. But JAXPR expects return values.
We can just say that multiple_results = True and return an empty tuple.
"""
kernel_gate_p.bind(kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len)
kernel_gate_p.bind(
kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len, static_params=static_params
)
return tuple()

@kernel_gate_p.def_impl
def gate_impl(kernel, *qubits_or_params, inst=None, qubits_len=-1):
def gate_impl(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params=None):
"""Concrete implementation."""
assert inst and qubits_len > 0
if static_params is None:
static_params = []
method = getattr(cudaq.Kernel, inst)
targets = qubits_or_params[:qubits_len]
params = qubits_or_params[qubits_len:]
if not params:
params = static_params
method(kernel, *params, *targets)
return tuple()

@kernel_gate_p.def_abstract_eval
def gate_abs(_kernel, *_qubits_or_params, inst=None, qubits_len=-1):
def gate_abs(_kernel, *_qubits_or_params, inst=None, qubits_len=-1, static_params=None):
"""Abstract evaluation."""
return tuple()

Expand Down
4 changes: 1 addition & 3 deletions frontend/test/lit/test_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,8 @@ def test_decompose_s():
@qml.qnode(dev)
# CHECK-LABEL: public @jit_decompose_s
def decompose_s():
# CHECK-NOT: name="S"
# CHECK: [[pi_div_2:%.+]] = arith.constant 1.57079{{.+}} : f64
# CHECK-NOT: name = "S"
# CHECK: {{%.+}} = quantum.custom "PhaseShift"([[pi_div_2]])
# CHECK: {{%.+}} = quantum.static_custom "PhaseShift" [1.570796e+00]
# CHECK-NOT: name = "S"
qml.S(wires=0)
return measure(wires=0)
Expand Down
2 changes: 1 addition & 1 deletion frontend/test/lit/test_if_else.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def circuit_single_gate(n: int):
# CHECK: [[b6:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t6]]
# CHECK: [[qreg_out1:%.+]] = scf.if [[b6]]
# CHECK-DAG: [[q4:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out]]
# CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.custom "RX"({{%.+}}) [[q4]]
# CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.static_custom "RX" [3.140000e+00] [[q4]]
# CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q5]]
# CHECK: scf.yield [[qreg_3]]
# CHECK: else
Expand Down
30 changes: 15 additions & 15 deletions frontend/test/lit/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
def sample1(x: float, y: float):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
# COM: CHECK: [[q0:%.+]] = quantum.custom "RZ"
# COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=0)

# COM: CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliZ]
Expand All @@ -54,7 +54,7 @@ def sample2(x: float, y: float):
qml.RX(x, wires=0)
# COM: CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
# COM: CHECK: [[q0:%.+]] = quantum.custom "RZ"
# COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=0)

# COM: CHECK: [[obs1:%.+]] = quantum.namedobs [[q1]][ PauliX]
Expand All @@ -77,7 +77,7 @@ def sample3(x: float, y: float):
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
# CHECK: [[q0:%.+]] = quantum.custom "RZ"
# CHECK: [[q0:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=0)

# CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]]
Expand Down Expand Up @@ -145,7 +145,7 @@ def test_sample_dynamic(shots: int):
def counts1(x: float, y: float):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
# COM: CHECK: [[q0:%.+]] = quantum.custom "RZ"
# COM: CHECK: [[q0:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=0)

# COM: CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliZ]
Expand All @@ -160,7 +160,7 @@ def counts2(x: float, y: float):
qml.RX(x, wires=0)
# COM: CHECK: [[q1:%.+]] = "quantum.custom"({{%.+}}, {{%.+}}) {gate_name = "RY"
qml.RY(y, wires=1)
# COM: CHECK: [[q0:%.+]] = "quantum.custom"({{%.+}}, {{%.+}}) {gate_name = "RZ"
# COM: CHECK: [[q0:%.+]] = "quantum.static_custom"({{%.+}}, {{%.+}}) {gate_name = "RZ"
qml.RZ(0.1, wires=0)

# COM: CHECK: [[obs1:%.+]] = "quantum.namedobs"([[q1]]) {type = #quantum<named_observable PauliX>}
Expand All @@ -183,7 +183,7 @@ def counts3(x: float, y: float):
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
# CHECK: [[q0:%.+]] = quantum.custom "RZ"
# CHECK: [[q0:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=0)

# CHECK: [[obs:%.+]] = quantum.compbasis [[q0]], [[q1]]
Expand Down Expand Up @@ -228,7 +228,7 @@ def test_counts_dynamic(shots: int):
def expval1(x: float, y: float):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
# CHECK: [[q0:%.+]] = quantum.custom "RZ"
# CHECK: [[q0:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=0)

# CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliX]
Expand All @@ -247,7 +247,7 @@ def expval2(x: float, y: float):
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
# CHECK: [[q2:%.+]] = quantum.custom "RZ"
# CHECK: [[q2:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=2)

# CHECK: [[p1:%.+]] = quantum.namedobs [[q0]][ PauliX]
Expand Down Expand Up @@ -304,7 +304,7 @@ def expval5(x: float, y: float):
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
# CHECK: [[q2:%.+]] = quantum.custom "RZ"
# CHECK: [[q2:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=2)

B = np.array(
Expand Down Expand Up @@ -334,7 +334,7 @@ def expval5(x: float, y: float):
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
# CHECK: [[q2:%.+]] = quantum.custom "RZ"
# CHECK: [[q2:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=2)

coeffs = np.array([0.2, -0.543])
Expand Down Expand Up @@ -426,7 +426,7 @@ def expval9(x: float, y: float):
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
# CHECK: [[q2:%.+]] = quantum.custom "RZ"
# CHECK: [[q2:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=2)

# CHECK: [[p1:%.+]] = quantum.namedobs [[q0]][ PauliX]
Expand All @@ -448,7 +448,7 @@ def expval10(x: float, y: float):
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
# CHECK: [[q2:%.+]] = quantum.custom "RZ"
# CHECK: [[q2:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=2)

B = np.array(
Expand Down Expand Up @@ -476,7 +476,7 @@ def expval10(x: float, y: float):
def var1(x: float, y: float):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
# CHECK: [[q0:%.+]] = quantum.custom "RZ"
# CHECK: [[q0:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=0)

# CHECK: [[obs:%.+]] = quantum.namedobs [[q0]][ PauliX]
Expand Down Expand Up @@ -519,7 +519,7 @@ def probs1(x: float, y: float):
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
# CHECK: [[q0:%.+]] = quantum.custom "RZ"
# CHECK: [[q0:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=0)

# qml.probs() # unsupported by PennyLane
Expand All @@ -540,7 +540,7 @@ def state1(x: float, y: float):
qml.RX(x, wires=0)
# CHECK: [[q1:%.+]] = quantum.custom "RY"
qml.RY(y, wires=1)
# CHECK: [[q0:%.+]] = quantum.custom "RZ"
# CHECK: [[q0:%.+]] = quantum.static_custom "RZ"
qml.RZ(0.1, wires=0)

# qml.state(wires=[0]) # unsupported by PennyLane
Expand Down
2 changes: 1 addition & 1 deletion frontend/test/lit/test_quantum_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_native_controlled_custom():
@qml.qnode(dev)
# CHECK-LABEL: public @jit_native_controlled
def native_controlled():
# CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.custom "Rot"
# CHECK: [[out:%.+]], [[out_ctrl:%.+]]:2 = quantum.static_custom "Rot"
# CHECK-SAME: ctrls
# CHECK-SAME: ctrlvals(%true, %true)
qml.ctrl(qml.Rot(0.3, 0.4, 0.5, wires=[0]), control=[1, 2])
Expand Down
Loading

0 comments on commit 48edd02

Please sign in to comment.