Skip to content

Commit

Permalink
fix: add explicit cast for i64 outputs as they may not be supported in
Browse files Browse the repository at this point in the history
all layers

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Apr 30, 2024
1 parent da25720 commit 7b6bdcd
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
26 changes: 16 additions & 10 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
Expand All @@ -26,6 +25,7 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
)

for i, output in enumerate(outputs):
name = f"output{i}"

output_dtype = dtype.unknown
if any(
op_name in output.name.split("_")
for op_name in (
Expand All @@ -514,16 +517,19 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
"any",
)
):
output_bool = True
else:
output_bool = False
name = f"output{i}"
output.name = name
self.ctx.net.mark_output(output)
if output_bool:
output.dtype = trt.DataType.BOOL
output_dtype = dtype.b
elif self.output_dtypes is not None:
output.dtype = self.output_dtypes[i].to(trt.DataType)
if self.output_dtypes[i] == dtype.i64:
output = self.ctx.net.add_cast(
output, dtype.i64.to(trt.DataType)
).get_output(0)
output_dtype = dtype.i64
else:
output_dtype = self.output_dtypes[i]

self.ctx.net.mark_output(output)
output.dtype = output_dtype.to(trt.DataType)
output.name = name

self._output_names.append(name)
_LOGGER.debug(
Expand Down
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# mypy: disallow-untyped-decorators=False

import logging
import operator
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -858,6 +860,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
allowed_casts = {
torch.float,
torch.int32,
torch.int64,
torch.bool,
torch.int8,
torch.float16,
Expand Down
4 changes: 4 additions & 0 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def run_test(
truncate_double=compilation_settings.truncate_double,
)

_LOGGER.debug(f"Compilation settings: {compilation_settings}")
_LOGGER.debug(f"Inputs: {input_specs}")
_LOGGER.debug(f"Output types: {output_dtypes}")

interp = TRTInterpreter(
mod,
input_specs,
Expand Down

0 comments on commit 7b6bdcd

Please sign in to comment.