AtenTensorOp::fold should not fold when result type is not fully specified (llvm#3494)

In one of our downstreams, we encountered an internal assertion failure
in an intermediate pass during an `AtenTensorOp::fold` invocation:
```
external/llvm-project/llvm/include/llvm/Support/Casting.h:650: decltype(auto) llvm::dyn_cast(const From &) [To = mlir::torch::Torch::NonValueTensorType, From = mlir::Type]: Assertion `detail::isPresent(Val) && "dyn_cast on a non-existent value"' failed.
```

The failure was triggered by this snippet in the IR:
```
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,1,15360],f32>}
...
    %218 = torch.aten.size %arg1 : !torch.tensor -> !torch.list<int>
    %219 = torch.aten.tensor %218, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
```
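
For context on the assertion itself (not part of the change): `llvm::dyn_cast` requires a present value and asserts when given a null `mlir::Type`, whereas `llvm::dyn_cast_if_present` tolerates null. A minimal sketch of that semantics, assuming standard LLVM/MLIR headers; `sketch` and the use of `mlir::IntegerType` are purely illustrative:
```c++
// Illustrative only (not part of this commit): llvm::dyn_cast asserts when
// handed an absent (null) value, which is the failure mode seen above.
#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinTypes.h"

void sketch(mlir::Type maybeNull) {
  // With a null maybeNull, this would trip the assertion
  // "dyn_cast on a non-existent value":
  //   auto bad = llvm::dyn_cast<mlir::IntegerType>(maybeNull);

  // The null-tolerant form returns a null result instead of asserting:
  if (auto intTy = llvm::dyn_cast_if_present<mlir::IntegerType>(maybeNull))
    (void)intTy; // safe to use here
}
```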

It turns out this had already been
[fixed](https://github.com/llvm/torch-mlir/pull/3189/files#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4fR3719)
upstream (we were on an old hash of torch-mlir). This PR adds just a
lit test for coverage of that specific change:
```c++
OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  // Bail out (no fold) unless the result is a value tensor with known sizes
  // and dtype -- this is the check the new lit test exercises.
  if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
    return nullptr;
  ...
```
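
With that guard, `fold` returns `nullptr` whenever the result type is not a fully specified `ValueTensorType` (e.g. a bare `!torch.tensor` with no sizes or dtype, as in the snippet above), so canonicalization leaves the op untouched; the test's `CHECK` line asserts exactly that the `torch.aten.tensor` op survives. For reference, tests in this file are driven by a lit RUN line along these lines (a sketch of the usual convention; the exact flags may differ):
```mlir
// RUN: torch-mlir-opt %s -canonicalize --split-input-file -verify-diagnostics | FileCheck %s
```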
sjain-stanford authored Jun 24, 2024
1 parent 61f37ae commit 09f5026
Showing 1 changed file with 10 additions and 0 deletions: test/Dialect/Torch/canonicalize.mlir
@@ -1534,6 +1534,16 @@ func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) {
  return %67 : !torch.vtensor<[1],si64>
}

// CHECK-LABEL: func.func @torch.aten.tensor$no_fold(
// CHECK: torch.aten.tensor %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
func.func @torch.aten.tensor$no_fold(%arg0: !torch.tensor) -> (!torch.tensor) {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %1 = torch.aten.size %arg0 : !torch.tensor -> !torch.list<int>
  %2 = torch.aten.tensor %1, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
  return %2 : !torch.tensor
}

// CHECK-LABEL: func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
