Skip to content

Commit

Permalink
Merge pull request #300 from randomir/fix-flaky-test
Browse files Browse the repository at this point in the history
Fix flaky tests
  • Loading branch information
randomir authored Dec 2, 2024
2 parents c9741e8 + 3f32640 commit ec4b3e6
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 24 deletions.
10 changes: 6 additions & 4 deletions hybrid/composers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ def next(self, states, **runopts):
synthesis_en = thesis_en

# input sanity check
# TODO: convert to hard input validation
assert len(thesis) == len(antithesis)
assert state_thesis.problem == state_antithesis.problem
if len(thesis) != len(antithesis):
raise ValueError("thesis-antithesis length mismatch")
if state_thesis.problem != state_antithesis.problem:
raise ValueError("thesis and antithesis refer to different problems")

diff = {v for v in thesis if thesis[v] != antithesis[v]}

Expand All @@ -116,7 +117,8 @@ def next(self, states, **runopts):
synthesis_samples = SampleSet.from_samples_bqm(synthesis, bqm)

# calculation sanity check
assert synthesis_samples.first.energy == synthesis_en
if synthesis_samples.first.energy != synthesis_en:
logger.error("Synthesis error: lowest energy sample is not on synthesis path.")

return state_thesis.updated(samples=synthesis_samples)

Expand Down
5 changes: 3 additions & 2 deletions hybrid/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,10 @@ def __init__(self, sampler, fields, **sample_kwargs):
if not isinstance(sampler, dimod.Sampler):
raise TypeError("'sampler' should be 'dimod.Sampler'")
try:
assert len(tuple(fields)) == 2
if len(tuple(fields)) != 2:
raise ValueError
except:
raise ValueError("'fields' should be two-tuple with input/output state fields")
raise ValueError("'fields' should be a two-tuple with input/output state fields")

self.sampler = sampler
self.input, self.output = fields
Expand Down
39 changes: 21 additions & 18 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,13 +435,14 @@ def test_validation(self):
class TestHybridRunnable(unittest.TestCase):
bqm = dimod.BinaryQuadraticModel({}, {'ab': 1, 'bc': 1, 'ca': -1}, 0, dimod.SPIN)
init_state = State.from_sample(min_sample(bqm), bqm)
ground_energy = dimod.ExactSolver().sample(bqm).first.energy

def test_generic(self):
runnable = HybridRunnable(TabuSampler(), fields=('problem', 'samples'))
response = runnable.run(self.init_state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().samples.record[0].energy, -3.0)
self.assertEqual(response.result().samples.record[0].energy, self.ground_energy)

def test_validation(self):
with self.assertRaises(TypeError):
Expand All @@ -462,54 +463,56 @@ def test_problem_sampler_runnable(self):
response = runnable.run(self.init_state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().samples.record[0].energy, -3.0)
self.assertEqual(response.result().samples.record[0].energy, self.ground_energy)

def test_subproblem_sampler_runnable(self):
runnable = HybridSubproblemRunnable(TabuSampler())
state = self.init_state.updated(subproblem=self.bqm)
response = runnable.run(state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().subsamples.record[0].energy, -3.0)
self.assertEqual(response.result().subsamples.record[0].energy, self.ground_energy)

def test_runnable_composition(self):
runnable = IdentityDecomposer() | HybridSubproblemRunnable(TabuSampler()) | IdentityComposer()
response = runnable.run(self.init_state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().samples.record[0].energy, -3.0)
self.assertEqual(response.result().samples.record[0].energy, self.ground_energy)

def test_racing_workflow_with_oracle_subsolver(self):
class ExactSolver(dimod.ExactSolver):
"""Exact solver that returns only the ground state."""
def sample(self, bqm):
return super().sample(bqm).truncate(1)

workflow = hybrid.LoopUntilNoImprovement(hybrid.RacingBranches(
hybrid.InterruptableTabuSampler(),
hybrid.EnergyImpactDecomposer(size=1)
| HybridSubproblemRunnable(dimod.ExactSolver())
| HybridSubproblemRunnable(ExactSolver())
| hybrid.SplatComposer()
) | hybrid.ArgMin(), convergence=3)
state = State.from_sample(min_sample(self.bqm), self.bqm)
response = workflow.run(state)
response = workflow.run(self.init_state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().samples.record[0].energy, -3.0)
self.assertEqual(response.result().samples.record[0].energy, self.ground_energy)

def test_sampling_parameters_filtering(self):
class Sampler(dimod.ExactSolver):
"""Exact solver that fails if a sampling parameter is provided."""
parameters = {}
def sample(self, bqm):
return super().sample(bqm)
return super().sample(bqm).truncate(1)

workflow = hybrid.LoopUntilNoImprovement(hybrid.RacingBranches(
hybrid.InterruptableTabuSampler(),
hybrid.EnergyImpactDecomposer(size=1)
| HybridSubproblemRunnable(Sampler())
| hybrid.SplatComposer()
) | hybrid.ArgMin(), convergence=3)
state = State.from_sample(min_sample(self.bqm), self.bqm)
response = workflow.run(state)
workflow = (
hybrid.IdentityDecomposer()
| HybridSubproblemRunnable(Sampler(), unknown_sampler_argument=1)
| hybrid.IdentityComposer()
)
response = workflow.run(self.init_state)

self.assertIsInstance(response, concurrent.futures.Future)
self.assertEqual(response.result().samples.record[0].energy, -3.0)
self.assertEqual(response.result().samples.record[0].energy, self.ground_energy)


class TestLogging(unittest.TestCase):
Expand Down

0 comments on commit ec4b3e6

Please sign in to comment.