Skip to content

Commit

Permalink
Add lowering support for math::AbsIOp (#2875)
Browse files Browse the repository at this point in the history
There is no lowering support for math::AbsIOp, so an integer-typed
operand fails to lower: it is routed to math::AbsFOp, whose operand #0
must be floating-point-like.
  • Loading branch information
aviator19941 authored Feb 8, 2024
1 parent 44f8f89 commit 9659a43
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
5 changes: 4 additions & 1 deletion lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
return createEqual(b, loc, floatDtype, self, zero);
}
if (isa<AtenAbsOp>(op))
if (isa<AtenAbsOp>(op)) {
if (payloadArgs[0].getType().isa<IntegerType>())
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
}
if (isa<AtenIsinfOp>(op)) {
Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
Value infinity = b.create<arith::ConstantOp>(
Expand Down
6 changes: 4 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseSubScalarIntModule_basic",
"ElementwiseWhereScalarModule_basic",
"ElementwiseAbsModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleI32Static_basic",
"EmbeddingModuleI32_basic",
Expand Down Expand Up @@ -1060,7 +1061,8 @@
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"ElementwiseAbsModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"ElementwiseAddModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic",
Expand Down
30 changes: 26 additions & 4 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2113,7 +2113,7 @@ def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseAbsModule(torch.nn.Module):
class ElementwiseAbsFloatModule(torch.nn.Module):

def __init__(self):
super().__init__()
Expand All @@ -2127,9 +2127,31 @@ def forward(self, a):
return torch.abs(a)


@register_test_case(module_factory=lambda: ElementwiseAbsModule())
def ElementwiseAbsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-1.0, high=1.0))
@register_test_case(module_factory=lambda: ElementwiseAbsFloatModule())
def ElementwiseAbsFloatModule_basic(module, tu: TestUtils):
    # Cover the sign boundary explicitly: negative, zero, positive floats
    # (shape (1, 1, 3), default float32 dtype).
    inp = torch.tensor([[[-1.0, 0.0, 1.0]]])
    module.forward(inp)


# ==============================================================================


class ElementwiseAbsIntModule(torch.nn.Module):
    """E2E test module: elementwise absolute value over a rank-3 int64 tensor.

    Exercises the integer path of the abs lowering (math.absi rather than
    math.absf).
    """

    def __init__(self):
        super(ElementwiseAbsIntModule, self).__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, a):
        # Traces to aten.abs with an integer-typed operand.
        result = torch.abs(a)
        return result


@register_test_case(module_factory=lambda: ElementwiseAbsIntModule())
def ElementwiseAbsIntModule_basic(module, tu: TestUtils):
    # Cover the sign boundary explicitly: negative, zero, positive integers
    # (shape (1, 1, 3), default int64 dtype for a Python int list).
    inp = torch.tensor([[[-1, 0, 1]]])
    module.forward(inp)


# ==============================================================================
Expand Down

0 comments on commit 9659a43

Please sign in to comment.