[Functionalization] Fix test_simple_expand_on_2d_tensor (#4452)
Summary:
This pull request fixes TestDynamicShapes.test_simple_expand_on_2d_tensor by:
1. enabling the Python dispatcher in the test so that dynamic ops are decomposed correctly (sketched below), and
2. implementing the missing symbolic size ops: ge, lt, and clone.

Test Plan:
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" python test/test_dynamic_shapes.py -v TestDynamicShapes.test_simple_expand_on_2d_tensor

Fixes #4448
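
For context, here is a minimal sketch, not part of this diff, of the code path the fix enables, mirroring the updated test further below. It assumes PyTorch's torch._dispatch.python.enable_python_dispatcher context manager and the standard torch_xla device/metrics helpers; the exact body of the updated test may differ.

# Minimal sketch, not part of this commit. Assumes enable_python_dispatcher is
# available from torch._dispatch.python and that the environment is set as in
# the Test Plan (XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter").
import torch
from torch._dispatch.python import enable_python_dispatcher
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

dev = xm.xla_device()
t1 = torch.zeros(5, 2, device=dev)
t1[3][0] = 1
t2 = torch.nonzero(t1)  # first dimension is dynamic, upper bound <= 10

with enable_python_dispatcher():
    # expand with SymInt sizes is decomposed by the Python dispatcher; the
    # decomposition is expected to clone the symbolic sizes, which is counted
    # as xla::size_clone by the newly implemented XLASymNodeImpl::clone.
    t3 = torch.ones(1, device=dev)
    t4 = t3.expand(t2.shape[0], t2.shape[1])

print(met.counter_value("xla::size_clone"))  # > 0 per the updated test below
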
alanwaketan authored and vanbasten23 committed Feb 2, 2023
1 parent 3a7546c commit 1ed625e
Showing 4 changed files with 123 additions and 9 deletions.
42 changes: 39 additions & 3 deletions test/test_dynamic_shapes.py
@@ -24,9 +24,8 @@ def test_simple_expand(self):
t6 = t5.expand(t2.size(0))
self.assertIn('<=10', torch_xla._XLAC._get_xla_tensors_text([t6]))
t6_cpu = t6.cpu()
self.assertEqual(t6_cpu.shape[0], 2)
self.assertEqual(t6_cpu.shape[0], 2) # 10 instead of 2

@unittest.skip("fails with functionalization")
def test_simple_expand_on_2d_tensor(self):
size1 = 5
size2 = 2
@@ -57,6 +56,10 @@ def test_simple_expand_on_2d_tensor(self):
self.assertEqual(t4.shape[0], 2)
self.assertEqual(t4.shape[1], size2)

# size_clone should be called as part of the decomposition performed by
# the Python dispatcher.
self.assertGreater(met.counter_value("xla::size_clone"), 0)

def test_simple_expand_add_dimension(self):
size1 = 5
size2 = 2
@@ -82,7 +85,6 @@ def test_wrap(self):
a3 = a2.shape[0] + 3 # tests wrap
self.assertIsInstance(a3, torch.SymInt)

@unittest.skip("fails with functionalization")
def test_sizeAdd(self):
size1 = 5
size2 = 2
@@ -142,6 +144,40 @@ def test_empty_symint(self):
self.assertIsInstance(t2.shape[1], int)
self.assertEqual(t2.shape[1], 2)

def test_sizeGe(self):
met.clear_all()

size1 = 5
size2 = 2
t1 = torch.zeros([size1, size2], device=dev)
t1[3][0] = 1
# t2 has size [<=10, 2]
t2 = torch.nonzero(t1)
# Create a SizeGe IR node.
# t2.shape[1] generates a SizeConstant node.
dyn_size = t2.shape[0] >= t2.shape[1]
self.assertGreater(met.counter_value("xla::size_ge"), 0)
# Exercises SizeGe::getDynamicValue.
dynamic_size = int(dyn_size)
self.assertEqual(dynamic_size, 0)

def test_sizeLt(self):
met.clear_all()

size1 = 5
size2 = 2
t1 = torch.zeros([size1, size2], device=dev)
t1[3][0] = 1
# t2 has size [<=10, 2]
t2 = torch.nonzero(t1)
# Create a SizeLt IR node.
# t2.shape[1] generates a SizeConstant node.
dyn_size = t2.shape[0] < t2.shape[1]
self.assertGreater(met.counter_value("xla::size_lt"), 0)
# Exercises SizeLt::getDynamicValue.
dynamic_size = int(dyn_size)
self.assertEqual(dynamic_size, 1)


if __name__ == '__main__':
assert os.environ['XLA_EXPERIMENTAL'] != ''
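
The Test Plan above exercises only the expand test; assuming the same invocation pattern, the new size-comparison tests added in this file can presumably be run individually as well, for example:

XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" python test/test_dynamic_shapes.py -v TestDynamicShapes.test_sizeGe
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" python test/test_dynamic_shapes.py -v TestDynamicShapes.test_sizeLt
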
44 changes: 44 additions & 0 deletions torch_xla/csrc/ops/dynamic_ir.cpp
@@ -149,6 +149,50 @@ int64_t SizeEq::getDynamicValue() const {

std::string SizeEq::ToString() const { return "aten::size_eq"; }

SizeGe::SizeGe(torch::lazy::Value a, torch::lazy::Value b)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::ge")},
{a, b},
xla::ShapeUtil::MakeShape(
GetShapeDimensionType(/*device=*/nullptr), {}),
1) {
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
};

int64_t SizeGe::getDynamicValue() const {
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
return dim_node_0->getDynamicValue() >= dim_node_1->getDynamicValue() ? 1 : 0;
}

std::string SizeGe::ToString() const { return "aten::ge_size"; }

SizeLt::SizeLt(torch::lazy::Value a, torch::lazy::Value b)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::lt")},
{a, b},
xla::ShapeUtil::MakeShape(
GetShapeDimensionType(/*device=*/nullptr), {}),
1) {
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
};

int64_t SizeLt::getDynamicValue() const {
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
XLA_CHECK(dim_node_0);
XLA_CHECK(dim_node_1);
return dim_node_0->getDynamicValue() < dim_node_1->getDynamicValue() ? 1 : 0;
}

std::string SizeLt::ToString() const { return "aten::lt_size"; }

SizeConstant::SizeConstant(int64_t val)
: Scalar(c10::Scalar{val},
xla::ShapeUtil::MakeShape(
30 changes: 30 additions & 0 deletions torch_xla/csrc/ops/dynamic_ir.h
@@ -68,6 +68,36 @@ class SizeEq : public XlaNode, public torch::lazy::DimensionNode {
}
};

class SizeGe : public XlaNode, public torch::lazy::DimensionNode {
public:
SizeGe(torch::lazy::Value a, torch::lazy::Value b);
int64_t getDynamicValue() const override;
int64_t getStaticValue() const override {
TORCH_CHECK(false, "Comparison operators should be using getDynamicValue");
}
bool isSymbolic() const override { return true; }
std::string ToString() const override;
virtual XlaOpVector Lower(LoweringContext* loctx) const override {
// TODO: not sure we will ever need it?
TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
}
};

class SizeLt : public XlaNode, public torch::lazy::DimensionNode {
public:
SizeLt(torch::lazy::Value a, torch::lazy::Value b);
int64_t getDynamicValue() const override;
int64_t getStaticValue() const override {
TORCH_CHECK(false, "Comparison operators should be using getDynamicValue");
}
bool isSymbolic() const override { return true; }
std::string ToString() const override;
virtual XlaOpVector Lower(LoweringContext* loctx) const override {
// TODO: not sure we will ever need it?
TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
}
};

class SizeAdd : public XlaNode, public torch::lazy::DimensionNode {
public:
SizeAdd(torch::lazy::Value a, torch::lazy::Value b);
16 changes: 10 additions & 6 deletions torch_xla/csrc/tensor.cpp
@@ -693,8 +693,10 @@ c10::SymNode XLASymNodeImpl::gt(const c10::SymNode& other) {
}

c10::SymNode XLASymNodeImpl::lt(const c10::SymNode& other) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
TORCH_LAZY_FN_COUNTER("xla::size_");
auto p_other = dynamic_cast<XLASymNodeImpl*>(other.get());
auto n_lt = torch::lazy::MakeNode<SizeLt>(node(), p_other->node());
return c10::make_intrusive<XLASymNodeImpl>(n_lt);
}

c10::SymNode XLASymNodeImpl::le(const c10::SymNode& other) {
@@ -703,8 +705,10 @@ c10::SymNode XLASymNodeImpl::le(const c10::SymNode& other) {
}

c10::SymNode XLASymNodeImpl::ge(const c10::SymNode& other) {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
TORCH_LAZY_FN_COUNTER("xla::size_");
auto p_other = dynamic_cast<XLASymNodeImpl*>(other.get());
auto n_ge = torch::lazy::MakeNode<SizeGe>(node(), p_other->node());
return c10::make_intrusive<XLASymNodeImpl>(n_ge);
}

c10::SymNode XLASymNodeImpl::ceil() {
@@ -733,8 +737,8 @@ c10::SymNode XLASymNodeImpl::sym_max(const c10::SymNode& other) {
}

c10::SymNode XLASymNodeImpl::clone() {
XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
<< " has not been implemented.";
TORCH_LAZY_FN_COUNTER("xla::size_");
return c10::make_intrusive<XLASymNodeImpl>(node());
}

c10::SymNode XLASymNodeImpl::sym_float() {
