Fixes for PyTorch/XLA functionalization integration #88787

Closed
wants to merge 11 commits
14 changes: 13 additions & 1 deletion aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -51,6 +51,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
),
value_(value)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
set_constructor_metadata();
}

@@ -130,6 +132,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
),
value_(view_value)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
set_constructor_metadata();
// Copy the original tensor's ViewMeta vector and push the current one.
if (!base->view_metas_.empty()) {
@@ -168,7 +172,9 @@ void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta m
// So, these ops are special - they're mutation AND view ops. They get special codegen.
// An example is transpose_, e.g. `a.transpose_()`
// Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
at::AutoDispatchSkipFunctionalize guard;
value_ = meta.forward_fn(value_, meta.out_index);
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}

// Note [Functionalization: Mutation Removal]
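For context on the hunk above: `mutate_view_meta` handles ops like `transpose_` that both mutate a tensor and turn it into a view of shared storage. A minimal eager-mode sketch of that aliasing behavior, assuming a standard ATen build (variable names are illustrative):

```cpp
#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::arange(6).reshape({2, 3});
  at::Tensor b = a;      // b shares a's storage
  a.transpose_(0, 1);    // mutation AND view: a now exposes the same storage with swapped strides
  // When run under the functionalization pass, this call does not mutate in place;
  // instead it appends a transpose ViewMeta to a's FunctionalTensorWrapper so the
  // alias relationship with b can be replayed later.
  return 0;
}
```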
@@ -200,15 +206,20 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
// TODO: going to need to change this if we want nested functionalize() transforms.
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
value_ = other;
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
// out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
// We need to propagate that metadata mutation to the wrapper (new size).
set_sizes_and_strides(value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());
auto sizes_ = value_.sym_sizes();
auto strides_ = value_.sym_strides();
auto storage_offset_ = value_.sym_storage_offset();
set_sizes_and_strides(sizes_, strides_, storage_offset_);
if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
// .to() should not re-entrantly go through functionalization.
at::AutoDispatchSkipFunctionalize guard;
// and we want _to_copy() to show up in the graph, not the composite .to() operator
// (this can happen if autograd has already run by the time we enter this code)
value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}
}

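As the comment in `replace_` notes, out= ops are allowed to resize their outputs. A minimal eager-mode illustration of that behavior, assuming a standard ATen build:

```cpp
#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::randn({4, 4});
  at::Tensor out = at::empty({0});   // zero-sized output
  // add_out resizes `out` to {4, 4}, mutating both its data and its metadata
  // (sizes/strides/storage) -- exactly the metadata change the wrapper above
  // has to propagate via set_sizes_and_strides().
  at::add_out(out, x, x);
  return 0;
}
```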
@@ -243,6 +254,7 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
// Then it's safe to throw out the old storage and replace it with the new, larger one.
storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
value_ = other;
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
generation_ = 0;
// And update the metadata on the wrapper to reflect the new sizes and strides
set_sizes_and_strides(value_.sizes(), value_.strides());
14 changes: 10 additions & 4 deletions aten/src/ATen/native/CPUFallback.cpp
@@ -65,7 +65,7 @@ c10::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args
}


void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views) {
auto& schema_args = op.schema().arguments();
const auto num_arguments = schema_args.size();
auto arguments = torch::jit::last(stack, num_arguments);
@@ -176,9 +176,15 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
} else {
dev_str << "<none>";
}
TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
if (error_on_views) {
TORCH_CHECK(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
} else {
TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
"but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
}
}
// Case (2): copy case. Copy the cpu output tensor to the original device.

2 changes: 1 addition & 1 deletion aten/src/ATen/native/CPUFallback.h
@@ -11,7 +11,7 @@ namespace at { namespace native {

// This function implements a boxed fallback to CPU.
// External backends can add their own custom logging on top of it to customize their own CPU fallbacks.
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);

// This is a helper function that backends can use to directly call their boxed CPU fallback
// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
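Since that TODO still asks for a usage example, here is a hypothetical sketch of how an external backend might wire up the boxed CPU fallback with the new `error_on_views` flag. This is not the actual PyTorch/XLA registration, just an illustration of the intended call pattern; the function name is made up:

```cpp
#include <ATen/native/CPUFallback.h>
#include <torch/library.h>

// Hypothetical boxed fallback for an external backend.
static void my_backend_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // Hard-error on view ops instead of warning: once functionalization has run,
  // no view op should ever reach the CPU fallback.
  at::native::cpu_fallback(op, stack, /*error_on_views=*/true);
}

TORCH_LIBRARY_IMPL(_, XLA, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&my_backend_cpu_fallback>());
}
```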
8 changes: 4 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -5062,7 +5062,7 @@
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: slice_scatter
CompositeExplicitAutogradNonFunctional: slice_scatter
autogen: slice_scatter.out
tags: core

@@ -5071,23 +5071,23 @@
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: select_scatter_symint
CompositeExplicitAutogradNonFunctional: select_scatter_symint
autogen: select_scatter.out

- func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
variants: function, method
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: diagonal_scatter
CompositeExplicitAutogradNonFunctional: diagonal_scatter
autogen: diagonal_scatter.out

- func: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
variants: function, method
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: as_strided_scatter_symint
CompositeExplicitAutogradNonFunctional: as_strided_scatter_symint
autogen: as_strided_scatter.out

- func: smm(Tensor self, Tensor mat2) -> Tensor
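For context on the ops moved to `CompositeExplicitAutogradNonFunctional` above: the `*_scatter` ops are the functional counterparts of writing into a view, which is why functionalization cares about them. A minimal sketch of `slice_scatter`'s semantics, assuming a standard ATen build:

```cpp
#include <ATen/ATen.h>

int main() {
  at::Tensor self = at::zeros({4, 4});
  at::Tensor src = at::ones({2, 4});
  // Functional form: returns a new tensor; `self` is left untouched.
  at::Tensor out = at::slice_scatter(self, src, /*dim=*/0, /*start=*/1, /*end=*/3, /*step=*/1);
  // Mutating formulation that functionalization rewrites into the call above:
  // self.slice(0, 1, 3).copy_(src);
  return 0;
}
```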
1 change: 0 additions & 1 deletion torch/_meta_registrations.py
@@ -1799,7 +1799,6 @@ def meta_select(self, dim, index):

check(
not (-index > size or index >= size),
lambda: f"select(): index {index} out of range for tensor of size "
f"{self.size()} at dimension {dim}",
IndexError,
)
2 changes: 1 addition & 1 deletion torch/_subclasses/meta_utils.py
@@ -463,7 +463,7 @@ def __call__(
or (ignore_subclass and isinstance(t, torch.Tensor))
or isinstance(t, FakeTensor)
):
if any(
if t.device.type != "xla" and any(
[
t.is_sparse_csr,
t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
15 changes: 9 additions & 6 deletions torch/csrc/lazy/core/shape_inference.cpp
@@ -51,6 +51,7 @@

#include <ATen/AccumulateType.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Functions.h>
@@ -1304,7 +1305,7 @@ std::vector<Shape> compute_shape_select_scatter(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::select_scatter(
auto out_meta = at::compositeexplicitautogradnonfunctional::select_scatter(
self_meta, src_meta, dim, index);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
@@ -1329,7 +1330,7 @@ std::vector<Shape> compute_shape_diagonal_scatter(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::diagonal_scatter(
auto out_meta = at::compositeexplicitautogradnonfunctional::diagonal_scatter(
self_meta, src_meta, offset, dim1, dim2);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
@@ -1355,8 +1356,9 @@ std::vector<Shape> compute_shape_slice_scatter_symint(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::slice_scatter_symint(
self_meta, src_meta, dim, start, end, step);
auto out_meta =
at::compositeexplicitautogradnonfunctional::slice_scatter_symint(
self_meta, src_meta, dim, start, end, step);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

@@ -1380,8 +1382,9 @@ std::vector<Shape> compute_shape_as_strided_scatter_symint(
/*layout=*/c10::make_optional(src.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::compositeexplicitautograd::as_strided_scatter_symint(
self_meta, src_meta, size, stride, storage_offset);
auto out_meta =
at::compositeexplicitautogradnonfunctional::as_strided_scatter_symint(
self_meta, src_meta, size, stride, storage_offset);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}

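The pattern used throughout these shape functions is to run the composite (non-functional) kernel on meta tensors, which yields correct output metadata without allocating or touching data. A condensed sketch of that idea, assuming a standard ATen build with the header added above:

```cpp
#include <ATen/ATen.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>

int main() {
  // Meta tensors carry sizes/strides/dtype but no storage.
  at::Tensor self_meta = at::empty({4, 4}, at::TensorOptions().device(at::kMeta));
  at::Tensor src_meta = at::empty({2, 4}, at::TensorOptions().device(at::kMeta));
  at::Tensor out_meta = at::compositeexplicitautogradnonfunctional::slice_scatter(
      self_meta, src_meta, /*dim=*/0, /*start=*/1, /*end=*/3, /*step=*/1);
  // out_meta.sizes() == {4, 4}; only shape inference happened, no data was computed.
  return 0;
}
```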
1 change: 1 addition & 0 deletions torch/csrc/utils/tensor_new.cpp
@@ -432,6 +432,7 @@ Tensor internal_new_from_data(
// to dispatch to it.
// TODO: arguably it should have an autograd implementation that noops
at::AutoDispatchBelowADInplaceOrView guard;

return at::lift_fresh(tensor);
}

13 changes: 12 additions & 1 deletion torchgen/dest/register_dispatch_key.py
@@ -323,10 +323,21 @@ def gen_out_inplace_wrapper(
for i, ret_name in enumerate(return_names)
)
returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
else:
elif len(return_names) == 1:
ret_name = return_names[0]
updates = f"{copy_op}({func_res}, {ret_name});"
returns = ret_name
else:
assert len(f.func.arguments.out) == 1
returns = ""
out_arg = f.func.arguments.out[0]
if out_arg.type.is_list_like():
updates = f"""\
for (int64_t i = 0; i < {func_res}.size(); ++i) {{
{copy_op}({func_res}[i], {out_arg.name}[i]);
}}"""
else:
updates = f"{copy_op}({func_res}, {out_arg.name});"

functional_sig = self.wrapper_kernel_sig(g.functional)
wrapper_name = sig.name()
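The new `else` branch above covers out= overloads that fall through the earlier return-name cases: no returns and a single, possibly list-like, out argument. Roughly the shape of the wrapper it emits, with hypothetical names and `copy_` standing in for whatever backend-specific copy op the codegen substitutes:

```cpp
#include <ATen/ATen.h>
#include <vector>

// Hypothetical functional variant returning a list of tensors.
static std::vector<at::Tensor> foo_functional(const at::Tensor& self) {
  return {self + 1, self + 2};
}

// Sketch of the generated out= wrapper: call the functional op, then copy each
// result element-wise into the list-like out argument.
static void wrapper_foo_out(const at::Tensor& self, at::TensorList out) {
  auto tmp_output = foo_functional(self);
  for (int64_t i = 0; i < static_cast<int64_t>(tmp_output.size()); ++i) {
    out[i].copy_(tmp_output[i]);
  }
}

int main() {
  at::Tensor self = at::zeros({2});
  std::vector<at::Tensor> out = {at::empty({2}), at::empty({2})};
  wrapper_foo_out(self, out);
  return 0;
}
```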
17 changes: 16 additions & 1 deletion torchgen/gen_functionalization_type.py
@@ -541,6 +541,11 @@ def emit_inplace_functionalization_body(
for a in f.func.arguments.flat_all
if a.type.is_tensor_like() and a.annotation is None
]
non_mutated_tensor_names = [
a.name
for a in f.func.arguments.flat_all
if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
]
# all mutable inputs must be functional tensors in order to participate in functionalization
check_all_mutated_args_are_functional = " && ".join(
["true"]
@@ -556,6 +561,14 @@
for a in non_mutated_names
]
)

check_any_non_mutated_tensors_are_xla = " || ".join(
["false"]
+ [
f"{a}.device().type() == c10::DeviceType::XLA"
for a in non_mutated_tensor_names
]
)
# These are used in the cases where we don't functionalize and redispatch to the inplace op
# case 1: we hit an inplace op that doesn't have an out-of-place equivalent
# case 2: we hit an inplace ops but our inputs are not functional tensors (in which case our kernel just no-ops)
@@ -619,7 +632,9 @@
}}
{unwrap_tensor_args_str}
if (!({check_all_mutated_args_are_functional})) {{
if (({check_any_non_mutated_args_are_functional})) {{
// We want to disable this check if there are any XLA tensors.
// cpu_tensor.copy_(xla_tensor) is valid code.
if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
// case 1: trying to mutate a non functional tensor with a functional tensor is an error
TORCH_INTERNAL_ASSERT(false,
"mutating a non-functional tensor with a functional tensor is not allowed.",
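The comment added in the generated code above ("cpu_tensor.copy_(xla_tensor) is valid code") describes the situation the relaxed check permits. A sketch of that scenario, assuming an XLA plugin such as torch_xla is loaded in-process so the XLA device is actually usable:

```cpp
#include <ATen/ATen.h>

int main() {
  at::Tensor cpu_tensor = at::empty({2, 2});
  at::Tensor xla_tensor = at::ones({2, 2}, at::TensorOptions().device(at::kXLA));
  // When the XLA backend runs under functionalization, xla_tensor is wrapped in a
  // FunctionalTensorWrapper while cpu_tensor is not; this copy must not trip the
  // "mutating a non-functional tensor with a functional tensor" assert.
  cpu_tensor.copy_(xla_tensor);
  return 0;
}
```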