Fix erroneous warnings about different devices #1225
Merged
+3
−3
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this implement/fix? Explain your changes
Fix for erroneous warnings saying tensors are on different devices when they are on the same device.
This bug occurs when PyTorch
device
objects are used rather than their string representations. In the comparisons that check the tensors are on the right device one side is cast to a string but not the other. Whendevice
objects are used this leads to comparing an object to a string and hence them not being equal and a warning being raised, e.g.,This bug is fixed by casting both sides of the comparison to strings.
Does this close any currently open issues?
None
Any relevant code examples, logs, error output, etc?
Minimum working example to illustrate bug and replicate the warning shown above,
Any other comments?
None
Checklist
Put an
x
in the boxes that apply. You can also fill these out after creatingthe PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.
guidelines
with
pytest.mark.slow
.guidelines
main
(or there are no conflicts withmain
)