[TorchToLinalg] address a dtype mismatch in aten.multinomial lowering (llvm#3630)

Resolves <llvm#3628>
Unblocks a compile failure for one of the MiGraphx models
(`AgentModel`).
zjgarvey authored Aug 20, 2024
1 parent f72770a commit f66908f
Showing 3 changed files with 57 additions and 15 deletions.
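
For context, a minimal repro sketch of the failure mode (a hypothetical module, not taken from the patch): before this change, the lowering accumulated the weight sum in f64 while the payload and CDF buffer used the input element type, so any non-f64 input such as float32 hit a type mismatch during lowering.

import torch

# Hypothetical repro (names are illustrative, not from the patch):
# aten.multinomial over float32 weights previously failed to lower.
class MultinomialF32(torch.nn.Module):
    def forward(self, weights):
        return torch.ops.aten.multinomial(weights, 16, replacement=True)

w = torch.rand(10, 100, dtype=torch.float32)
w = w / w.sum(-1, keepdim=True)   # make each row a probability distribution
print(MultinomialF32()(w).shape)  # torch.Size([10, 16])
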
15 changes: 14 additions & 1 deletion lib/Conversion/TorchToLinalg/Random.cpp
@@ -287,8 +287,16 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {

    Value initSum = rewriter.create<arith::ConstantOp>(
        loc, f64Ty, rewriter.getF64FloatAttr(0.0));
+   int64_t srcWidth = cast<mlir::FloatType>(elemTy).getWidth();
+   if (srcWidth > 64)
+     op->emitWarning("Op bitwidth will be truncated from " +
+                     std::to_string(srcWidth) + " bits to 64 bits.");
    auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
      Value input = payloadArgs[0];
+     if (srcWidth < 64)
+       input = b.create<arith::ExtFOp>(loc, f64Ty, input);
+     if (srcWidth > 64)
+       input = b.create<arith::TruncFOp>(loc, f64Ty, input);
      Value result = payloadArgs[1];
      Value nextSum = b.create<arith::AddFOp>(loc, input, result);
      b.create<linalg::YieldOp>(loc, nextSum);
@@ -310,7 +318,7 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {

    // compute cdf in loop
    Value initCdf = b.create<tensor::EmptyOp>(
-       loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), elemTy);
+       loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty);
    Value cdf =
        b.create<scf::ForOp>(
             loc, cstZero, numCategories, cstOne, ValueRange{initCdf},
@@ -330,6 +338,11 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {
              ind = ValueRange{jIndex, iIndex};
            }
            Value currWeight = b.create<tensor::ExtractOp>(loc, self, ind);
+           if (srcWidth < 64)
+             currWeight = b.create<arith::ExtFOp>(loc, f64Ty, currWeight);
+           if (srcWidth > 64)
+             currWeight =
+                 b.create<arith::TruncFOp>(loc, f64Ty, currWeight);
            Value currMass = b.create<arith::DivFOp>(loc, currWeight, sum);
            Value currCum =
                b.create<scf::IfOp>(
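A rough Python analogue of what the lowering computes after this change (a sketch; only the sum and CDF pieces appear in the hunks above, so the inverse-CDF sampling step is my assumption about the rest of the lowering): all intermediate arithmetic happens in float64, with narrower floats extended (arith.extf) and wider ones truncated (arith.truncf, with a warning) up front.

import torch

def multinomial_sketch(weights: torch.Tensor, num_samples: int) -> torch.Tensor:
    # mirror the ExtFOp/TruncFOp inserts: promote to f64 before any arithmetic
    w = weights.to(torch.float64)
    total = w.sum()                       # f64 running sum (the linalg reduction)
    cdf = torch.cumsum(w / total, dim=0)  # f64 CDF buffer (tensor.empty + scf.for)
    u = torch.rand(num_samples, dtype=torch.float64)
    # inverse-transform sampling: index of the first CDF entry exceeding u
    return torch.searchsorted(cdf, u)

print(multinomial_sketch(torch.tensor([0.1, 0.2, 0.7], dtype=torch.float32), 8))
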
6 changes: 4 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2318,6 +2318,8 @@
"ElementwiseLog2IntModule_basic",
"ElementwiseFminModule_basic",
"ElementwiseFmaxModule_basic",
"MultinomialModule2D_basic",
"MultinomialModule2D_F32",
"PixelShuffleModuleStaticRank4Float32_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
@@ -2346,6 +2348,8 @@
"MoveDimIntNegativeIndexModule_basic",
"ReduceL3NormKeepDimModule_basic",
"ViewSizeFromOtherTensor_basic",
# incorrect shape generated by torch.onnx.export (needs an unsqueeze)
"MultinomialModule_basic",
# Failure - onnx_export
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
@@ -2849,8 +2853,6 @@
"ElementwiseUnaryIntModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"MultinomialModule_basic",
"MultinomialModule2D_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ReduceAnyFloatModule_basic",
51 changes: 39 additions & 12 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
@@ -377,10 +377,20 @@ def BernoulliPModule_basic(module, tu: TestUtils):
# ==============================================================================


-class MultinomialModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+def generate_sample_distr(sizes: list[int], torchdtype, tu: TestUtils):
+    assert len(sizes) == 1 or len(sizes) == 2
+    init = tu.rand(*sizes).to(dtype=torchdtype).abs()
+    normalized = init / (init.sum(-1, True, dtype=torchdtype))
+    return normalized
+
+
+class MultinomialBase(torch.nn.Module):
+    def _forward(self, x):
+        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
+        return a
+
+
+class MultinomialModule(MultinomialBase):
    @export
    @annotate_args(
        [
@@ -389,20 +399,36 @@ def __init__(self):
        ]
    )
    def forward(self, x):
-       a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-       return a.mean(dtype=torch.double)
+       return self._forward(x).mean(dtype=torch.double)


@register_test_case(module_factory=lambda: MultinomialModule())
def MultinomialModule_basic(module, tu: TestUtils):
-   x = tu.rand(100).double()
+   x = generate_sample_distr([100], torch.float64, tu)
    module.forward(x)


-class MultinomialModule2D(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+class MultinomialModule2DF32(MultinomialBase):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)
+
+
+@register_test_case(module_factory=lambda: MultinomialModule2DF32())
+def MultinomialModule2D_F32(module, tu: TestUtils):
+    x = generate_sample_distr([10, 100], torch.float32, tu)
+    module.forward(x)
+
+
+class MultinomialModule2D(MultinomialBase):
    @export
    @annotate_args(
        [
@@ -411,13 +437,14 @@ def __init__(self):
        ]
    )
    def forward(self, x):
-       a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-       return a.mean(dtype=torch.double)
+       # note: this should really call mean(-1)
+       # for some reason, doing this causes a torchscript numerics error?
+       return self._forward(x).mean(dtype=torch.double)


@register_test_case(module_factory=lambda: MultinomialModule2D())
def MultinomialModule2D_basic(module, tu: TestUtils):
-   x = tu.rand(10, 100).double()
+   x = generate_sample_distr([10, 100], torch.float64, tu)
    module.forward(x)


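As a side note on the new test helper: generate_sample_distr just turns random nonnegative weights into a probability distribution along the last dimension. A standalone sketch, substituting torch.rand for the suite's TestUtils.rand (an assumption on my part; the suite's version is seeded for reproducibility):

import torch

def generate_sample_distr_demo(sizes, dtype):
    init = torch.rand(*sizes).to(dtype=dtype).abs()
    return init / init.sum(-1, True, dtype=dtype)

probs = generate_sample_distr_demo([10, 100], torch.float32)
print(probs.sum(-1))  # every row sums to ~1.0
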
