diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 794b755998fe..479bc1c0d620 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -424,8 +424,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( b.create(loc, b.getFloatAttr(floatDtype, 0)); return createEqual(b, loc, floatDtype, self, zero); } - if (isa(op)) + if (isa(op)) { + if (payloadArgs[0].getType().isa()) + return b.create(loc, payloadArgs[0]); return b.create(loc, payloadArgs[0]); + } if (isa(op)) { Value abs = b.create(loc, payloadArgs[0]); Value infinity = b.create( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 973f75a2637a..088fbe36f0eb 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -574,7 +574,8 @@ "ElementwiseSubScalarFloatModule_basic", "ElementwiseSubScalarIntModule_basic", "ElementwiseWhereScalarModule_basic", - "ElementwiseAbsModule_basic", + "ElementwiseAbsFloatModule_basic", + "ElementwiseAbsIntModule_basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleI32Static_basic", "EmbeddingModuleI32_basic", @@ -1055,7 +1056,8 @@ "EinsumStaticContractRhsModule_basic", "EinsumStaticFourDimensionModule_basic", "EinsumStaticModule_basic", - "ElementwiseAbsModule_basic", + "ElementwiseAbsFloatModule_basic", + "ElementwiseAbsIntModule_basic", "ElementwiseAddModule_basic", "ElementwiseAddScalarFloatModule_basic", "ElementwiseAddScalarInt64Module_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index f711af6d4639..c1a827ffe108 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2113,7 +2113,7 @@ def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseAbsModule(torch.nn.Module): +class ElementwiseAbsFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -2127,9 +2127,31 @@ def forward(self, a): return torch.abs(a) -@register_test_case(module_factory=lambda: ElementwiseAbsModule()) -def ElementwiseAbsModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5, low=-1.0, high=1.0)) +@register_test_case(module_factory=lambda: ElementwiseAbsFloatModule()) +def ElementwiseAbsFloatModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[[-1.0, 0.0, 1.0]]])) + + +# ============================================================================== + + +class ElementwiseAbsIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.abs(a) + + +@register_test_case(module_factory=lambda: ElementwiseAbsIntModule()) +def ElementwiseAbsIntModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[[-1, 0, 1]]])) # ==============================================================================