
[Bug] MatMul operator in TVM seems fragile #16891

Closed
shaoyuyoung opened this issue Apr 16, 2024 · 1 comment
Labels
needs-triage, type: bug

Comments

@shaoyuyoung

shaoyuyoung commented Apr 16, 2024

TVM seems to impose strict restrictions on the MatMul operator, meaning it cannot multiply tensors of different ranks.

Look at this simple graph. In PyTorch and ONNX, the model is correctly defined, and the input and output shapes are exactly as shown below.
The evidence is here: https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul
(screenshot: model graph with input and output shapes)
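For reference, a minimal sketch of the documented broadcasting rules, using the same shapes as the reproduction script below:

import torch

# torch.matmul appends a trailing 1 to a 1-D second operand, multiplies,
# then removes the added dimension again (per the docs linked above).
print(torch.matmul(torch.randn(5, 1), torch.randn(1)).shape)     # torch.Size([5])
print(torch.matmul(torch.randn(5, 4, 5), torch.randn(5)).shape)  # torch.Size([5, 4])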

When I try to convert ONNX to TVM, I get an error indicating that the tensor shapes are inconsistent. However, when converting PyTorch to TVM, everything is OK!

I guess one possible reason is that TorchScript plays a normalizing role here that ONNX does not.

Moreover, look at the last line of the error message. I wonder why T.int64(1) appears there; TVM's handling of int64 shape dimensions seems fragile.

(screenshot: TVM error message, reproduced as text below)

Expected behavior

Compilation should pass, since the model produces results in both ONNX and PyTorch.

Actual behavior

Compilation failure

Traceback (most recent call last):
  18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  17: tvm::transform::Pass::operator()(tvm::IRModule) const
  16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  14: _ZN3tvm7runtime13PackedFun
  13: tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::DynamicToStatic()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  12: tvm::relay::DynamicToStatic(tvm::relay::Function, tvm::IRModule)
  11: tvm::relay::DynamicToStaticMutator::PrepareInput(tvm::RelayExpr const&)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  2: tvm::relay::TypeSolver::Solve()
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: bool tvm::relay::BatchMatmulRel<tvm::relay::BatchMatmulAttrs>(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  File "/root/anaconda3/conda-bld/tvm-package_1701590675822/work/src/relay/op/nn/nn.h", line 212
InternalError: Check failed: (reporter->AssertEQ(xk, yk)) is false: BatchDot: shapes of x and y is inconsistent,  x shape=[T.int64(1), 5, 5], y shape=[5, 5, 4]

Environment

Operating System: Ubuntu 18
TVM: 0.15
Torch: 2.1.1
ONNX: 1.15.0

Steps to reproduce

Here is the script:

import torch
import torch.nn as nn
import tvm
from tvm import relay
import onnx

class DirectMatMulModel(nn.Module):
    def __init__(self):
        super(DirectMatMulModel, self).__init__()

    def forward(self, x1, x2, y1, y2):
        result1 = torch.matmul(x1, x2)
        result2 = torch.matmul(y1, y2)
        final_result = torch.matmul(result1, result2)
        return final_result


torch_model = DirectMatMulModel().eval()

x1 = torch.randn(5, 1)     # 2-D
x2 = torch.randn(1)        # 1-D
y1 = torch.randn(5, 4, 5)  # 3-D (batched)
y2 = torch.randn(5)        # 1-D

scripted_model = torch.jit.trace(torch_model, (x1, x2, y1, y2))

torch.onnx.export(torch_model,
                      (x1, x2, y1, y2),
                      "direct_matmul_model.onnx",
                      export_params=True,
                      opset_version=12,
                      do_constant_folding=True,
                      input_names=['x1', 'x2', 'y1', 'y2'],
                      output_names=['output'])

onnx_model = onnx.load("direct_matmul_model.onnx")
onnx.checker.check_model(onnx_model)

def compile_onnx():
    mod_from_onnx, params_onnx = relay.frontend.from_onnx(onnx_model, shape={'x1': [5, 1], 'x2': [1], 'y1': [5, 4, 5], 'y2': [5]})
    with tvm.transform.PassContext(opt_level=4):
        executor = relay.build_module.create_executor(
            'graph', mod_from_onnx, tvm.cpu(), 'llvm', params_onnx
        ).evaluate()

def compile_torch():
    mod_from_torch, params_torch = relay.frontend.from_pytorch(scripted_model, input_infos=[('x1', [5, 1]), ('x2', [1]), ('y1', [5, 4, 5]), ('y2', [5])])
    with tvm.transform.PassContext(opt_level=4):
        executor = relay.build_module.create_executor(
            'graph', mod_from_torch, tvm.cpu(), 'llvm', params_torch
        ).evaluate()

try:
    compile_torch()
except Exception as e:
    print(f"torch fail\n {e}")

try:
    compile_onnx()
except Exception as e:
    print(f"onnx fail\n {e}")

Triage

  • needs-triage
shaoyuyoung added the needs-triage and type: bug labels on Apr 16, 2024
xhmelon added a commit to xhmelon/tvm that referenced this issue May 1, 2024
This commit provides batch_matmul conversions between a 3D or above
matrix and a 1D matrix with proper broadcasting, which improves
the robustness of the ONNX frontend. This issue was captured in apache#16891.
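A rough sketch of the conversion idea described in this commit (a paraphrase, not the actual patch):

import torch

def matmul_3d_with_1d(x, y):
    # Promote the 1-D operand to a matrix, let it broadcast across x's
    # batch dimensions, then drop the appended axis again.
    out = torch.matmul(x, y.reshape(-1, 1))  # (b, m, k) @ (k, 1) -> (b, m, 1)
    return out.squeeze(-1)                   # -> (b, m)

assert matmul_3d_with_1d(torch.randn(5, 4, 5), torch.randn(5)).shape == (5, 4)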
Hzfengsy pushed a commit that referenced this issue May 5, 2024
* [Bugfix][VTA] Fix FSIM compile error on macOS.

VTA FSIM could not be built on macOS, because it leverages malloc.h and
memalign, yet both have been deprecated and are not provided by
macOS. This issue was captured in #13173.

This commit stops including malloc.h in the VTA Runtime, since stdlib.h
provides the functions we need.

This commit uses posix_memalign instead of memalign; it is a portable standard function.

* Fix format.

* [Bugfix][ONNX] Improve broadcast and batch_matmul conversion

This commit provides batch_matmul conversions between a 3D or above
matrix and a 1D matrix with proper broadcasting, which improves
the robustness of the ONNX frontend. This issue was captured in #16891.

* Fix format.
@shaoyuyoung
Author

shaoyuyoung commented May 6, 2024

Thanks for your fixes! 😊 @xhmelon

It seems everything is OK!

This issue can be closed, even though the commit has not been successfully built at this time.
