run opted out decompositions
bdhirsh committed Jun 16, 2022
1 parent de45c7c commit 552fab9
Showing 3 changed files with 78 additions and 0 deletions.
1 change: 1 addition & 0 deletions torch_patches/.torch_pin
@@ -0,0 +1 @@
#79420
59 changes: 59 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1,4 +1,5 @@
#include <ATen/ExpandUtils.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
@@ -3520,4 +3521,62 @@ XLANativeFunctions::native_group_norm(const at::Tensor& input,
eps);
}

// All of the below ops correspond to CompositeExplicitAutograd kernels from
// core that call into view operators internally. These are all composite ops
// that LTC can technically re-use / get for free, but we need to
// "functionalize" them to remove the view ops before we can use them.
at::Tensor XLANativeFunctions::block_diag(at::TensorList tensors) {
return at::native::block_diag(tensors);
}
at::Tensor XLANativeFunctions::new_empty_strided(
const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
return at::native::new_empty_strided(self, size, stride, dtype, layout,
device, pin_memory);
}

at::Tensor XLANativeFunctions::narrow_copy(const at::Tensor& self, int64_t dim,
int64_t start, int64_t length) {
return at::native::narrow_copy(self, dim, start, length);
}
at::Tensor XLANativeFunctions::pixel_shuffle(const at::Tensor& self,
int64_t upscale_factor) {
return at::native::pixel_shuffle(self, upscale_factor);
}
at::Tensor XLANativeFunctions::pixel_unshuffle(const at::Tensor& self,
int64_t downscale_factor) {
return at::native::pixel_unshuffle(self, downscale_factor);
}
at::Tensor XLANativeFunctions::select_backward(const at::Tensor& grad_output,
at::IntArrayRef input_sizes,
int64_t dim, int64_t index) {
return at::native::select_backward(grad_output, input_sizes, dim, index);
}
::std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::linalg_inv_ex(
const at::Tensor& self, bool check_errors) {
return at::native::linalg_inv_ex(self, check_errors);
}
at::Tensor XLANativeFunctions::linalg_pinv(
const at::Tensor& self, const c10::optional<at::Tensor>& atol,
const c10::optional<at::Tensor>& rtol, bool hermitian) {
return at::native::linalg_pinv(self, atol, rtol, hermitian);
}

at::Tensor XLANativeFunctions::diagonal_backward(const at::Tensor& grad_output,
at::IntArrayRef input_sizes,
int64_t offset, int64_t dim1,
int64_t dim2) {
return at::native::diagonal_backward(grad_output, input_sizes, offset, dim1,
dim2);
}

at::Tensor XLANativeFunctions::slice_backward(const at::Tensor& grad_output,
at::IntArrayRef input_sizes,
int64_t dim, int64_t start,
int64_t end, int64_t step) {
return at::native::slice_backward(grad_output, input_sizes, dim, start, end,
step);
}

} // namespace torch_xla
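For context, here is a minimal usage sketch (not part of this commit) of how one of the forwarded ops behaves from Python; it assumes torch_xla is installed and an XLA device is available. block_diag is CompositeExplicitAutograd in core, so on XLA it now routes through the forwarding kernel above, with the view ops inside the composite implementation functionalized into copies that XLA can lower.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # assumes an XLA device is configured
a = torch.randn(2, 2, device=device)
b = torch.randn(3, 3, device=device)
# Dispatches to XLANativeFunctions::block_diag -> at::native::block_diag;
# functionalization rewrites the internal view ops before lowering.
out = torch.block_diag(a, b)
print(out.shape)  # torch.Size([5, 5])
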
18 changes: 18 additions & 0 deletions xla_native_functions.yaml
@@ -332,6 +332,24 @@ supported:
- where.self
- xlogy.Tensor
- zero_
# Below are all operators that are "composite" in core,
# but require us to explicitly re-enable functionalization in order to use them.
# Why? These operators are all CompositeExplicitAutograd, which means that they run
# after functionalization, but their implementations call view operators
# (which we need to functionalize away).
- block_diag
- slice_backward
- diagonal_backward
- new_empty_strided
- narrow_copy
- pixel_shuffle
- pixel_unshuffle
- select_backward
- linalg_inv_ex
- linalg_pinv.atol_rtol_tensor
# The same applies to these ops, but we already have direct lowerings for them
# - _trilinear
# - logsumexp.out
autograd:
- max_pool2d
- max_pool3d
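
As a rough illustration of why functionalization is needed for these ops (a sketch under the same assumptions as above, not part of the commit): the core slice_backward implementation allocates a zero-filled tensor and copies grad_output into a slice view of it, so the view plus in-place copy must be functionalized before XLA can lower the result.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 6, device=device, requires_grad=True)
y = x[:, 1:5]        # forward slice (a view op)
y.sum().backward()   # autograd invokes slice_backward, now handled by the kernel above
print(x.grad.shape)  # torch.Size([4, 6]); columns outside the slice get zero gradient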
