Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make empty_symint support dynamism. #4550

Merged
merged 1 commit into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5173,7 +5173,12 @@ TEST_F(AtenXlaTensorTest, TestOneIndexTransfer) {
}
}

// Temporarily disable test. See https://github.com/pytorch/xla/issues/4501
// Temporarily disable test. The original issue,
// https://github.com/pytorch/xla/issues/4501, has been resolved. The remaining
// error is https://gist.github.com/vanbasten23/b3a79e0cc7f17edc0018eb83cdd5d738
// (see https://github.com/pytorch/xla/issues/4432 for more info); it occurs on
// TPU but not on CPU.
//
/*
TEST_F(AtenXlaTensorTest, TestNonzero) {
torch::Tensor a = torch::zeros({4, 2}, torch::TensorOptions(torch::kFloat));
Expand Down
38 changes: 38 additions & 0 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch, torch_xla
import torch_xla.core.xla_model as xm

# Enable the Python dispatcher for the lifetime of this module so the
# SymInt-aware decompositions are exercised; it is released with `del pd`
# after the tests finish.
pd = torch._C._EnablePythonDispatcher()
# All test tensors in this file are allocated on the XLA device.
dev = xm.xla_device()


Expand Down Expand Up @@ -53,6 +54,24 @@ def test_simple_expand_on_2d_tensor(self):
self.assertEqual(t4.shape[0], 2)
self.assertEqual(t4.shape[1], size2)

def test_simple_expand_add_dimension(self):
    """Expanding with two symbolic sizes yields SymInt dims on both axes."""
    rows, cols = 5, 2
    base = torch.zeros([rows, cols], device=dev)
    base[3][0] = 1
    base[3][1] = 1
    # nonzero() produces a dynamically shaped tensor: shape [<=10, 2].
    dynamic = torch.nonzero(base)
    ones = torch.ones(1, device=dev)

    expanded = ones.expand(dynamic.shape[0], dynamic.shape[0])
    for axis in (0, 1):
        self.assertIsInstance(expanded.shape[axis], torch.SymInt)
        self.assertEqual(str(expanded.shape[axis]), '<=10')
        self.assertEqual(expanded.shape[axis], 2)

def test_wrap(self):
a1 = torch.tensor([[1, 0, 0, 5, 0, 6]], device=dev)
a2 = torch.nonzero(a1)
Expand Down Expand Up @@ -80,8 +99,27 @@ def test_sizeAdd(self):
t4 = t3.expand(dyn_size)
self.assertEqual(t4.size(0), 3)

def get_dynamic_tensor(self):
    """Build a tensor with a dynamic first dimension via nonzero().

    The result has upper-bounded shape [<=6, 2] with a real size of [3, 2].
    """
    source = torch.tensor([[1, 0, 0, 5, 0, 6]], device=dev)
    return torch.nonzero(source)

def test_empty_symint(self):
    """torch.empty with a dynamic SymInt shape keeps the dynamic dimension."""
    # dynamic.shape == torch.Size([<=6, 2]); the real size is [3, 2].
    dynamic = self.get_dynamic_tensor()
    # Printing `dynamic` would crash the test, so avoid doing that here.
    self.assertIsInstance(dynamic.shape[0], torch.SymInt)
    result = torch.empty(dynamic.shape, dtype=torch.int32, device=dev)
    self.assertIsInstance(result.shape[0], torch.SymInt)
    self.assertEqual(str(result.shape[0]), '<=6')
    self.assertEqual(result.shape[0], 3)
    self.assertIsInstance(result.shape[1], int)
    self.assertEqual(result.shape[1], 2)


if __name__ == '__main__':
    # Dynamic-shape support is experimental; refuse to run without the flag.
    # Use .get() so a missing variable reports a clear assertion failure
    # instead of an opaque KeyError.
    assert os.environ.get('XLA_EXPERIMENTAL', '') != '', \
        'XLA_EXPERIMENTAL must be set to run the dynamic shape tests'
    # exit=False is required: the default (exit=True) raises SystemExit inside
    # unittest.main(), which would make the cleanup below unreachable.
    test = unittest.main(exit=False)
    # DISABLE PYTHON DISPATCHER FLAG
    del pd
    sys.exit(0 if test.result.wasSuccessful() else 1)
14 changes: 10 additions & 4 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1129,15 +1129,21 @@ at::Tensor XLANativeFunctions::empty_symint(
c10::optional<bool> pin_memory,
c10::optional<at::MemoryFormat> /* memory_format */) {
TORCH_LAZY_FN_COUNTER("xla::");
auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
c10::optional<at::IntArrayRef> int_sizes =
c10::asIntArrayRefSlowOpt(sym_size);
bool all_dims_static = int_sizes.has_value();
// PT empty*() are optimizations to avoid initializing the data when it is
// known it will be completely rewritten. But since for us doing a zero*()
// does not actually end up doing any memory initialization, we use that and
// avoid going to CPU for it. A common PT pattern is indeed doing empty() plus
// s_copy_().
return bridge::AtenFromXlaTensor(tensor_methods::full(
XlaHelpers::I64List(size), 0, GetXlaDeviceOrCurrent(device),
at::dtype_or_default(dtype)));
if (all_dims_static) {
return bridge::AtenFromXlaTensor(tensor_methods::full(
XlaHelpers::I64List(int_sizes.value()), 0,
GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
}
return bridge::AtenFromXlaTensor(tensor_methods::full_symint(
sym_size, 0, GetXlaDeviceOrCurrent(device), at::dtype_or_default(dtype)));
}

at::Tensor XLANativeFunctions::empty_strided_symint(
Expand Down
18 changes: 18 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,24 @@ XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value,
device, *scalar_type);
}

// Creates a tensor of the given (possibly symbolic) shape filled with
// `fill_value`. Static dimensions must be non-negative; symbolic dimensions
// are accepted as-is.
XLATensorPtr full_symint(at::SymIntArrayRef sym_size,
                         const at::Scalar& fill_value,
                         const torch::lazy::BackendDevice& device,
                         at::ScalarType scalar_type) {
  // Only concrete (non-symbolic) dimensions can be validated here.
  const bool has_negative_dim =
      std::any_of(sym_size.begin(), sym_size.end(), [](const at::SymInt& dim) {
        return !dim.is_symbolic() && dim < 0;
      });
  XLA_CHECK(!has_negative_dim) << "Dimensions cannot be negative numbers";

  torch::lazy::Value scalar_ir = XLAGraphExecutor::Get()->GetIrValueForScalar(
      fill_value, MakeXlaPrimitiveType(scalar_type, &device), sym_size, device);
  return XLATensor::Create(scalar_ir, device, scalar_type);
}

XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
const XLATensorPtr& index) {
xla::Shape input_shape = input->shape();
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ XLATensorPtr full(absl::Span<const int64_t> size, const at::Scalar& fill_value,
XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value,
const torch::lazy::BackendDevice& device,
c10::optional<at::ScalarType> scalar_type);
// Creates a tensor filled with `fill_value` whose shape may contain symbolic
// (dynamic) dimensions.
XLATensorPtr full_symint(at::SymIntArrayRef sym_size,
                         const at::Scalar& fill_value,
                         const torch::lazy::BackendDevice& device,
                         at::ScalarType scalar_type);

XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
const XLATensorPtr& index);
Expand Down
9 changes: 9 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/dynamic_ir.h"
#include "torch_xla/csrc/ops/expand.h"
#include "torch_xla/csrc/ops/expand_symint.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/ops/view.h"
#include "torch_xla/csrc/ops/xla_ops.h"
Expand Down Expand Up @@ -255,6 +256,14 @@ torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
return ir_value;
}

// Builds an IR value for `value` and expands it to the possibly-dynamic shape
// described by `sym_size` via an ExpandSymInt node.
torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value, xla::PrimitiveType type,
    c10::SymIntArrayRef sym_size, const torch::lazy::BackendDevice& device) {
  SymIntElements dims(sym_size);
  return torch::lazy::MakeNode<ExpandSymInt>(
      GetIrValueForScalar(value, type, device), dims);
}

torch::lazy::Value XLAGraphExecutor::GetIrValueForScalar(
const at::Scalar& value, const xla::Shape& shape,
const torch::lazy::BackendDevice& device) {
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
const at::Scalar& value, xla::PrimitiveType type,
absl::Span<const int64_t> dimensions,
const torch::lazy::BackendDevice& device);
// Builds an IR value for the scalar expanded to a possibly-dynamic shape
// described by `sym_size` (SymInt dimensions).
torch::lazy::Value GetIrValueForScalar(
    const at::Scalar& value, xla::PrimitiveType type,
    c10::SymIntArrayRef sym_size, const torch::lazy::BackendDevice& device);
torch::lazy::Value GetIrValueForScalar(
const at::Scalar& value, const xla::Shape& shape,
const torch::lazy::BackendDevice& device);
Expand Down