diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp
index b7a939cbdc3f40..10427b4675d990 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.cpp
+++ b/aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -51,6 +51,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
     ),
     value_(value)
 {
+  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
+  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   set_constructor_metadata();
 }
 
@@ -130,6 +132,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
     ),
     value_(view_value)
 {
+  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
+  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   set_constructor_metadata();
   // Copy the original tensor's ViewMeta vector and push the current one.
   if (!base->view_metas_.empty()) {
@@ -168,7 +172,9 @@ void FunctionalTensorWrapper::mutate_view_meta(at::functionalization::ViewMeta m
   // So, these ops are special - they're mutation AND view ops. They get special codegen.
   // An example is transpose_, e.g. `a.transpose_()`
   // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
+  at::AutoDispatchSkipFunctionalize guard;
   value_ = meta.forward_fn(value_, meta.out_index);
+  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
 }
 
 // Note [Functionalization: Mutation Removal]
@@ -200,15 +206,20 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
   // TODO: going to need to change this if we want nested functionalize() transforms.
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(other));
   value_ = other;
+  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   // out= ops are allowed to resize the output tensors, mutating both the data and metadata of the tensor.
   // We need to propagate that metadata mutation to the wrapper (new size).
-  set_sizes_and_strides(value_.sym_sizes(), value_.sym_strides(), value_.sym_storage_offset());
+  auto sizes_ = value_.sym_sizes();
+  auto strides_ = value_.sym_strides();
+  auto storage_offset_ = value_.sym_storage_offset();
+  set_sizes_and_strides(sizes_, strides_, storage_offset_);
   if (dtype() != value_.unsafeGetTensorImpl()->dtype() || layout() != value_.unsafeGetTensorImpl()->layout()) {
     // .to() should not re-entrantly go through functionalization.
     at::AutoDispatchSkipFunctionalize guard;
     // and we want _to_copy() to show up in the graph, not the composite .to() operator
     // (this can happen if autograd has already run by the time we enter this code)
     value_ = at::_to_copy(value_, c10::TensorOptions().dtype(dtype()).layout(layout()));
+    TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   }
 }
 
@@ -243,6 +254,7 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
   // Then it's safe to throw out the old storage and replace it with the new, larger one.
   storage_ = c10::Storage(c10::make_intrusive<functionalization::FunctionalStorageImpl>(other));
   value_ = other;
+  TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
   generation_ = 0;
   // And update the metadata on the wrapper to reflect the new sizes and strides
   set_sizes_and_strides(value_.sizes(), value_.strides());
diff --git a/aten/src/ATen/native/CPUFallback.cpp b/aten/src/ATen/native/CPUFallback.cpp
index e1c6b6fcda865c..83d9516766c4e7 100644
--- a/aten/src/ATen/native/CPUFallback.cpp
+++ b/aten/src/ATen/native/CPUFallback.cpp
@@ -65,7 +65,7 @@ c10::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args
 }
 
 
-void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views) {
   auto& schema_args = op.schema().arguments();
   const auto num_arguments = schema_args.size();
   auto arguments = torch::jit::last(stack, num_arguments);
@@ -176,9 +176,15 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
         } else {
           dev_str << "<none>";
         }
-        TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
-          "but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
-          "falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
+        if (error_on_views) {
+          TORCH_CHECK(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
+            "but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
+            "falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
+        } else {
+          TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
+            "but it has no implementation for the backend \"", dev_str.str(), "\". View operators don't support ",
+            "falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
+        }
       }
       // Case (2): copy case. Copy the cpu output tensor to the original device.
diff --git a/aten/src/ATen/native/CPUFallback.h b/aten/src/ATen/native/CPUFallback.h
index 2d4dfc98aa06ed..acc603b7a2b017 100644
--- a/aten/src/ATen/native/CPUFallback.h
+++ b/aten/src/ATen/native/CPUFallback.h
@@ -11,7 +11,7 @@ namespace at { namespace native {
 
 // This function implements a boxed fallback to CPU.
 // External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
-TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false);
 
 // This is a helper function that backends can use to directly call their boxed CPU fallback
 // TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index bec46a06eec874..4e3d86fe15ccbb 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5062,7 +5062,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CompositeExplicitAutograd: slice_scatter
+    CompositeExplicitAutogradNonFunctional: slice_scatter
   autogen: slice_scatter.out
   tags: core
 
@@ -5071,7 +5071,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CompositeExplicitAutograd: select_scatter_symint
+    CompositeExplicitAutogradNonFunctional: select_scatter_symint
   autogen: select_scatter.out
 
 - func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
@@ -5079,7 +5079,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CompositeExplicitAutograd: diagonal_scatter
+    CompositeExplicitAutogradNonFunctional: diagonal_scatter
   autogen: diagonal_scatter.out
 
 - func: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
@@ -5087,7 +5087,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CompositeExplicitAutograd: as_strided_scatter_symint
+    CompositeExplicitAutogradNonFunctional: as_strided_scatter_symint
   autogen: as_strided_scatter.out
 
 - func: smm(Tensor self, Tensor mat2) -> Tensor
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 3ad1866250e17e..9fbb2f31162199 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -1799,7 +1799,6 @@ def meta_select(self, dim, index):
     check(
         not (-index > size or index >= size),
-
         lambda: f"select(): index {index} out of range for tensor of size "
         f"{self.size()} at dimension {dim}",
         IndexError,
     )
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 2c298d84eaecb0..9240abda0276a2 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -463,7 +463,7 @@ def __call__(
             or (ignore_subclass and isinstance(t, torch.Tensor))
             or isinstance(t, FakeTensor)
         ):
-            if any(
+            if t.device.type != "xla" and any(
                 [
                     t.is_sparse_csr,
                     t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp
index a75142cae28018..62c4dc462413d1 100644
--- a/torch/csrc/lazy/core/shape_inference.cpp
+++ b/torch/csrc/lazy/core/shape_inference.cpp
@@ -51,6 +51,7 @@
 #include
 #include
+#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
 #include
 #include
 #include
@@ -1304,7 +1305,7 @@ std::vector<Shape> compute_shape_select_scatter(
       /*layout=*/c10::make_optional(src.layout()),
       /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
       /*pin_memory=*/c10::nullopt);
-  auto out_meta = at::compositeexplicitautograd::select_scatter(
+  auto out_meta = at::compositeexplicitautogradnonfunctional::select_scatter(
       self_meta, src_meta, dim, index);
   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
 }
@@ -1329,7 +1330,7 @@ std::vector<Shape> compute_shape_diagonal_scatter(
       /*layout=*/c10::make_optional(src.layout()),
       /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
       /*pin_memory=*/c10::nullopt);
-  auto out_meta = at::compositeexplicitautograd::diagonal_scatter(
+  auto out_meta = at::compositeexplicitautogradnonfunctional::diagonal_scatter(
       self_meta, src_meta, offset, dim1, dim2);
   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
 }
@@ -1355,8 +1356,9 @@ std::vector<Shape> compute_shape_slice_scatter_symint(
       /*layout=*/c10::make_optional(src.layout()),
       /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
       /*pin_memory=*/c10::nullopt);
-  auto out_meta = at::compositeexplicitautograd::slice_scatter_symint(
-      self_meta, src_meta, dim, start, end, step);
+  auto out_meta =
+      at::compositeexplicitautogradnonfunctional::slice_scatter_symint(
+          self_meta, src_meta, dim, start, end, step);
   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
 }
 
@@ -1380,8 +1382,9 @@ std::vector<Shape> compute_shape_as_strided_scatter_symint(
       /*layout=*/c10::make_optional(src.layout()),
       /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
       /*pin_memory=*/c10::nullopt);
-  auto out_meta = at::compositeexplicitautograd::as_strided_scatter_symint(
-      self_meta, src_meta, size, stride, storage_offset);
+  auto out_meta =
+      at::compositeexplicitautogradnonfunctional::as_strided_scatter_symint(
+          self_meta, src_meta, size, stride, storage_offset);
   return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
 }
 
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index b193bb7922b350..6d2181de081df4 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -432,6 +432,7 @@ Tensor internal_new_from_data(
   // to dispatch to it.
   // TODO: arguably it should have an autograd implementation that noops
   at::AutoDispatchBelowADInplaceOrView guard;
+
   return at::lift_fresh(tensor);
 }
 
diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py
index 871d227eba8f18..a08bdc6f963eb8 100644
--- a/torchgen/dest/register_dispatch_key.py
+++ b/torchgen/dest/register_dispatch_key.py
@@ -323,10 +323,21 @@ def gen_out_inplace_wrapper(
                 for i, ret_name in enumerate(return_names)
             )
             returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
-        else:
+        elif len(return_names) == 1:
             ret_name = return_names[0]
             updates = f"{copy_op}({func_res}, {ret_name});"
             returns = ret_name
+        else:
+            assert len(f.func.arguments.out) == 1
+            returns = ""
+            out_arg = f.func.arguments.out[0]
+            if out_arg.type.is_list_like():
+                updates = f"""\
+    for (int64_t i = 0; i < {func_res}.size(); ++i) {{
+        {copy_op}({func_res}[i], {out_arg.name}[i]);
+    }}"""
+            else:
+                updates = f"{copy_op}({func_res}, {out_arg.name});"
 
         functional_sig = self.wrapper_kernel_sig(g.functional)
         wrapper_name = sig.name()
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index bb858a36766c42..bb1f5603aa5d0f 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -541,6 +541,11 @@ def emit_inplace_functionalization_body(
         for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is None
    ]
+    non_mutated_tensor_names = [
+        a.name
+        for a in f.func.arguments.flat_all
+        if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
+    ]
     # all mutable inputs must be functional tensors in order to participate in functionalization
     check_all_mutated_args_are_functional = " && ".join(
         ["true"]
@@ -556,6 +561,14 @@ def emit_inplace_functionalization_body(
             for a in non_mutated_names
         ]
     )
+
+    check_any_non_mutated_tensors_are_xla = " || ".join(
+        ["false"]
+        + [
+            f"{a}.device().type() == c10::DeviceType::XLA"
+            for a in non_mutated_tensor_names
+        ]
+    )
     # These are used in the cases where we don't functionalize and redispatch to the inplace op
     # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
     # case 2: we hit an inplace ops but our inputs are not functional tensors (in which case our kernel just no-ops)
@@ -619,7 +632,9 @@ def emit_inplace_functionalization_body(
       }}
       {unwrap_tensor_args_str}
       if (!({check_all_mutated_args_are_functional})) {{
-        if (({check_any_non_mutated_args_are_functional})) {{
+        // We want to disable this check if there are any XLA tensors.
+        // cpu_tensor.copy_(xla_tensor) is valid code.
+        if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
          // case 1: trying to mutate a non functional tensor with a functional tensor is an error
          TORCH_INTERNAL_ASSERT(false,
            "mutating a non-functional tensor with a functional tensor is not allowed.",