Closed
Labels: bug, component: lowering, component: partitioning
Description
Bug Description
I'm running a minimal test script. When I compile the model with torch_tensorrt, it throws a shape mismatch error in the TorchScript interpreter, but torch.jit.trace on its own does not produce this error.
Can you please help me with this problem?
To Reproduce
Steps to reproduce the behavior:
- python trt_sample.py
code sample:
`
import time
import random

import torch
import torch_tensorrt
from torch import nn

torch.manual_seed(0)


class Model(torch.nn.Module):
    def __init__(self, n_in=128, n_out=64, norm='LN', act=True):
        super().__init__()
        self.linear = nn.Linear(n_in, n_out, bias=False)
        self.norm = nn.LayerNorm(n_out)
        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, cout, rand, emb):
        # The number of selected indices depends on the values in `rand`.
        idcs = torch.where(rand > 0.2)[0]
        out = self.linear(emb.view(1, -1, 128))
        out = self.norm(out)
        if self.act:
            out = self.relu(out)
        # In-place indexed assignment; this is the line reported in the error.
        cout[:, idcs.long()] = out
        # cout.index_add_(1, idcs, out)
        return cout


# Init data: `emb` is sized from the number of entries of `rand` above 0.2.
rand = torch.rand(5740).cuda()
s = torch.where(rand > 0.2)[0]
emb = torch.rand((1, len(s) * 128)).cuda()
cout = torch.rand(1, 5740, 64).cuda()

# Init model
model = Model().eval().cuda()

# torch jit trace
traced = torch.jit.trace(model, (cout.clone(), rand.clone(), emb.clone()))
traced_output = traced(cout.clone(), rand.clone(), emb.clone())
print('traced output: ', traced_output.shape)

# trt compile
enabled_precisions = {torch.float}
trt_ts_module = torch_tensorrt.compile(
    model,  # traced,
    inputs=(cout.clone(), rand.clone(), emb.clone()),
    enabled_precisions=enabled_precisions,
    torch_executed_ops=['aten::index', 'aten::index_put_', 'aten::where', 'aten::to'],
    truncate_long_and_double=True,
    debug=True
)
print('trt compiled')

# infer with trt
start_time = time.time()  # was undefined in the original snippet
result = trt_ts_module(cout.clone(), rand.clone(), emb.clone())
print('trt: ', time.time() - start_time)
`
error message:
`
Traceback (most recent call last):
File "test_trace.py", line 52, in <module>
trt_ts_module = torch_tensorrt.compile(
File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 113, in compile
return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 134, in compile
compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "test_trace.py", line 30, in forward
out = self.relu(out)
cout[:, idcs.long()] = out
~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
# cout.index_add_(1, idcs, out)
return cout
RuntimeError: shape mismatch: value tensor of shape [1, 4582, 64] cannot be broadcast to indexing result of shape [1, 4631, 64]
`
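One possible reading of this error, offered as an assumption rather than a confirmed diagnosis: the number of indices returned by `torch.where(rand > 0.2)` depends on the values in `rand`, while the length of `emb` was fixed from one particular `rand`. If `forward` is ever evaluated with a different `rand` of the same shape (whether the compile step does this is not established here), the two counts can disagree. The sketch below reproduces that class of error in plain PyTorch on CPU, using small hypothetical tensors (`rand_a`, `rand_b`) that are not part of the original script.

```python
import torch

# Two hypothetical inputs with the same shape but different values.
rand_a = torch.tensor([0.1, 0.5, 0.9, 0.3])   # 3 entries above 0.2
rand_b = torch.tensor([0.1, 0.05, 0.9, 0.3])  # 2 entries above 0.2

idcs_a = torch.where(rand_a > 0.2)[0]
idcs_b = torch.where(rand_b > 0.2)[0]
n_a, n_b = idcs_a.numel(), idcs_b.numel()

# The value tensor is sized from rand_a, as `emb` is sized from `rand` in the script.
out = torch.ones(1, n_a, 64)
cout = torch.zeros(1, 4, 64)

# Indexing with the matching indices works.
cout[:, idcs_a] = out

# Indexing with indices computed from different data fails.
mismatch_raised = False
try:
    cout[:, idcs_b] = out
except RuntimeError as err:
    mismatch_raised = True
    print(err)
```

If this is what happens during compilation, the two sizes in the error message (4582 and 4631) would correspond to index counts computed from two different `rand` tensors.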
Expected behavior
The script compiles and runs without a shape mismatch error.
Environment
docker image: nvcr.io/nvidia/pytorch:22.09-py3