diff --git a/examples/cvae/baseline.py b/examples/cvae/baseline.py index 68695c6d5e..f901631806 100644 --- a/examples/cvae/baseline.py +++ b/examples/cvae/baseline.py @@ -34,8 +34,12 @@ def __init__(self, masked_with=-1): def forward(self, input, target): target = target.view(input.shape) - loss = F.binary_cross_entropy(input, target, reduction="none") - loss[target == self.masked_with] = 0 + # only calculate loss on target pixels (value = -1) + loss = F.binary_cross_entropy( + input[target != self.masked_with], + target[target != self.masked_with], + reduction="none", + ) return loss.sum() diff --git a/tests/test_examples.py b/tests/test_examples.py index d8abf2c450..8e62a7f770 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -123,10 +123,7 @@ "vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential", "vae/vae.py --num-epochs=1", "vae/vae_comparison.py --num-epochs=1", - pytest.param( - "cvae/main.py --num-quadrant-inputs=1 --num-epochs=1", - marks=pytest.mark.skip(reason="https://github.com/pyro-ppl/pyro/issues/3273"), - ), + "cvae/main.py --num-quadrant-inputs=1 --num-epochs=1", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=0 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 ",