diff --git a/benchmark/pl_benchmark.py b/benchmark/pl_benchmark.py index 3d69261..93a1179 100644 --- a/benchmark/pl_benchmark.py +++ b/benchmark/pl_benchmark.py @@ -48,7 +48,9 @@ def make_qiskit_tapes(nq: int, nl: int, shots: int): @pytest.mark.parametrize("nq", n_qubits) def test_sim_aer(benchmark, shots, n_layers, nq): tapes = make_qiskit_tapes(nq, n_layers, shots) - sim = qml.device("qiskit.aer", backend="aer_simulator_statevector", wires=nq, shots=shots) + sim = qml.device( + "qiskit.aer", backend="aer_simulator_statevector", wires=nq, shots=shots + ) benchmark.pedantic(sim.execute, args=(tapes,), iterations=5, warmup_rounds=1) diff --git a/src/braket/simulator_v2/base_simulator_v2.py b/src/braket/simulator_v2/base_simulator_v2.py index 2e299ba..9fbb580 100644 --- a/src/braket/simulator_v2/base_simulator_v2.py +++ b/src/braket/simulator_v2/base_simulator_v2.py @@ -39,6 +39,9 @@ def setup_julia(): from juliacall import Main as jl + # These are used at simulator class instantiation to trigger + # precompilation of Julia methods which may be invalidated + # or uncacheable. Total time for this should be <1s. jl.seval("using BraketSimulator, JSON3") exact_sv_oq3 = """ OPENQASM 3.0; @@ -91,6 +94,10 @@ def setup_julia(): def setup_pool(): + # We use a multiprocessing Pool with one worker + # in order to bypass the Python GIL. This protects us + # when the simulator is used from a non-main thread from another + # Python module, as occurs in the Qiskit-Braket plugin. global __JULIA_POOL__ __JULIA_POOL__ = Pool(processes=1) __JULIA_POOL__.apply(setup_julia) @@ -99,6 +106,11 @@ def setup_pool(): return +# large arrays are extremely expensive to transfer among Python +# processes because they are pickle'd. For large arrays like for +# StateVector, DensityMatrix, or Probability result types, we +# instead do an mmap to disk, which is dramatically faster. For +# smaller objects this isn't helpful. def _handle_mmaped_result(raw_result, mmap_paths, obj_lengths): result = GateModelTaskResult(**raw_result) if mmap_paths: @@ -125,6 +137,8 @@ def _handle_mmaped_result(raw_result, mmap_paths, obj_lengths): class BaseLocalSimulatorV2(BaseLocalSimulator): def __init__(self, device: str): global __JULIA_POOL__ + # if the pool is already set up, no need + # to do anything if __JULIA_POOL__ is None: setup_pool() self._device = device @@ -155,6 +169,7 @@ def run_openqasm( """ global __JULIA_POOL__ + # pass inputs and source as strings to avoid pickling a dict inputs_dict = json.dumps(openqasm_ir.inputs) if openqasm_ir.inputs else "{}" try: jl_result = __JULIA_POOL__.apply( @@ -170,6 +185,7 @@ def run_openqasm( # attach the result types if not shots: + # have to convert the types of array result types to what the BDK expects result = _result_value_to_ndarray(result) else: result.resultTypes = [rt.type for rt in result.resultTypes] diff --git a/src/braket/simulator_v2/julia_workers.py b/src/braket/simulator_v2/julia_workers.py index b0efabf..ba329d9 100644 --- a/src/braket/simulator_v2/julia_workers.py +++ b/src/braket/simulator_v2/julia_workers.py @@ -8,7 +8,6 @@ def _handle_julia_error(error): # in case juliacall isn't loaded - print(error) if type(error).__name__ == "JuliaError": python_exception = getattr(error.exception, "alternate_type", None) if python_exception is None: