
🐛 [Bug] Encountered bug when using Torch-TensorRT #1453

@stupiding

Bug Description

I'm running a minimal test script. When I use `torch_tensorrt` to compile the model, it throws a shape mismatch error in the TorchScript interpreter, but calling the `torch.jit.trace`d module directly does not produce the same error.
Can you please help me with this problem?
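
For context, the indexing in the model below is data-dependent: how many positions `torch.where(rand > 0.2)` selects depends on the *values* in `rand`, not just its shape. A minimal sketch of that behaviour (my own illustration, not part of the original script):

```python
import torch

# The count of selected indices changes with the random values, even though the
# input shape stays fixed at (5740,), so the shape of the indexing result is
# only known at run time.
for _ in range(3):
    rand = torch.rand(5740)
    print(int((rand > 0.2).sum()))  # typically a different count on each draw
```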

To Reproduce

Steps to reproduce the behavior:

  1. `python trt_sample.py`

code sample:

```python
import time
import random

import torch
import torch_tensorrt
from torch import nn

torch.manual_seed(0)


class Model(torch.nn.Module):
    def __init__(self, n_in=128, n_out=64, norm='LN', act=True):
        super().__init__()

        self.linear = nn.Linear(n_in, n_out, bias=False)
        self.norm = nn.LayerNorm(n_out)
        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, cout, rand, emb):
        # Number of selected positions depends on the values in `rand`
        idcs = torch.where(rand > 0.2)[0]

        out = self.linear(emb.view(1, -1, 128))
        out = self.norm(out)

        if self.act:
            out = self.relu(out)

        # Scatter the projected embeddings into the selected positions
        cout[:, idcs.long()] = out
        # cout.index_add_(1, idcs, out)
        return cout


# Init data
rand = torch.rand(5740).cuda()
s = torch.where(rand > 0.2)[0]

emb = torch.rand((1, len(s) * 128)).cuda()
cout = torch.rand(1, 5740, 64).cuda()

# Init model
model = Model().eval().cuda()

# torch jit trace
traced = torch.jit.trace(model, (cout.clone(), rand.clone(), emb.clone()))
traced_output = traced(cout.clone(), rand.clone(), emb.clone())
print('traced output: ', traced_output.shape)

# trt compile
enabled_precisions = {torch.float}
trt_ts_module = torch_tensorrt.compile(
    model,  # traced,
    inputs=(cout.clone(), rand.clone(), emb.clone()),
    enabled_precisions=enabled_precisions,
    torch_executed_ops=['aten::index', 'aten::index_put_', 'aten::where', 'aten::to'],
    truncate_long_and_double=True,
    debug=True
)
print('trt compiled')

# infer with trt
start_time = time.time()
result = trt_ts_module(cout.clone(), rand.clone(), emb.clone())
print('trt: ', time.time() - start_time)

```

error message:
```

Traceback (most recent call last):
  File "test_trace.py", line 52, in <module>
    trt_ts_module = torch_tensorrt.compile(
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 113, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 134, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "test_trace.py", line 30, in forward
            out = self.relu(out)

        cout[:, idcs.long()] = out
        ~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        # cout.index_add_(1, idcs, out)
        return cout
RuntimeError: shape mismatch: value tensor of shape [1, 4582, 64] cannot be broadcast to indexing result of shape [1, 4631, 64]

```
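
A minimal sketch of where the two shapes in the traceback likely come from. The assumption here (mine, not stated in the report) is that `compile_graph` re-runs the TorchScript graph with freshly generated inputs of the recorded shapes, so the `rand` it uses no longer corresponds to the `emb` it uses:

```python
import torch

# rand as seen at trace time vs. a different random draw of the same shape,
# as a compiler-generated input presumably would be.
rand_trace = torch.rand(5740)
rand_compile = torch.rand(5740)

n_trace = int((rand_trace > 0.2).sum())      # emb is sized n_trace * 128, so
                                             # `out` has shape (1, n_trace, 64)
n_compile = int((rand_compile > 0.2).sum())  # the indexing result has shape
                                             # (1, n_compile, 64)

# Whenever n_trace != n_compile, `cout[:, idcs.long()] = out` raises exactly the
# "shape mismatch" seen above (e.g. 4582 vs 4631 in the traceback).
print(n_trace, n_compile)
```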

Expected behavior

Compilation and inference run without a shape mismatch error.

Environment

Docker image: `nvcr.io/nvidia/pytorch:22.09-py3`
