Skip to content

Commit

Permalink
Stabilise flaky test
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Jun 4, 2024
1 parent f32f081 commit 450477e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
8 changes: 4 additions & 4 deletions tests/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ def test_convgnp_mask(nps):
conv_receptive_field=0.5,
conv_layers=1,
conv_channels=1,
# Dividing by the density channel makes the forward very sensitive to the
# numerics.
divide_by_density=False,
# A large margin and `float64`s help with numerical stability.
margin=2,
dtype=nps.dtype64,
)
xc, yc, xt, yt = generate_data(nps)
xc, yc, xt, yt = generate_data(nps, dtype=nps.dtype64)

# Predict without the final three points.
pred = model(xc[:, :, :-3], yc[:, :, :-3], xt)
Expand Down
17 changes: 10 additions & 7 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def approx(
nps_tf.dtype64 = tf.float64


@pytest.fixture(params=[nps_torch, nps_tf], scope="module")
@pytest.fixture(params=[nps_tf, nps_torch], scope="module")
def nps(request):
return request.param

Expand All @@ -64,14 +64,17 @@ def generate_data(
n_context=5,
n_target=7,
binary=False,
dtype=None,
):
xc = B.randn(nps.dtype, batch_size, dim_x, n_context)
yc = B.randn(nps.dtype, batch_size, dim_y, n_context)
xt = B.randn(nps.dtype, batch_size, dim_x, n_target)
yt = B.randn(nps.dtype, batch_size, dim_y, n_target)
if dtype is None:
dtype = nps.dtype
xc = B.randn(dtype, batch_size, dim_x, n_context)
yc = B.randn(dtype, batch_size, dim_y, n_context)
xt = B.randn(dtype, batch_size, dim_x, n_target)
yt = B.randn(dtype, batch_size, dim_y, n_target)
if binary:
yc = B.cast(nps.dtype, yc >= 0)
yt = B.cast(nps.dtype, yt >= 0)
yc = B.cast(dtype, yc >= 0)
yt = B.cast(dtype, yt >= 0)
return xc, yc, xt, yt


Expand Down

0 comments on commit 450477e

Please sign in to comment.