
Fix erroneous warnings about different devices #1225

Merged — 1 commit merged into sbi-dev:main on Aug 20, 2024

Conversation


@ThomasGesseyJonesPX (Contributor) commented Aug 19, 2024

What does this implement/fix? Explain your changes

Fixes erroneous warnings claiming tensors are on different devices when they are in fact on the same device.

This bug occurs when PyTorch device objects are passed rather than their string representations. In the comparisons that check whether tensors are on the right device, one side is cast to a string but the other is not. When device objects are used, an object is therefore compared to a string; the two never compare equal, and a spurious warning is raised, e.g.,

UserWarning: Data x has device 'cuda:0'. Moving x to the data_device 'cuda:0'. Training will proceed on device 'cuda:0'.
  theta, x = validate_theta_and_x(

This bug is fixed by casting both sides of the comparison to strings.
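The comparison bug and the fix can be illustrated with a small, self-contained sketch. The `Device` class below is a hypothetical stand-in for `torch.device` (it is not the actual sbi or PyTorch code); it only mimics an object whose `str()` form names a device:

```python
# Minimal sketch of the comparison bug, using a hypothetical stand-in
# class for torch.device (NOT the actual sbi/PyTorch code).
class Device:
    """Mimics an object whose str() form names a device."""

    def __init__(self, spec: str) -> None:
        self.spec = spec

    def __str__(self) -> str:
        return self.spec


data_device = Device("cuda:0")  # a device object
training_device = "cuda:0"      # the same device as a string

# Buggy check: only one side is cast to str, so a string is compared
# to an object; they never compare equal and the warning fires.
buggy_mismatch = str(training_device) != data_device

# Fixed check: cast both sides to str before comparing.
fixed_mismatch = str(training_device) != str(data_device)

print(buggy_mismatch)  # True  (spurious mismatch -> erroneous warning)
print(fixed_mismatch)  # False (devices correctly recognised as equal)
```

Without a user-defined `__eq__`, comparing a string to the object falls back to identity and is always unequal, which is why the buggy check warns even when both sides name the same device.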

Does this close any currently open issues?

None

Any relevant code examples, logs, error output, etc?

Minimal working example to illustrate the bug and reproduce the warning shown above:

"""MWE Example of erroneous warning in device comparison."""

import torch
from sbi.inference import SNPE
from sbi.utils.user_input_checks import process_prior, process_simulator

# Set device to GPU
device = torch.device("cuda:0")

# Define prior and simulator to be on the GPU
prior = torch.distributions.MultivariateNormal(
    torch.zeros(2, dtype=float, device=device), torch.eye(2, dtype=float, device=device)
)
simulator = lambda theta: theta

prior.sample()
prior, num_parameters, prior_returns_numpy = process_prior(prior)
simulator = process_simulator(simulator, prior, prior_returns_numpy)

# Also set the NPE to be on the same GPU
inference = SNPE(prior=prior, device=device)
theta = prior.sample((100,))
x = simulator(theta)
inference = inference.append_simulations(theta, x)

# This will raise a warning saying that the device of the density estimator is different from the 
# device of the prior even though they are both on the same GPU
density_estimator = inference.train(max_num_epochs=1)

Any other comments?

None

Checklist

Put an x in the boxes that apply. You can also fill these out after creating
the PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.

  • I have read and understood the contribution
    guidelines
  • I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
  • I have commented my code, particularly in hard-to-understand areas
  • I have added tests that prove my fix is effective or that my feature works
  • I have reported how long the new tests run and potentially marked them
    with pytest.mark.slow.
  • New and existing unit tests pass locally with my changes
  • I performed linting and formatting as described in the contribution
    guidelines
  • I rebased on main (or there are no conflicts with main)
  • For reviewer: The continuous deployment (CD) workflows are passing.


@janfb janfb self-requested a review August 19, 2024 16:52
@janfb janfb self-assigned this Aug 19, 2024
@janfb janfb added the bug Something isn't working label Aug 19, 2024
@janfb janfb added this to the Hackathon and release 0.23 milestone Aug 19, 2024
@janfb (Contributor) left a comment


Makes sense - thanks a lot for fixing this 🙏

Device tests are passing locally.


codecov bot commented Aug 19, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 76.90%. Comparing base (5de7784) to head (6632774).
Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1225      +/-   ##
==========================================
- Coverage   85.44%   76.90%   -8.54%     
==========================================
  Files         101      101              
  Lines        7941     7941              
==========================================
- Hits         6785     6107     -678     
- Misses       1156     1834     +678     
Flag        Coverage Δ
unittests   76.90% <100.00%> (-8.54%) ⬇️

Flags with carried forward coverage won't be shown.

Files                             Coverage Δ
sbi/utils/nn_utils.py             88.23% <100.00%> (ø)
sbi/utils/user_input_checks.py    76.31% <100.00%> (ø)

... and 25 files with indirect coverage changes

@janfb merged commit b3254ed into sbi-dev:main Aug 20, 2024
6 checks passed
Labels: bug (Something isn't working)
Projects: none yet
Linked issues: none yet
2 participants