From e93e551fef7f2e6a0060da36f7ace6ee044e268a Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Thu, 27 Jun 2024 15:01:22 -0700 Subject: [PATCH] feat: Introduce `run_multiple` method (#264) Introduces `run_multiple` method to `BraketSimulator` to allow backends to leverage their own batching implementations. Will next publish SDK PR to make use of this interface. --- src/braket/default_simulator/simulator.py | 9 ++-- src/braket/simulator/braket_simulator.py | 41 ++++++++++++++++++- .../test_density_matrix_simulator.py | 24 ++++++++++- .../test_state_vector_simulator.py | 22 +++++++++- 4 files changed, 88 insertions(+), 8 deletions(-) diff --git a/src/braket/default_simulator/simulator.py b/src/braket/default_simulator/simulator.py index 7fad664b..f97cda99 100644 --- a/src/braket/default_simulator/simulator.py +++ b/src/braket/default_simulator/simulator.py @@ -129,7 +129,6 @@ def run( Args: circuit_ir (Union[OpenQASMProgram, JaqcdProgram]): Circuit specification. - qubit_count (int, jaqcd-only): Number of qubits. shots (int, optional): The number of shots to simulate. Default is 0, which performs a full analytical simulation. batch_size (int, optional): The size of the circuit partitions to contract, @@ -632,7 +631,7 @@ def run_openqasm( def run_jaqcd( self, circuit_ir: JaqcdProgram, - qubit_count: int, + qubit_count: Any = None, shots: int = 0, *, batch_size: int = 1, @@ -642,7 +641,7 @@ def run_jaqcd( Args: circuit_ir (Program): ir representation of a braket circuit specifying the instructions to execute. - qubit_count (int): Unused parameter; in signature for backwards-compatibility + qubit_count (Any): Unused parameter; in signature for backwards-compatibility shots (int): The number of times to run the circuit. batch_size (int): The size of the circuit partitions to contract, if applying multiple gates at a time is desired; see `StateVectorSimulation`. @@ -657,6 +656,10 @@ def run_jaqcd( as a result type when shots=0. Or, if StateVector and Amplitude result types are requested when shots>0. """ + if qubit_count is not None: + warnings.warn( + f"qubit_count is deprecated for {type(self).__name__} and can be set to None" + ) self._validate_ir_results_compatibility( circuit_ir.results, device_action_type=DeviceActionType.JAQCD, diff --git a/src/braket/simulator/braket_simulator.py b/src/braket/simulator/braket_simulator.py index d3a32227..19569181 100644 --- a/src/braket/simulator/braket_simulator.py +++ b/src/braket/simulator/braket_simulator.py @@ -12,7 +12,10 @@ # language governing permissions and limitations under the License. from abc import ABC, abstractmethod -from typing import Union +from collections.abc import Sequence +from multiprocessing import Pool +from os import cpu_count +from typing import Optional, Union from braket.device_schema import DeviceCapabilities from braket.ir.ahs import Program as AHSProgram @@ -49,7 +52,7 @@ def run( Run the task specified by the given IR. Extra arguments will contain any additional information necessary to run the task, - such as number of qubits. + such as the extra parameters for AHS simulations. Args: ir (Union[OQ3Program, AHSProgram, JaqcdProgram]): The IR representation of the program @@ -59,6 +62,40 @@ def run( representing the results of the simulation. """ + def run_multiple( + self, + programs: Sequence[Union[OQ3Program, AHSProgram, JaqcdProgram]], + max_parallel: Optional[int] = None, + *args, + **kwargs, + ) -> list[Union[GateModelTaskResult, AnalogHamiltonianSimulationTaskResult]]: + """ + Run the tasks specified by the given IR programs. + + Extra arguments will contain any additional information necessary to run the tasks, + such as the extra parameters for AHS simulations. + + Args: + programs (Sequence[Union[OQ3Program, AHSProgram, JaqcdProgram]]): The IR representations + of the programs + max_parallel (Optional[int]): The maximum number of programs to run in parallel. + Default is the number of logical CPUs. + + Returns: + list[Union[GateModelTaskResult, AnalogHamiltonianSimulationTaskResult]]: A list of + result objects, with the ith object being the result of the ith program. + """ + max_parallel = max_parallel or cpu_count() + with Pool(min(max_parallel, len(programs))) as pool: + param_list = [(program, args, kwargs) for program in programs] + results = pool.starmap(self._run_wrapped, param_list) + return results + + def _run_wrapped( + self, ir: Union[OQ3Program, AHSProgram, JaqcdProgram], args, kwargs + ): # pragma: no cover + return self.run(ir, *args, **kwargs) + @property @abstractmethod def properties(self) -> DeviceCapabilities: diff --git a/test/unit_tests/braket/default_simulator/test_density_matrix_simulator.py b/test/unit_tests/braket/default_simulator/test_density_matrix_simulator.py index cebfeda0..ded5ef66 100644 --- a/test/unit_tests/braket/default_simulator/test_density_matrix_simulator.py +++ b/test/unit_tests/braket/default_simulator/test_density_matrix_simulator.py @@ -133,7 +133,8 @@ def test_simulator_run_bell_pair(bell_ir, caplog): simulator = DensityMatrixSimulator() shots_count = 10000 if isinstance(bell_ir, JaqcdProgram): - result = simulator.run(bell_ir, qubit_count=2, shots=shots_count) + # Ignore qubit_count + result = simulator.run(bell_ir, shots=shots_count) else: result = simulator.run(bell_ir, shots=shots_count) @@ -392,7 +393,6 @@ def test_properties(): "deviceParameters": GateModelSimulatorDeviceParameters.schema(), } ) - print(expected_properties) assert simulator.properties == expected_properties @@ -864,3 +864,23 @@ def test_noncontiguous_qubits_openqasm(qasm_file_name): (np.allclose(measurement, [0, 0]) or np.allclose(measurement, [1, 1])) for measurement in measurements ) + + +def test_run_multiple(): + payloads = [ + OpenQASMProgram( + source=f""" + OPENQASM 3.0; + bit[1] b; + qubit[1] q; + {gate} q[0]; + #pragma braket result density_matrix + """ + ) + for gate in ["h", "z", "x"] + ] + simulator = DensityMatrixSimulator() + results = simulator.run_multiple(payloads, shots=0) + assert np.allclose(results[0].resultTypes[0].value, np.array([[0.5, 0.5], [0.5, 0.5]])) + assert np.allclose(results[1].resultTypes[0].value, np.array([[1, 0], [0, 0]])) + assert np.allclose(results[2].resultTypes[0].value, np.array([[0, 0], [0, 1]])) diff --git a/test/unit_tests/braket/default_simulator/test_state_vector_simulator.py b/test/unit_tests/braket/default_simulator/test_state_vector_simulator.py index b4207a04..dbea83fb 100644 --- a/test/unit_tests/braket/default_simulator/test_state_vector_simulator.py +++ b/test/unit_tests/braket/default_simulator/test_state_vector_simulator.py @@ -102,7 +102,7 @@ def test_simulator_run_bell_pair(bell_ir, batch_size, caplog): shots_count = 10000 if isinstance(bell_ir, JaqcdProgram): # Ignore qubit_count - result = simulator.run(bell_ir, qubit_count=10, shots=shots_count, batch_size=batch_size) + result = simulator.run(bell_ir, shots=shots_count, batch_size=batch_size) else: result = simulator.run(bell_ir, shots=shots_count, batch_size=batch_size) @@ -1425,3 +1425,23 @@ def test_noncontiguous_qubits_jaqcd_multiple_targets(): assert result.measuredQubits == [0, 1] assert result.resultTypes[0].value == -1 + + +def test_run_multiple(): + payloads = [ + OpenQASMProgram( + source=f""" + OPENQASM 3.0; + bit[1] b; + qubit[1] q; + {gate} q[0]; + #pragma braket result state_vector + """ + ) + for gate in ["h", "z", "x"] + ] + simulator = StateVectorSimulator() + results = simulator.run_multiple(payloads, shots=0) + assert np.allclose(results[0].resultTypes[0].value, np.array([1, 1]) / np.sqrt(2)) + assert np.allclose(results[1].resultTypes[0].value, np.array([1, 0])) + assert np.allclose(results[2].resultTypes[0].value, np.array([0, 1]))