Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change: force precompilation at startup #44

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 28 additions & 31 deletions benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,29 +55,24 @@ def run_sim_batch(oq3_prog, sim, shots):
return


device_ids = ("braket_sv", "braket_sv_v2", "braket_dm", "braket_dm_v2")
device_ids = ("sv", "dm")

generators = (ghz, qft)
generators = ("ghz", "qft")


@pytest.mark.parametrize("device_id", device_ids)
@pytest.mark.parametrize("nq", n_qubits)
@pytest.mark.parametrize("exact_results", exact_shots_results)
@pytest.mark.parametrize("circuit", generators)
def test_exact_shots(benchmark, device_id, nq, exact_results, circuit):
if device_id in ("braket_dm_v2", "braket_dm") and (
exact_results in ("state_vector",) or nq > 10
):
pytest.skip()
if (
device_id in ("braket_sv",)
and exact_results in ("density_matrix q[0], q[1]",)
and nq >= 17
):
if device_id == "dm" and (exact_results in ("state_vector",) or nq > 10):
pytest.skip()
result_type = exact_results
oq3_prog = Program(source=circuit(nq, result_type))
sim = LocalSimulator(device_id)
if circuit == "qft":
oq3_prog = Program(source=qft(nq, result_type))
elif circuit == "ghz":
oq3_prog = Program(source=ghz(nq, result_type))
sim = LocalSimulator(f"braket_{device_id}_v2")
benchmark.pedantic(run_sim, args=(oq3_prog, sim, 0), iterations=5, warmup_rounds=1)


Expand All @@ -89,17 +84,16 @@ def test_exact_shots(benchmark, device_id, nq, exact_results, circuit):
def test_exact_shots_batched(
benchmark, device_id, nq, batch_size, exact_results, circuit
):
if device_id in ("braket_dm_v2", "braket_dm") and (
exact_results in ("state_vector,") or nq >= 5
):
pytest.skip()
if nq >= 10:
if (
device_id == "dm" and (exact_results in ("state_vector,") or nq >= 5)
) or nq >= 15:
pytest.skip()
# skip all for now as this is very expensive
pytest.skip()
result_type = exact_results
oq3_prog = [Program(source=circuit(nq, result_type)) for _ in range(batch_size)]
sim = LocalSimulator(device_id)
if circuit == "qft":
oq3_prog = [Program(source=qft(nq, result_type)) for _ in range(batch_size)]
elif circuit == "ghz":
oq3_prog = [Program(source=ghz(nq, result_type)) for _ in range(batch_size)]
sim = LocalSimulator(f"braket_{device_id}_v2")
benchmark.pedantic(
run_sim_batch, args=(oq3_prog, sim, 0), iterations=5, warmup_rounds=1
)
Expand All @@ -114,11 +108,14 @@ def test_exact_shots_batched(
@pytest.mark.parametrize("nonzero_shots_results", nonzero_shots_results)
@pytest.mark.parametrize("circuit", generators)
def test_nonzero_shots(benchmark, device_id, nq, shots, nonzero_shots_results, circuit):
if device_id in ("braket_dm_v2", "braket_dm") and nq > 10:
if device_id in ("dm",) and nq > 10:
pytest.skip()
result_type = nonzero_shots_results
oq3_prog = Program(source=circuit(nq, result_type))
sim = LocalSimulator(device_id)
if circuit == "qft":
oq3_prog = Program(source=qft(nq, result_type))
elif circuit == "ghz":
oq3_prog = Program(source=ghz(nq, result_type))
sim = LocalSimulator(f"braket_{device_id}_v2")
benchmark.pedantic(
run_sim, args=(oq3_prog, sim, shots), iterations=5, warmup_rounds=1
)
Expand All @@ -134,17 +131,17 @@ def test_nonzero_shots(benchmark, device_id, nq, shots, nonzero_shots_results, c
def test_nonzero_shots_batched(
benchmark, device_id, nq, batch_size, shots, nonzero_shots_results, circuit
):
if device_id in ("braket_dm_v2", "braket_dm") and nq >= 5:
if device_id in ("dm") and nq >= 5:
pytest.skip()
if nq >= 10:
pytest.skip()

# skip all for now as this is very expensive
pytest.skip()

result_type = nonzero_shots_results
oq3_prog = [Program(source=circuit(nq, result_type)) for _ in range(batch_size)]
sim = LocalSimulator(device_id)
if circuit == "qft":
oq3_prog = [Program(source=qft(nq, result_type)) for _ in range(batch_size)]
elif circuit == "ghz":
oq3_prog = [Program(source=ghz(nq, result_type)) for _ in range(batch_size)]
sim = LocalSimulator(f"braket_{device_id}_v2")
benchmark.pedantic(
run_sim_batch, args=(oq3_prog, sim, shots), iterations=5, warmup_rounds=1
)
Expand Down
83 changes: 83 additions & 0 deletions benchmark/pl_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import networkx as nx
import numpy as np
import pennylane as qml
import pytest

# Seed NumPy's global RNG with a fixed value so repeated runs are repeatable.
# NOTE(review): no direct NumPy random call is visible in this file; presumably
# the seed matters to library code invoked below -- confirm.
np.random.seed(0x1C2C6D66)
# Qubit counts swept by the parametrized benchmarks: 3 through 15 inclusive.
n_qubits = range(3, 16)
# Layer counts swept by the parametrized benchmarks: 1 through 4 inclusive.
n_layers = range(1, 5)
# Shot counts swept by the parametrized benchmarks (a single value, 100).
shots = (100,)


def make_wide_tapes(nq: int, nl: int, shots: int):
    """Build a one-element list holding an expanded max-clique benchmark tape.

    The circuit applies a Hadamard to each of the ``nq`` wires, followed by
    ``nl`` repetitions of an ``ApproxTimeEvolution`` of the cost Hamiltonian
    (time 0.2) and of the mixer Hamiltonian (time 0.4). Both Hamiltonians
    come from ``qml.qaoa.max_clique`` on an Erdos-Renyi graph generated with
    ``p=0.5`` and ``seed=42``, so the same ``nq`` always yields the same graph.

    Args:
        nq: Number of graph nodes, and of wires receiving a Hadamard.
        nl: Number of cost/mixer layer pairs appended after the Hadamards.
        shots: Shot count attached to the tape.

    Returns:
        A list containing the single tape after ``expand(depth=5)``.
    """
    cost_time, mixer_time = 0.2, 0.4
    graph = nx.erdos_renyi_graph(nq, p=0.5, seed=42)
    cost_h, mixer_h = qml.qaoa.max_clique(graph, constrained=False)

    operations = [qml.Hadamard(wire) for wire in range(nq)]
    for _ in range(nl):
        operations.append(qml.templates.ApproxTimeEvolution(cost_h, cost_time, 1))
        operations.append(qml.templates.ApproxTimeEvolution(mixer_h, mixer_time, 1))

    # One expectation value per term of the cost Hamiltonian.
    expectations = [
        qml.expval(term) for _coeff, term in zip(cost_h.coeffs, cost_h.ops)
    ]
    tape = qml.tape.QuantumTape(operations, expectations, shots=shots)
    return [tape.expand(depth=5)]


def make_qiskit_tapes(nq: int, nl: int, shots: int):
    """Build the wide benchmark tapes and compile them with a qiskit.aer device.

    Args:
        nq: Number of wires; forwarded to ``make_wide_tapes`` and the device.
        nl: Number of layers; forwarded to ``make_wide_tapes``.
        shots: Shot count; forwarded to ``make_wide_tapes`` and the device.

    Returns:
        Whatever the device's ``compile_circuits`` returns for the tapes.
    """
    source_tapes = make_wide_tapes(nq, nl, shots)
    device_options = {
        "backend": "aer_simulator_statevector",
        "wires": nq,
        "shots": shots,
        "statevector_parallel_threshold": 8,
    }
    compiler = qml.device("qiskit.aer", **device_options)
    return compiler.compile_circuits(source_tapes)


@pytest.mark.parametrize("shots", shots)
@pytest.mark.parametrize("n_layers", n_layers)
@pytest.mark.parametrize("nq", n_qubits)
def test_sim_aer(benchmark, shots, n_layers, nq):
    """Benchmark ``execute`` on a qiskit.aer statevector device.

    The tapes are compiled by ``make_qiskit_tapes`` before timing starts, so
    only the device's ``execute`` call is measured (iterations=5,
    warmup_rounds=1).
    """
    compiled = make_qiskit_tapes(nq, n_layers, shots)
    device = qml.device(
        "qiskit.aer",
        backend="aer_simulator_statevector",
        wires=nq,
        shots=shots,
    )
    benchmark.pedantic(
        device.execute,
        args=(compiled,),
        iterations=5,
        warmup_rounds=1,
    )


@pytest.mark.parametrize("shots", shots)
@pytest.mark.parametrize("n_layers", n_layers)
@pytest.mark.parametrize("nq", n_qubits)
def test_sim_v2(benchmark, shots, n_layers, nq):
    """Benchmark ``execute`` on braket.local.qubit with the braket_sv_v2 backend.

    Tapes come from ``make_wide_tapes`` and are built before timing starts;
    only the device's ``execute`` call is measured (iterations=5,
    warmup_rounds=1).
    """
    circuits = make_wide_tapes(nq, n_layers, shots)
    device = qml.device(
        "braket.local.qubit",
        backend="braket_sv_v2",
        wires=nq,
        shots=shots,
    )
    benchmark.pedantic(
        device.execute,
        args=(circuits,),
        iterations=5,
        warmup_rounds=1,
    )


@pytest.mark.parametrize("shots", shots)
@pytest.mark.parametrize("n_layers", n_layers)
@pytest.mark.parametrize("nq", n_qubits)
def test_sim_v1(benchmark, shots, n_layers, nq):
    """Benchmark ``execute`` on braket.local.qubit with the braket_sv backend.

    Tapes come from ``make_wide_tapes`` and are built before timing starts;
    only the device's ``execute`` call is measured (iterations=5,
    warmup_rounds=1).
    """
    circuits = make_wide_tapes(nq, n_layers, shots)
    device = qml.device(
        "braket.local.qubit",
        backend="braket_sv",
        wires=nq,
        shots=shots,
    )
    benchmark.pedantic(
        device.execute,
        args=(circuits,),
        iterations=5,
        warmup_rounds=1,
    )


@pytest.mark.parametrize("shots", shots)
@pytest.mark.parametrize("n_layers", n_layers)
@pytest.mark.parametrize("nq", n_qubits)
def test_sim_lightning(benchmark, shots, n_layers, nq):
    """Benchmark ``execute`` on the lightning.qubit device.

    Tapes come from ``make_wide_tapes`` and are built before timing starts;
    only the device's ``execute`` call is measured (iterations=5,
    warmup_rounds=1).
    """
    circuits = make_wide_tapes(nq, n_layers, shots)
    device = qml.device(
        "lightning.qubit",
        wires=nq,
        shots=shots,
    )
    benchmark.pedantic(
        device.execute,
        args=(circuits,),
        iterations=5,
        warmup_rounds=1,
    )
4 changes: 2 additions & 2 deletions src/braket/juliapkg.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{
"julia": "1.10",
"julia": "1.11",
"packages": {
"BraketSimulator": {
"uuid": "76d27892-9a0b-406c-98e4-7c178e9b3dff",
"version": "0.0.5"
"version": "0.0.7"
},
"JSON3": {
"uuid": "0f8b85d8-7281-11e9-16c2-39a750bddbf1",
Expand Down
66 changes: 59 additions & 7 deletions src/braket/simulator_v2/base_simulator_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def setup_julia():
import sys

# don't reimport if we don't have to
if "juliacall" in sys.modules:
if "juliacall" in sys.modules and hasattr(sys.modules["juliacall"], "Main"):
os.environ["PYTHON_JULIACALL_HANDLE_SIGNALS"] = "yes"
return
else:
Expand All @@ -37,26 +37,67 @@ def setup_julia():
):
os.environ[k] = os.environ.get(k, default)

import juliacall
from juliacall import Main as jl

jl = juliacall.Main
# These are used at simulator class instantiation to trigger
# precompilation of Julia methods which may be invalidated
# or uncacheable. Total time for this should be <1s.
jl.seval("using BraketSimulator, JSON3")
stock_oq3 = """
exact_sv_oq3 = """
OPENQASM 3.0;
input float p;
qubit[2] q;
h q[0];
cphaseshift(1.5707963267948966) q[1], q[0];
rx(1.5707963267948966) q[0];
ry(1.5707963267948966) q[0];
rz(p) q[0];
rz(p) q[0];
ry(1) q[1];
rx(0) q[1];
rz(2) q[1];
cnot q;
#pragma braket noise bit_flip(0.1) q[0]
#pragma braket result variance y(q[0])
#pragma braket result expectation y(q[0])
#pragma braket result expectation y(q[0]) @ z(q[1])
#pragma braket result expectation z(q[0]) @ z(q[1])
#pragma braket result density_matrix q[0], q[1]
#pragma braket result probability
"""
jl.BraketSimulator.simulate("braket_dm_v2", stock_oq3, "{}", 0)
inexact_sv_oq3 = """
OPENQASM 3.0;
input float p;
qubit[9] q;
h q;
#pragma braket result variance y(q[0])
#pragma braket result expectation z(q[1])
#pragma braket result expectation z(q[1]) @ z(q[2])
#pragma braket result expectation x(q[3]) @ x(q[4])
#pragma braket result expectation y(q[5]) @ y(q[6])
#pragma braket result expectation h(q[7]) @ h(q[8])
"""
stock_dm_oq3 = """
OPENQASM 3.0;
input float p;
qubit[2] q;
h q[0];
#pragma braket noise bit_flip(0.1) q[0]
#pragma braket noise phase_flip(0.1) q[0]
#pragma braket result variance y(q[0])
#pragma braket result expectation y(q[0])
#pragma braket result density_matrix q[0], q[1]
"""
jl.BraketSimulator.simulate("braket_sv_v2", exact_sv_oq3, '{"p": 1.57}', 0)
jl.BraketSimulator.simulate("braket_sv_v2", inexact_sv_oq3, '{"p": 1.57}', 100)
jl.BraketSimulator.simulate("braket_dm_v2", stock_dm_oq3, '{"p": 1.57}', 0)
return


def setup_pool():
# We use a multiprocessing Pool with one worker
# in order to bypass the Python GIL. This protects us
# when the simulator is used from a non-main thread from another
# Python module, as occurs in the Qiskit-Braket plugin.
global __JULIA_POOL__
__JULIA_POOL__ = Pool(processes=1)
__JULIA_POOL__.apply(setup_julia)
Expand All @@ -65,6 +106,11 @@ def setup_pool():
return


# Large arrays are extremely expensive to transfer between Python
# processes because they are pickled. For large arrays, such as those
# in StateVector, DensityMatrix, or Probability result types, we
# instead mmap them to disk, which is dramatically faster. For
# smaller objects this isn't helpful.
def _handle_mmaped_result(raw_result, mmap_paths, obj_lengths):
result = GateModelTaskResult(**raw_result)
if mmap_paths:
Expand All @@ -91,6 +137,8 @@ def _handle_mmaped_result(raw_result, mmap_paths, obj_lengths):
class BaseLocalSimulatorV2(BaseLocalSimulator):
def __init__(self, device: str):
global __JULIA_POOL__
# if the pool is already set up, no need
# to do anything
if __JULIA_POOL__ is None:
setup_pool()
self._device = device
Expand Down Expand Up @@ -120,10 +168,13 @@ def run_openqasm(
are requested when shots>0.
"""
global __JULIA_POOL__

# pass inputs and source as strings to avoid pickling a dict
inputs_dict = json.dumps(openqasm_ir.inputs) if openqasm_ir.inputs else "{}"
try:
jl_result = __JULIA_POOL__.apply(
translate_and_run,
[self._device, openqasm_ir, shots],
[self._device, openqasm_ir.source, inputs_dict, shots],
)
except Exception as e:
_handle_julia_error(e)
Expand All @@ -134,6 +185,7 @@ def run_openqasm(

# attach the result types
if not shots:
# have to convert the types of array result types to what the BDK expects
result = _result_value_to_ndarray(result)
else:
result.resultTypes = [rt.type for rt in result.resultTypes]
Expand Down
16 changes: 5 additions & 11 deletions src/braket/simulator_v2/julia_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

def _handle_julia_error(error):
# in case juliacall isn't loaded
print(error)
if type(error).__name__ == "JuliaError":
python_exception = getattr(error.exception, "alternate_type", None)
if python_exception is None:
Expand All @@ -27,23 +26,18 @@ def _handle_julia_error(error):


def translate_and_run(
device_id: str, openqasm_ir: OpenQASMProgram, shots: int = 0
device_id: str, openqasm_source: str, openqasm_inputs: str, shots: int = 0
) -> str:
jl = sys.modules["juliacall"].Main
jl.GC.enable(False)
jl_inputs = json.dumps(openqasm_ir.inputs) if openqasm_ir.inputs else "{}"
jl = getattr(sys.modules["juliacall"], "Main")
try:
result = jl.BraketSimulator.simulate(
device_id,
openqasm_ir.source,
jl_inputs,
openqasm_source,
openqasm_inputs,
shots,
)

except Exception as e:
_handle_julia_error(e)
finally:
jl.GC.enable(True)

return result

Expand All @@ -55,7 +49,7 @@ def translate_and_run_multiple(
inputs: Optional[Union[dict, Sequence[dict]]] = None,
) -> List[str]:
inputs = inputs or {}
jl = sys.modules["juliacall"].Main
jl = getattr(sys.modules["juliacall"], "Main")
irs = [program.source for program in programs]
py_inputs = {}
if len(inputs) > 1 or isinstance(inputs, dict):
Expand Down
Loading