Fix casting step outputs to device #253

Merged
merged 1 commit into from
Feb 20, 2024

Conversation

gsarti
Member

@gsarti gsarti commented Feb 20, 2024

Description

Ensures the output of the attribute_step method is always moved to the CPU to minimize GPU memory allocation, while also preventing the device issues raised in #251.

@gsarti
Member Author

gsarti commented Feb 20, 2024

@xuan25 could you check whether this works for you on CUDA? The version we merged from your PR causes issues when out.show() is called: tensors are never cast back to CPU, so the tensor.numpy() call in the viz method fails. This should fix that while preserving CUDA compatibility. Thanks in advance!
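For readers following along, the idea of casting step outputs back to CPU can be sketched roughly as below. This is not the code from the merged commit: `to_cpu` is a hypothetical helper name, and it is duck-typed (anything exposing `.detach()` and `.cpu()`, such as a `torch.Tensor`, is treated as a tensor) so that the sketch runs without a GPU.

```python
from typing import Any


def to_cpu(value: Any) -> Any:
    """Recursively move tensor-like values to the CPU (hypothetical helper).

    Objects with both `.detach()` and `.cpu()` are treated as tensors.
    Dicts, lists and tuples are traversed; anything else is returned as-is.
    """
    if hasattr(value, "detach") and hasattr(value, "cpu"):
        # Drop the autograd graph, then copy to host memory
        return value.detach().cpu()
    if isinstance(value, dict):
        return {key: to_cpu(item) for key, item in value.items()}
    if isinstance(value, list):
        return [to_cpu(item) for item in value]
    if isinstance(value, tuple):
        # Note: named tuples come back as plain tuples in this sketch
        return tuple(to_cpu(item) for item in value)
    return value
```

Applied to the output of each step, a helper like this would mean that a later `tensor.numpy()` call in a visualization routine only ever sees CPU tensors, which is what that call requires.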

@xuan25
Contributor

xuan25 commented Feb 20, 2024

Hi @gsarti, thanks for pointing it out. I can confirm that all 50 tests pass on my side with this branch.

I also wonder whether it would be more reasonable to move the tensor to the CPU only when show, or another operation that has to run on the CPU, is invoked. Otherwise, the tensor could stay on its original device (the same as the model) to avoid potentially unnecessary overhead.
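A rough sketch of this lazy alternative, under the assumption that outputs live in a simple container: `StepOutputSketch`, `as_cpu` and this `show` are hypothetical stand-ins, not the library's classes or methods.

```python
from typing import Any, Dict


class StepOutputSketch:
    """Hypothetical container that leaves tensors on their original device."""

    def __init__(self, tensors: Dict[str, Any]) -> None:
        self.tensors = tensors

    def as_cpu(self) -> Dict[str, Any]:
        # Build CPU copies on demand; the stored tensors are not modified
        return {name: t.detach().cpu() for name, t in self.tensors.items()}

    def show(self) -> Dict[str, Any]:
        # Stand-in for a visualization: a real one would call .numpy() on
        # these CPU copies before rendering
        return self.as_cpu()
```

The trade-off is that the stored tensors keep occupying device memory for as long as the container is alive, and every call to `show` repeats the device-to-host copy.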

@gsarti
Member Author

gsarti commented Feb 20, 2024

Thanks @xuan25, will merge then! I think that if the goal is to make the methods as accessible as possible, saving a bit of GPU RAM by moving tensors back to CPU should be the preferred approach, since it could let people use bigger batch sizes or models. I think it's worth keeping as-is for now, but in the future an option would be to add a flag to control this behavior, while still keeping the memory-saving one as the default!
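One possible shape for such a flag, purely as an illustration: `finalize_step_output` and `keep_on_device` are hypothetical names, not an existing option, and the function assumes step outputs arrive as a flat dict of tensor-like values.

```python
from typing import Any, Dict


def finalize_step_output(
    output: Dict[str, Any], keep_on_device: bool = False
) -> Dict[str, Any]:
    """Hypothetical post-processing of one attribution step's outputs."""
    if keep_on_device:
        # Opt-in: skip the copy and leave tensors on the model's device
        return output
    # Default: memory-saving behavior, tensor-like values go back to the CPU
    return {
        name: value.detach().cpu()
        if hasattr(value, "detach") and hasattr(value, "cpu")
        else value
        for name, value in output.items()
    }
```

With the default left at `False`, existing callers would keep the memory-saving behavior, and only users who explicitly opt in would need to handle device placement themselves before calling CPU-only operations.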

@gsarti gsarti merged commit b038877 into main Feb 20, 2024
3 checks passed
@gsarti gsarti deleted the fix-device branch February 20, 2024 13:56