Skip to content

Commit

Permalink
Cast static/dynamic shape for onnx.If branches to match result type (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 authored Nov 1, 2024
1 parent 3cfb7c8 commit 39d69db
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
34 changes: 29 additions & 5 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,15 +211,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// Move the ONNX subgraph bodies into the corresponding torch.prim.If
// regions (presumably cloning/splicing ops — confirm against the
// inlineIfCase helper defined earlier in this handler).
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
inlineIfCase(*elseRegion, primIfOp.getElseRegion());
// Rewrite each branch's terminator as torch.prim.If.yield, inserting a
// static-info cast where a yielded value's static/dynamic shape differs
// from the declared result type. Returns failure() when the branch
// cannot be made to match (operand count, non-tensor types, missing
// size info, or rank mismatch).
auto replaceTerminator = [&](Region &region) -> LogicalResult {
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = region.front().getTerminator();
rewriter.setInsertionPoint(terminator);

// Cast result shape if there is a static/dynamic difference.
llvm::SmallVector<Value> terOperands = terminator->getOperands();
if (terOperands.size() != resultTypes.size())
return failure();
for (size_t i = 0; i < terOperands.size(); i++) {
mlir::Type terType = terOperands[i].getType();
// Guard the casts: a null dyn_cast (non-vtensor type) or a type
// without size info would otherwise crash on getSizes().
auto terTensorType = dyn_cast<Torch::ValueTensorType>(terType);
auto resTensorType = dyn_cast<Torch::ValueTensorType>(resultTypes[i]);
if (!terTensorType || !resTensorType || !terTensorType.hasSizes() ||
!resTensorType.hasSizes())
return failure();
// Only static-vs-dynamic dimension mismatches can be bridged by a
// tensor_static_info_cast; a rank mismatch cannot.
int64_t terOpRank = terTensorType.getSizes().size();
int64_t resRank = resTensorType.getSizes().size();
if (terOpRank != resRank)
return failure();
if (terType != resultTypes[i]) {
Value cast = rewriter.create<Torch::TensorStaticInfoCastOp>(
binder.getLoc(), resultTypes[i], terOperands[i]);
terOperands[i] = cast;
}
}

rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(terminator,
terOperands);
return success();
};
// Rewrite both branch terminators; bail out of the pattern (rather than
// asserting) if either branch cannot be adapted to the result types.
if (failed(replaceTerminator(primIfOp.getThenRegion())) ||
    failed(replaceTerminator(primIfOp.getElseRegion())))
  return rewriter.notifyMatchFailure(binder.op,
                                     "terminator replace failure");

rewriter.replaceOp(binder.op, primIfOp.getResults());
return success();
Expand Down
21 changes: 21 additions & 0 deletions test/Conversion/TorchOnnxToTorch/ops/if.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,24 @@ func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<
}
return %0 : !torch.vtensor<[1],f32>
}

// -----

// CHECK-LABEL: func.func @test_ifop_cast_shape
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[?],si64>)
// CHECK-DAG: %[[CAST:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64>
// CHECK-DAG: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[?],si64>
// CHECK-DAG: } else {
// CHECK-DAG: %[[SQUEEZE:.*]] = torch.prims.squeeze %arg1, %{{.*}} : !torch.vtensor<[?,1],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
// CHECK-DAG: torch.prim.If.yield %[[SQUEEZE]] : !torch.vtensor<[?],si64>
// Regression test for onnx.If lowering when a branch yields a tensor whose
// static shape ([0]) differs from the op's declared dynamic result type
// ([?]): the lowering must insert a torch.tensor_static_info_cast before
// the yield (see the CHECK lines above) instead of failing type matching.
func.func @test_ifop_cast_shape(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[?,1],si64>) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[?],si64> {
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%2 = torch.operator "onnx.Squeeze"(%arg1, %1) : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64>
torch.operator_terminator %2 : !torch.vtensor<[?],si64>
}, {
// This region yields a statically-shaped [0] tensor; the result type is
// the dynamic [?], which is what forces the static-info cast.
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<0xsi64>} : () -> !torch.vtensor<[0],si64>
torch.operator_terminator %1 : !torch.vtensor<[0],si64>
}
return %0 : !torch.vtensor<[?],si64>
}

0 comments on commit 39d69db

Please sign in to comment.