From 4d51d555a7e1847b666b594122ca39ceda97cba4 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Wed, 18 Jan 2023 23:11:20 -0800
Subject: [PATCH] [Functionalization] Fix test_simple_expand_on_2d_tensor (#4452)

Summary:
This pull request fixes TestDynamicShapes.test_simple_expand_on_2d_tensor by:
1. enabling the Python dispatcher in the test such that dynamic ops are
   decomposed correctly,
2. implementing the missing sym size ops: ge, lt, and clone.

Test Plan:
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" python test/test_dynamic_shapes.py -v TestDynamicShapes.test_simple_expand_on_2d_tensor

Fixes #4448
---
 test/test_dynamic_shapes.py       | 42 ++++++++++++++++++++++++++---
 torch_xla/csrc/ops/dynamic_ir.cpp | 44 +++++++++++++++++++++++++++++++
 torch_xla/csrc/ops/dynamic_ir.h   | 30 +++++++++++++++++++++
 torch_xla/csrc/tensor.cpp         | 16 ++++++-----
 4 files changed, 123 insertions(+), 9 deletions(-)

diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index a9d0b3f9030b..d8eb44cb360d 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -24,9 +24,8 @@ def test_simple_expand(self):
     t6 = t5.expand(t2.size(0))
     self.assertIn('<=10', torch_xla._XLAC._get_xla_tensors_text([t6]))
     t6_cpu = t6.cpu()
-    self.assertEqual(t6_cpu.shape[0], 2)
+    self.assertEqual(t6_cpu.shape[0], 2)  # 10 instead of 2
 
-  @unittest.skip("fails with functionalization")
   def test_simple_expand_on_2d_tensor(self):
     size1 = 5
     size2 = 2
@@ -57,6 +56,10 @@ def test_simple_expand_on_2d_tensor(self):
     self.assertEqual(t4.shape[0], 2)
     self.assertEqual(t4.shape[1], size2)
 
+    # size_clone should be called as part of decomposition from
+    # the python dispatcher.
+    self.assertGreater(met.counter_value("xla::size_clone"), 0)
+
   def test_simple_expand_add_dimension(self):
     size1 = 5
     size2 = 2
@@ -82,7 +85,6 @@ def test_wrap(self):
     a3 = a2.shape[0] + 3  # tests wrap
     self.assertIsInstance(a3, torch.SymInt)
 
-  @unittest.skip("fails with functionalization")
   def test_sizeAdd(self):
     size1 = 5
     size2 = 2
@@ -142,6 +144,40 @@ def test_empty_symint(self):
     self.assertIsInstance(t2.shape[1], int)
     self.assertEqual(t2.shape[1], 2)
 
+  def test_sizeGe(self):
+    met.clear_all()
+
+    size1 = 5
+    size2 = 2
+    t1 = torch.zeros([size1, size2], device=dev)
+    t1[3][0] = 1
+    # t2 has size [<=10, 2]
+    t2 = torch.nonzero(t1)
+    # Create a SizeGe IR node.
+    # t2.shape[1] generates a SizeConstant node.
+    dyn_size = t2.shape[0] >= t2.shape[1]
+    self.assertGreater(met.counter_value("xla::size_ge"), 0)
+    # Exercises SizeGe::getDynamicValue.
+    dynamic_size = int(dyn_size)
+    self.assertEqual(dynamic_size, 0)
+
+  def test_sizeLt(self):
+    met.clear_all()
+
+    size1 = 5
+    size2 = 2
+    t1 = torch.zeros([size1, size2], device=dev)
+    t1[3][0] = 1
+    # t2 has size [<=10, 2]
+    t2 = torch.nonzero(t1)
+    # Create a SizeLt IR node.
+    # t2.shape[1] generates a SizeConstant node.
+    dyn_size = t2.shape[0] < t2.shape[1]
+    self.assertGreater(met.counter_value("xla::size_lt"), 0)
+    # Exercises SizeLt::getDynamicValue.
+    dynamic_size = int(dyn_size)
+    self.assertEqual(dynamic_size, 1)
+
 if __name__ == '__main__':
   assert os.environ['XLA_EXPERIMENTAL'] != ''
diff --git a/torch_xla/csrc/ops/dynamic_ir.cpp b/torch_xla/csrc/ops/dynamic_ir.cpp
index 32dd7ce6c62f..5875035718af 100644
--- a/torch_xla/csrc/ops/dynamic_ir.cpp
+++ b/torch_xla/csrc/ops/dynamic_ir.cpp
@@ -149,6 +149,50 @@ int64_t SizeEq::getDynamicValue() const {
 
 std::string SizeEq::ToString() const { return "aten::size_eq"; }
 
+SizeGe::SizeGe(torch::lazy::Value a, torch::lazy::Value b)
+    : XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::ge")},
+              {a, b},
+              xla::ShapeUtil::MakeShape(
+                  GetShapeDimensionType(/*device=*/nullptr), {}),
+              1) {
+  const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
+  const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
+  XLA_CHECK(dim_node_0);
+  XLA_CHECK(dim_node_1);
+};
+
+int64_t SizeGe::getDynamicValue() const {
+  const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
+  const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
+  XLA_CHECK(dim_node_0);
+  XLA_CHECK(dim_node_1);
+  return dim_node_0->getDynamicValue() >= dim_node_1->getDynamicValue() ? 1 : 0;
+}
+
+std::string SizeGe::ToString() const { return "aten::ge_size"; }
+
+SizeLt::SizeLt(torch::lazy::Value a, torch::lazy::Value b)
+    : XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::lt")},
+              {a, b},
+              xla::ShapeUtil::MakeShape(
+                  GetShapeDimensionType(/*device=*/nullptr), {}),
+              1) {
+  const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
+  const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
+  XLA_CHECK(dim_node_0);
+  XLA_CHECK(dim_node_1);
+};
+
+int64_t SizeLt::getDynamicValue() const {
+  const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
+  const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
+  XLA_CHECK(dim_node_0);
+  XLA_CHECK(dim_node_1);
+  return dim_node_0->getDynamicValue() < dim_node_1->getDynamicValue() ? 1 : 0;
+}
+
+std::string SizeLt::ToString() const { return "aten::lt_size"; }
+
 SizeConstant::SizeConstant(int64_t val)
     : Scalar(c10::Scalar{val},
              xla::ShapeUtil::MakeShape(
diff --git a/torch_xla/csrc/ops/dynamic_ir.h b/torch_xla/csrc/ops/dynamic_ir.h
index 88fac1fec338..898e381c2ec8 100644
--- a/torch_xla/csrc/ops/dynamic_ir.h
+++ b/torch_xla/csrc/ops/dynamic_ir.h
@@ -68,6 +68,36 @@ class SizeEq : public XlaNode, public torch::lazy::DimensionNode {
   }
 };
 
+class SizeGe : public XlaNode, public torch::lazy::DimensionNode {
+ public:
+  SizeGe(torch::lazy::Value a, torch::lazy::Value b);
+  int64_t getDynamicValue() const override;
+  int64_t getStaticValue() const override {
+    TORCH_CHECK(false, "Comparison operators should be using getDynamicValue");
+  }
+  bool isSymbolic() const override { return true; }
+  std::string ToString() const override;
+  virtual XlaOpVector Lower(LoweringContext* loctx) const override {
+    // TODO: not sure we will ever need it?
+    TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
+  }
+};
+
+class SizeLt : public XlaNode, public torch::lazy::DimensionNode {
+ public:
+  SizeLt(torch::lazy::Value a, torch::lazy::Value b);
+  int64_t getDynamicValue() const override;
+  int64_t getStaticValue() const override {
+    TORCH_CHECK(false, "Comparison operators should be using getDynamicValue");
+  }
+  bool isSymbolic() const override { return true; }
+  std::string ToString() const override;
+  virtual XlaOpVector Lower(LoweringContext* loctx) const override {
+    // TODO: not sure we will ever need it?
+    TORCH_CHECK(false, "Lowering comparison nodes isn't supported yet!");
+  }
+};
+
 class SizeAdd : public XlaNode, public torch::lazy::DimensionNode {
  public:
   SizeAdd(torch::lazy::Value a, torch::lazy::Value b);
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 964a3f7493bd..b057e0e3e676 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -693,8 +693,10 @@ c10::SymNode XLASymNodeImpl::gt(const c10::SymNode& other) {
 }
 
 c10::SymNode XLASymNodeImpl::lt(const c10::SymNode& other) {
-  XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
-                   << " has not been implemented.";
+  TORCH_LAZY_FN_COUNTER("xla::size_");
+  auto p_other = dynamic_cast<XLASymNodeImpl*>(other.get());
+  auto n_lt = torch::lazy::MakeNode<SizeLt>(node(), p_other->node());
+  return c10::make_intrusive<XLASymNodeImpl>(n_lt);
 }
 
 c10::SymNode XLASymNodeImpl::le(const c10::SymNode& other) {
@@ -703,8 +705,10 @@ c10::SymNode XLASymNodeImpl::le(const c10::SymNode& other) {
 }
 
 c10::SymNode XLASymNodeImpl::ge(const c10::SymNode& other) {
-  XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
-                   << " has not been implemented.";
+  TORCH_LAZY_FN_COUNTER("xla::size_");
+  auto p_other = dynamic_cast<XLASymNodeImpl*>(other.get());
+  auto n_ge = torch::lazy::MakeNode<SizeGe>(node(), p_other->node());
+  return c10::make_intrusive<XLASymNodeImpl>(n_ge);
 }
 
 c10::SymNode XLASymNodeImpl::ceil() {
@@ -733,8 +737,8 @@ c10::SymNode XLASymNodeImpl::sym_max(const c10::SymNode& other) {
 }
 
 c10::SymNode XLASymNodeImpl::clone() {
-  XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
-                   << " has not been implemented.";
+  TORCH_LAZY_FN_COUNTER("xla::size_");
+  return c10::make_intrusive<XLASymNodeImpl>(node());
 }
 
 c10::SymNode XLASymNodeImpl::sym_float() {
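
Usage sketch (illustrative, not part of the diff): the snippet below mirrors the
new test_sizeGe/test_sizeLt tests and shows the behavior this patch enables. It
assumes an XLA device is available and that the XLA_EXPERIMENTAL flags from the
Test Plan above are set before running.

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    dev = xm.xla_device()
    t1 = torch.zeros(5, 2, device=dev)
    t1[3][0] = 1
    # torch.nonzero produces an upper-bounded dynamic shape, [<=10, 2] here.
    t2 = torch.nonzero(t1)

    # Before this patch, >= and < on these symbolic sizes failed with
    # "has not been implemented"; now they build SizeGe/SizeLt IR nodes.
    ge = t2.shape[0] >= t2.shape[1]
    lt = t2.shape[0] < t2.shape[1]
    print(int(ge), int(lt))  # 0 1 for this input (one nonzero row vs. 2 columns)
    print(met.counter_value("xla::size_ge"), met.counter_value("xla::size_lt"))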