Skip to content

Commit

Permalink
fix: add explicit cast for i64 outputs as they may not be supported in
Browse files Browse the repository at this point in the history
all layers

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Apr 30, 2024
1 parent 77df69d commit 2d0fa75
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 14 deletions.
27 changes: 17 additions & 10 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
Expand All @@ -26,6 +25,7 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
)

for i, output in enumerate(outputs):
name = f"output{i}"

output_dtype = dtype.unknown
if any(
op_name in output.name.split("_")
for op_name in (
Expand All @@ -514,16 +517,20 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
"any",
)
):
output_bool = True
else:
output_bool = False
name = f"output{i}"
output.name = name
self.ctx.net.mark_output(output)
if output_bool:
output.dtype = trt.DataType.BOOL
output_dtype = dtype.b
elif self.output_dtypes is not None:
output.dtype = self.output_dtypes[i].to(trt.DataType)
if self.output_dtypes[i] == dtype.i64:
output = self.ctx.net.add_cast(
output, dtype.i64.to(trt.DataType)
).get_output(0)
output_dtype = dtype.i64
else:
output_dtype = self.output_dtypes[i]

self.ctx.net.mark_output(output)
if output_dtype is not dtype.unknown:
output.dtype = output_dtype.to(trt.DataType, use_default=True)
output.name = name

self._output_names.append(name)
_LOGGER.debug(
Expand Down
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# mypy: disallow-untyped-decorators=False

import logging
import operator
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -858,6 +860,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
allowed_casts = {
torch.float,
torch.int32,
torch.int64,
torch.bool,
torch.int8,
torch.float16,
Expand Down
4 changes: 4 additions & 0 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def run_test(
truncate_double=compilation_settings.truncate_double,
)

_LOGGER.debug(f"Compilation settings: {compilation_settings}")
_LOGGER.debug(f"Inputs: {input_specs}")
_LOGGER.debug(f"Output types: {output_dtypes}")

interp = TRTInterpreter(
mod,
input_specs,
Expand Down
4 changes: 0 additions & 4 deletions tests/py/dynamo/models/test_dtype_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def forward(self, x):
inputs=[in_tensor],
pass_through_build_failures=True,
truncate_double=True,
output_format="fx",
min_block_size=1,
use_python_runtime=False,
)
Expand Down Expand Up @@ -78,7 +77,6 @@ def forward(self, x):
inputs=[in_tensor],
pass_through_build_failures=True,
truncate_double=True,
output_format="fx",
min_block_size=1,
use_python_runtime=True,
)
Expand Down Expand Up @@ -123,7 +121,6 @@ def forward(self, x):
inputs=[in_tensor],
pass_through_build_failures=True,
truncate_double=False,
output_format="fx",
min_block_size=1,
use_python_runtime=False,
)
Expand Down Expand Up @@ -163,7 +160,6 @@ def forward(self, x):
inputs=[in_tensor],
pass_through_build_failures=True,
truncate_double=False,
output_format="fx",
min_block_size=1,
use_python_runtime=True,
)
Expand Down

0 comments on commit 2d0fa75

Please sign in to comment.