diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index e696ccf4f1df..5e2ec31863c4 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -5173,7 +5173,12 @@ TEST_F(AtenXlaTensorTest, TestOneIndexTransfer) {
   }
 }
 
-// Temporarily disable test. See https://github.com/pytorch/xla/issues/4501
+// Temporarily disable this test. The original issue
+// https://github.com/pytorch/xla/issues/4501 has been resolved. The next error
+// is https://gist.github.com/vanbasten23/b3a79e0cc7f17edc0018eb83cdd5d738 (see
+// https://github.com/pytorch/xla/issues/4432 for more info). That error
+// happens on TPU but not on CPU.
+//
 /*
 TEST_F(AtenXlaTensorTest, TestNonzero) {
   torch::Tensor a = torch::zeros({4, 2}, torch::TensorOptions(torch::kFloat));
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index 5c504e4d456d..353cd8ae555b 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -4,6 +4,7 @@
 import torch, torch_xla
 import torch_xla.core.xla_model as xm
 
+pd = torch._C._EnablePythonDispatcher()
 dev = xm.xla_device()
 
 
@@ -53,6 +54,24 @@ def test_simple_expand_on_2d_tensor(self):
     self.assertEqual(t4.shape[0], 2)
     self.assertEqual(t4.shape[1], size2)
 
+  def test_simple_expand_add_dimension(self):
+    size1 = 5
+    size2 = 2
+    t1 = torch.zeros([size1, size2], device=dev)
+    t1[3][0] = 1
+    t1[3][1] = 1
+    # t2 has size [<=10, 2]
+    t2 = torch.nonzero(t1)
+    t3 = torch.ones(1, device=dev)
+
+    t4 = t3.expand(t2.shape[0], t2.shape[0])
+    self.assertIsInstance(t4.shape[0], torch.SymInt)
+    self.assertEqual(str(t4.shape[0]), '<=10')
+    self.assertEqual(t4.shape[0], 2)
+    self.assertIsInstance(t4.shape[1], torch.SymInt)
+    self.assertEqual(str(t4.shape[1]), '<=10')
+    self.assertEqual(t4.shape[1], 2)
+
   def test_wrap(self):
     a1 = torch.tensor([[1, 0, 0, 5, 0, 6]], device=dev)
     a2 = torch.nonzero(a1)
@@ -80,8 +99,27 @@ def test_sizeAdd(self):
     t4 = t3.expand(dyn_size)
     self.assertEqual(t4.size(0), 3)
 
+  def get_dynamic_tensor(self):
+    a1 = torch.tensor([[1, 0, 0, 5, 0, 6]], device=dev)
+    a2 = torch.nonzero(a1)
+    return a2
+
+  def test_empty_symint(self):
+    # t1.shape = torch.Size([<=6, 2]) with real size [3, 2]
+    t1 = self.get_dynamic_tensor()
+    # Don't print t1, otherwise it would cause the test to crash.
+    self.assertIsInstance(t1.shape[0], torch.SymInt)
+    t2 = torch.empty(t1.shape, dtype=torch.int32, device=dev)
+    self.assertIsInstance(t2.shape[0], torch.SymInt)
+    self.assertEqual(str(t2.shape[0]), '<=6')
+    self.assertEqual(t2.shape[0], 3)
+    self.assertIsInstance(t2.shape[1], int)
+    self.assertEqual(t2.shape[1], 2)
+
 
 if __name__ == '__main__':
   assert os.environ['XLA_EXPERIMENTAL'] != ''
   test = unittest.main()
+  # DISABLE PYTHON DISPATCHER FLAG
+  del pd
   sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 4ac2dd6425b9..e3a1d8cb0222 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1129,15 +1129,21 @@ at::Tensor XLANativeFunctions::empty_symint(
     c10::optional<bool> pin_memory,
     c10::optional<at::MemoryFormat> /* memory_format */) {
   TORCH_LAZY_FN_COUNTER("xla::");
-  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
+  c10::optional<at::IntArrayRef> int_sizes =
+      c10::asIntArrayRefSlowOpt(sym_size);
+  bool all_dims_static = int_sizes.has_value();
   // PT empty*() are optimizations to avoid initializing the data when it is
   // known it will be completely rewritten. But since for us doing a zero*()
   // does not actually end up doing any memory initialization, we use that and
   // avoid going to CPU for it. A common PT pattern is indeed doing empty() plus
   // s_copy_().
-  return bridge::AtenFromXlaTensor(tensor_methods::full(
-      XlaHelpers::I64List(size), 0, GetXlaDeviceOrCurrent(device),
-      at::dtype_or_default(dtype)));
+  if (all_dims_static) {
+    return bridge::AtenFromXlaTensor(tensor_methods::full(
+        XlaHelpers::I64List(int_sizes.value()), 0,
+        GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
+  }
+  return bridge::AtenFromXlaTensor(tensor_methods::full_symint(
+      sym_size, 0, GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
 }
 
 at::Tensor XLANativeFunctions::empty_strided_symint(
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 97d6a4cdb7f6..c2c53ccf9b1e 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -1236,6 +1236,24 @@ XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value,
                        device, *scalar_type);
 }
 
+XLATensorPtr full_symint(at::SymIntArrayRef sym_size,
+                         const at::Scalar& fill_value,
+                         const torch::lazy::BackendDevice& device,
+                         at::ScalarType scalar_type) {
+  XLA_CHECK(std::all_of(sym_size.begin(), sym_size.end(), [](at::SymInt dim) {
+    if (!dim.is_symbolic()) {
+      return dim >= 0;
+    }
+    return true;
+  })) << "Dimensions cannot be negative numbers";
+
+  return XLATensor::Create(
+      XLAGraphExecutor::Get()->GetIrValueForScalar(
+          fill_value, MakeXlaPrimitiveType(scalar_type, &device), sym_size,
+          device),
+      device, scalar_type);
+}
+
 XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
                     const XLATensorPtr& index) {
   xla::Shape input_shape = input->shape();
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 36ec4e959f12..947453f7efd0 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -359,6 +359,10 @@ XLATensorPtr full(absl::Span<const int64_t> size, const at::Scalar& fill_value,
 XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value,
                        const torch::lazy::BackendDevice& device,
                        c10::optional<at::ScalarType> scalar_type);
+XLATensorPtr full_symint(at::SymIntArrayRef sym_size,
+                         const at::Scalar& fill_value,
+                         const torch::lazy::BackendDevice& device,
+                         at::ScalarType scalar_type);
 
 XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
                     const XLATensorPtr& index);
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index be6eb63c21bb..4c2482cf1b1b 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -46,6 +46,7 @@
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/ops/dynamic_ir.h"
 #include "torch_xla/csrc/ops/expand.h"
+#include "torch_xla/csrc/ops/expand_symint.h"
 #include "torch_xla/csrc/ops/ops.h"
 #include "torch_xla/csrc/ops/view.h"
 #include "torch_xla/csrc/ops/xla_ops.h"
@@ -255,6 +256,14 @@ torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
   return ir_value;
 }
 
+torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
+    const at::Scalar& value, xla::PrimitiveType type,
+    c10::SymIntArrayRef sym_size, const torch::lazy::BackendDevice& device) {
+  torch::lazy::Value ir_value = GetIrValueForScalar(value, type, device);
+  SymIntElements size_elements = SymIntElements(sym_size);
+  return torch::lazy::MakeNode<ExpandSymInt>(ir_value, size_elements);
+}
+
 torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
     const at::Scalar& value, const xla::Shape& shape,
     const torch::lazy::BackendDevice& device) {
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index 1eb9d763a5f7..7a0e35ac09d3 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -76,6 +76,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
       const at::Scalar& value, xla::PrimitiveType type,
       absl::Span<const int64_t> dimensions,
       const torch::lazy::BackendDevice& device);
+  torch::lazy::Value GetIrValueForScalar(
+      const at::Scalar& value, xla::PrimitiveType type,
+      c10::SymIntArrayRef sym_size, const torch::lazy::BackendDevice& device);
   torch::lazy::Value GetIrValueForScalar(
       const at::Scalar& value, const xla::Shape& shape,
       const torch::lazy::BackendDevice& device);
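
Note (not part of the patch): the snippet below is a minimal, hypothetical sketch of the user-visible path this change enables, mirroring the new test_empty_symint test. It assumes an XLA device is available and XLA_EXPERIMENTAL is set so that torch.nonzero produces a dynamically shaped tensor; with the patch applied, torch.empty on a SymInt shape routes through full_symint/ExpandSymInt instead of requiring all dimensions to be static.

import torch
import torch_xla.core.xla_model as xm

pd = torch._C._EnablePythonDispatcher()  # route SymInt-shaped calls through the Python dispatcher
dev = xm.xla_device()

t1 = torch.tensor([[1, 0, 0, 5, 0, 6]], device=dev)
t2 = torch.nonzero(t1)  # dynamically shaped: [<=6, 2] with a real size of [3, 2]
t3 = torch.empty(t2.shape, dtype=torch.int32, device=dev)  # hits empty_symint -> full_symint
assert isinstance(t3.shape[0], torch.SymInt) and str(t3.shape[0]) == '<=6'
del pd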