-
Notifications
You must be signed in to change notification settings - Fork 542
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add low-precision to TFNO #172
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great, thanks!
if self.fno_block_precision == 'half': | ||
x = x.half() | ||
else: | ||
x = x.float() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we remove this one or do you think we need to always explicitly cast here @crwhite14 @rtu715 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can remove the "else: x = x.float()"
I just did that in this commit: bcc52e7#diff-5ae1e49af12ed16c75135c0043a08575110fd03d4c722a837b60aa0950b31e32L342-L343
Thanks, great PR @crwhite14 @rtu715, merging! |
Add low-precision to TFNO
This pull request adds low precision options for TFNO. Here is a brief summary.
opt.amp_autocast
was previously in the yaml file but was a no-op. This PR getsopt.amp_autocast
workingfno_block_precision
stabilizer
fno_block_precision='half'
.Running
Improves runtime and memory by up to 30%, depending on the GPU used, the resolution of the data (greater speedups for 64x64 resolution or higher), and other hyperparameters such as factorization and rank.