From b12c2d9238090ace78db034bc4d44d9f75eabb9e Mon Sep 17 00:00:00 2001 From: Chunhui Gu Date: Fri, 9 Feb 2024 18:23:17 -0600 Subject: [PATCH] Fix bug in CVAE example The bug is caused by input validation (only accept value between 0 or 1 otherwise raise error) of F.binary_cross_entropy_loss after Pytorch 2.1 update. The fix is to manually mask the input of the loss function to avoid invalid value rather than removing the loss of invalid input from the final loss result. Please refer to:https://github.com/pyro-ppl/pyro/issues/3273 for the initial report of the bug --- examples/cvae/baseline.py | 8 ++++++-- tests/test_examples.py | 5 +---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/cvae/baseline.py b/examples/cvae/baseline.py index 68695c6d5e..f901631806 100644 --- a/examples/cvae/baseline.py +++ b/examples/cvae/baseline.py @@ -34,8 +34,12 @@ def __init__(self, masked_with=-1): def forward(self, input, target): target = target.view(input.shape) - loss = F.binary_cross_entropy(input, target, reduction="none") - loss[target == self.masked_with] = 0 + # only calculate loss on target pixels (value = -1) + loss = F.binary_cross_entropy( + input[target != self.masked_with], + target[target != self.masked_with], + reduction="none", + ) return loss.sum() diff --git a/tests/test_examples.py b/tests/test_examples.py index d8abf2c450..8e62a7f770 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -123,10 +123,7 @@ "vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential", "vae/vae.py --num-epochs=1", "vae/vae_comparison.py --num-epochs=1", - pytest.param( - "cvae/main.py --num-quadrant-inputs=1 --num-epochs=1", - marks=pytest.mark.skip(reason="https://github.com/pyro-ppl/pyro/issues/3273"), - ), + "cvae/main.py --num-quadrant-inputs=1 --num-epochs=1", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=0 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=1 ", "contrib/funsor/hmm.py --num-steps=1 --truncate=10 --model=2 ",