Skip to content

Commit

Permalink
Cuda support for static params
Browse files Browse the repository at this point in the history
  • Loading branch information
erick-xanadu authored and mehrdad2m committed Dec 17, 2024
1 parent a96f730 commit 8bb0129
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,11 @@ def change_instruction(ctx, eqn):
op = params["op"]
cuda_inst_name = from_catalyst_to_cuda[op]
qubits_len = params["qubits_len"]
static_params = params.get("static_params")

# Now, we can map to the correct op
# For now just assume rx
cuda_inst(ctx.kernel, *qubits_or_params, inst=cuda_inst_name, qubits_len=qubits_len)
cuda_inst(ctx.kernel, *qubits_or_params, inst=cuda_inst_name, qubits_len=qubits_len, static_params=static_params)

Check notice on line 435 in frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py#L435

Line too long (117/100) (line-too-long)

# Finally determine how many are qubits.
qubits = qubits_or_params[:qubits_len]
Expand Down
12 changes: 8 additions & 4 deletions frontend/catalyst/third_party/cuda/primitives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,27 +301,31 @@ def make_primitive_for_gate():
kernel_gate_p = jax.core.Primitive("kernel_inst")
kernel_gate_p.multiple_results = True

def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1):
def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params=None):
"""Convenience.
Quantum operations in CUDA-quantum return no values. But JAXPR expects return values.
We can just say that multiple_results = True and return an empty tuple.
"""
kernel_gate_p.bind(kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len)
kernel_gate_p.bind(kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len, static_params=static_params)
return tuple()

@kernel_gate_p.def_impl
def gate_impl(kernel, *qubits_or_params, inst=None, qubits_len=-1):
def gate_impl(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params=None):
"""Concrete implementation."""
assert inst and qubits_len > 0
if static_params is None:
static_params = []
method = getattr(cudaq.Kernel, inst)
targets = qubits_or_params[:qubits_len]
params = qubits_or_params[qubits_len:]
if not params:
params = static_params
method(kernel, *params, *targets)
return tuple()

@kernel_gate_p.def_abstract_eval
def gate_abs(_kernel, *_qubits_or_params, inst=None, qubits_len=-1):
def gate_abs(_kernel, *_qubits_or_params, inst=None, qubits_len=-1, static_params=None):
"""Abstract evaluation."""
return tuple()

Expand Down

0 comments on commit 8bb0129

Please sign in to comment.