Skip to content

Commit

Permalink
Use pytest fixture for multiprocessing Pool in XEB tests (#6766)
Browse files Browse the repository at this point in the history
Direct Pool executions cause flaky test outcomes in our internal test
framework.  This change makes it easier to work around the problem.

No change in the effective test code.
  • Loading branch information
pavoljuhas authored Oct 14, 2024
1 parent 54f9e8c commit 9ff9b6b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 33 deletions.
44 changes: 24 additions & 20 deletions cirq-core/cirq/experiments/xeb_fitting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import itertools
import multiprocessing
from typing import Iterable
from typing import Iterable, Iterator

import networkx as nx
import numpy as np
Expand All @@ -40,6 +42,13 @@
_POOL_NUM_PROCESSES = min(4, multiprocessing.cpu_count())


@pytest.fixture
def pool() -> Iterator[multiprocessing.pool.Pool]:
ctx = multiprocessing.get_context()
with ctx.Pool(_POOL_NUM_PROCESSES) as pool:
yield pool


@pytest.fixture(scope='module')
def circuits_cycle_depths_sampled_df():
q0, q1 = cirq.LineQubit.range(2)
Expand Down Expand Up @@ -207,7 +216,7 @@ def test_get_initial_simplex():
assert simplex.shape[1] == len(names)


def test_characterize_phased_fsim_parameters_with_xeb():
def test_characterize_phased_fsim_parameters_with_xeb(pool):
q0, q1 = cirq.LineQubit.range(2)
rs = np.random.RandomState(52)
circuits = [
Expand All @@ -232,17 +241,16 @@ def test_characterize_phased_fsim_parameters_with_xeb():
characterize_phi=False,
)
p_circuits = [parameterize_circuit(circuit, options) for circuit in circuits]
with multiprocessing.Pool(_POOL_NUM_PROCESSES) as pool:
result = characterize_phased_fsim_parameters_with_xeb(
sampled_df=sampled_df,
parameterized_circuits=p_circuits,
cycle_depths=cycle_depths,
options=options,
# speed up with looser tolerances:
fatol=1e-2,
xatol=1e-2,
pool=pool,
)
result = characterize_phased_fsim_parameters_with_xeb(
sampled_df=sampled_df,
parameterized_circuits=p_circuits,
cycle_depths=cycle_depths,
options=options,
# speed up with looser tolerances:
fatol=1e-2,
xatol=1e-2,
pool=pool,
)
opt_res = result.optimization_results[(q0, q1)]
assert np.abs(opt_res.x[0] + np.pi / 4) < 0.1
assert np.abs(opt_res.fun) < 0.1 # noiseless simulator
Expand All @@ -252,7 +260,7 @@ def test_characterize_phased_fsim_parameters_with_xeb():


@pytest.mark.parametrize('use_pool', (True, False))
def test_parallel_full_workflow(use_pool):
def test_parallel_full_workflow(request, use_pool):
circuits = rqcg.generate_library_of_2q_circuits(
n_library_circuits=5,
two_qubit_gate=cirq.ISWAP**0.5,
Expand All @@ -272,10 +280,8 @@ def test_parallel_full_workflow(use_pool):
combinations_by_layer=combs,
)

if use_pool:
pool = multiprocessing.Pool(_POOL_NUM_PROCESSES)
else:
pool = None
# avoid starting worker pool if it is not needed
pool = request.getfixturevalue("pool") if use_pool else None

fids_df_0 = benchmark_2q_xeb_fidelities(
sampled_df=sampled_df, circuits=circuits, cycle_depths=cycle_depths, pool=pool
Expand All @@ -296,8 +302,6 @@ def test_parallel_full_workflow(use_pool):
xatol=5e-2,
pool=pool,
)
if pool is not None:
pool.terminate()

assert len(result.optimization_results) == graph.number_of_edges()
for opt_res in result.optimization_results.values():
Expand Down
29 changes: 16 additions & 13 deletions cirq-core/cirq/experiments/xeb_simulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import multiprocessing
from typing import Dict, Any, Optional
from typing import Sequence
from typing import Any, Dict, Iterator, Optional, Sequence

import numpy as np
import pandas as pd
Expand All @@ -27,7 +28,14 @@
_POOL_NUM_PROCESSES = min(4, multiprocessing.cpu_count())


def test_simulate_2q_xeb_circuits():
@pytest.fixture
def pool() -> Iterator[multiprocessing.pool.Pool]:
ctx = multiprocessing.get_context()
with ctx.Pool(_POOL_NUM_PROCESSES) as pool:
yield pool


def test_simulate_2q_xeb_circuits(pool):
q0, q1 = cirq.LineQubit.range(2)
circuits = [
rqcg.random_rotations_between_two_qubit_circuit(
Expand All @@ -45,8 +53,7 @@ def test_simulate_2q_xeb_circuits():
assert len(row['pure_probs']) == 4
assert np.isclose(np.sum(row['pure_probs']), 1)

with multiprocessing.Pool(_POOL_NUM_PROCESSES) as pool:
df2 = simulate_2q_xeb_circuits(circuits, cycle_depths, pool=pool)
df2 = simulate_2q_xeb_circuits(circuits, cycle_depths, pool=pool)

pd.testing.assert_frame_equal(df, df2)

Expand Down Expand Up @@ -121,8 +128,8 @@ def _ref_simulate_2q_xeb_circuits(
return pd.DataFrame(records).set_index(['circuit_i', 'cycle_depth']).sort_index()


@pytest.mark.parametrize('multiprocess', (True, False))
def test_incremental_simulate(multiprocess):
@pytest.mark.parametrize('use_pool', (True, False))
def test_incremental_simulate(request, use_pool):
q0, q1 = cirq.LineQubit.range(2)
circuits = [
rqcg.random_rotations_between_two_qubit_circuit(
Expand All @@ -132,16 +139,12 @@ def test_incremental_simulate(multiprocess):
]
cycle_depths = np.arange(3, 100, 9, dtype=np.int64)

if multiprocess:
pool = multiprocessing.Pool(_POOL_NUM_PROCESSES)
else:
pool = None
# avoid starting worker pool if it is not needed
pool = request.getfixturevalue("pool") if use_pool else None

df_ref = _ref_simulate_2q_xeb_circuits(circuits=circuits, cycle_depths=cycle_depths, pool=pool)

df = simulate_2q_xeb_circuits(circuits=circuits, cycle_depths=cycle_depths, pool=pool)
if pool is not None:
pool.terminate()
pd.testing.assert_frame_equal(df_ref, df)

# Use below for approximate equality, if e.g. you're using qsim:
Expand Down

0 comments on commit 9ff9b6b

Please sign in to comment.