From 3cfb7c8df6d83e817815be8cec62e118dcceca9d Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Fri, 1 Nov 2024 12:10:47 -0500 Subject: [PATCH] Add an info cast to `prims.squeeze` decomposition (#3844) The onnx ingest sometimes has poorly propagated shape information. E.g.: ```mlir ... %9020 = torch.prims.squeeze %9010#1, %9019 : !torch.vtensor<[?,384,1],f32>, !torch.list -> !torch.vtensor<[1,384],f32> return %9015, %9020 : !torch.vtensor<[1,384],f32>, !torch.vtensor<[1,384],f32> } } ``` This occurred at the boundary of the onnx model `migraphx_bert__bert-large-uncased`. Evidently, the output value tensor info had more information than could be propagated forward. The `PrimsSqueeze` lowering was returning a `!torch.vtensor<[?,384],f32>` which was causing a type mismatch with the `func.return`. --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index bbd1f3bf855b..004aaa5a77e5 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -8958,7 +8958,8 @@ class DecomposePrimsSqueezeOp : public OpRewritePattern { } result = *squeezeTensorInfo; } - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp(op, op.getType(), + result); return success(); } };