[Onnx Zoo Models] expected offsets to be non-negative, but got -1 in TensorPadToTensorInsertSlicePass #19935

Open · pravg-amd opened this issue on Feb 7, 2025 · 1 comment
Labels: bug 🐞 Something isn't working


@pravg-amd

What happened?

For the IR below,

module {
  func.func @nasnetalarge_Opset16_timm(%arg0: !torch.vtensor<[?,168,?,?],f32>) ->  !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
    %1689 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__362> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %1690 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__363> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64>
    %1691 = torch.operator "onnx.ConstantOfShape"(%1689) {torch.onnx.value = dense_resource<__364> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64>
    %1692 = torch.operator "onnx.Concat"(%1690, %1691) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64>
    %1693 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__365> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64>
    %1694 = torch.operator "onnx.Reshape"(%1692, %1693) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64>
    %1695 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__366> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %1696 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__367> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %1697 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__368> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %1698 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__369> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %1699 = torch.operator "onnx.Slice"(%1694, %1696, %1697, %1695, %1698) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64>
    %1700 = torch.operator "onnx.Transpose"(%1699) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64>
    %1701 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__370> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
    %1702 = torch.operator "onnx.Reshape"(%1700, %1701) {torch.onnx.allowzero = 0 : si64} : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64>
    %1703 = torch.operator "onnx.Cast"(%1702) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64>
    %cst2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__371> : tensor<f32>} : () -> !torch.vtensor<[],f32>
    %0 = torch.operator "onnx.Pad"(%arg0, %1703, %cst2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,168,?,?],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32>
    return %0 : !torch.vtensor<[?,?,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      __362: "0x080000000400000000000000",
      __363: "0x08000000FFFFFFFFFFFFFFFF0100000000000000FFFFFFFFFFFFFFFF0100000000000000",
      __364: "0x080000000000000000000000",
      __365: "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
      __366: "0x080000000000000000000000",
      __367: "0x08000000FFFFFFFFFFFFFFFF",
      __368: "0x080000000100000000000080",
      __369: "0x08000000FFFFFFFFFFFFFFFF",
      __370: "0x08000000FFFFFFFFFFFFFFFF",
      __371: "0x0800000000000000",
      __372: "0x080000000200000000000000"
    }
  }
#-}
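To see where the -1 comes from, the constant subgraph that feeds %1703 (the pads operand of onnx.Pad) can be folded by hand. The Python sketch below assumes the usual MLIR hex-blob layout for dense_resource values (a 4-byte alignment header, here 0x08000000, followed by little-endian int64 data); the helpers decode_i64 and replay_pads are hypothetical names for illustration only. It yields pads = [0, 0, -1, -1, 0, 0, 1, 1], which in ONNX's [x1_begin, ..., xn_begin, x1_end, ..., xn_end] layout is a begin pad of -1 and an end pad of 1 on the last two dimensions.

```python
import struct

# Hex blobs copied from the dialect_resources section above
# (__371, the f32 pad value 0.0, and the unused __372 are omitted).
RESOURCES = {
    "__362": "0x080000000400000000000000",
    "__363": "0x08000000FFFFFFFFFFFFFFFF0100000000000000FFFFFFFFFFFFFFFF0100000000000000",
    "__364": "0x080000000000000000000000",
    "__365": "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
    "__366": "0x080000000000000000000000",
    "__367": "0x08000000FFFFFFFFFFFFFFFF",
    "__368": "0x080000000100000000000080",
    "__369": "0x08000000FFFFFFFFFFFFFFFF",
    "__370": "0x08000000FFFFFFFFFFFFFFFF",
}


def decode_i64(blob: str) -> list[int]:
    """Decode a hex blob as little-endian int64 values.

    Assumes the first 4 bytes are the alignment header (0x08000000 = 8).
    """
    raw = bytes.fromhex(blob[2:])  # drop the "0x" prefix
    payload = raw[4:]  # skip the 4-byte alignment header
    return list(struct.unpack(f"<{len(payload) // 8}q", payload))


def replay_pads() -> list[int]:
    """Constant-fold the onnx ops that produce %1703."""
    # %1691 = onnx.ConstantOfShape(shape=[4], value=0)
    (length,) = decode_i64(RESOURCES["__362"])
    (fill,) = decode_i64(RESOURCES["__364"])
    zeros = [fill] * length
    # %1692 = onnx.Concat(%1690, %1691), axis 0 -> shape [8]
    flat = decode_i64(RESOURCES["__363"]) + zeros
    # %1694 = onnx.Reshape(%1692, [-1, 2]) -> shape [4, 2]
    rows = [flat[i:i + 2] for i in range(0, len(flat), 2)]
    # %1699 = onnx.Slice(starts=[-1], ends=[INT64_MIN + 1], axes=[0], steps=[-1]).
    # Python's slice clamping matches ONNX here: the rows come out reversed.
    (start,) = decode_i64(RESOURCES["__367"])
    (end,) = decode_i64(RESOURCES["__368"])
    (step,) = decode_i64(RESOURCES["__369"])
    rows = rows[start:end:step]
    # %1700 = onnx.Transpose(perm=[1, 0]) -> shape [2, 4]
    cols = [list(c) for c in zip(*rows)]
    # %1702 = onnx.Reshape(%1700, [-1]) -> shape [8]; the Cast to int64 is a no-op.
    return [v for col in cols for v in col]


print(replay_pads())  # [0, 0, -1, -1, 0, 0, 1, 1]
```

The Slice step is the interesting one: __368 decodes to INT64_MIN + 1, so with a step of -1 the whole axis is reversed, which moves the [-1, 1] pairs onto the last two dimensions.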

compilation fails in TensorPadToTensorInsertSlicePass with the following error:

test.mlir:19:10: error: expected offsets to be non-negative, but got -1
    %0 = torch.operator "onnx.Pad"(%arg0, %1703, %cst2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,168,?,?],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32>
         ^
test.mlir:19:10: note: see current operation: %16 = "tensor.insert_slice"(%7, %12, %13, %14, %15) <{operandSegmentSizes = array<i32: 1, 1, 0, 3, 0>, static_offsets = array<i64: 0, 0, -1, -1>, static_sizes = array<i64: -9223372036854775808, 168, -9223372036854775808, -9223372036854775808>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<?x168x?x?xf32>, tensor<?x168x?x?xf32>, index, index, index) -> tensor<?x168x?x?xf32>
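The note shows that the pad was rewritten into a tensor.insert_slice whose static_offsets equal the low padding, [0, 0, -1, -1]. The Python model below assumes the rewrite fills a result of the padded shape with the pad value and inserts the source at offset = low; it works on a 2-D nested list, and pad_via_insert_slice is a hypothetical helper, not IREE code. It shows why a negative low pad cannot be expressed this way.

```python
def pad_via_insert_slice(src, low, high, value):
    """Hypothetical model: pad a 2-D nested list as fill + insert_slice.

    The result has size low + src_size + high per dimension; it is first filled
    with `value`, then `src` is copied in at offsets equal to `low`.
    """
    # Mirror the tensor.insert_slice verifier check that fails in this issue.
    # Without it, Python's negative indexing would silently wrap around.
    for offset in low:
        if offset < 0:
            raise ValueError(f"expected offsets to be non-negative, but got {offset}")
    rows, cols = len(src), len(src[0])
    out = [[value] * (low[1] + cols + high[1]) for _ in range(low[0] + rows + high[0])]
    for i in range(rows):
        for j in range(cols):
            out[low[0] + i][low[1] + j] = src[i][j]
    return out


# Non-negative padding works: one row of zeros on top, one column on the right.
print(pad_via_insert_slice([[1, 2], [3, 4]], low=[1, 0], high=[0, 1], value=0))
# [[0, 0, 0], [1, 2, 0], [3, 4, 0]]

# The H/W padding from this issue (low = -1, high = 1) is rejected.
try:
    pad_via_insert_slice([[1, 2], [3, 4]], low=[-1, -1], high=[1, 1], value=0)
except ValueError as err:
    print(err)  # expected offsets to be non-negative, but got -1
```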

Steps to reproduce your issue

Command:

iree-compile test.mlir --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host -o outs.vmfb

What component(s) does this issue relate to?

Compiler

Version information

IREE commit: 5f7b471

Additional context

Models: Onnx Zoo Models (nasnetalarge_Opset16_timm / pnasnet5large_Opset16_timm)

IR dumped using: '--mlir-print-ir-after-all --mlir-print-ir-before-all --mlir-disable-threading --mlir-elide-elementsattrs-if-larger=4'

IR Dump: https://gist.github.com/pravg-amd/167c7ecdac7aa2fdc3b8beaeb7152268

@hanhanW (Contributor) commented on Feb 7, 2025

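I think it is because you are padding with negative sizes: the pad list below is [-1, 1, -1, 1, 0, 0, 0, 0], which lowers to tensor.pad low[0, 0, -1, -1] high[0, 0, 1, 1], and that negative low padding becomes the negative tensor.insert_slice offsets.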

// -----// IR Dump Before ConvertTorchToLinalg (convert-torch-to-linalg) //----- //
func.func @nasnetalarge_Opset16_timm(%arg0: !torch.vtensor<[?,168,?,?],f32>) -> !torch.vtensor<[?,168,?,?],f32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
  %int0 = torch.constant.int 0
  %int-1 = torch.constant.int -1
  %int1 = torch.constant.int 1
  %float0.000000e00 = torch.constant.float 0.000000e+00
  %0 = torch.prim.ListConstruct %int-1, %int1, %int-1, %int1, %int0, %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.constant_pad_nd %arg0, %0, %float0.000000e00 : !torch.vtensor<[?,168,?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,168,?,?],f32>
  return %1 : !torch.vtensor<[?,168,?,?],f32>
}
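The ListConstruct above follows aten.constant_pad_nd ordering: (begin, end) pairs starting from the last dimension. ONNX instead lists all begin pads and then all end pads in dimension order, and tensor.pad (in the dump that follows) uses per-dimension low/high lists. The hypothetical helpers below sketch both reorderings in Python; they are an illustration, not the torch-mlir conversion code. Starting from the onnx.Pad pads operand, which folds to [0, 0, -1, -1, 0, 0, 1, 1] for the constants in the reported IR, they reproduce the list above and the low/high of the tensor.pad.

```python
def onnx_pads_to_torch_pad_list(pads):
    """Reorder ONNX pads [x1_begin, ..., xn_begin, x1_end, ..., xn_end] into
    torch's (begin, end) pairs, starting from the last dimension."""
    n = len(pads) // 2
    begins, ends = pads[:n], pads[n:]
    pad_list = []
    for d in reversed(range(n)):
        pad_list += [begins[d], ends[d]]
    return pad_list


def torch_pad_list_to_low_high(pad_list, rank):
    """Map torch (begin, end) pairs (last dimension first) to per-dimension
    low/high amounts in dimension order, as tensor.pad prints them."""
    low, high = [0] * rank, [0] * rank
    for k in range(len(pad_list) // 2):
        d = rank - 1 - k  # pair k pads dimension rank - 1 - k
        low[d], high[d] = pad_list[2 * k], pad_list[2 * k + 1]
    return low, high


pad_list = onnx_pads_to_torch_pad_list([0, 0, -1, -1, 0, 0, 1, 1])
print(pad_list)  # [-1, 1, -1, 1, 0, 0, 0, 0]
print(torch_pad_list_to_low_high(pad_list, rank=4))  # ([0, 0, -1, -1], [0, 0, 1, 1])
```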

// -----// IR Dump After ConvertTorchToLinalg (convert-torch-to-linalg) //----- //
func.func @nasnetalarge_Opset16_timm(%arg0: !torch.vtensor<[?,168,?,?],f32>) -> !torch.vtensor<[?,168,?,?],f32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
  %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,168,?,?],f32> -> tensor<?x168x?x?xf32>
  %int0 = torch.constant.int 0
  %int-1 = torch.constant.int -1
  %int1 = torch.constant.int 1
  %float0.000000e00 = torch.constant.float 0.000000e+00
  %1 = torch_c.to_f64 %float0.000000e00
  %2 = torch.prim.ListConstruct %int-1, %int1, %int-1, %int1, %int0, %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %3 = arith.truncf %1 : f64 to f32
  %padded = tensor.pad %0 low[0, 0, -1, -1] high[0, 0, 1, 1] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
    tensor.yield %3 : f32
  } : tensor<?x168x?x?xf32> to tensor<?x168x?x?xf32>
  %cast = tensor.cast %padded : tensor<?x168x?x?xf32> to tensor<?x168x?x?xf32>
  %4 = torch_c.from_builtin_tensor %cast : tensor<?x168x?x?xf32> -> !torch.vtensor<[?,168,?,?],f32>
  return %4 : !torch.vtensor<[?,168,?,?],f32>
}
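Negative amounts in aten.constant_pad_nd remove elements rather than add them, so low[0, 0, -1, -1] high[0, 0, 1, 1] keeps the H and W sizes but drops the first row and column and appends a row and column of the pad value. The Python sketch below (pad_2d is a hypothetical helper on 2-D nested lists, not torch-mlir or IREE code) spells that out as a crop followed by a pad with only non-negative amounts. In tensor terms this would correspond to a tensor.extract_slice followed by a tensor.pad whose low/high are all non-negative, which is one possible direction for keeping the resulting insert_slice offsets valid.

```python
def pad_2d(img, low, high, value):
    """Constant-pad a 2-D nested list where negative amounts crop
    (the aten.constant_pad_nd meaning). Hypothetical helper for illustration."""
    # Step 1: crop the negative amounts (an extract_slice-style step).
    r0, c0 = max(-low[0], 0), max(-low[1], 0)
    r1, c1 = len(img) - max(-high[0], 0), len(img[0]) - max(-high[1], 0)
    cropped = [row[c0:c1] for row in img[r0:r1]]
    # Step 2: pad with only the non-negative amounts, so every offset is >= 0.
    lo = [max(low[0], 0), max(low[1], 0)]
    hi = [max(high[0], 0), max(high[1], 0)]
    width = lo[1] + (c1 - c0) + hi[1]
    out = [[value] * width for _ in range(lo[0])]
    for row in cropped:
        out.append([value] * lo[1] + row + [value] * hi[1])
    out += [[value] * width for _ in range(hi[0])]
    return out


img = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
# H/W padding from this issue: low = [-1, -1], high = [1, 1].
print(pad_2d(img, low=[-1, -1], high=[1, 1], value=0))
# [[5, 6, 0], [8, 9, 0], [0, 0, 0]]
```

The output keeps the 3x3 shape, matching the tensor<?x168x?x?xf32> to tensor<?x168x?x?xf32> result type of the tensor.pad above.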
