From fd597d6d1fe2f38dc14e7e2fbf6a72657cf0f3b5 Mon Sep 17 00:00:00 2001 From: Ryan Shaffer <3620100+rmshaffer@users.noreply.github.com> Date: Mon, 23 Sep 2024 14:43:07 -0400 Subject: [PATCH] fix: Pass through inputs for SerializableProgram simulation (#1033) --- src/braket/devices/local_simulator.py | 5 +++-- .../braket/devices/test_local_simulator.py | 21 +++++++++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/braket/devices/local_simulator.py b/src/braket/devices/local_simulator.py index 02b379596..f9e70a1ea 100644 --- a/src/braket/devices/local_simulator.py +++ b/src/braket/devices/local_simulator.py @@ -290,8 +290,9 @@ def _(self, program: OpenQASMProgram, inputs: Optional[dict[str, float]], _shots return program @_construct_payload.register - def _(self, program: SerializableProgram, _inputs, _shots): - return OpenQASMProgram(source=program.to_ir(ir_type=IRType.OPENQASM)) + def _(self, program: SerializableProgram, inputs: Optional[dict[str, float]], _shots): + inputs_copy = inputs.copy() if inputs is not None else {} + return OpenQASMProgram(source=program.to_ir(ir_type=IRType.OPENQASM), inputs=inputs_copy) @_construct_payload.register def _(self, program: AnalogHamiltonianSimulation, _inputs, _shots): diff --git a/test/unit_tests/braket/devices/test_local_simulator.py b/test/unit_tests/braket/devices/test_local_simulator.py index 451553f02..6c2976812 100644 --- a/test/unit_tests/braket/devices/test_local_simulator.py +++ b/test/unit_tests/braket/devices/test_local_simulator.py @@ -588,10 +588,8 @@ def test_run_serializable_program_model(): source=""" qubit[2] q; bit[2] c; - h q[0]; cnot q[0], q[1]; - c = measure q; """ ) @@ -599,6 +597,25 @@ def test_run_serializable_program_model(): assert task.result() == GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT) +def test_run_serializable_program_model_with_inputs(): + dummy = DummySerializableProgramSimulator() + sim = LocalSimulator(dummy) + task = sim.run( + DummySerializableProgram( + source=""" +input float a; +qubit[2] q; +bit[2] c; +h q[0]; +cnot q[0], q[1]; +c = measure q; +""" + ), + inputs={"a": 0.1}, + ) + assert task.result() == GateModelQuantumTaskResult.from_object(GATE_MODEL_RESULT) + + @pytest.mark.xfail(raises=ValueError) def test_run_gate_model_value_error(): dummy = DummyCircuitSimulator()