From 9659a436d1374612d7d2c7518a74dfd9ae821bc0 Mon Sep 17 00:00:00 2001
From: Avinash Sharma
Date: Thu, 8 Feb 2024 14:53:40 -0800
Subject: [PATCH] Add lowering support for math::AbsIOp (#2875)

There is no lowering support for math::AbsIOp, so an AtenAbsOp with an
integer operand fails to lower: it gets converted to math::AbsFOp, whose
operand #0 must be floating-point-like. Lower integer-typed operands to
math::AbsIOp instead, and split the abs e2e test into float and int
variants.
---
 .../TorchToLinalg/Uncategorized.cpp    |  5 +++-
 projects/pt1/e2e_testing/xfail_sets.py |  6 ++--
 .../test_suite/elementwise.py          | 30 ++++++++++++++++---
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 794b755998fe..479bc1c0d620 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -424,8 +424,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
         b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
     return createEqual(b, loc, floatDtype, self, zero);
   }
-  if (isa<AtenAbsOp>(op))
+  if (isa<AtenAbsOp>(op)) {
+    if (payloadArgs[0].getType().isa<IntegerType>())
+      return b.create<math::AbsIOp>(loc, payloadArgs[0]);
     return b.create<math::AbsFOp>(loc, payloadArgs[0]);
+  }
   if (isa<AtenIsinfOp>(op)) {
     Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
     Value infinity = b.create<arith::ConstantOp>(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 0d789a22db0e..26f3e843954f 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -579,7 +579,8 @@
     "ElementwiseSubScalarFloatModule_basic",
     "ElementwiseSubScalarIntModule_basic",
     "ElementwiseWhereScalarModule_basic",
-    "ElementwiseAbsModule_basic",
+    "ElementwiseAbsFloatModule_basic",
+    "ElementwiseAbsIntModule_basic",
     "EmbeddingModule1DIndices_basic",
     "EmbeddingModuleI32Static_basic",
     "EmbeddingModuleI32_basic",
@@ -1060,7 +1061,8 @@
     "EinsumStaticContractRhsModule_basic",
     "EinsumStaticFourDimensionModule_basic",
     "EinsumStaticModule_basic",
-    "ElementwiseAbsModule_basic",
+    "ElementwiseAbsFloatModule_basic",
+    "ElementwiseAbsIntModule_basic",
     "ElementwiseAddModule_basic",
     "ElementwiseAddScalarFloatModule_basic",
     "ElementwiseAddScalarInt64Module_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index f711af6d4639..c1a827ffe108 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -2113,7 +2113,7 @@ def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class ElementwiseAbsModule(torch.nn.Module):
+class ElementwiseAbsFloatModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -2127,9 +2127,31 @@ def forward(self, a):
         return torch.abs(a)
 
 
-@register_test_case(module_factory=lambda: ElementwiseAbsModule())
-def ElementwiseAbsModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 4, 5, low=-1.0, high=1.0))
+@register_test_case(module_factory=lambda: ElementwiseAbsFloatModule())
+def ElementwiseAbsFloatModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[[-1.0, 0.0, 1.0]]]))
+
+
+# ==============================================================================
+
+
+class ElementwiseAbsIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return torch.abs(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseAbsIntModule())
+def ElementwiseAbsIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[[-1, 0, 1]]]))
 
 
 # ==============================================================================
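
For reference, the integer case covered by the new ElementwiseAbsIntModule
test behaves as follows in eager PyTorch (a standalone sketch, not part of
the patch; the e2e harness compares the compiled module against this eager
output, and the tensor values mirror ElementwiseAbsIntModule_basic above):

    import torch

    # abs on an integer tensor stays integer (int64 in, int64 out); this is
    # the path that previously got routed to math::AbsFOp and failed to
    # lower, and now goes through math::AbsIOp.
    a = torch.tensor([[[-1, 0, 1]]])   # default integer dtype is torch.int64
    print(torch.abs(a))                # tensor([[[1, 0, 1]]])
    print(torch.abs(a).dtype)          # torch.int64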