-
-
Notifications
You must be signed in to change notification settings - Fork 985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bug in CVAE example #3325
Fix bug in CVAE example #3325
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix, just a couple of suggestions.
pytest.param( | ||
"cvae/main.py --num-quadrant-inputs=1 --num-epochs=1", | ||
marks=pytest.mark.skip(reason="https://github.com/pyro-ppl/pyro/issues/3273"), | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you replace these four lines with just the string
"cvae/main.py --num-quadrant-inputs=1 --num-epochs=1",
The pytest.mark.skip is just to skip the test, and we now want it to run.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added the line for the test.
examples/cvae/baseline.py
Outdated
loss = F.binary_cross_entropy(input, target, reduction="none") | ||
loss[target == self.masked_with] = 0 | ||
# only calculate loss on target pixels (value = -1) | ||
loss = F.binary_cross_entropy(input[target != self.masked_with], target[target != self.masked_with], reduction='none') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you run make lint
to fix the lint error? I think it's just a line too long.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Had checked the format using make lint
and fix the lint error.
The bug is caused by input validation (only accept value between 0 or 1 otherwise raise error) of F.binary_cross_entropy_loss after Pytorch 2.1 update. The fix is to manually mask the input of the loss function to avoid invalid value rather than removing the loss of invalid input from the final loss result. Please refer to:pyro-ppl#3273 for the initial report of the bug
Fixes #3273
The bug is caused by input validation (only accept value between 0 or 1 otherwise raise error) of F.binary_cross_entropy_loss after Pytorch 2.1 update. The fix is to manually mask the input of the loss function to avoid invalid value rather than removing the loss of invalid input from the final loss result. Please refer to:#3273 for the initial report of the bug