From 511ea57967f543ab8f2f45e690f1127e41ba95df Mon Sep 17 00:00:00 2001 From: Mehrdad Malekmohammadi Date: Tue, 17 Dec 2024 17:42:07 -0500 Subject: [PATCH] make format --- .../third_party/cuda/catalyst_to_cuda_interpreter.py | 8 +++++++- frontend/catalyst/third_party/cuda/primitives/__init__.py | 4 +++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py index 2d96839314..adc6545473 100644 --- a/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py +++ b/frontend/catalyst/third_party/cuda/catalyst_to_cuda_interpreter.py @@ -432,7 +432,13 @@ def change_instruction(ctx, eqn): # Now, we can map to the correct op # For now just assume rx - cuda_inst(ctx.kernel, *qubits_or_params, inst=cuda_inst_name, qubits_len=qubits_len, static_params=static_params) + cuda_inst( + ctx.kernel, + *qubits_or_params, + inst=cuda_inst_name, + qubits_len=qubits_len, + static_params=static_params, + ) # Finally determine how many are qubits. qubits = qubits_or_params[:qubits_len] diff --git a/frontend/catalyst/third_party/cuda/primitives/__init__.py b/frontend/catalyst/third_party/cuda/primitives/__init__.py index b542b88ea7..5c00100308 100644 --- a/frontend/catalyst/third_party/cuda/primitives/__init__.py +++ b/frontend/catalyst/third_party/cuda/primitives/__init__.py @@ -307,7 +307,9 @@ def gate_func(kernel, *qubits_or_params, inst=None, qubits_len=-1, static_params Quantum operations in CUDA-quantum return no values. But JAXPR expects return values. We can just say that multiple_results = True and return an empty tuple. """ - kernel_gate_p.bind(kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len, static_params=static_params) + kernel_gate_p.bind( + kernel, *qubits_or_params, inst=inst, qubits_len=qubits_len, static_params=static_params + ) return tuple() @kernel_gate_p.def_impl