diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 5d378d9fa1ef..fcb4285c76af 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -53,6 +53,8 @@ namespace {
 at::Tensor to_meta(const at::Tensor& tensor) {
   // undefined tensors can't be converted to the meta device, since they don't
   // have sizes/strides
+  std::cout << "WONJOO: at aten_xla_type.cpp, to_meta1" << std::endl;
+  std::cout << "WONJOO: at aten_xla_type.cpp, to_meta2, tensor_is_functional=" << at::functionalization::impl::isFunctionalTensor(tensor) << std::endl;
   if (!tensor.defined()) return tensor;
   auto out = at::native::empty_strided_meta_symint(
       tensor.sym_sizes(), tensor.sym_strides(),
@@ -458,6 +460,7 @@ at::Tensor& XLANativeFunctions::_amp_update_scale_(at::Tensor& current_scale,
 at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self,
                                           const at::Tensor& dst,
                                           bool non_blocking) {
+  std::cout << "WONJOO: at aten_xla_type.cpp, _copy_from" << std::endl;
   TORCH_LAZY_FN_COUNTER("xla::");
   auto dst_tensor = bridge::TryGetXlaTensor(dst);
   auto self_tensor = bridge::TryGetXlaTensor(self);
@@ -652,6 +655,7 @@ at::Tensor XLANativeFunctions::argmin(const at::Tensor& self,
 at::Tensor XLANativeFunctions::as_strided_copy(
     const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
     c10::optional<int64_t> storage_offset) {
+  std::cout << "WONJOO: at aten_xla_type.cpp, as_strided_copy1" << std::endl;
   TORCH_LAZY_FN_COUNTER("xla::");
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
   auto xsize = XlaHelpers::I64List(size);
@@ -671,6 +675,7 @@ at::Tensor XLANativeFunctions::as_strided_scatter(
     const at::Tensor& base, const at::Tensor& mutated_view,
     at::IntArrayRef size, at::IntArrayRef stride,
     c10::optional<int64_t> storage_offset) {
+  std::cout << "WONJOO: at aten_xla_type.cpp, as_strided_scatter1" << std::endl;
   TORCH_LAZY_FN_COUNTER("xla::");
   auto base_ = bridge::GetXlaTensor(base);
   auto xsize = XlaHelpers::I64List(size);
@@ -1148,6 +1153,7 @@ at::Tensor XLANativeFunctions::empty_symint(
     c10::optional<at::Layout> layout, c10::optional<at::Device> device,
     c10::optional<bool> pin_memory,
     c10::optional<at::MemoryFormat> /* memory_format */) {
+  std::cout << "WONJOO: at XLANativeFunctions::empty_symint" << std::endl;
   TORCH_LAZY_FN_COUNTER("xla::");
   auto size = c10::asIntArrayRefSlow(sym_size);
   // PT empty*() are optimizations to avoid initializing the data when it is
@@ -1381,6 +1387,8 @@ at::Tensor& XLANativeFunctions::index_fill_(at::Tensor& self, int64_t dim,
 at::Tensor& XLANativeFunctions::index_put_(
     at::Tensor& self, const c10::List<c10::optional<at::Tensor>>& indices,
     const at::Tensor& values, bool accumulate) {
+  std::cout << "WONJOO: at aten_xla_type.cpp, index_put_1" << std::endl;
+  std::cout << "WONJOO: at aten_xla_type.cpp, index_put_2, self.is_functional=" << at::functionalization::impl::isFunctionalTensor(self) << std::endl;
   TORCH_LAZY_FN_COUNTER("xla::");
   XLA_CHECK(self.scalar_type() == values.scalar_type());
   CanonicalIndexInfo canonical_index_info =
@@ -1399,6 +1407,7 @@ at::Tensor& XLANativeFunctions::index_put_(
       canonical_index_info.start_dim,
       bridge::GetOrCreateXlaTensor(values, *device), accumulate,
       canonical_index_info.result_permutation);
+  std::cout << "WONJOO: at aten_xla_type.cpp, index_put_3, self.is_functional=" << at::functionalization::impl::isFunctionalTensor(self) << std::endl;
   return self;
 }
 
@@ -1468,17 +1477,18 @@ at::Tensor XLANativeFunctions::lerp(const at::Tensor& self,
       bridge::GetXlaTensor(self), bridge::GetXlaTensor(end), weight));
 }
 
-at::Tensor XLANativeFunctions::lift_fresh(const at::Tensor& self) {
-  std::cout
-      << "WONJOO: at aten_xla_type.cpp, lift_fresh" << std::endl;
-  return at::functionalization::impl::to_functional_tensor(self);
-  // return at::functionalization::functionalize_aten_op<ATEN_OP(lift_fresh)>::call(self);
+at::Tensor XLANativeFunctions::lift(const at::Tensor& tensor) {
+  std::cout << "WONJOO: at XLANativeFunctions::lift" << std::endl;
+  TORCH_INTERNAL_ASSERT(
+      !at::functionalization::impl::isFunctionalTensor(tensor));
+  return at::functionalization::impl::to_functional_tensor(tensor);
 }
 
-at::Tensor XLANativeFunctions::lift_fresh_copy(const at::Tensor& self) {
-  std::cout << "WONJOO: at aten_xla_type.cpp, lift_fresh_copy" << std::endl;
-  return at::functionalization::functionalize_aten_op<ATEN_OP(
-      lift_fresh_copy)>::call(self);
+at::Tensor XLANativeFunctions::lift_fresh(const at::Tensor& tensor) {
+  std::cout << "WONJOO: at XLANativeFunctions::lift_fresh" << std::endl;
+  TORCH_INTERNAL_ASSERT(
+      !at::functionalization::impl::isFunctionalTensor(tensor));
+  return at::functionalization::impl::to_functional_tensor(tensor);
 }
 
 at::Tensor XLANativeFunctions::linspace(const at::Scalar& start,
@@ -3114,6 +3124,7 @@ at::Tensor XLANativeFunctions::block_diag(at::TensorList tensors) {
   return at::functionalization::functionalize_aten_op<ATEN_OP(block_diag)>::call(tensors);
 }
 
+
 at::Tensor XLANativeFunctions::new_empty_strided_symint(
     const at::Tensor& self, at::SymIntArrayRef size, at::SymIntArrayRef stride,
     c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index 9bb0d70e2d9f..d93cba4eb9c9 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -197,8 +197,8 @@ supported:
   - leaky_relu_backward
   - lerp.Scalar
   - lerp.Tensor
+  - lift
   - lift_fresh
-  - lift_fresh_copy
   - linspace
   - log
   - log1p