Skip to content

Commit

Permalink
Fix data type bug
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Sep 15, 2022
1 parent 4a21213 commit adc842c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion neuralprocesses/model/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def elbo(

if normalise:
# Normalise by the number of targets.
elbos = elbos / num_data(xt, yt)
elbos = elbos / B.cast(float64, num_data(xt, yt))

return state, elbos

Expand Down
2 changes: 1 addition & 1 deletion neuralprocesses/model/loglik.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def loglik(

if normalise:
# Normalise by the number of targets.
logpdfs = logpdfs / num_data(xt, yt)
logpdfs = logpdfs / B.cast(float64, num_data(xt, yt))

return state, logpdfs

Expand Down
10 changes: 6 additions & 4 deletions tests/test_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,23 +249,25 @@ def test_forward(nps, model_sample):
check_prediction(nps, pred, yt)


@pytest.mark.parametrize("normalise", [False, True])
@pytest.mark.flaky(reruns=3)
def test_elbo(nps, model_sample):
def test_elbo(nps, model_sample, normalise):
model, sample = model_sample
model = model()
xc, yc, xt, yt = sample()
elbos = nps.elbo(model, xc, yc, xt, yt, num_samples=2)
elbos = nps.elbo(model, xc, yc, xt, yt, num_samples=2, normalise=normalise)
assert B.rank(elbos) == 1
assert np.isfinite(B.to_numpy(B.sum(elbos)))
assert B.dtype(elbos) == nps.dtype64


@pytest.mark.parametrize("normalise", [False, True])
@pytest.mark.flaky(reruns=3)
def test_loglik(nps, model_sample):
def test_loglik(nps, model_sample, normalise):
model, sample = model_sample
model = model()
xc, yc, xt, yt = sample()
logpdfs = nps.loglik(model, xc, yc, xt, yt, num_samples=2)
logpdfs = nps.loglik(model, xc, yc, xt, yt, num_samples=2, normalise=normalise)
assert B.rank(logpdfs) == 1
assert np.isfinite(B.to_numpy(B.sum(logpdfs)))
assert B.dtype(logpdfs) == nps.dtype64
Expand Down

0 comments on commit adc842c

Please sign in to comment.