Commit 6cba93b

[ONNX][TorchToLinalg] Add support for dynamic dims in Interpolate lowering (#3351)

Addresses [Shark-Turbine #196](nod-ai/SHARK-TestSuite#196)

Related tracker: [Shark-Turbine #566](nod-ai/SHARK-ModelDev#566)

Related onnx.Resize issue: [Shark-Turbine #616](nod-ai/SHARK-ModelDev#616)
zjgarvey authored May 17, 2024
1 parent 513d89c commit 6cba93b
Showing 3 changed files with 13 additions and 28 deletions.
26 changes: 9 additions & 17 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -2912,11 +2912,13 @@ class ConvertInterpolateOp
     auto inputType = input.getType().cast<RankedTensorType>();
     auto inputRank = inputType.getRank();
 
-    if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) {
-      return rewriter.notifyMatchFailure(op, "error: Dynamic dim on resize op");
-    }
-
     SmallVector<Value, 2> outputSizeIntValues;
+    Value inputSizeH = getDimOp(rewriter, loc, input, 2);
+    inputSizeH = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIntegerType(64), inputSizeH);
+    Value inputSizeW = getDimOp(rewriter, loc, input, 3);
+    inputSizeW = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIntegerType(64), inputSizeW);
 
     if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
       SmallVector<Value, 2> ScaleFactorTorchFloat;
@@ -2927,8 +2929,6 @@ class ConvertInterpolateOp
       SmallVector<Value, 2> ScaleFactorFloatValues;
       ScaleFactorFloatValues = getTypeConvertedValues(
           rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
-      Value inputSizeH = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI64IntegerAttr(inputType.getShape()[2]));
       Value inputHFP = rewriter.create<arith::SIToFPOp>(
           loc, rewriter.getF32Type(), inputSizeH);
       Value scale = rewriter.create<arith::TruncFOp>(loc, inputHFP.getType(),
@@ -2938,8 +2938,6 @@ class ConvertInterpolateOp
       outputH =
           rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputH);
 
-      Value inputSizeW = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getI64IntegerAttr(inputType.getShape()[3]));
       Value inputWFP = rewriter.create<arith::SIToFPOp>(
           loc, rewriter.getF32Type(), inputSizeW);
       scale = rewriter.create<arith::TruncFOp>(loc, inputWFP.getType(),
@@ -2960,11 +2958,9 @@ class ConvertInterpolateOp
       outputSizeIntValues = getTypeConvertedValues(
           rewriter, loc, getTypeConverter(), outputSizeTorchInt);
     }
-    int hDimOffset = 2;
-    SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
-    dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
-    dims[hDimOffset + 1] =
-        castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
+    SmallVector<Value> dims = getTensorSizesUntilDim(rewriter, loc, input, 1);
+    dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[0]));
+    dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[1]));
 
     Value outTensor = rewriter.create<tensor::EmptyOp>(
         loc, getAsOpFoldResult(dims), inputType.getElementType());
@@ -2983,10 +2979,6 @@ class ConvertInterpolateOp
         [&](OpBuilder &b, Location loc, ValueRange args) {
           Value outputSizeH = outputSizeIntValues[0];
           Value outputSizeW = outputSizeIntValues[1];
-          Value inputSizeH = b.create<arith::ConstantOp>(
-              loc, b.getI64IntegerAttr(inputType.getShape()[2]));
-          Value inputSizeW = b.create<arith::ConstantOp>(
-              loc, b.getI64IntegerAttr(inputType.getShape()[3]));
           Value retVal;
           if (mode == "nearest") {
             retVal =
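The core of the change is the first hunk: rather than folding `inputType.getShape()[2]` and `[3]` into `arith.constant` ops, which forced the pattern to reject dynamic dims up front, the lowering now asks for the sizes at runtime via `getDimOp` and casts the resulting `index` values to `i64`. A minimal sketch of the IR this produces for the spatial sizes, assuming `getDimOp` materializes a `tensor.dim` as it does elsewhere in TorchToLinalg (function and value names here are illustrative, not from the commit):

```mlir
// Hypothetical helper: fetch H and W of an NCHW tensor with dynamic spatial dims.
func.func @spatial_sizes(%input: tensor<1x1x?x?xf32>) -> (i64, i64) {
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  // tensor.dim reads the sizes at runtime, so dynamic dims are fine.
  %dimH = tensor.dim %input, %c2 : tensor<1x1x?x?xf32>
  %dimW = tensor.dim %input, %c3 : tensor<1x1x?x?xf32>
  // Cast index -> i64 to feed the scale arithmetic that follows.
  %sizeH = arith.index_cast %dimH : index to i64
  %sizeW = arith.index_cast %dimW : index to i64
  return %sizeH, %sizeW : i64, i64
}
```

Because the two sizes are created once at the top of the pattern, the scale-factor branch and the `linalg.generic` body reuse them instead of re-materializing constants, which is why the remaining hunks in this file are pure deletions.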
3 changes: 0 additions & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2607,9 +2607,6 @@
"BernoulliTensorModule_basic",
# Failure - onnx_lowering: onnx.ReduceProd
"ReduceProdDimIntFloatModule_basic",
# Failure - onnx_lowering: onnx.Resize
"UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dStaticSize_basic",
# Failure - onnx_lowering: onnx.ScatterElements
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMinModuleIncludeSelf",
12 changes: 4 additions & 8 deletions test/Conversion/TorchToLinalg/resize.mlir
@@ -4,15 +4,13 @@
 func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[generic:.*]] = linalg.generic
-  // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
-  // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
   // CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32
   // CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32
   // CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32
   // CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32
   // CHECK: %[[x13:.*]] = linalg.index 2 : index
   // CHECK: %[[x14:.*]] = linalg.index 3 : index
-  // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
+  // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
   // CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
   // CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32
   // CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64
@@ -23,7 +21,7 @@ func.func @test_resize_sizes_linear(
   // CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32
   // CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32
   // CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32
-  // CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
+  // CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
   // CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
   // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32
   // CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64
@@ -96,12 +94,10 @@ func.func @test_resize_sizes_linear(
 
 func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[GENERIC:.*]] = linalg.generic
-  // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
-  // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
   // CHECK: %[[x13:.*]] = linalg.index 2 : index
   // CHECK: %[[x14:.*]] = linalg.index 3 : index
-  // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
-  // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
+  // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
+  // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
   // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
   // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
   // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
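The test updates mirror the code change. Since the spatial sizes are no longer `arith.constant 2 : i64` and `4 : i64`, those CHECK lines are dropped, and the FileCheck variables are instead bound at their first use. The relevant FileCheck detail: `%[[name]]` matches a previously bound variable, while `%[[name:.*]]` binds it in place. A schematic contrast of the two styles (CHECK lines only, not a complete test):

```mlir
// Before: the height is a compile-time constant, bound where it is created.
// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32

// After: no constant exists to anchor on; bind whatever value feeds the sitofp.
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
```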
