Closed
Labels
component: partitioning, feature request, release: v1.3
Description
As we know, if a model contains operators that Torch-TensorRT does not support, it is partitioned into TensorRT and Torch subgraphs. TensorRT does not support int64 values and truncates int64 to int32.
In some cases, an operator in a Torch subgraph (such as aten::index) consumes an int64 value that was produced by a TensorRT subgraph and has therefore been truncated to int32, which causes an error. We need to track these data type conversions and automatically convert values back to their original type at the boundary between TensorRT and Torch subgraphs.
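The bookkeeping this request asks for can be sketched in plain Python, independent of Torch-TensorRT internals. This is only an illustration under stated assumptions: `Segment`, `plan_casts`, and the `TRUNCATED` table are hypothetical names invented here, dtypes are plain strings, and the value name `out` is made up. The segments mirror the TensorRT / Torch / TensorRT split reported in the reproduction for this issue.

```python
from dataclasses import dataclass

# Assumption for this sketch: these are the only narrowing conversions
# applied when truncate_long_and_double is enabled.
TRUNCATED = {"int64": "int32", "float64": "float32"}


@dataclass
class Segment:
    target: str    # "TensorRT" or "Torch"
    inputs: list   # names of the values the segment consumes
    outputs: dict  # value name -> dtype in the original, unpartitioned graph


def plan_casts(segments):
    """List (consumer_index, value, from_dtype, to_dtype) for every value that
    a TensorRT segment narrows and a later Torch segment consumes."""
    original = {}  # dtype each value has in the original graph
    actual = {}    # dtype each value really has after partitioned execution
    casts = []
    for i, seg in enumerate(segments):
        if seg.target == "Torch":
            for name in seg.inputs:
                if name in actual and actual[name] != original[name]:
                    casts.append((i, name, actual[name], original[name]))
        for name, dtype in seg.outputs.items():
            original[name] = dtype
            if seg.target == "TensorRT":
                actual[name] = TRUNCATED.get(dtype, dtype)
            else:
                actual[name] = dtype
    return casts


segments = [
    Segment("TensorRT", ["index.1", "data.1"], {"index": "int64", "data.3": "float32"}),
    Segment("Torch", ["data.3", "index"], {"0": "float32"}),
    Segment("TensorRT", ["0"], {"out": "float32"}),
]
casts = plan_casts(segments)
print(casts)
```

Each reported tuple names a Torch segment input that would need a cast back to its original type before the segment runs. A real pass would operate on graph values and insert a conversion node rather than return tuples.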
Here is a typical case:
import torch
import torch.nn as nn
import torch_tensorrt


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

    def forward(self, data, index):
        src = 1
        index = index.to(torch.int64)
        data = data * data
        data = data.scatter_(1, index, src)
        data = data + 1
        return data


data = torch.randn([5, 5])
index = torch.randint(0, 4, [2, 2], dtype=torch.int32)

compile_spec = {
    "inputs": None,
    "device": {
        "device_type": torch_tensorrt.DeviceType.GPU,
        "gpu_id": 0,
        "allow_gpu_fallback": False,
        "disable_tf32": False
    },
    "truncate_long_and_double": True,
    "require_full_compilation": False,
    "torch_executed_ops": ["aten::scatter_", "aten::scatter"],
    "min_block_size": 1
}

net = Net()
model = torch.jit.trace(net, (data, index))

torch_type = torch.float32
min_shape = [5, 5]
data2 = torch_tensorrt.Input(shape=min_shape, dtype=torch_type)
torch_type = torch.int32
index2 = torch_tensorrt.Input(shape=min_shape, dtype=torch_type)
inputs = [data2, index2]
compile_spec["inputs"] = inputs

with torch_tensorrt.logging.debug():
    trt_mod = torch_tensorrt.ts.compile(model, **compile_spec)

inputs = [data.cuda(), index.cuda()]
output = trt_mod(*inputs)
print(output)
Subgraph log:
INFO: [Torch-TensorRT - Debug Build] - Partitioned Graph: [Segment Block @0:
    Target: TensorRT
    Graph: graph(%index.1 : Tensor,
          %data.1 : Tensor):
      %2 : int = prim::Constant[value=4]() # test_int64.py:28:0
      %3 : bool = prim::Constant[value=0]() # test_int64.py:28:0
      %4 : NoneType = prim::Constant()
      %index : Tensor = aten::to(%index.1, %2, %3, %3, %4) # test_int64.py:28:0
      %data.3 : Tensor = aten::mul(%data.1, %data.1) # test_int64.py:29:0
      return (%index, %data.3)
Segment Block @1:
    Target: Torch
    Graph: graph(%data.3 : Tensor,
          %index : Tensor):
      %2 : int = prim::Constant[value=1]() # test_int64.py:30:0
      %0 : Tensor = aten::scatter(%data.3, %2, %index, %2) # test_int64.py:30:0
      return (%0)
Segment Block @2:
    Target: TensorRT
    Graph: graph(%1 : Tensor):
      %2 : Tensor = prim::Constant[value={1}]() # test_int64.py:31:0
      %3 : int = prim::Constant[value=1]() # test_int64.py:30:0
      %0 : Tensor = aten::add(%1, %2, %3) # test_int64.py:31:0
      return (%0)
]
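To see which values cross from a TensorRT block into a Torch block in a dump like this one, a small stdlib-only helper can read the block headers, graph signatures, and return statements. This is a rough sketch: `parse_segments` and `trt_to_torch_values` are hypothetical helpers, not part of Torch-TensorRT, the `LOG` string is an abbreviated copy of the log in this issue, and matching values by name is only a heuristic because each block's graph uses its own local value names.

```python
import re

# Abbreviated copy of the partitioned-graph dump from this issue.
LOG = """
Partitioned Graph: [Segment Block @0:
Target: TensorRT
Graph: graph(%index.1 : Tensor,
%data.1 : Tensor):
%index : Tensor = aten::to(%index.1, %2, %3, %3, %4) # test_int64.py:28:0
%data.3 : Tensor = aten::mul(%data.1, %data.1) # test_int64.py:29:0
return (%index, %data.3)
Segment Block @1:
Target: Torch
Graph: graph(%data.3 : Tensor,
%index : Tensor):
%0 : Tensor = aten::scatter(%data.3, %2, %index, %2) # test_int64.py:30:0
return (%0)
Segment Block @2:
Target: TensorRT
Graph: graph(%1 : Tensor):
%0 : Tensor = aten::add(%1, %2, %3) # test_int64.py:31:0
return (%0)
]
"""


def parse_segments(log_text):
    """Split the dump into dicts holding each block's target, inputs, outputs."""
    segments = []
    # Everything after a 'Segment Block @N:' header belongs to that block.
    for body in re.split(r"Segment Block @\d+:", log_text)[1:]:
        target = re.search(r"Target:\s*(\w+)", body).group(1)
        signature = re.search(r"graph\((.*?)\):", body, re.S).group(1)
        inputs = re.findall(r"%([\w.]+)\s*:", signature)
        returned = re.search(r"return \((.*?)\)", body).group(1)
        outputs = re.findall(r"%([\w.]+)", returned)
        segments.append({"target": target, "inputs": inputs, "outputs": outputs})
    return segments


def trt_to_torch_values(segments):
    """For each TensorRT block directly followed by a Torch block, list the
    value names the first returns and the second takes as input."""
    crossings = []
    for i in range(len(segments) - 1):
        producer, consumer = segments[i], segments[i + 1]
        if producer["target"] == "TensorRT" and consumer["target"] == "Torch":
            shared = [v for v in producer["outputs"] if v in consumer["inputs"]]
            crossings.append((i, i + 1, shared))
    return crossings


segments = parse_segments(LOG)
crossings = trt_to_torch_values(segments)
print(crossings)
```

For this log the helper reports that `index` and `data.3` pass from block @0 to block @1. Of the two, `index` is the one converted to int64 by `aten::to` inside the TensorRT block, so it is the value whose type would need to be restored before `aten::scatter` runs.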