Skip to content

Commit

Permalink
Fix off by one bug (thanks @tom-andersson!)
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Jul 30, 2023
1 parent 33a6a49 commit 46bd53e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
11 changes: 10 additions & 1 deletion neuralprocesses/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ def _pad_zeros(x: B.Numeric, up_to: int, axis: int):
return B.concat(x, B.zeros(B.dtype(x), *shape), axis=axis)


def _ceil_to_closest_multiple(n, m):
d, r = divmod(n, m)
# If `n` is zero, then we must also round up.
if n == 0 or r > 0:
return (d + 1) * m
else:
return d * m


@_dispatch
def _determine_ns(xc: tuple, multiple: Union[int, tuple]):
ns = [B.shape(xci, 2) for xci in xc]
Expand All @@ -48,7 +57,7 @@ def _determine_ns(xc: tuple, multiple: Union[int, tuple]):
multiple = (multiple,) * len(ns)

# Ceil to the closest multiple of `multiple`.
ns = [((n - 1) // m + 1) * m for n, m in zip(ns, multiple)]
ns = [_ceil_to_closest_multiple(n, m) for n, m in zip(ns, multiple)]

return ns

Expand Down
22 changes: 21 additions & 1 deletion tests/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from .test_architectures import generate_data
from .util import nps, approx # noqa
from .util import approx, nps # noqa


@pytest.mark.flaky(reruns=3)
Expand Down Expand Up @@ -31,3 +31,23 @@ def test_convgnp_mask(nps):
# Check that the two ways of doing it coincide.
approx(pred.mean, pred_masked.mean)
approx(pred.var, pred_masked.var)


@pytest.mark.parametrize("ns", [(10,), (0,), (10, 5), (10, 0), (0, 10), (15, 5, 10)])
@pytest.mark.parametrize("multiple", [1, 2, 3, 5])
def test_mask_contexts(nps, ns, multiple):
x, y = nps.merge_contexts(
*((B.randn(nps.dtype, 2, 3, n), B.randn(nps.dtype, 2, 4, n)) for n in ns),
multiple=multiple
)

# Test that the output is of the right shape.
if max(ns) == 0:
assert B.shape(y.y, 2) == multiple
else:
assert B.shape(y.y, 2) == ((max(ns) - 1) // multiple + 1) * multiple

# Test that the mask is right.
mask = y.mask == 1 # Convert mask to booleans.
assert B.all(B.take(B.flatten(y.y), B.flatten(mask)) != 0)
assert B.all(B.take(B.flatten(y.y), B.flatten(~mask)) == 0)

0 comments on commit 46bd53e

Please sign in to comment.