Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Include measured in noncontiguous qubit map #267

Merged
merged 6 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/braket/default_simulator/openqasm/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def add_measure(self, target: tuple[int], classical_targets: Iterable[int] = Non
if qubit in self.measured_qubits:
raise ValueError(f"Qubit {qubit} is already measured or captured.")
self.measured_qubits.append(qubit)
self.qubit_set.add(qubit)
self.target_classical_indices.append(
classical_targets[index]
if classical_targets
Expand Down
71 changes: 31 additions & 40 deletions src/braket/default_simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def _create_results_obj(
openqasm_ir: OpenQASMProgram,
simulation: Simulation,
measured_qubits: list[int] = None,
mapped_measured_qubits: list[int] = None,
) -> GateModelTaskResult:
return GateModelTaskResult.construct(
taskMetadata=TaskMetadata(
Expand All @@ -267,10 +268,8 @@ def _create_results_obj(
action=openqasm_ir,
),
resultTypes=results,
measurements=self._formatted_measurements(simulation, measured_qubits),
measuredQubits=(
measured_qubits if measured_qubits else self._get_all_qubits(simulation.qubit_count)
),
measurements=self._formatted_measurements(simulation, mapped_measured_qubits),
measuredQubits=(measured_qubits or list(range(simulation.qubit_count))),
)

@staticmethod
Expand Down Expand Up @@ -348,10 +347,6 @@ def _validate_input_provided(self, circuit: Circuit) -> None:
missing_input = param.free_symbols.pop()
raise NameError(f"Missing input variable '{missing_input}'.")

@staticmethod
def _get_all_qubits(qubit_count: int) -> list[int]:
return list(range(qubit_count))

@staticmethod
def _tensor_product_index_dict(
observable: TensorProduct, func: Callable[[Observable], Any]
Expand Down Expand Up @@ -383,7 +378,7 @@ def _observable_hash(observable: Observable) -> Union[str, dict[int, str]]:
return str(observable.__class__.__name__)

@staticmethod
def _map_circuit_to_contiguous_qubits(circuit: Union[Circuit, JaqcdProgram]) -> Circuit:
def _map_circuit_to_contiguous_qubits(circuit: Union[Circuit, JaqcdProgram]) -> dict[int, int]:
"""
Maps the qubits in operations and result types to contiguous qubits.

Expand All @@ -392,24 +387,23 @@ def _map_circuit_to_contiguous_qubits(circuit: Union[Circuit, JaqcdProgram]) ->
result types.

Returns:
Circuit: The circuit with qubits in operations and result types mapped
to contiguous qubits.
dict[int, int]: Map of qubit index to corresponding contiguous index
"""
circuit_qubit_set = BaseLocalSimulator._get_circuit_qubit_set(circuit)
qubit_map = BaseLocalSimulator._contiguous_qubit_mapping(circuit_qubit_set)
BaseLocalSimulator._map_instructions_to_qubits(circuit, qubit_map)
return circuit
BaseLocalSimulator._map_circuit_qubits(circuit, qubit_map)
return qubit_map

@staticmethod
def _get_circuit_qubit_set(circuit: Union[Circuit, JaqcdProgram]) -> set:
def _get_circuit_qubit_set(circuit: Union[Circuit, JaqcdProgram]) -> set[int]:
"""
Returns the set of qubits used in the given circuit.

Args:
circuit (Union[Circuit, JaqcdProgram]): The circuit from which to extract the qubit set.

Returns:
set: The set of qubits used in the circuit.
set[int]: The set of qubits used in the circuit.
"""
if isinstance(circuit, Circuit):
return circuit.qubit_set
Expand All @@ -425,12 +419,13 @@ def _get_circuit_qubit_set(circuit: Union[Circuit, JaqcdProgram]) -> set:
return BaseLocalSimulator._get_qubits_referenced(operations)

@staticmethod
def _map_instructions_to_qubits(circuit: Union[Circuit, JaqcdProgram], qubit_map: dict):
def _map_circuit_qubits(circuit: Union[Circuit, JaqcdProgram], qubit_map: dict[int, int]):
"""
Maps the qubits in operations and result types to contiguous qubits.

Args:
circuit (Circuit): The circuit containing the operations and result types.
qubit_map (dict[int, int]): The mapping from qubits to their contiguous indices.

Returns:
Circuit: The circuit with qubits in operations and result types mapped
Expand All @@ -441,7 +436,6 @@ def _map_instructions_to_qubits(circuit: Union[Circuit, JaqcdProgram], qubit_map
BaseLocalSimulator._map_circuit_results(circuit, qubit_map)
else:
BaseLocalSimulator._map_jaqcd_instructions(circuit, qubit_map)

return circuit

@staticmethod
Expand Down Expand Up @@ -514,13 +508,13 @@ def _map_instruction_attributes(instruction, qubit_map: dict):
instruction.targets = [qubit_map.get(q, q) for q in instruction.targets]

@staticmethod
def _contiguous_qubit_mapping(qubit_set: list[int]) -> dict[int, int]:
def _contiguous_qubit_mapping(qubit_set: set[int]) -> dict[int, int]:
"""
Maping of qubits to contiguous integers. The qubit mapping may be discontiguous or
contiguous.

Args:
qubit_set (list[int]): List of qubits to be mapped.
qubit_set (set[int]): List of qubits to be mapped.

Returns:
dict[int, int]: Dictionary where keys are qubits and values are contiguous integers.
Expand Down Expand Up @@ -548,22 +542,16 @@ def _formatted_measurements(
]
# Gets the subset of measurements from the full measurements
if measured_qubits is not None and measured_qubits != []:
if any(qubit in range(simulation.qubit_count) for qubit in measured_qubits):
measured_qubits = np.array(measured_qubits)
in_circuit_mask = measured_qubits < simulation.qubit_count
measured_qubits_in_circuit = measured_qubits[in_circuit_mask]
measured_qubits_not_in_circuit = measured_qubits[~in_circuit_mask]

measurements_array = np.array(measurements)
selected_measurements = measurements_array[:, measured_qubits_in_circuit]
measurements = np.pad(
selected_measurements, ((0, 0), (0, len(measured_qubits_not_in_circuit)))
).tolist()

else:
measurements = np.zeros(
(simulation.shots, len(measured_qubits)), dtype=int
).tolist()
measured_qubits = np.array(measured_qubits)
in_circuit_mask = measured_qubits < simulation.qubit_count
measured_qubits_in_circuit = measured_qubits[in_circuit_mask]
measured_qubits_not_in_circuit = measured_qubits[~in_circuit_mask]

measurements_array = np.array(measurements)
selected_measurements = measurements_array[:, measured_qubits_in_circuit]
measurements = np.pad(
selected_measurements, ((0, 0), (0, len(measured_qubits_not_in_circuit)))
).tolist()
return measurements

def run_openqasm(
Expand Down Expand Up @@ -593,8 +581,12 @@ def run_openqasm(
are requested when shots>0.
"""
circuit = self.parse_program(openqasm_ir).circuit
qubit_map = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit)
qubit_count = circuit.num_qubits
measured_qubits = circuit.measured_qubits
mapped_measured_qubits = (
[qubit_map[q] for q in measured_qubits] if measured_qubits else None
)

self._validate_ir_results_compatibility(
circuit.results,
Expand All @@ -607,8 +599,6 @@ def run_openqasm(
self._validate_input_provided(circuit)
BaseLocalSimulator._validate_shots_and_ir_results(shots, circuit.results, qubit_count)

circuit = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit)

results = circuit.results

simulation = self.initialize_simulation(
Expand All @@ -635,7 +625,9 @@ def run_openqasm(
else:
simulation.evolve(circuit.basis_rotation_instructions)

return self._create_results_obj(results, openqasm_ir, simulation, measured_qubits)
return self._create_results_obj(
results, openqasm_ir, simulation, measured_qubits, mapped_measured_qubits
)

def run_jaqcd(
self,
Expand Down Expand Up @@ -674,8 +666,7 @@ def run_jaqcd(
device_action_type=DeviceActionType.JAQCD,
)
BaseLocalSimulator._validate_shots_and_ir_results(shots, circuit_ir.results, qubit_count)

circuit_ir = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit_ir)
BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit_ir)

operations = [
from_braket_instruction(instruction) for instruction in circuit_ir.instructions
Expand Down
6 changes: 0 additions & 6 deletions test/resources/discontiguous.qasm

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"braketSchemaHeader": {"name": "braket.ir.jaqcd.program", "version": "1"},
"instructions": [
{"target": 2, "type": "x"},
{"target": 2, "type": "h"},
{"control": 2, "target": 9, "type": "cnot"}
],
"results": [],
Expand Down
6 changes: 6 additions & 0 deletions test/resources/noncontiguous_physical.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
OPENQASM 3.0;
bit[2] b;
h $2;
cnot $2, $8;
b[0] = measure $2;
b[1] = measure $8;
7 changes: 7 additions & 0 deletions test/resources/noncontiguous_virtual.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
OPENQASM 3.0;
bit[2] b;
qubit[10] q;
h q[2];
cnot q[2], q[8];
b[0] = measure q[2];
b[1] = measure q[8];
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,12 @@ def grcs_8_qubit(ir_type):


@pytest.fixture
def discontiguous_jaqcd():
with open("test/resources/discontiguous_jaqcd.json") as jaqcd_definition:
def noncontiguous_jaqcd():
with open("test/resources/noncontiguous_jaqcd.json") as jaqcd_definition:
data = json.load(jaqcd_definition)
return json.dumps(data)


@pytest.fixture
def discontiguous_qasm():
return OpenQASMProgram(source="test/resources/discontiguous.qasm")


@pytest.fixture
def bell_ir(ir_type):
return (
Expand Down Expand Up @@ -828,35 +823,44 @@ def test_measure_no_gates():

def test_measure_with_qubits_not_used():
qasm = """
bit[4] b;
qubit[4] q;
h q[0];
cnot q[0], q[1];
bit[5] b;
qubit[5] q;
h q[1];
cnot q[1], q[3];
b = measure q;
"""
simulator = DensityMatrixSimulator()
result = simulator.run(OpenQASMProgram(source=qasm), shots=1000)
measurements = np.array(result.measurements, dtype=int)
assert 400 < np.sum(measurements, axis=0)[0] < 600
assert 400 < np.sum(measurements, axis=0)[1] < 600
assert 400 < np.sum(measurements, axis=0)[3] < 600
assert np.sum(measurements, axis=0)[0] == 0
assert np.sum(measurements, axis=0)[2] == 0
assert np.sum(measurements, axis=0)[3] == 0
assert len(measurements[0]) == 4
assert result.measuredQubits == [0, 1, 2, 3]
assert np.sum(measurements, axis=0)[4] == 0
assert len(measurements[0]) == 5
assert result.measuredQubits == [0, 1, 2, 3, 4]


def test_discontiguous_qubits_jaqcd(discontiguous_jaqcd):
prg = JaqcdProgram.parse_raw(discontiguous_jaqcd)
def test_noncontiguous_qubits_jaqcd(noncontiguous_jaqcd):
prg = JaqcdProgram.parse_raw(noncontiguous_jaqcd)
result = DensityMatrixSimulator().run(prg, qubit_count=2, shots=1)

assert result.measuredQubits == [0, 1]
assert result.measurements == [["1", "1"]]
assert result.measurements in ([["0", "0"]], [["1", "1"]])


def test_discontiguous_qubits_openqasm(discontiguous_qasm):
@pytest.mark.parametrize("qasm_file_name", ["noncontiguous_virtual", "noncontiguous_physical"])
def test_noncontiguous_qubits_openqasm(qasm_file_name):
simulator = DensityMatrixSimulator()
result = simulator.run(discontiguous_qasm, shots=1000)
shots = 1000
result = simulator.run(
OpenQASMProgram(source=f"test/resources/{qasm_file_name}.qasm"), shots=shots
)

assert result.measuredQubits == [2, 8]
measurements = np.array(result.measurements, dtype=int)
assert len(measurements[0]) == 5
assert result.measuredQubits == [0, 1, 2, 3, 4]
assert measurements.shape == (shots, 2)
assert all(
(np.allclose(measurement, [0, 0]) or np.allclose(measurement, [1, 1]))
for measurement in measurements
)
Loading