Skip to content

Commit

Permalink
fix: Force StateVector and DensityMatrix values to be ndarrays and test
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws committed May 6, 2024
1 parent aec3305 commit 3803744
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 1 deletion.
37 changes: 37 additions & 0 deletions src/braket/simulator_v2/simulator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys

import numpy as np
from braket.default_simulator.operation_helpers import from_braket_instruction
from braket.default_simulator.result_types import TargetedResultType
from braket.default_simulator.simulator import BaseLocalSimulator
Expand All @@ -8,7 +9,9 @@
GateModelSimulatorDeviceCapabilities,
GateModelSimulatorDeviceParameters,
)
from braket.ir.jaqcd import DensityMatrix
from braket.ir.jaqcd import Program as JaqcdProgram
from braket.ir.jaqcd import StateVector
from braket.ir.openqasm import Program as OpenQASMProgram
from braket.task_result import GateModelTaskResult

Expand Down Expand Up @@ -88,6 +91,16 @@ def run_jaqcd(
)
r = jl.simulate(self._device, [circuit_ir], qubit_count, shots)
r.additionalMetadata.action = circuit_ir
if not shots:
# need to convert `list` value for `statevector`
# and `densitymatrix` result types to `np.ndarray`
for result_ind, result_type in enumerate(r.resultTypes):
if isinstance(result_type.type, StateVector) or isinstance(
result_type.type, DensityMatrix
):
r.resultTypes[result_ind].value = np.asarray(
r.resultTypes[result_ind].value
)
return r

def run_openqasm(
Expand Down Expand Up @@ -156,6 +169,16 @@ def run_openqasm(
# attach the result types
if shots:
r.resultTypes = results
else:
# need to convert `list` value for `statevector`
# and `densitymatrix` result types to `np.ndarray`
for result_ind, result_type in enumerate(r.resultTypes):
if isinstance(result_type.type, StateVector) or isinstance(
result_type.type, DensityMatrix
):
r.resultTypes[result_ind].value = np.asarray(
r.resultTypes[result_ind].value
)
return r

@property
Expand Down Expand Up @@ -473,6 +496,13 @@ def run_jaqcd(
)
r = jl.simulate(self._device, [circuit_ir], qubit_count, shots)
r.additionalMetadata.action = circuit_ir
if not shots:
# need to convert `list` value for `densitymatrix` result type to `np.ndarray`
for result_ind, result_type in enumerate(r.resultTypes):
if isinstance(result_type.type, DensityMatrix):
r.resultTypes[result_ind].value = np.asarray(
r.resultTypes[result_ind].value
)
return r

def run_openqasm(
Expand Down Expand Up @@ -540,6 +570,13 @@ def run_openqasm(
# attach the result types
if shots:
r.resultTypes = results
else:
# need to convert `list` value for `densitymatrix` result type to `np.ndarray`
for result_ind, result_type in enumerate(r.resultTypes):
if isinstance(result_type.type, DensityMatrix):
r.resultTypes[result_ind].value = np.asarray(
r.resultTypes[result_ind].value
)
return r

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
GateModelSimulatorDeviceCapabilities,
GateModelSimulatorDeviceParameters,
)
from braket.ir.jaqcd import Expectation
from braket.ir.jaqcd import DensityMatrix, Expectation
from braket.ir.jaqcd import Program as JaqcdProgram
from braket.ir.openqasm import Program as OpenQASMProgram
from braket.task_result import AdditionalMetadata, TaskMetadata
Expand Down Expand Up @@ -846,3 +846,34 @@ def test_measure_targets():
assert 400 < np.sum(measurements, axis=0)[0] < 600
assert len(measurements[0]) == 1
assert result.measuredQubits == [0]


@pytest.mark.parametrize(
"jaqcd_string, oq3_pragma, jaqcd_type",
[
["densitymatrix", "density_matrix", DensityMatrix()],
],
)
def test_simulator_analytic_value_type(jaqcd_string, oq3_pragma, jaqcd_type):
simulator = DensityMatrixSimulator()
jaqcd = JaqcdProgram.parse_raw(
json.dumps(
{
"instructions": [{"type": "h", "target": 0}],
"results": [{"type": jaqcd_string}],
}
)
)
qasm = OpenQASMProgram(
source=f"""
qubit q;
h q;
#pragma braket result {oq3_pragma}
"""
)
result = simulator.run(jaqcd, qubit_count=2, shots=0)
assert result.resultTypes[0].type == jaqcd_type
assert isinstance(result.resultTypes[0].value, np.ndarray)
result = simulator.run(qasm, shots=0)
assert result.resultTypes[0].type == jaqcd_type
assert isinstance(result.resultTypes[0].value, np.ndarray)
32 changes: 32 additions & 0 deletions test/unit_tests/braket/simulator_v2/test_state_vector_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,3 +1410,35 @@ def test_rotation_parameter_expressions(operation, state_vector):
result = simulator.run(OpenQASMProgram(source=qasm), shots=0)
assert result.resultTypes[0].type == StateVector()
assert np.allclose(result.resultTypes[0].value, np.array(state_vector))


@pytest.mark.parametrize(
"jaqcd_string, oq3_pragma, jaqcd_type",
[
["statevector", "state_vector", StateVector()],
["densitymatrix", "density_matrix", DensityMatrix()],
],
)
def test_simulator_analytic_value_type(jaqcd_string, oq3_pragma, jaqcd_type):
simulator = StateVectorSimulator()
jaqcd = JaqcdProgram.parse_raw(
json.dumps(
{
"instructions": [{"type": "h", "target": 0}],
"results": [{"type": jaqcd_string}],
}
)
)
qasm = OpenQASMProgram(
source=f"""
qubit q;
h q;
#pragma braket result {oq3_pragma}
"""
)
result = simulator.run(jaqcd, qubit_count=2, shots=0)
assert result.resultTypes[0].type == jaqcd_type
assert isinstance(result.resultTypes[0].value, np.ndarray)
result = simulator.run(qasm, shots=0)
assert result.resultTypes[0].type == jaqcd_type
assert isinstance(result.resultTypes[0].value, np.ndarray)

0 comments on commit 3803744

Please sign in to comment.