From c918cdb4dcc7b4e047571b8d000f109650c1ea21 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Mon, 29 Apr 2024 17:27:42 -0700
Subject: [PATCH] fix: add explicit cast for i64 outputs as they may not be
 supported in all layers

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 .../dynamo/conversion/_TRTInterpreter.py      | 27 ++++++++++++-------
 .../dynamo/conversion/aten_ops_converters.py  |  3 +++
 tests/py/dynamo/conversion/harness.py         |  4 +++
 tests/py/dynamo/models/test_dtype_support.py  |  4 ---
 4 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 9a75add755..59d2c5d6c0 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -4,7 +4,6 @@
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
 import numpy as np
-import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -26,6 +25,7 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
+import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
             )
 
         for i, output in enumerate(outputs):
+            name = f"output{i}"
+
+            output_dtype = dtype.unknown
             if any(
                 op_name in output.name.split("_")
                 for op_name in (
@@ -514,16 +517,20 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
                     "any",
                 )
             ):
-                output_bool = True
-            else:
-                output_bool = False
-            name = f"output{i}"
-            output.name = name
-            self.ctx.net.mark_output(output)
-            if output_bool:
-                output.dtype = trt.DataType.BOOL
+                output_dtype = dtype.b
             elif self.output_dtypes is not None:
-                output.dtype = self.output_dtypes[i].to(trt.DataType)
+                if self.output_dtypes[i] == dtype.i64:
+                    output = self.ctx.net.add_cast(
+                        output, dtype.i64.to(trt.DataType)
+                    ).get_output(0)
+                    output_dtype = dtype.i64
+                else:
+                    output_dtype = self.output_dtypes[i]
+
+            self.ctx.net.mark_output(output)
+            if output_dtype is not dtype.unknown:
+                output.dtype = output_dtype.to(trt.DataType, use_default=True)
+            output.name = name
 
             self._output_names.append(name)
             _LOGGER.debug(
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index c566d9de0a..fe6bd11579 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1,3 +1,5 @@
+# mypy: disallow-untyped-decorators=False
+
 import logging
 import operator
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
@@ -858,6 +860,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
         allowed_casts = {
             torch.float,
             torch.int32,
+            torch.int64,
             torch.bool,
             torch.int8,
             torch.float16,
diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py
index 72c701014b..7ce3939371 100644
--- a/tests/py/dynamo/conversion/harness.py
+++ b/tests/py/dynamo/conversion/harness.py
@@ -251,6 +251,10 @@ def run_test(
                 truncate_double=compilation_settings.truncate_double,
             )
 
+        _LOGGER.debug(f"Compilation settings: {compilation_settings}")
+        _LOGGER.debug(f"Inputs: {input_specs}")
+        _LOGGER.debug(f"Output types: {output_dtypes}")
+
         interp = TRTInterpreter(
             mod,
             input_specs,
diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py
index 4af933da60..e88c85de75 100644
--- a/tests/py/dynamo/models/test_dtype_support.py
+++ b/tests/py/dynamo/models/test_dtype_support.py
@@ -39,7 +39,6 @@ def forward(self, x):
             inputs=[in_tensor],
             pass_through_build_failures=True,
             truncate_double=True,
-            output_format="fx",
             min_block_size=1,
             use_python_runtime=False,
         )
@@ -78,7 +77,6 @@ def forward(self, x):
             inputs=[in_tensor],
             pass_through_build_failures=True,
             truncate_double=True,
-            output_format="fx",
             min_block_size=1,
             use_python_runtime=True,
         )
@@ -123,7 +121,6 @@ def forward(self, x):
             inputs=[in_tensor],
             pass_through_build_failures=True,
             truncate_double=False,
-            output_format="fx",
             min_block_size=1,
             use_python_runtime=False,
         )
@@ -163,7 +160,6 @@ def forward(self, x):
             inputs=[in_tensor],
             pass_through_build_failures=True,
             truncate_double=False,
-            output_format="fx",
             min_block_size=1,
             use_python_runtime=True,
         )
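
Reviewer note, not part of the patch to apply: the core of this change is
routing INT64 engine outputs through an explicit cast layer instead of
assigning the dtype directly on whatever tensor the producing layer emits,
since not all TensorRT layers can produce INT64 tensors. Below is a minimal
standalone sketch of that pattern, assuming a TensorRT build whose Python API
exposes trt.DataType.INT64 and INetworkDefinition.add_cast (e.g. TensorRT
10.x); the toy identity network and the names "x"/"output0" are illustrative
only:

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(0)

    # Toy single-output network: an INT32 input passed through an identity layer.
    x = network.add_input("x", trt.int32, (4,))
    out = network.add_identity(x).get_output(0)

    # The pattern this patch applies in TRTInterpreter.output(): when the
    # requested output dtype is INT64, append an explicit cast layer and use
    # its result as the network output, rather than forcing .dtype on the
    # producing layer's tensor.
    requested = trt.DataType.INT64
    if requested == trt.DataType.INT64:
        out = network.add_cast(out, requested).get_output(0)

    network.mark_output(out)
    out.dtype = requested  # safe now: the cast layer genuinely produces INT64
    out.name = "output0"

This mirrors the interpreter change above, where dtype.i64.to(trt.DataType)
resolves to trt.DataType.INT64 and self.ctx.net is the INetworkDefinition
being built.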