Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Introduce run_multiple method #264

Merged
merged 19 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion src/braket/simulator/braket_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# language governing permissions and limitations under the License.

from abc import ABC, abstractmethod
from typing import Union
from collections.abc import Mapping, Sequence
from multiprocessing import Pool
from os import cpu_count
from typing import Any, Optional, Union

from braket.device_schema import DeviceCapabilities
from braket.ir.ahs import Program as AHSProgram
Expand Down Expand Up @@ -59,6 +62,56 @@ def run(
representing the results of the simulation.
"""

def run_multiple(
self,
payloads: Sequence[Union[OQ3Program, AHSProgram, JaqcdProgram]],
speller26 marked this conversation as resolved.
Show resolved Hide resolved
args: Optional[Sequence[Sequence[Any]]] = None,
kwargs: Optional[Sequence[Mapping[str, Any]]] = None,
max_parallel: Optional[int] = None,
) -> list[Union[GateModelTaskResult, AnalogHamiltonianSimulationTaskResult]]:
"""
Run the tasks specified by the given IR payloads.

Extra arguments will contain any additional information necessary to run the tasks,
such as number of shots.

Args:
payloads (Sequence[Union[OQ3Program, AHSProgram, JaqcdProgram]]): The IR representations
of the programs
args (Optional[Sequence[Sequence[Any]]]): The positional args to include with
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I only submit one of args and kwargs? Is the intent that specifying shots in kwargs should override any value in args?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can specify both, just as you can technically

device.run(circuit, 1000, shots=1000)

it'll just error out with something like

LocalSimulator.run() got multiple values for argument 'shots'

each payload; the nth entry of this sequence corresponds to the nth payload.
If specified, the length of args must be equal to the length of payloads.
Default: None.
kwargs (Optional[Sequence[Mapping[str, Any]]]): The keyword args to include with
each payload; the nth entry of this sequence corresponds to the nth payload.
If specified, the length of kwargs must be equal to the length of payloads.
Default: None.
max_parallel (Optional[int]): The maximum number of payloads to run in parallel.
Default is the number of CPUs.
speller26 marked this conversation as resolved.
Show resolved Hide resolved

Returns:
list[Union[GateModelTaskResult, AnalogHamiltonianSimulationTaskResult]]: A list of
result objects, with the ith object being the result of the ith program.
"""
max_parallel = max_parallel or cpu_count()
math411 marked this conversation as resolved.
Show resolved Hide resolved
if args and len(args) != len(payloads):
raise ValueError("The number of arguments must equal the number of payloads.")
if kwargs and len(kwargs) != len(payloads):
raise ValueError("The number of keyword arguments must equal the number of payloads.")
get_nth_args = (lambda n: args[n]) if args else lambda _: []
get_nth_kwargs = (lambda n: kwargs[n]) if kwargs else lambda _: {}
with Pool(min(max_parallel, len(payloads))) as pool:
results = pool.starmap(
self._run_wrapped,
[(payloads[i], get_nth_args(i), get_nth_kwargs(i)) for i in range(len(payloads))],
)
return results

def _run_wrapped(
self, ir: Union[OQ3Program, AHSProgram, JaqcdProgram], args, kwargs
): # pragma: no cover
return self.run(ir, *args, **kwargs)

@property
@abstractmethod
def properties(self) -> DeviceCapabilities:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -860,3 +860,69 @@ def test_discontiguous_qubits_openqasm(discontiguous_qasm):
measurements = np.array(result.measurements, dtype=int)
assert len(measurements[0]) == 5
assert result.measuredQubits == [0, 1, 2, 3, 4]


def test_run_multiple():
payloads = [
OpenQASMProgram(
source=f"""
OPENQASM 3.0;
bit[2] b;
qubit[2] q;
{gates[0]} q[0];
{gates[1]} q[1];
b = measure q;
"""
)
for gates in [("x", "z"), ("z", "x"), ("x", "x")]
]
args = [[2], [5], [10]]
kwargs = [{"shots": 3}, {"shots": 6}, {"shots": 9}]
expected_measurements = [[1, 0], [0, 1], [1, 1]]
simulator = DensityMatrixSimulator()
for result, payload_args, expected in zip(
simulator.run_multiple(payloads, args=args), args, expected_measurements
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aren't you missing a kwargs=kwargs here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see what's going on. IMO this is pretty confusing, to allow shots to be in either

):
measurements = np.array(result.measurements, dtype=int)
print(measurements)
speller26 marked this conversation as resolved.
Show resolved Hide resolved
print(expected)
assert len(measurements) == payload_args[0]
assert all(np.all(expected == actual) for actual in measurements)
for result, payload_kwargs, expected in zip(
simulator.run_multiple(payloads, kwargs=kwargs), kwargs, expected_measurements
):
measurements = np.array(result.measurements, dtype=int)
assert len(measurements) == payload_kwargs["shots"]
assert all(np.all(expected == actual) for actual in measurements)


def test_run_multiple_wrong_num_args():
payload = OpenQASMProgram(
source="""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
h q[0];
b = measure q;
"""
)
args = [[2], [5], [10], [15]]
simulator = DensityMatrixSimulator()
with pytest.raises(ValueError):
simulator.run_multiple([payload] * (len(args) - 1), args=args)


def test_run_multiple_wrong_num_kwargs():
payload = OpenQASMProgram(
source="""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
h q[0];
b = measure q;
"""
)
kwargs = [{"shots": 3}, {"shots": 6}]
simulator = DensityMatrixSimulator()
with pytest.raises(ValueError):
simulator.run_multiple([payload] * (len(kwargs) + 1), kwargs=kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -1420,3 +1420,69 @@ def test_discontiguous_qubits_jaqcd_multiple_targets():

assert result.measuredQubits == [0, 1]
assert result.resultTypes[0].value == -1


def test_run_multiple():
payloads = [
OpenQASMProgram(
source=f"""
OPENQASM 3.0;
bit[2] b;
qubit[2] q;
{gates[0]} q[0];
{gates[1]} q[1];
b = measure q;
"""
)
for gates in [("x", "z"), ("z", "x"), ("x", "x")]
]
args = [[2], [5], [10]]
kwargs = [{"shots": 3}, {"shots": 6}, {"shots": 9}]
expected_measurements = [[1, 0], [0, 1], [1, 1]]
simulator = StateVectorSimulator()
for result, payload_args, expected in zip(
simulator.run_multiple(payloads, args=args), args, expected_measurements
):
measurements = np.array(result.measurements, dtype=int)
print(measurements)
speller26 marked this conversation as resolved.
Show resolved Hide resolved
print(expected)
assert len(measurements) == payload_args[0]
assert all(np.all(expected == actual) for actual in measurements)
for result, payload_kwargs, expected in zip(
simulator.run_multiple(payloads, kwargs=kwargs), kwargs, expected_measurements
):
measurements = np.array(result.measurements, dtype=int)
assert len(measurements) == payload_kwargs["shots"]
assert all(np.all(expected == actual) for actual in measurements)


def test_run_multiple_wrong_num_args():
payload = OpenQASMProgram(
source="""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
h q[0];
b = measure q;
"""
)
args = [[2], [5], [10], [15]]
simulator = StateVectorSimulator()
with pytest.raises(ValueError):
simulator.run_multiple([payload] * (len(args) - 1), args=args)


def test_run_multiple_wrong_num_kwargs():
payload = OpenQASMProgram(
source="""
OPENQASM 3.0;
bit[1] b;
qubit[1] q;
h q[0];
b = measure q;
"""
)
kwargs = [{"shots": 3}, {"shots": 6}]
simulator = StateVectorSimulator()
with pytest.raises(ValueError):
simulator.run_multiple([payload] * (len(kwargs) + 1), kwargs=kwargs)
Loading