Skip to content

Commit

Permalink
[Linalg] Promote type for compare tensor op (iree-org#3416)
Browse files Browse the repository at this point in the history
  • Loading branch information
penguin-wwy authored Jun 4, 2024
1 parent 661be2d commit d59d0b6
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 75 deletions.
103 changes: 28 additions & 75 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,59 +149,18 @@ static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
}

template <typename OpTy>
static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
std::is_same<OpTy, AtenLeTensorOp>() ||
std::is_same<OpTy, AtenGtTensorOp>() ||
std::is_same<OpTy, AtenGeTensorOp>() ||
std::is_same<OpTy, AtenEqTensorOp>() ||
std::is_same<OpTy, AtenNeTensorOp>(),
"unimplemented: op type not supported");

Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();

// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
op.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}

Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenNeTensorOp>()) {
return createNotEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported");
}
template <class T, class... Ts>
struct is_any_same : std::disjunction<std::is_same<T, Ts>...> {};

template <typename OpTy>
static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtScalarOp>() ||
std::is_same<OpTy, AtenLeScalarOp>() ||
std::is_same<OpTy, AtenEqScalarOp>() ||
std::is_same<OpTy, AtenNeScalarOp>() ||
std::is_same<OpTy, AtenGtScalarOp>() ||
std::is_same<OpTy, AtenGeScalarOp>(),
"unimplemented: op type not supported");
static Value createCompareOp(OpBuilder &b, Location loc, OpTy op, Value lhs,
Value rhs) {
static_assert(
is_any_same<OpTy, AtenLtScalarOp, AtenLeScalarOp, AtenEqScalarOp,
AtenNeScalarOp, AtenGtScalarOp, AtenGeScalarOp,
AtenLtTensorOp, AtenLeTensorOp, AtenGtTensorOp,
AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp>(),
"unimplemented: op type not supported");

Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();
Expand Down Expand Up @@ -229,22 +188,22 @@ static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
return nullptr;
}

if constexpr (std::is_same<OpTy, AtenLtScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenLtScalarOp, AtenLtTensorOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenLeScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenLeScalarOp, AtenLeTensorOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGtScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenGtScalarOp, AtenGtTensorOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGeScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenGeScalarOp, AtenGeTensorOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenEqScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenEqScalarOp, AtenEqTensorOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenNeScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenNeScalarOp, AtenNeTensorOp>()) {
return createNotEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported");
Expand Down Expand Up @@ -892,28 +851,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::Atan2Op>(loc, lhs, rhs);
}
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, leTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, gtTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, geTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, eqTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto neTensor = dyn_cast<AtenNeTensorOp>(op)) {
return createCompareTensorOp(b, loc, neTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, neTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
AtenDivTensorOp::Adaptor adaptor(operands);
Expand Down Expand Up @@ -996,27 +949,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}

if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
}

if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, geScalar, payloadArgs[0], operands[1]);
}

if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
}

if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, neScalar, payloadArgs[0], operands[1]);
}

if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
}

if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, leScalar, payloadArgs[0], operands[1]);
}

if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
Expand Down
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
"InterpolateDynamicModule_scales_recompute_bilinear",
"ElementwiseFloatTensorGtIntTensorModule_basic",
}

LINALG_CRASHING_SET = {
Expand Down Expand Up @@ -2707,6 +2708,7 @@
"ElementwiseTanIntModule_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseUnaryIntModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
Expand Down Expand Up @@ -3786,6 +3788,7 @@
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,51 @@ def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))


class ElementwiseIntTensorLtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.int64, True),
([-1], torch.float64, True),
]
)
def forward(self, x, y):
return torch.lt(x, y)


@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule())
def ElementwiseIntTensorLtFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5, high=10), tu.rand(5, high=10).to(torch.float64))


class ElementwiseFloatTensorGtIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
([-1], torch.int32, True),
]
)
def forward(self, x, y):
return torch.gt(x, y)


@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule())
def ElementwiseFloatTensorGtIntTensorModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(3, 5, high=10).to(torch.float32),
tu.randint(5, high=10, dtype=torch.int32),
)


# ==============================================================================


Expand Down

0 comments on commit d59d0b6

Please sign in to comment.