🐛 [Bug] Runtime Error on BART #1532

Closed · gs-olive opened this issue Dec 6, 2022 · 3 comments
Labels: bug Something isn't working

gs-olive commented Dec 6, 2022

Bug Description

When performing inference with a Torch-TRT converted BART network (https://huggingface.co/facebook/bart-base), the following error is encountered:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(243): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(325): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(850): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/models/bart/modeling_bart.py(1233): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1190): _call_impl
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(976): trace_module
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(759): trace
RuntimeError: The size of tensor a (1536) must match the size of tensor b (128) at non-singleton dimension 3

Note that compilation of the model succeeds.

To Reproduce

Steps to reproduce the behavior:

  1. Run torch_tensorrt.compile with the BART model as input, using fp32 precision.
  2. Provide two fixed-size inputs, each of shape [1, 128], and enable truncate_long_and_double with a 12 GB workspace.
  3. Pass model keyword arguments to disable the attention and hidden-state outputs.
  4. Run inference with the compiled model on two sample inputs (a minimal sketch of these steps follows below).
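
Below is a minimal sketch of these steps, assuming an input_ids / attention_mask input pair and the standard transformers / torch_tensorrt APIs; the exact model keyword arguments and input choices used originally are not shown in this issue, so treat them as placeholders.

import torch
import torch_tensorrt
from transformers import BartModel

# Load BART and disable attention / hidden-state outputs (assumed kwargs).
model = BartModel.from_pretrained(
    "facebook/bart-base",
    torchscript=True,            # return tuples instead of ModelOutput objects
    output_attentions=False,
    output_hidden_states=False,
).eval()

# Two fixed-size [1, 128] inputs: token IDs and an attention mask.
input_ids = torch.randint(0, model.config.vocab_size, (1, 128))
attention_mask = torch.ones(1, 128, dtype=torch.int64)

traced = torch.jit.trace(model, (input_ids, attention_mask))

compiled = torch_tensorrt.compile(
    traced,
    inputs=[
        torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32),
        torch_tensorrt.Input(shape=[1, 128], dtype=torch.int32),
    ],
    enabled_precisions={torch.float32},  # fp32 precision
    truncate_long_and_double=True,
    workspace_size=12 << 30,             # 12 GB workspace
)

# Compilation succeeds; the RuntimeError above is raised here at inference time.
outputs = compiled(input_ids, attention_mask)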

Expected behavior

The model should successfully perform inference with Torch-TRT. Specifically, internal shape issues should either be caught at compile time or should not cause errors at runtime.

Environment

  • Torch-TensorRT Version: 1.4.0.dev0+81f2dabb
  • PyTorch Version: 1.14.0.dev20221114+cu116
  • CPU Architecture: Intel Xeon CPU
  • OS: Ubuntu 20.04
  • How you installed PyTorch: pip
  • Build command you used: python setup.py develop
  • Are you using local sources or building from archives: local
  • Python version: 3.8.13
  • CUDA version: 11.6

Additional context

The problem currently seems to be related to Torch-TensorRT flattening input tensors in a way that is inconsistent with the analogous PyTorch behavior. Two potentially relevant operations are aten::mul and aten::add, which appear frequently in the BART code as replacements for the linear layer, inserted by the LinearToAddMM lowering pass:

void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
  // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
  std::string flatten_linear_pattern = R"IR(
    graph(%input, %weight, %bias):
      %res = aten::linear(%input, %weight, %bias)
      return (%res))IR";
  std::string fused_linear = R"IR(
    graph(%input, %weight_t, %bias):
      %1: int = prim::Constant[value=1]()
      %weight = aten::t(%weight_t)
      %mm: Tensor = aten::matmul(%input, %weight)
      %b_f: Tensor = trt::const(%bias)
      %out: Tensor = aten::add(%b_f, %mm, %1)
      return (%out))IR";

Temporary Solution

A temporary fix to this problem is to add the following to the compilation arguments in torch_tensorrt.compile:

torch_tensorrt.compile(..., torch_executed_ops=["aten::mul"], ...)

This workaround succeeds because it happens to exclude the problematic code, which may be related to the aten::mul operator itself.

Related Issues

Potentially related to Issue #1455, as a similar error appears under certain compilation configurations for that model as well.

Additional Note

The bug appears to be nondeterministic: after recompiling and running inference on the model several times, inference ultimately completes successfully.

gs-olive added the bug (Something isn't working) label on Dec 6, 2022
Christina-Young-NVIDIA (Collaborator) commented

Bo has reproduced this issue.

gs-olive (Collaborator, Author) commented

Related Findings

The Torch-executed code block from which the error stems is:

attn_weights1 = torch.bmm(_73, torch.transpose(_72, 1, 2))
_75 = torch.reshape(attn_weights1, [_54, 12, _56, _74])
attn_weights02 = torch.add(_75, _21)
                  ~~~~~~~~~ <--- HERE
_76 = torch.reshape(attn_weights02, [_67, _56, _74])
input0 = torch.softmax(_76, -1)

It appears that the reshaped tensor _75 has a different final dimension from the other tensor _21, causing the add to fail. The operator causing the shape mismatch likely arises within Torch-TensorRT or TensorRT itself, since the mismatch does not occur during the Torch-only dry run in partitioning (the error is only thrown at inference time).
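
For illustration, the same broadcast failure can be reproduced directly in PyTorch. The shapes below are assumptions chosen to match the sizes in the error message, not values extracted from the model:

import torch

# Hypothetical shapes: the reshaped attention weights (_75) end up with a
# final dimension of 1536, while the added tensor (_21) has 128.
attn_weights = torch.randn(1, 12, 128, 1536)
attn_mask = torch.randn(1, 1, 128, 128)

# Raises: RuntimeError: The size of tensor a (1536) must match the size of
# tensor b (128) at non-singleton dimension 3
out = torch.add(attn_weights, attn_mask)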

gs-olive self-assigned this on Feb 2, 2023

gs-olive commented Feb 2, 2023

Fixed by #1619

gs-olive closed this as completed on Feb 2, 2023