From 09f502667b400865843aea90f6f6b6c104969be4 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Mon, 24 Jun 2024 15:22:50 -0700 Subject: [PATCH] `AtenTensorOp::fold` should not fold when result type is not fully specified (#3494) In one of our downstreams, we encountered an internal assertion failure in an intermediate pass from `AtenTensorOp::fold` invocation: ``` external/llvm-project/llvm/include/llvm/Support/Casting.h:650: decltype(auto) llvm::dyn_cast(const From &) [To = mlir::torch::Torch::NonValueTensorType, From = mlir::Type]: Assertion `detail::isPresent(Val) && "dyn_cast on a non-existent value"' failed. ``` for this snippet in the IR: ``` %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,1,15360],f32>} ... %218 = torch.aten.size %arg1 : !torch.tensor -> !torch.list %219 = torch.aten.tensor %218, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.tensor ``` Turns out this was [fixed](https://github.com/llvm/torch-mlir/pull/3189/files#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4fR3719) eventually (and we were on an old hash of torch-mlir). This PR submits just the lit test for test coverage on that specific change: ```c++ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { auto resultTy = dyn_cast(getType()); // lit test this if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; ... ``` --- test/Dialect/Torch/canonicalize.mlir | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 250f11cf67a1..aa943a5a1e5a 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1534,6 +1534,16 @@ func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) { return %67 : !torch.vtensor<[1],si64> } +// CHECK-LABEL: func.func @torch.aten.tensor$no_fold( +// CHECK: torch.aten.tensor %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.tensor +func.func @torch.aten.tensor$no_fold(%arg0: !torch.tensor) -> (!torch.tensor) { + %none = torch.constant.none + %false = torch.constant.bool false + %1 = torch.aten.size %arg0 : !torch.tensor -> !torch.list + %2 = torch.aten.tensor %1, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.tensor + return %2 : !torch.tensor +} + // CHECK-LABEL: func.func @torch.aten.tensor.float( // CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor) : !torch.vtensor<[],f32> func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {