Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: StateVector shouldn't be a supported pragma for DM simulator #25

Merged
merged 8 commits into from
Jul 6, 2024
Merged
7 changes: 3 additions & 4 deletions src/braket/simulator_v2/base_simulator_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections.abc import Sequence
from typing import Any, Optional, Union

import juliacall
import numpy as np
from braket.default_simulator.result_types import TargetedResultType
from braket.default_simulator.simulator import BaseLocalSimulator
Expand Down Expand Up @@ -100,7 +99,7 @@ def run_jaqcd(
translated_ir, qubit_count = self._jaqcd_to_jl(circuit_ir, shots)
try:
r = jl.simulate(self._device, translated_ir, qubit_count, shots)
except juliacall.JuliaError as e:
except JuliaError as e:
_handle_julia_error(e)
r.additionalMetadata.action = circuit_ir
r = _result_value_to_ndarray(r)
Expand Down Expand Up @@ -158,7 +157,7 @@ def run_openqasm(
"""
try:
r = jl.simulate(self._device, self._openqasm_to_jl(openqasm_ir), shots)
except juliacall.JuliaError as e:
except JuliaError as e:
_handle_julia_error(e)
r.additionalMetadata.action = openqasm_ir
# attach the result types
Expand Down Expand Up @@ -209,7 +208,7 @@ def run_multiple(
shots=shots,
inputs=inputs,
)
except juliacall.JuliaError as e:
except JuliaError as e:
_handle_julia_error(e)

for r_ix, result in enumerate(results):
Expand Down
1 change: 0 additions & 1 deletion src/braket/simulator_v2/density_matrix_simulator_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def properties(self) -> GateModelSimulatorDeviceCapabilities:
],
"supportedPragmas": [
"braket_unitary_matrix",
"braket_result_type_state_vector",
"braket_result_type_density_matrix",
"braket_result_type_sample",
"braket_result_type_expectation",
Expand Down
50 changes: 21 additions & 29 deletions src/braket/simulator_v2/julia_import.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,34 @@
import os
import sys
import warnings

import juliacall

# Check if JuliaCall is already loaded, and if so, warn the user
# about the relevant environment variables. If not loaded,
# set up sensible defaults.
if "juliacall" in sys.modules:
# Required to avoid segfaults (https://juliapy.github.io/PythonCall.jl/dev/faq/)
if os.environ.get("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes") != "yes":
warnings.warn(
"`juliacall` module has already been imported. "
+ "Make sure that you have set the environment variable "
+ "`PYTHON_JULIACALL_HANDLE_SIGNALS=yes` to avoid segfaults. "
"`PYTHON_JULIACALL_HANDLE_SIGNALS` environment variable "
+ "is set to something other than 'yes' or ''. "
+ "You will experience segfaults if running with Julia multithreading."
)
else:
# Required to avoid segfaults (https://juliapy.github.io/PythonCall.jl/dev/faq/)
if os.environ.get("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes") != "yes":
warnings.warn(
"`PYTHON_JULIACALL_HANDLE_SIGNALS` environment variable "
+ "is set to something other than 'yes' or ''. "
+ "You will experience segfaults if running with Julia multithreading."
)

if os.environ.get("PYTHON_JULIACALL_THREADS", "auto") != "auto":
warnings.warn(
"`PYTHON_JULIACALL_THREADS` environment variable is set to "
+ "something other than `auto`, so `amazon-braket-simulator-v2` "
+ "was not able to set it."
)

for k, default in (
("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes"),
("PYTHON_JULIACALL_THREADS", "auto"),
("PYTHON_JULIACALL_OPTLEVEL", "3"),
# let the user's Conda/Pip handle installing things
("JULIA_CONDAPKG_BACKEND", "Null"),
):
os.environ[k] = os.environ.get(k, default)
if os.environ.get("PYTHON_JULIACALL_THREADS", "auto") != "auto":
warnings.warn(
"`PYTHON_JULIACALL_THREADS` environment variable is set to "
+ "something other than `auto`, so `amazon-braket-simulator-v2` "
+ "was not able to set it."
)

import juliacall
for k, default in (
("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes"),
("PYTHON_JULIACALL_THREADS", "auto"),
("PYTHON_JULIACALL_OPTLEVEL", "3"),
# let the user's Conda/Pip handle installing things
("JULIA_CONDAPKG_BACKEND", "Null"),
):
os.environ[k] = os.environ.get(k, default)

jl = juliacall.Base.Module()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def test_properties():
],
"supportedPragmas": [
"braket_unitary_matrix",
"braket_result_type_state_vector",
"braket_result_type_density_matrix",
"braket_result_type_sample",
"braket_result_type_expectation",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def test_simulator_run_grcs_16(grcs_16_qubit, batch_size):
if isinstance(grcs_16_qubit.circuit_ir, JaqcdProgram):
result = simulator.run(
grcs_16_qubit.circuit_ir,
qubit_count=16,
shots=0,
batch_size=batch_size,
)
Expand All @@ -102,9 +101,7 @@ def test_simulator_run_bell_pair(bell_ir, batch_size, caplog):
simulator = StateVectorSimulator()
shots_count = 10000
if isinstance(bell_ir, JaqcdProgram):
result = simulator.run(
bell_ir, qubit_count=2, shots=shots_count, batch_size=batch_size
)
result = simulator.run(bell_ir, shots=shots_count, batch_size=batch_size)
else:
result = simulator.run(bell_ir, shots=shots_count, batch_size=batch_size)

Expand Down Expand Up @@ -729,7 +726,6 @@ def test_simulator_identity(caplog):
if isinstance(program, JaqcdProgram):
result = simulator.run(
program,
qubit_count=2,
shots=shots_count,
)
else:
Expand All @@ -756,7 +752,7 @@ def test_simulator_instructions_not_supported(circuit_noise):
)
with pytest.raises(TypeError, match=no_noise):
if isinstance(circuit_noise, JaqcdProgram):
simulator.run(circuit_noise, qubit_count=2, shots=0)
simulator.run(circuit_noise, shots=0)
else:
simulator.run(circuit_noise, shots=0)

Expand All @@ -765,7 +761,7 @@ def test_simulator_run_no_results_no_shots(bell_ir):
simulator = StateVectorSimulator()
with pytest.raises(ValueError):
if isinstance(bell_ir, JaqcdProgram):
simulator.run(bell_ir, qubit_count=2, shots=0)
simulator.run(bell_ir, shots=0)
else:
simulator.run(bell_ir, shots=0)

Expand All @@ -788,7 +784,7 @@ def test_simulator_run_amplitude_shots():
"""
)
with pytest.raises(ValueError):
simulator.run(jaqcd, qubit_count=2, shots=100)
simulator.run(jaqcd, shots=100)
with pytest.raises(ValueError):
simulator.run(qasm, shots=100)

Expand Down Expand Up @@ -838,7 +834,7 @@ def test_simulator_run_statevector_shots():
"""
)
with pytest.raises(ValueError):
simulator.run(jaqcd, qubit_count=2, shots=100)
simulator.run(jaqcd, shots=100)
with pytest.raises(ValueError):
simulator.run(qasm, shots=100)

Expand Down Expand Up @@ -871,7 +867,7 @@ def test_simulator_run_result_types_shots(caplog):
"""
)
shots_count = 100
jaqcd_result = simulator.run(jaqcd, qubit_count=2, shots=shots_count)
jaqcd_result = simulator.run(jaqcd, shots=shots_count)
qasm_result = simulator.run(qasm, shots=shots_count)
for result in jaqcd_result, qasm_result:
assert all([len(measurement) == 2] for measurement in result.measurements)
Expand Down Expand Up @@ -911,7 +907,7 @@ def test_simulator_run_result_types_shots_basis_rotation_gates(caplog):
"""
)
shots_count = 1000
jaqcd_result = simulator.run(jaqcd, qubit_count=2, shots=shots_count)
jaqcd_result = simulator.run(jaqcd, shots=shots_count)
qasm_result = simulator.run(qasm, shots=shots_count)
for result in jaqcd_result, qasm_result:
assert all([len(measurement) == 2] for measurement in result.measurements)
Expand Down Expand Up @@ -941,7 +937,7 @@ def test_simulator_run_result_types_shots_basis_rotation_gates_value_error():
)
)
shots_count = 1000
simulator.run(ir, qubit_count=2, shots=shots_count)
simulator.run(ir, shots=shots_count)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1031,7 +1027,7 @@ def test_simulator_run_observable_references_invalid_qubit(ir, qubit_count):
shots_count = 0
if isinstance(ir, JaqcdProgram):
with pytest.raises(ValueError):
simulator.run(ir, qubit_count=qubit_count, shots=shots_count)
simulator.run(ir, shots=shots_count)
else:
# index error since you're indexing from a logical qubit
with pytest.raises(IndexError):
Expand All @@ -1046,7 +1042,7 @@ def test_simulator_bell_pair_result_types(
simulator = StateVectorSimulator()
ir = bell_ir_with_result(targets)
if isinstance(ir, JaqcdProgram):
result = simulator.run(ir, qubit_count=2, shots=0, batch_size=batch_size)
result = simulator.run(ir, shots=0, batch_size=batch_size)
else:
result = simulator.run(ir, shots=0, batch_size=batch_size)
assert len(result.resultTypes) == 2
Expand Down Expand Up @@ -1082,7 +1078,7 @@ def test_simulator_fails_samples_0_shots():
"""
)
with pytest.raises(ValueError):
simulator.run(jaqcd, qubit_count=1, shots=0)
simulator.run(jaqcd, shots=0)
with pytest.raises(ValueError):
simulator.run(qasm, shots=0)

Expand Down Expand Up @@ -1161,7 +1157,7 @@ def test_simulator_valid_observables(result_types, expected):
}
)
)
result = simulator.run(prog, qubit_count=2, shots=0)
result = simulator.run(prog, shots=0)
for i in range(len(result_types)):
assert np.allclose(result.resultTypes[i].value, expected[i])

Expand Down Expand Up @@ -1482,7 +1478,7 @@ def test_simulator_analytic_value_type(jaqcd_string, oq3_pragma, jaqcd_type):
#pragma braket result {oq3_pragma}
"""
)
result = simulator.run(jaqcd, qubit_count=2, shots=0)
result = simulator.run(jaqcd, shots=0)
assert result.resultTypes[0].type == jaqcd_type
assert isinstance(result.resultTypes[0].value, np.ndarray)
result = simulator.run(qasm, shots=0)
Expand Down Expand Up @@ -1568,12 +1564,29 @@ def test_noncontiguous_qubits_jaqcd_multiple_targets():
"results": [{"type": "expectation", "observable": ["z"], "targets": [4]}],
}
prg = JaqcdProgram.parse_raw(json.dumps(jaqcd_program))
result = StateVectorSimulator().run(prg, qubit_count=2, shots=0)
result = StateVectorSimulator().run(prg, shots=0)

assert result.measuredQubits == [0, 1]
assert result.resultTypes[0].value == -1


def test_run_multiple_single_circuit():
payload = [
OpenQASMProgram(
source="""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
h q[0];
#pragma braket result state_vector
"""
)
]
simulator = StateVectorSimulator()
results = simulator.run_multiple(payload, shots=0)
assert np.allclose(results[0].resultTypes[0].value, np.array([1, 1]) / np.sqrt(2))


def test_run_multiple():
payloads = [
OpenQASMProgram(
Expand Down
Loading