From 84d5fd29a4fb56fa155d2d80df497fd34f416a9c Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Mon, 8 Jul 2024 11:38:54 -0700 Subject: [PATCH 1/2] change: adding check for main thread --- src/braket/simulator_v2/base_simulator_v2.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/braket/simulator_v2/base_simulator_v2.py b/src/braket/simulator_v2/base_simulator_v2.py index 213f229..63afd9f 100644 --- a/src/braket/simulator_v2/base_simulator_v2.py +++ b/src/braket/simulator_v2/base_simulator_v2.py @@ -1,5 +1,6 @@ import sys import warnings +import threading from collections.abc import Sequence from typing import Any, Optional, Union @@ -92,6 +93,7 @@ def run_jaqcd( as a result type when shots=0. Or, if StateVector and Amplitude result types are requested when shots>0. """ + _validate_thread() if qubit_count is not None: warnings.warn( f"qubit_count is deprecated for {type(self).__name__} and can be set to None" @@ -155,6 +157,7 @@ def run_openqasm( as a result type when shots=0. Or, if StateVector and Amplitude result types are requested when shots>0. """ + _validate_thread() try: r = jl.simulate(self._device, self._openqasm_to_jl(openqasm_ir), shots) except JuliaError as e: @@ -200,6 +203,7 @@ def run_multiple( list[GateModelTaskResult]: A list of result objects, with the ith object being the result of the ith program. """ + _validate_thread() try: results = jl.simulate( self._device, @@ -245,6 +249,14 @@ def _validate_jaqcd(self, circuit_ir, qubit_count: int, shots: int): ) +def _validate_thread(): + if threading.current_thread() is not threading.main_thread(): + raise RuntimeError( + "Simulations must be run from the Main thread. " + "For multiple simulations, please use run_batch() instead." + ) + + def _result_value_to_ndarray( task_result: GateModelTaskResult, ) -> GateModelTaskResult: From 9a0d640d03776409ebb0dfecf61e60d5425586d2 Mon Sep 17 00:00:00 2001 From: Milan Krneta Date: Tue, 9 Jul 2024 13:36:23 -0700 Subject: [PATCH 2/2] adding tests --- src/braket/simulator_v2/base_simulator_v2.py | 2 +- .../simulator_v2/test_density_matrix_simulator_v2.py | 11 +++++++++++ .../simulator_v2/test_state_vector_simulator_v2.py | 11 +++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/braket/simulator_v2/base_simulator_v2.py b/src/braket/simulator_v2/base_simulator_v2.py index 63afd9f..8b876fe 100644 --- a/src/braket/simulator_v2/base_simulator_v2.py +++ b/src/braket/simulator_v2/base_simulator_v2.py @@ -1,6 +1,6 @@ import sys -import warnings import threading +import warnings from collections.abc import Sequence from typing import Any, Optional, Union diff --git a/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator_v2.py b/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator_v2.py index f823965..6c56ee0 100644 --- a/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator_v2.py +++ b/test/unit_tests/braket/simulator_v2/test_density_matrix_simulator_v2.py @@ -15,6 +15,7 @@ import json import sys from collections import Counter, namedtuple +from unittest.mock import patch import numpy as np import pytest @@ -930,3 +931,13 @@ def test_kraus_noise(): result = device.run(program) probabilities = result.resultTypes[0].value assert np.allclose(probabilities, [0.18, 0, 0.82, 0]) + + +@patch("braket.simulator_v2.base_simulator_v2.threading") +def test_threading(mock_threading): + program = OpenQASMProgram(source="""OPENQASM 3.0;""") + simulator = DensityMatrixSimulator() + with pytest.raises( + RuntimeError, match="Simulations must be run from the Main thread.*" + ): + simulator.run(program, shots=0) diff --git a/test/unit_tests/braket/simulator_v2/test_state_vector_simulator_v2.py b/test/unit_tests/braket/simulator_v2/test_state_vector_simulator_v2.py index 6afd193..ee2735d 100644 --- a/test/unit_tests/braket/simulator_v2/test_state_vector_simulator_v2.py +++ b/test/unit_tests/braket/simulator_v2/test_state_vector_simulator_v2.py @@ -18,6 +18,7 @@ # import re import sys from collections import Counter, namedtuple +from unittest.mock import patch import numpy as np import pytest @@ -1605,3 +1606,13 @@ def test_run_multiple(): assert np.allclose(results[0].resultTypes[0].value, np.array([1, 1]) / np.sqrt(2)) assert np.allclose(results[1].resultTypes[0].value, np.array([1, 0])) assert np.allclose(results[2].resultTypes[0].value, np.array([0, 1])) + + +@patch("braket.simulator_v2.base_simulator_v2.threading") +def test_threading(mock_threading): + program = OpenQASMProgram(source="""OPENQASM 3.0;""") + simulator = StateVectorSimulator() + with pytest.raises( + RuntimeError, match="Simulations must be run from the Main thread.*" + ): + simulator.run(program, shots=0)