[Functionalization] Fix test_simple_expand_on_2d_tensor (#4452)
Summary:
This pull request fixes TestDynamicShapes.test_simple_expand_on_2d_tensor by:
1. enabling the Python dispatcher in the test so that dynamic ops are decomposed correctly, and
2. implementing the missing sym size ops: ge, lt, and clone.

Test Plan:
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" python test/test_dynamic_shapes.py -v TestDynamicShapes.test_simple_expand_on_2d_tensor
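
For reference, the rough flow the test exercises is sketched below. This is a hedged sketch that loosely mirrors the test file in this diff rather than copying its exact body, and it assumes the XLA_EXPERIMENTAL flags from the command above are set.

import torch, torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

# Enabling the Python dispatcher is what makes dynamic ops such as expand
# decompose into the sym-size calls that the metrics counters track.
pd = torch._C._EnablePythonDispatcher()
dev = xm.xla_device()

t1 = torch.zeros([5, 2], device=dev)
t1[3][0] = 1
t2 = torch.nonzero(t1)                    # dynamic shape: [<=10, 2]
t3 = torch.ones(1, device=dev)
t4 = t3.expand(t2.shape[0], t2.shape[1])  # expand over a dynamic dimension
assert met.counter_value("xla::size_clone") > 0  # clone hit via the decomposition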

Fixes #4448
alanwaketan committed Jan 25, 2023
1 parent 7b6667f commit 9da8d2a
Showing 4 changed files with 125 additions and 9 deletions.
44 changes: 41 additions & 3 deletions test/test_dynamic_shapes.py
@@ -3,7 +3,9 @@
import unittest
import torch, torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

pd = torch._C._EnablePythonDispatcher()
dev = xm.xla_device()


@@ -22,9 +24,8 @@ def test_simple_expand(self):
t6 = t5.expand(t2.size(0))
self.assertIn('<=10', torch_xla._XLAC._get_xla_tensors_text([t6]))
t6_cpu = t6.cpu()
self.assertEqual(t6_cpu.shape[0], 2)
self.assertEqual(t6_cpu.shape[0], 2) # 10 instead of 2

@unittest.skip("fails with functionalization")
def test_simple_expand_on_2d_tensor(self):
size1 = 5
size2 = 2
@@ -55,14 +56,17 @@ def test_simple_expand_on_2d_tensor(self):
self.assertEqual(t4.shape[0], 2)
self.assertEqual(t4.shape[1], size2)

# size_clone should be called as part of decomposition from
# the python dispatcher.
self.assertGreater(met.counter_value("xla::size_clone"), 0)

def test_wrap(self):
a1 = torch.tensor([[1, 0, 0, 5, 0, 6]], device=dev)
a2 = torch.nonzero(a1)
self.assertTrue(a2.shape[0] == 3)
a3 = a2.shape[0] + 3 # tests wrap
self.assertIsInstance(a3, torch.SymInt)

@unittest.skip("fails with functionalization")
def test_sizeAdd(self):
size1 = 5
size2 = 2
@@ -83,6 +87,40 @@ def test_sizeAdd(self):
t4 = t3.expand(dyn_size)
self.assertEqual(t4.size(0), 3)

def test_sizeGe(self):
met.clear_all()

size1 = 5
size2 = 2
t1 = torch.zeros([size1, size2], device=dev)
t1[3][0] = 1
# t2 has size [<=10, 2]
t2 = torch.nonzero(t1)
# Create a SizeGe IR node.
# t2.shape[1] generates a SizeConstant node.
dyn_size = t2.shape[0] >= t2.shape[1]
self.assertGreater(met.counter_value("xla::size_ge"), 0)
# Exercises SizeGe::getDynamicValue.
dynamic_size = int(dyn_size)
self.assertEqual(dynamic_size, 0)

def test_sizeLt(self):
met.clear_all()

size1 = 5
size2 = 2
t1 = torch.zeros([size1, size2], device=dev)
t1[3][0] = 1
# t2 has size [<=10, 2]
t2 = torch.nonzero(t1)
# Create a SizeLt IR node.
# t2.shape[1] generates a SizeConstant node.
dyn_size = t2.shape[0] < t2.shape[1]
self.assertGreater(met.counter_value("xla::size_lt"), 0)
# Exercises SizeLt::getDynamicValue.
dynamic_size = int(dyn_size)
self.assertEqual(dynamic_size, 1)


if __name__ == '__main__':
assert os.environ['XLA_EXPERIMENTAL'] != ''
44 changes: 44 additions & 0 deletions torch_xla/csrc/ops/dynamic_ir.cpp
@@ -118,6 +118,50 @@ int64_t SizeEq::getDynamicValue() const {

std::string SizeEq::ToString() const { return "aten::eq_size"; }

SizeGe::SizeGe(torch::lazy::Value a, torch::lazy::Value b)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::ge")},
{a, b},
xla::ShapeUtil::MakeShape(
GetShapeDimensionType(/*device=*/nullptr), {}),
1) {
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
};

int64_t SizeGe::getDynamicValue() const {
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
return dim_node_0->getDynamicValue() >= dim_node_1->getDynamicValue() ? 1 : 0;
}

std::string SizeGe::ToString() const { return "aten::ge_size"; }

SizeLt::SizeLt(torch::lazy::Value a, torch::lazy::Value b)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::lt")},
{a, b},
xla::ShapeUtil::MakeShape(
GetShapeDimensionType(/*device=*/nullptr), {}),
1) {
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
};

int64_t SizeLt::getDynamicValue() const {
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
return dim_node_0->getDynamicValue() < dim_node_1->getDynamicValue() ? 1 : 0;
}

std::string SizeLt::ToString() const { return "aten::lt_size"; }

SizeConstant::SizeConstant(int64_t val)
: Scalar(c10::Scalar{val},
xla::ShapeUtil::MakeShape(
30 changes: 30 additions & 0 deletions torch_xla/csrc/ops/dynamic_ir.h
@@ -68,6 +68,36 @@ class SizeEq : public XlaNode, public torch::lazy::DimensionNode {
}
};

class SizeGe : public XlaNode, public torch::lazy::DimensionNode {
public:
SizeGe(torch::lazy::Value a, torch::lazy::Value b);
int64_t getDynamicValue() const override;
int64_t getStaticValue() const override {
TORCH_CHECK(false, "Comparison operators should be using getDynamicValue");
}
bool isSymbolic() const override { return true; }
std::string ToString() const override;
virtual XlaOpVector Lower(LoweringContext* loctx) const override {
// TODO: it is unclear whether lowering will ever be needed for comparison nodes.
TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
}
};

class SizeLt : public XlaNode, public torch::lazy::DimensionNode {
public:
SizeLt(torch::lazy::Value a, torch::lazy::Value b);
int64_t getDynamicValue() const override;
int64_t getStaticValue() const override {
TORCH_CHECK(false, "Comparison operators should be using getDynamicValue");
}
bool isSymbolic() const override { return true; }
std::string ToString() const override;
virtual XlaOpVector Lower(LoweringContext* loctx) const override {
// TODO: it is unclear whether lowering will ever be needed for comparison nodes.
TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
}
};

class SizeAdd : public XlaNode, public torch::lazy::DimensionNode {
public:
SizeAdd(torch::lazy::Value a, torch::lazy::Value b);
16 changes: 10 additions & 6 deletions torch_xla/csrc/tensor.cpp
@@ -673,8 +673,10 @@ c10::SymNode XLASymNodeImpl::gt(const c10::SymNode& other) {
}

c10::SymNode XLASymNodeImpl::lt(const c10::SymNode& other) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
TORCH_LAZY_FN_COUNTER("xla::size_");
auto p_other = dynamic_cast<XLASymNodeImpl*>(other.get());
auto n_lt = torch::lazy::MakeNode<SizeLt>(node(), p_other->node());
return c10::make_intrusive<XLASymNodeImpl>(n_lt);
}

c10::SymNode XLASymNodeImpl::le(const c10::SymNode& other) {
@@ -683,8 +685,10 @@ c10::SymNode XLASymNodeImpl::le(const c10::SymNode& other) {
}

c10::SymNode XLASymNodeImpl::ge(const c10::SymNode& other) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
TORCH_LAZY_FN_COUNTER("xla::size_");
auto p_other = dynamic_cast<XLASymNodeImpl*>(other.get());
auto n_ge = torch::lazy::MakeNode<SizeGe>(node(), p_other->node());
return c10::make_intrusive<XLASymNodeImpl>(n_ge);
}

c10::SymNode XLASymNodeImpl::ceil() {
@@ -713,8 +717,8 @@ c10::SymNode XLASymNodeImpl::sym_max(const c10::SymNode& other) {
}

c10::SymNode XLASymNodeImpl::clone() {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
TORCH_LAZY_FN_COUNTER("xla::size_");
return c10::make_intrusive<XLASymNodeImpl>(node());
}

c10::SymNode XLASymNodeImpl::sym_float() {
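
Taken together, these tensor.cpp changes are what surface the new IR nodes at the Python level: `>=` and `<` on XLA SymInts now route through XLASymNodeImpl::ge and XLASymNodeImpl::lt (building SizeGe / SizeLt nodes and bumping the xla::size_ge / xla::size_lt counters), clone() simply wraps the existing node, and int() on the result goes through getDynamicValue. A rough usage sketch, assuming the same device and dispatcher setup as in the test file above:

t1 = torch.zeros([5, 2], device=dev)
t1[3][0] = 1
t2 = torch.nonzero(t1)             # shape [<=10, 2]; t2.shape[0] is a torch.SymInt

ge = t2.shape[0] >= t2.shape[1]    # XLASymNodeImpl::ge -> SizeGe node
lt = t2.shape[0] < t2.shape[1]     # XLASymNodeImpl::lt -> SizeLt node

# int() forces SizeGe::getDynamicValue / SizeLt::getDynamicValue; t1 has a
# single nonzero element, so the dynamic value of t2.shape[0] is 1.
print(int(ge))  # 0  (1 >= 2 is false)
print(int(lt))  # 1  (1 < 2 is true)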
