make torch.amp.autocast more generic (#125103)
Summary:

# Motivation
As discussed in [#124479](pytorch/pytorch#124479), `torch.amp.autocast` cannot be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, since `torch.amp.autocast` lacks the default `dtype` that CPU (`torch.bfloat16`) and CUDA (`torch.float16`) each have. We would like `torch.amp.autocast` to be more generic so that developers and customers can write device-agnostic code, as there is little justification for adding a device-specific `torch.xxx.amp.autocast` for every device backend.

# Solution
When `None` is passed as `dtype`, use `torch.get_autocast_dtype` to obtain the default dtype for each backend. `torch.get_autocast_dtype` also needs to be supported in the JIT path for backward compatibility.

# Additional Context
With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`. Two new unit tests cover this change in the eager and JIT paths respectively.

X-link: pytorch/pytorch#125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui
Reviewed By: izaitsevfb
Differential Revision: D57138276
fbshipit-source-id: 17f883924e43f68dd6836d99b06fe8a47cfccbf6
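
As a usage note, here is a minimal sketch of the device-agnostic pattern this change enables; the model and tensor are illustrative, not taken from the PR:

```python
import torch

# With this change, omitting `dtype` lets torch.amp.autocast fall back to
# each backend's default autocast dtype, as queried via
# torch.get_autocast_dtype: torch.float16 on CUDA, torch.bfloat16 on CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(8, 8).to(device_type)
x = torch.randn(4, 8, device=device_type)

# Device-agnostic: no per-backend dtype needs to be spelled out.
with torch.amp.autocast(device_type=device_type):
    y = model(x)

print(y.dtype)  # torch.float16 on CUDA, torch.bfloat16 on CPU
```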