Skip to content

Commit

Permalink
change: Move conversion into its own function
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws committed May 6, 2024
1 parent 3803744 commit 7f837e0
Showing 1 changed file with 22 additions and 31 deletions.
53 changes: 22 additions & 31 deletions src/braket/simulator_v2/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@
from .julia_import import jl, jlBraketSimulator


def _analytic_result_value_to_ndarray(
task_result: GateModelTaskResult,
) -> GateModelTaskResult:
"""Convert any StateVector or DensityMatrix result values from raw Python lists to the expected
np.ndarray. This must be done because the wrapper Julia simulator results Python lists to comply
with the pydantic specification for ResultTypeValues.
"""
for result_ind, result_type in enumerate(task_result.resultTypes):
if isinstance(result_type.type, StateVector) or isinstance(
result_type.type, DensityMatrix
):
task_result.resultTypes[result_ind].value = np.asarray(
task_result.resultTypes[result_ind].value
)
return task_result


class StateVectorSimulatorV2(BaseLocalSimulator):
"""A state vector simulator meant to run directly on the user's machine using a Julia backend.
Expand Down Expand Up @@ -92,15 +109,7 @@ def run_jaqcd(
r = jl.simulate(self._device, [circuit_ir], qubit_count, shots)
r.additionalMetadata.action = circuit_ir
if not shots:
# need to convert `list` value for `statevector`
# and `densitymatrix` result types to `np.ndarray`
for result_ind, result_type in enumerate(r.resultTypes):
if isinstance(result_type.type, StateVector) or isinstance(
result_type.type, DensityMatrix
):
r.resultTypes[result_ind].value = np.asarray(
r.resultTypes[result_ind].value
)
r = _analytic_result_value_to_ndarray(r)
return r

def run_openqasm(
Expand Down Expand Up @@ -166,19 +175,11 @@ def run_openqasm(
self._device, [circuit], qubit_count, shots, measured_qubits=measured_qubits
)
r.additionalMetadata.action = openqasm_ir
# attach the result types
if shots:
# attach the result types
r.resultTypes = results
else:
# need to convert `list` value for `statevector`
# and `densitymatrix` result types to `np.ndarray`
for result_ind, result_type in enumerate(r.resultTypes):
if isinstance(result_type.type, StateVector) or isinstance(
result_type.type, DensityMatrix
):
r.resultTypes[result_ind].value = np.asarray(
r.resultTypes[result_ind].value
)
r = _analytic_result_value_to_ndarray(r)
return r

@property
Expand Down Expand Up @@ -497,12 +498,7 @@ def run_jaqcd(
r = jl.simulate(self._device, [circuit_ir], qubit_count, shots)
r.additionalMetadata.action = circuit_ir
if not shots:
# need to convert `list` value for `densitymatrix` result type to `np.ndarray`
for result_ind, result_type in enumerate(r.resultTypes):
if isinstance(result_type.type, DensityMatrix):
r.resultTypes[result_ind].value = np.asarray(
r.resultTypes[result_ind].value
)
r = _analytic_result_value_to_ndarray(r)
return r

def run_openqasm(
Expand Down Expand Up @@ -571,12 +567,7 @@ def run_openqasm(
if shots:
r.resultTypes = results
else:
# need to convert `list` value for `densitymatrix` result type to `np.ndarray`
for result_ind, result_type in enumerate(r.resultTypes):
if isinstance(result_type.type, DensityMatrix):
r.resultTypes[result_ind].value = np.asarray(
r.resultTypes[result_ind].value
)
r = _analytic_result_value_to_ndarray(r)
return r

@property
Expand Down

0 comments on commit 7f837e0

Please sign in to comment.