Skip to content

Commit

Permalink
fix: cleanup and set timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws committed Aug 15, 2024
1 parent 58b7ad0 commit b8312b7
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 51 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ on:
jobs:
build:
runs-on: ${{ matrix.os }}
timeout-minutes: 10
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ test = [
"pytest-rerunfailures",
"pytest-timeout",
"pytest-xdist",
"qiskit-braket-provider",
"qiskit==1.1.2",
"qiskit-braket-provider==0.4.1",
"qiskit-algorithms",
"sphinx",
"sphinx-rtd-theme",
Expand Down
13 changes: 13 additions & 0 deletions src/braket/simulator_v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from braket.ir.openqasm import Program as OpenQASMProgram

from braket.simulator_v2.density_matrix_simulator_v2 import ( # noqa: F401
DensityMatrixSimulatorV2,
)
Expand All @@ -7,3 +9,14 @@
)

from ._version import __version__ # noqa: F401

payload = OpenQASMProgram(
source="""
OPENQASM 3.0;
qubit[1] q;
h q[0];
#pragma braket result state_vector
"""
)
StateVectorSimulatorV2().run_openqasm(payload)
StateVectorSimulatorV2().run_multiple([payload, payload])
79 changes: 36 additions & 43 deletions src/braket/simulator_v2/base_simulator_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,6 @@ def __init__(self, device):
def initialize_simulation(self, **kwargs):
return

def _openqasm_to_jl(self, openqasm_ir: OpenQASMProgram):
# convert to the Julia OpenQasmProgram type for dispatch
jl_braket_schema_header = jl.BraketSimulator.braketSchemaHeader(
jl.convert(jl.String, openqasm_ir.braketSchemaHeader.name),
jl.convert(jl.String, openqasm_ir.braketSchemaHeader.version),
)
if openqasm_ir.inputs:
jl_inputs = jl.Dict[jl.String, jl.Any](
jl.Pair(
jl.convert(jl.String, input_key),
(
jl.convert(jl.String, input_val)
if isinstance(input_val, str)
else jl.convert(jl.Number, input_val)
),
)
for (input_key, input_val) in openqasm_ir.inputs.items()
)
else:
jl_inputs = jl.Dict[jl.String, jl.Float64]()
jl_source = jl.convert(jl.String, openqasm_ir.source)
return jl.BraketSimulator.OpenQasmProgram(
jl_braket_schema_header,
jl_source,
jl_inputs,
)

def _ir_list_to_jl(self, payloads: list[OpenQASMProgram], shots: int):
return [self._openqasm_to_jl(ir) for ir in payloads]

def run_openqasm(
self,
openqasm_ir: OpenQASMProgram,
Expand All @@ -71,9 +41,18 @@ def run_openqasm(
are requested when shots>0.
"""
try:
jl_ir = self._openqasm_to_jl(openqasm_ir)
jl_shots = jl.convert(jl.Int, shots)
jl_result = jl.simulate(self._device, [jl_ir], jl_shots)[0]
jl_shots = shots
jl_inputs = (
jl.Dict[jl.String, jl.Any](
jl.Pair(jl.convert(jl.String, k), jl.convert(jl.Any, v))
for (k, v) in openqasm_ir.inputs.items()
)
if openqasm_ir.inputs
else jl.Dict[jl.String, jl.Any]()
)
jl_result = jl.BraketSimulator.simulate._jl_call_nogil(
self._device, openqasm_ir.source, jl_inputs, jl_shots
)
except JuliaError as e:
_handle_julia_error(e)

Expand All @@ -92,8 +71,8 @@ def run_multiple(
programs: Sequence[OpenQASMProgram],
max_parallel: Optional[int] = -1,
shots: Optional[int] = 0,
inputs: Optional[Union[dict, Sequence[dict]]] = None,
): # -> list[GateModelTaskResult]:
inputs: Optional[Union[dict, Sequence[dict]]] = {},
) -> list[GateModelTaskResult]:
"""
Run the tasks specified by the given IR programs.
Extra arguments will contain any additional information necessary to run the tasks,
Expand All @@ -107,15 +86,28 @@ def run_multiple(
the result of the ith program.
"""
try:
julia_irs = self._ir_list_to_jl(programs, shots)
julia_inputs = (
jl.Dict[jl.String, jl.Float64]() if inputs is None else inputs
)
jl_results = jl.simulate(
irs = jl.Vector[jl.String]()
is_single_input = isinstance(inputs, dict) or len(inputs) == 1
py_inputs = {}
if (is_single_input and isinstance(inputs, dict)) or not is_single_input:
py_inputs = [inputs.copy() for _ in range(len(programs))]
elif is_single_input and not isinstance(inputs, dict):
py_inputs = [inputs[0].copy() for _ in range(len(programs))]
else:
py_inputs = inputs
jl_inputs = jl.Vector[jl.Dict[jl.String, jl.Any]]()
for p_ix, program in enumerate(programs):
irs.append(program.source)
if program.inputs:
jl_inputs.append(program.inputs | py_inputs[p_ix])
else:
jl_inputs.append(py_inputs[p_ix])

jl_results = jl.BraketSimulator.simulate._jl_call_nogil(
self._device,
julia_irs,
jl.convert(jl.Int, shots),
inputs=julia_inputs,
irs,
jl_inputs,
shots,
max_parallel=jl.convert(jl.Int, max_parallel),
)

Expand Down Expand Up @@ -178,6 +170,7 @@ def reconstruct_complex(v):

def _handle_julia_error(julia_error: JuliaError):
try:
print(julia_error)
python_exception = getattr(julia_error.exception, "alternate_type", None)
if python_exception is None:
error = julia_error
Expand Down
2 changes: 0 additions & 2 deletions test/unit_tests/braket/simulator_v2/test_qiskit_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,5 @@ def vqe():
@pytest.mark.timeout(10)
def test_qiskit_vqe(H2_op, vqe):
# Find the ground state
print("Computing VQE", flush=True)
result = vqe.compute_minimum_eigenvalue(H2_op)
print("Done computing VQE", flush=True)
assert result.eigenvalue < 0.0
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@

import cmath
import re

# import re
import sys
from collections import Counter, namedtuple
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pytest
Expand Down Expand Up @@ -1131,3 +1130,52 @@ def test_run_multiple():
assert np.allclose(results[0].resultTypes[0].value, np.array([1, 1]) / np.sqrt(2))
assert np.allclose(results[1].resultTypes[0].value, np.array([1, 0]))
assert np.allclose(results[2].resultTypes[0].value, np.array([0, 1]))


@pytest.mark.timeout(10)
def test_run_single_executor():
payload = OpenQASMProgram(
source="""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
h q[0];
#pragma braket result state_vector
"""
)
pool = ThreadPoolExecutor(2)
fs = {
pool.submit(StateVectorSimulator().run_openqasm, payload): ix
for ix in range(10)
}
for future in as_completed(fs):
results = future.result()
assert np.allclose(results.resultTypes[0].value, np.array([1, 1]) / np.sqrt(2))


@pytest.mark.timeout(10)
def test_run_multiple_executor():
payloads = [
OpenQASMProgram(
source=f"""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
{gate} q[0];
#pragma braket result state_vector
"""
)
for gate in ["h", "z", "x"]
]
pool = ThreadPoolExecutor(2)
fs = {
pool.submit(StateVectorSimulator().run_multiple, payloads): ix
for ix in range(10)
}
for future in as_completed(fs):
results = future.result()
assert np.allclose(
results[0].resultTypes[0].value, np.array([1, 1]) / np.sqrt(2)
)
assert np.allclose(results[1].resultTypes[0].value, np.array([1, 0]))
assert np.allclose(results[2].resultTypes[0].value, np.array([0, 1]))
3 changes: 0 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ envlist = linters,docs,unit-tests

[testenv:unit-tests]
basepython = python3
setenv =
JULIA_PKG_USE_CLI_GIT=true
JULIA_CONDAPKG_BACKEND="Null"
# {posargs} contains additional arguments specified when invoking tox. e.g. tox -- -s -k test_foo.py
deps =
{[test-deps]deps}
Expand Down

0 comments on commit b8312b7

Please sign in to comment.