Skip to content

Commit

Permalink
preserve measuredQubits
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 committed Jun 26, 2024
1 parent ff309b8 commit b4500cd
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 21 deletions.
34 changes: 15 additions & 19 deletions src/braket/default_simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def _create_results_obj(
openqasm_ir: OpenQASMProgram,
simulation: Simulation,
measured_qubits: list[int] = None,
mapped_measured_qubits: list[int] = None,
) -> GateModelTaskResult:
return GateModelTaskResult.construct(
taskMetadata=TaskMetadata(
Expand All @@ -267,10 +268,8 @@ def _create_results_obj(
action=openqasm_ir,
),
resultTypes=results,
measurements=self._formatted_measurements(simulation, measured_qubits),
measuredQubits=(
measured_qubits if measured_qubits else self._get_all_qubits(simulation.qubit_count)
),
measurements=self._formatted_measurements(simulation, mapped_measured_qubits),
measuredQubits=(measured_qubits or list(range(simulation.qubit_count))),
)

@staticmethod
Expand Down Expand Up @@ -348,10 +347,6 @@ def _validate_input_provided(self, circuit: Circuit) -> None:
missing_input = param.free_symbols.pop()
raise NameError(f"Missing input variable '{missing_input}'.")

@staticmethod
def _get_all_qubits(qubit_count: int) -> list[int]:
return list(range(qubit_count))

@staticmethod
def _tensor_product_index_dict(
observable: TensorProduct, func: Callable[[Observable], Any]
Expand Down Expand Up @@ -383,7 +378,7 @@ def _observable_hash(observable: Observable) -> Union[str, dict[int, str]]:
return str(observable.__class__.__name__)

@staticmethod
def _map_circuit_to_contiguous_qubits(circuit: Union[Circuit, JaqcdProgram]) -> Circuit:
def _map_circuit_to_contiguous_qubits(circuit: Union[Circuit, JaqcdProgram]) -> dict[int, int]:
"""
Maps the qubits in operations and result types to contiguous qubits.
Expand All @@ -392,13 +387,12 @@ def _map_circuit_to_contiguous_qubits(circuit: Union[Circuit, JaqcdProgram]) ->
result types.
Returns:
Circuit: The circuit with qubits in operations and result types mapped
to contiguous qubits.
dict[int, int]: Map of qubit index to corresponding contiguous index
"""
circuit_qubit_set = BaseLocalSimulator._get_circuit_qubit_set(circuit)
qubit_map = BaseLocalSimulator._contiguous_qubit_mapping(circuit_qubit_set)
BaseLocalSimulator._map_circuit_qubits(circuit, qubit_map)
return circuit
return qubit_map

@staticmethod
def _get_circuit_qubit_set(circuit: Union[Circuit, JaqcdProgram]) -> set[int]:
Expand Down Expand Up @@ -440,7 +434,6 @@ def _map_circuit_qubits(circuit: Union[Circuit, JaqcdProgram], qubit_map: dict[i
if isinstance(circuit, Circuit):
BaseLocalSimulator._map_circuit_instructions(circuit, qubit_map)
BaseLocalSimulator._map_circuit_results(circuit, qubit_map)
circuit.measured_qubits = [qubit_map[q] for q in circuit.measured_qubits]
else:
BaseLocalSimulator._map_jaqcd_instructions(circuit, qubit_map)
return circuit
Expand Down Expand Up @@ -587,11 +580,13 @@ def run_openqasm(
as a result type when shots=0. Or, if StateVector and Amplitude result types
are requested when shots>0.
"""
circuit = BaseLocalSimulator._map_circuit_to_contiguous_qubits(
self.parse_program(openqasm_ir).circuit
)
circuit = self.parse_program(openqasm_ir).circuit
qubit_map = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit)
qubit_count = circuit.num_qubits
measured_qubits = circuit.measured_qubits
mapped_measured_qubits = (
[qubit_map[q] for q in measured_qubits] if measured_qubits else None
)

self._validate_ir_results_compatibility(
circuit.results,
Expand Down Expand Up @@ -630,7 +625,9 @@ def run_openqasm(
else:
simulation.evolve(circuit.basis_rotation_instructions)

return self._create_results_obj(results, openqasm_ir, simulation, measured_qubits)
return self._create_results_obj(
results, openqasm_ir, simulation, measured_qubits, mapped_measured_qubits
)

def run_jaqcd(
self,
Expand Down Expand Up @@ -669,8 +666,7 @@ def run_jaqcd(
device_action_type=DeviceActionType.JAQCD,
)
BaseLocalSimulator._validate_shots_and_ir_results(shots, circuit_ir.results, qubit_count)

circuit_ir = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit_ir)
BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit_ir)

operations = [
from_braket_instruction(instruction) for instruction in circuit_ir.instructions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ def test_noncontiguous_qubits_openqasm(qasm_file_name):
OpenQASMProgram(source=f"test/resources/{qasm_file_name}.qasm"), shots=shots
)

assert result.measuredQubits == [0, 1]
assert result.measuredQubits == [2, 8]
measurements = np.array(result.measurements, dtype=int)
assert measurements.shape == (shots, 2)
assert all(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ def test_noncontiguous_qubits_openqasm(qasm_file_name):
OpenQASMProgram(source=f"test/resources/{qasm_file_name}.qasm"), shots=shots
)

assert result.measuredQubits == [0, 1]
assert result.measuredQubits == [2, 8]
measurements = np.array(result.measurements, dtype=int)
assert measurements.shape == (shots, 2)
assert all(
Expand Down

0 comments on commit b4500cd

Please sign in to comment.