Functionalization integration (#4158)
Summary: This is the consolidated branch of all the functionalization changes. For individual changes, please visit the corresponding PRs for details.

Here is a brief summary of the highlights: Functionalization is a dispatcher pass introduced upstream to remove views and mutations from a PyTorch program and produce functional graphs, which are easier for backend compilers to optimize. We have in-house infrastructure that achieves a similar effect, since the HLOs we generate are always functional. The benefit of adopting the upstream functionalization pass is that we can drop the in-house view infrastructure we have struggled to maintain for years and let upstream do the heavy lifting for us (see the illustrative sketch at the end of this message).

Implementation details:
1. To enable functionalization, we wrap each newly created at::Tensor that holds an XLATensor in a FunctionalWrapper and return the wrapper to Python. Any subsequent ops then go through the functionalization pass before reaching us, and correspondingly we unwrap the FunctionalWrapper before getting the XLATensor. In effect, a thin FunctionalWrapper layer is added to the aten-xla bridge: FunctionalWrapper <=> at::Tensor <=> XLATensor.
2. To support the new view ops, we implement at most two variants for each view op: view_copy, the view op that returns a copy instead; and view_scatter, used when extra logic is needed to reapply the updated value to the base view.
3. For in-place ops, we add a new op, _propagate_xla_data, to keep the in-place op optimizations we had before active.

Test Plan: CI
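
For reference, here is a minimal, illustrative sketch (not code from this PR) of the property the upstream pass provides, using the public torch.func.functionalize and make_fx APIs: tracing a program that contains views and in-place mutations yields a purely functional graph, with view_copy in place of view and out-of-place ops in place of mutations.

```python
# Illustrative only: shows upstream functionalization on a toy program,
# not the XLA-specific wiring described in this commit.
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x.view(2, 2)  # view op
    y.add_(1)         # in-place mutation through the view
    return y

x = torch.zeros(4)

# remove="mutations_and_views" also replaces view ops with their *_copy
# variants, matching the view_copy ops mentioned above.
gm = make_fx(functionalize(f, remove="mutations_and_views"))(x)
print(gm.code)  # the traced graph contains only functional ops
```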