Skip to content

Commit

Permalink
Add an info cast to prims.squeeze decomposition (#3844)
Browse files Browse the repository at this point in the history
The onnx ingest sometimes has poorly propagated shape information. E.g.:

```mlir
...
    %9020 = torch.prims.squeeze %9010#1, %9019 : !torch.vtensor<[?,384,1],f32>, !torch.list<int> -> !torch.vtensor<[1,384],f32>
    return %9015, %9020 : !torch.vtensor<[1,384],f32>, !torch.vtensor<[1,384],f32>
  }
}
```

This occurred at the boundary of the onnx model
`migraphx_bert__bert-large-uncased`. Evidently, the output value tensor
info had more information than could be propagated forward. The
`PrimsSqueeze` lowering was returning a `!torch.vtensor<[?,384],f32>`
which was causing a type mismatch with the `func.return`.
  • Loading branch information
zjgarvey authored Nov 1, 2024
1 parent a82ba1c commit 3cfb7c8
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8958,7 +8958,8 @@ class DecomposePrimsSqueezeOp : public OpRewritePattern<PrimsSqueezeOp> {
}
result = *squeezeTensorInfo;
}
rewriter.replaceOp(op, result);
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(op, op.getType(),
result);
return success();
}
};
Expand Down

0 comments on commit 3cfb7c8

Please sign in to comment.