Functionalization integration (#4158)
Summary:
This is the consolidated branch of all the functionalization changes. For individual changes, please visit the corresponding PRs for details.

Here is a brief summary of the highlights: Functionalization is a dispatcher pass introduced upstream that removes views and mutations from a PyTorch program to produce functional graphs, which are easier for backend compilers to optimize. We have an in-house infrastructure that does a similar trick, since the HLO we generate is always functional. The benefit of adopting the upstream functionalization pass is that we can get rid of the in-house view infrastructure we have struggled to maintain for years, and let upstream do all that heavy lifting for us.
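
For a concrete picture of what the pass does, here is a minimal sketch using the public torch.func.functionalize API, upstream's user-facing entry point to the same pass (the function and shapes are illustrative only):

    import torch
    from torch.func import functionalize
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        tmp = x.clone()
        y = tmp.view(-1)  # a view op
        y.add_(1)         # a mutation through the view
        return tmp

    # Tracing the functionalized program yields a graph with no views or
    # mutations: view becomes view_copy, add_ becomes add, and the update is
    # written back into the base functionally (via an inverse view_copy or a
    # *_scatter op).
    g = make_fx(functionalize(f, remove="mutations_and_views"))(torch.zeros(2, 2))
    print(g.code)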

Implementation details:
1. To enable functionalization, we just need to wrap the newly created at::Tensor that holds an XLATensor in a FunctionalWrapper and return the wrapper to Python. Any subsequent ops then go through the functionalization pass before reaching us. Correspondingly, we have to unwrap the FunctionalWrapper before getting at the XLATensor. In other words, a thin FunctionalWrapper layer is added to the aten-xla bridge: FunctionalWrapper <=> at::Tensor <=> XLATensor. (A wrap/unwrap sketch follows this list.)
2. To support the new view ops, we have to implement at most two variants for each view op: view_copy, which has the same semantics as the view op but returns a copy; and view_scatter, for the cases where extra logic is needed to re-apply an updated value onto the base tensor. (An example follows this list.)
3. For in-place ops, we have a new op called _propagate_xla_data to keep the in-place-op optimizations we had before active. (A sketch follows this list.)
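
As a rough illustration of the wrap/unwrap contract in point 1, here is a sketch using torch's private _to_functional_tensor/_from_functional_tensor bindings to the upstream wrapper (the real bridge does this in C++, and a CPU tensor stands in for an XLA one here):

    import torch

    t = torch.ones(2, 2)  # stands in for an at::Tensor holding an XLATensor
    wrapped = torch._to_functional_tensor(t)        # what we hand back to Python
    assert torch._is_functional_tensor(wrapped)

    inner = torch._from_functional_tensor(wrapped)  # unwrap before reaching the XLATensor
    assert not torch._is_functional_tensor(inner)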
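
For point 2, select is an example of a view op covered by an existing upstream pair, select_copy and select_scatter (shapes illustrative):

    import torch

    base = torch.zeros(3, 4)

    # view_copy variant: same indexing semantics as base.select(0, 1),
    # but returns a copy instead of a view.
    row = torch.ops.aten.select_copy(base, 0, 1)

    # view_scatter variant: re-applies the updated value onto the base.
    new_base = torch.select_scatter(base, row + 1, 0, 1)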
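
For point 3, conceptually the rewritten program for an in-place op looks like the following sketch (not the literal generated code; on non-XLA backends _propagate_xla_data should be a no-op):

    import torch

    x = torch.zeros(4)  # stands in for an XLA tensor

    # Before functionalization: x.add_(1)
    # After functionalization, conceptually:
    x_updated = torch.add(x, 1)
    # Carries in-place metadata from the old tensor to its functional
    # replacement, keeping backend in-place optimizations active.
    torch.ops.aten._propagate_xla_data(x, x_updated)
    x = x_updated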

Test Plan:
CI
bdhirsh authored Mar 2, 2023
1 parent 5ca5686 commit e49df83
Showing 25 changed files with 1,057 additions and 387 deletions.
6 changes: 6 additions & 0 deletions test/cpp/run_tests.sh
@@ -10,6 +10,12 @@ RMBUILD=1
LOGFILE=/tmp/pytorch_cpp_test.log
XLA_EXPERIMENTAL="nonzero:masked_select"

# See Note [Keep Going]
# Set to 1 to keep running the remaining tests after a failure.
CONTINUE_ON_ERROR=0
if [[ "$CONTINUE_ON_ERROR" == "1" ]]; then
  set +e
fi

if [ "$DEBUG" == "1" ]; then
BUILDTYPE="Debug"
fi