Skip to content

Commit

Permalink
Fix flaky test for sensitivity analysis
Browse files Browse the repository at this point in the history
Summary:
This fixes a flaky test for sensitivity analysis. Currently the QMC
integration done for computing Sobol indices is seeded for reproducibility.
That done for computing derivative-based sensitivity measures is not. This diff
changes that so that the derivative-based measures will also have seeded QMC
with the same seed, to match Sobol indices.

A side effect of this is that it fixes the flaky test, since the results are no
longer random based on the QMC points.

Reviewed By: bernardbeckerman

Differential Revision: D54855136
  • Loading branch information
bletham authored and facebook-github-bot committed Mar 13, 2024
1 parent aa9025e commit a885b93
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ax/utils/sensitivity/derivative_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
if input_qmc:
# pyre-fixme[4]: Attribute must be annotated.
self.input_mc_samples = (
draw_sobol_samples(bounds=bounds, n=num_mc_samples, q=1)
draw_sobol_samples(bounds=bounds, n=num_mc_samples, q=1, seed=1234)
.squeeze(1)
.to(dtype)
)
Expand Down
8 changes: 4 additions & 4 deletions ax/utils/sensitivity/tests/test_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,23 +271,23 @@ def test_SobolGPMean(self) -> None:
# Test with signed
model_bridge = get_modelbridge(modular=True)
# Unsigned
sobol_kwargs = {"input_qmc": True, "num_mc_samples": 10}
ind_dict = ax_parameter_sens(
model_bridge, # pyre-ignore
input_qmc=True,
num_mc_samples=10,
order="total",
signed=False,
**sobol_kwargs, # pyre-ignore
)
ind_deriv = compute_derivatives_from_model_list(
model_list=[model_bridge.model.surrogate.model],
bounds=torch.tensor(model_bridge.model.search_space_digest.bounds).T,
**sobol_kwargs,
)
ind_dict_signed = ax_parameter_sens(
model_bridge, # pyre-ignore
input_qmc=True,
num_mc_samples=10,
order="total",
# signed=True
**sobol_kwargs, # pyre-ignore
)
for i, pname in enumerate(["x1", "x2"]):
self.assertEqual(
Expand Down

0 comments on commit a885b93

Please sign in to comment.