Skip to content

Commit

Permalink
fix: StateVector shouldn't be a supported pragma for DM simulator (#25)
Browse files Browse the repository at this point in the history

Co-authored-by: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com>
  • Loading branch information
kshyatt-aws and rmshaffer authored Jul 6, 2024
1 parent d5e51c6 commit 38395a7
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 53 deletions.
7 changes: 3 additions & 4 deletions src/braket/simulator_v2/base_simulator_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections.abc import Sequence
from typing import Any, Optional, Union

import juliacall
import numpy as np
from braket.default_simulator.result_types import TargetedResultType
from braket.default_simulator.simulator import BaseLocalSimulator
Expand Down Expand Up @@ -100,7 +99,7 @@ def run_jaqcd(
translated_ir, qubit_count = self._jaqcd_to_jl(circuit_ir, shots)
try:
r = jl.simulate(self._device, translated_ir, qubit_count, shots)
except juliacall.JuliaError as e:
except JuliaError as e:
_handle_julia_error(e)
r.additionalMetadata.action = circuit_ir
r = _result_value_to_ndarray(r)
Expand Down Expand Up @@ -158,7 +157,7 @@ def run_openqasm(
"""
try:
r = jl.simulate(self._device, self._openqasm_to_jl(openqasm_ir), shots)
except juliacall.JuliaError as e:
except JuliaError as e:
_handle_julia_error(e)
r.additionalMetadata.action = openqasm_ir
# attach the result types
Expand Down Expand Up @@ -209,7 +208,7 @@ def run_multiple(
shots=shots,
inputs=inputs,
)
except juliacall.JuliaError as e:
except JuliaError as e:
_handle_julia_error(e)

for r_ix, result in enumerate(results):
Expand Down
1 change: 0 additions & 1 deletion src/braket/simulator_v2/density_matrix_simulator_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def properties(self) -> GateModelSimulatorDeviceCapabilities:
],
"supportedPragmas": [
"braket_unitary_matrix",
"braket_result_type_state_vector",
"braket_result_type_density_matrix",
"braket_result_type_sample",
"braket_result_type_expectation",
Expand Down
50 changes: 21 additions & 29 deletions src/braket/simulator_v2/julia_import.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,34 @@
import os
import sys
import warnings

import juliacall

# Check if JuliaCall is already loaded, and if so, warn the user
# about the relevant environment variables. If not loaded,
# set up sensible defaults.
if "juliacall" in sys.modules:
# Required to avoid segfaults (https://juliapy.github.io/PythonCall.jl/dev/faq/)
if os.environ.get("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes") != "yes":
warnings.warn(
"`juliacall` module has already been imported. "
+ "Make sure that you have set the environment variable "
+ "`PYTHON_JULIACALL_HANDLE_SIGNALS=yes` to avoid segfaults. "
"`PYTHON_JULIACALL_HANDLE_SIGNALS` environment variable "
+ "is set to something other than 'yes' or ''. "
+ "You will experience segfaults if running with Julia multithreading."
)
else:
# Required to avoid segfaults (https://juliapy.github.io/PythonCall.jl/dev/faq/)
if os.environ.get("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes") != "yes":
warnings.warn(
"`PYTHON_JULIACALL_HANDLE_SIGNALS` environment variable "
+ "is set to something other than 'yes' or ''. "
+ "You will experience segfaults if running with Julia multithreading."
)

if os.environ.get("PYTHON_JULIACALL_THREADS", "auto") != "auto":
warnings.warn(
"`PYTHON_JULIACALL_THREADS` environment variable is set to "
+ "something other than `auto`, so `amazon-braket-simulator-v2` "
+ "was not able to set it."
)

for k, default in (
("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes"),
("PYTHON_JULIACALL_THREADS", "auto"),
("PYTHON_JULIACALL_OPTLEVEL", "3"),
# let the user's Conda/Pip handle installing things
("JULIA_CONDAPKG_BACKEND", "Null"),
):
os.environ[k] = os.environ.get(k, default)
if os.environ.get("PYTHON_JULIACALL_THREADS", "auto") != "auto":
warnings.warn(
"`PYTHON_JULIACALL_THREADS` environment variable is set to "
+ "something other than `auto`, so `amazon-braket-simulator-v2` "
+ "was not able to set it."
)

import juliacall
for k, default in (
("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes"),
("PYTHON_JULIACALL_THREADS", "auto"),
("PYTHON_JULIACALL_OPTLEVEL", "3"),
# let the user's Conda/Pip handle installing things
("JULIA_CONDAPKG_BACKEND", "Null"),
):
os.environ[k] = os.environ.get(k, default)

jl = juliacall.Base.Module()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def test_properties():
],
"supportedPragmas": [
"braket_unitary_matrix",
"braket_result_type_state_vector",
"braket_result_type_density_matrix",
"braket_result_type_sample",
"braket_result_type_expectation",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def test_simulator_run_grcs_16(grcs_16_qubit, batch_size):
if isinstance(grcs_16_qubit.circuit_ir, JaqcdProgram):
result = simulator.run(
grcs_16_qubit.circuit_ir,
qubit_count=16,
shots=0,
batch_size=batch_size,
)
Expand All @@ -102,9 +101,7 @@ def test_simulator_run_bell_pair(bell_ir, batch_size, caplog):
simulator = StateVectorSimulator()
shots_count = 10000
if isinstance(bell_ir, JaqcdProgram):
result = simulator.run(
bell_ir, qubit_count=2, shots=shots_count, batch_size=batch_size
)
result = simulator.run(bell_ir, shots=shots_count, batch_size=batch_size)
else:
result = simulator.run(bell_ir, shots=shots_count, batch_size=batch_size)

Expand Down Expand Up @@ -729,7 +726,6 @@ def test_simulator_identity(caplog):
if isinstance(program, JaqcdProgram):
result = simulator.run(
program,
qubit_count=2,
shots=shots_count,
)
else:
Expand All @@ -756,7 +752,7 @@ def test_simulator_instructions_not_supported(circuit_noise):
)
with pytest.raises(TypeError, match=no_noise):
if isinstance(circuit_noise, JaqcdProgram):
simulator.run(circuit_noise, qubit_count=2, shots=0)
simulator.run(circuit_noise, shots=0)
else:
simulator.run(circuit_noise, shots=0)

Expand All @@ -765,7 +761,7 @@ def test_simulator_run_no_results_no_shots(bell_ir):
simulator = StateVectorSimulator()
with pytest.raises(ValueError):
if isinstance(bell_ir, JaqcdProgram):
simulator.run(bell_ir, qubit_count=2, shots=0)
simulator.run(bell_ir, shots=0)
else:
simulator.run(bell_ir, shots=0)

Expand All @@ -788,7 +784,7 @@ def test_simulator_run_amplitude_shots():
"""
)
with pytest.raises(ValueError):
simulator.run(jaqcd, qubit_count=2, shots=100)
simulator.run(jaqcd, shots=100)
with pytest.raises(ValueError):
simulator.run(qasm, shots=100)

Expand Down Expand Up @@ -838,7 +834,7 @@ def test_simulator_run_statevector_shots():
"""
)
with pytest.raises(ValueError):
simulator.run(jaqcd, qubit_count=2, shots=100)
simulator.run(jaqcd, shots=100)
with pytest.raises(ValueError):
simulator.run(qasm, shots=100)

Expand Down Expand Up @@ -871,7 +867,7 @@ def test_simulator_run_result_types_shots(caplog):
"""
)
shots_count = 100
jaqcd_result = simulator.run(jaqcd, qubit_count=2, shots=shots_count)
jaqcd_result = simulator.run(jaqcd, shots=shots_count)
qasm_result = simulator.run(qasm, shots=shots_count)
for result in jaqcd_result, qasm_result:
assert all([len(measurement) == 2] for measurement in result.measurements)
Expand Down Expand Up @@ -911,7 +907,7 @@ def test_simulator_run_result_types_shots_basis_rotation_gates(caplog):
"""
)
shots_count = 1000
jaqcd_result = simulator.run(jaqcd, qubit_count=2, shots=shots_count)
jaqcd_result = simulator.run(jaqcd, shots=shots_count)
qasm_result = simulator.run(qasm, shots=shots_count)
for result in jaqcd_result, qasm_result:
assert all([len(measurement) == 2] for measurement in result.measurements)
Expand Down Expand Up @@ -941,7 +937,7 @@ def test_simulator_run_result_types_shots_basis_rotation_gates_value_error():
)
)
shots_count = 1000
simulator.run(ir, qubit_count=2, shots=shots_count)
simulator.run(ir, shots=shots_count)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1031,7 +1027,7 @@ def test_simulator_run_observable_references_invalid_qubit(ir, qubit_count):
shots_count = 0
if isinstance(ir, JaqcdProgram):
with pytest.raises(ValueError):
simulator.run(ir, qubit_count=qubit_count, shots=shots_count)
simulator.run(ir, shots=shots_count)
else:
# index error since you're indexing from a logical qubit
with pytest.raises(IndexError):
Expand All @@ -1046,7 +1042,7 @@ def test_simulator_bell_pair_result_types(
simulator = StateVectorSimulator()
ir = bell_ir_with_result(targets)
if isinstance(ir, JaqcdProgram):
result = simulator.run(ir, qubit_count=2, shots=0, batch_size=batch_size)
result = simulator.run(ir, shots=0, batch_size=batch_size)
else:
result = simulator.run(ir, shots=0, batch_size=batch_size)
assert len(result.resultTypes) == 2
Expand Down Expand Up @@ -1082,7 +1078,7 @@ def test_simulator_fails_samples_0_shots():
"""
)
with pytest.raises(ValueError):
simulator.run(jaqcd, qubit_count=1, shots=0)
simulator.run(jaqcd, shots=0)
with pytest.raises(ValueError):
simulator.run(qasm, shots=0)

Expand Down Expand Up @@ -1161,7 +1157,7 @@ def test_simulator_valid_observables(result_types, expected):
}
)
)
result = simulator.run(prog, qubit_count=2, shots=0)
result = simulator.run(prog, shots=0)
for i in range(len(result_types)):
assert np.allclose(result.resultTypes[i].value, expected[i])

Expand Down Expand Up @@ -1482,7 +1478,7 @@ def test_simulator_analytic_value_type(jaqcd_string, oq3_pragma, jaqcd_type):
#pragma braket result {oq3_pragma}
"""
)
result = simulator.run(jaqcd, qubit_count=2, shots=0)
result = simulator.run(jaqcd, shots=0)
assert result.resultTypes[0].type == jaqcd_type
assert isinstance(result.resultTypes[0].value, np.ndarray)
result = simulator.run(qasm, shots=0)
Expand Down Expand Up @@ -1568,12 +1564,29 @@ def test_noncontiguous_qubits_jaqcd_multiple_targets():
"results": [{"type": "expectation", "observable": ["z"], "targets": [4]}],
}
prg = JaqcdProgram.parse_raw(json.dumps(jaqcd_program))
result = StateVectorSimulator().run(prg, qubit_count=2, shots=0)
result = StateVectorSimulator().run(prg, shots=0)

assert result.measuredQubits == [0, 1]
assert result.resultTypes[0].value == -1


def test_run_multiple_single_circuit():
payload = [
OpenQASMProgram(
source="""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
h q[0];
#pragma braket result state_vector
"""
)
]
simulator = StateVectorSimulator()
results = simulator.run_multiple(payload, shots=0)
assert np.allclose(results[0].resultTypes[0].value, np.array([1, 1]) / np.sqrt(2))


def test_run_multiple():
payloads = [
OpenQASMProgram(
Expand Down

0 comments on commit 38395a7

Please sign in to comment.