Skip to content

Commit

Permalink
2.3 cherry pick feat: Adding support for native int64 (#2789) (#2802)
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan authored May 1, 2024
1 parent 9e0b547 commit 0499493
Show file tree
Hide file tree
Showing 28 changed files with 383 additions and 106 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- --fix=lf
exclude: ^docs
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.1
rev: v14.0.6
hooks:
- id: clang-format
types_or: [c++, c, cuda]
Expand Down
5 changes: 5 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
});
m.def("set_logging_level", [](int64_t level) -> void {
util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
});
m.def(
"get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); });
}

} // namespace
Expand Down
3 changes: 2 additions & 1 deletion core/util/trt_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
{at::kFloat, nvinfer1::DataType::kFLOAT},
{at::kHalf, nvinfer1::DataType::kHALF},
{at::kInt, nvinfer1::DataType::kINT32},
{at::kLong, nvinfer1::DataType::kINT32},
{at::kLong, nvinfer1::DataType::kINT64},
{at::kChar, nvinfer1::DataType::kINT8},
{at::kByte, nvinfer1::DataType::kINT8},
{at::kBool, nvinfer1::DataType::kBOOL}};
Expand All @@ -304,6 +304,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_ma
{nvinfer1::DataType::kFLOAT, at::kFloat},
{nvinfer1::DataType::kHALF, at::kHalf},
{nvinfer1::DataType::kINT32, at::kInt},
{nvinfer1::DataType::kINT64, at::kLong},
{nvinfer1::DataType::kINT8, at::kChar},
{nvinfer1::DataType::kBOOL, at::kBool},
};
Expand Down
2 changes: 2 additions & 0 deletions core/util/trt_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
return stream << "Int8";
case nvinfer1::DataType::kINT32:
return stream << "Int32";
case nvinfer1::DataType::kINT64:
return stream << "Int64";
case nvinfer1::DataType::kBOOL:
return stream << "Bool";
default:
Expand Down
7 changes: 6 additions & 1 deletion py/torch_tensorrt/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from typing import Any, Optional, Type, Union

import numpy as np
import tensorrt as trt
import torch
from torch_tensorrt._features import ENABLED_FEATURES

import tensorrt as trt


class dtype(Enum):
"""Enum to set supported dtypes in the compiler"""
Expand Down Expand Up @@ -103,6 +104,8 @@ def _from(
return dtype.i8
elif t == trt.int32:
return dtype.i32
elif t == trt.int64:
return dtype.i64
elif t == trt.float16:
return dtype.f16
elif t == trt.float32:
Expand Down Expand Up @@ -227,6 +230,8 @@ def to(
return trt.DataType.INT8
elif self == dtype.i32:
return trt.DataType.INT32
elif self == dtype.i64:
return trt.DataType.INT64
elif self == dtype.f16:
return trt.DataType.HALF
elif self == dtype.f32:
Expand Down
51 changes: 40 additions & 11 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections.abc
import logging
import warnings
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union

import torch
Expand All @@ -22,7 +23,7 @@
UnsupportedOperatorException,
convert_module,
interpret_module_to_result,
repair_long_or_double_inputs,
repair_double_inputs,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
Expand Down Expand Up @@ -58,7 +59,7 @@ def compile(
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
torch_executed_ops: Optional[Collection[Target]] = None,
Expand All @@ -74,7 +75,7 @@ def compile(
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
Takes a existing TorchScript module and a set of settings to configure the compiler
and will convert methods to JIT Graphs which call equivalent TensorRT engines
Expand Down Expand Up @@ -115,7 +116,7 @@ def compile(
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
truncate_double (bool): Truncate weights provided in double (float64) to float32
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
Expand All @@ -138,6 +139,19 @@ def compile(
if debug:
set_log_level(logger.parent, logging.DEBUG)

if "truncate_long_and_double" in kwargs.keys():
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
raise ValueError(
'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
)
else:
truncate_double = kwargs["truncate_long_and_double"]
warnings.warn(
'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
DeprecationWarning,
stacklevel=2,
)

engine_capability = EngineCapability._from(engine_capability)

if torch_executed_modules is not None and torch_executed_modules:
Expand Down Expand Up @@ -185,7 +199,7 @@ def compile(
"version_compatible": version_compatible,
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
"truncate_double": truncate_double,
"use_fast_partitioner": use_fast_partitioner,
"num_avg_timing_iters": num_avg_timing_iters,
"enable_experimental_decompositions": enable_experimental_decompositions,
Expand Down Expand Up @@ -349,8 +363,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

assert submodule_inputs is not None
# Handle long/double inputs if requested by the user
if settings.truncate_long_and_double:
submodule_inputs = repair_long_or_double_inputs(
if settings.truncate_double:
submodule_inputs = repair_double_inputs(
partitioned_module,
submodule,
submodule_inputs,
Expand Down Expand Up @@ -423,7 +437,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

def convert_module_to_trt_engine(
exported_program: ExportedProgram,
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
inputs: Tuple[Any, ...],
*,
enabled_precisions: (
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
) = _defaults.ENABLED_PRECISIONS,
Expand All @@ -436,7 +451,7 @@ def convert_module_to_trt_engine(
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME,
truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
device: Device = Device._current_device(),
Expand All @@ -451,6 +466,7 @@ def convert_module_to_trt_engine(
dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
calibrator: object = None,
allow_shape_tensors: bool = False,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
Expand Down Expand Up @@ -488,7 +504,7 @@ def convert_module_to_trt_engine(
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
argument as None
truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
or only a selected subset of them
Expand All @@ -512,6 +528,19 @@ def convert_module_to_trt_engine(
if debug:
set_log_level(logger.parent, logging.DEBUG)

if "truncate_long_and_double" in kwargs.keys():
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
raise ValueError(
'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
)
else:
truncate_double = kwargs["truncate_long_and_double"]
warnings.warn(
'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
DeprecationWarning,
stacklevel=2,
)

input_list = list(inputs) if inputs is not None else []
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
# Prepare torch_trt inputs
Expand All @@ -531,7 +560,7 @@ def convert_module_to_trt_engine(
"version_compatible": version_compatible,
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
"truncate_double": truncate_double,
"use_fast_partitioner": use_fast_partitioner,
"enable_experimental_decompositions": enable_experimental_decompositions,
"device": device,
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
VERSION_COMPATIBLE = False
OPTIMIZATION_LEVEL = None
SPARSE_WEIGHTS = False
TRUNCATE_LONG_AND_DOUBLE = False
TRUNCATE_DOUBLE = False
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
Expand Down
10 changes: 5 additions & 5 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Collection, Optional, Union
from typing import Collection, Optional, Set, Union

from torch.fx.node import Target
from torch_tensorrt._Device import Device
Expand All @@ -23,7 +23,7 @@
REFIT,
REQUIRE_FULL_COMPILATION,
SPARSE_WEIGHTS,
TRUNCATE_LONG_AND_DOUBLE,
TRUNCATE_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
VERSION_COMPATIBLE,
Expand All @@ -50,7 +50,7 @@ class CompilationSettings:
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
argument as None
truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
or only a selected subset of them
Expand All @@ -71,7 +71,7 @@ class CompilationSettings:
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
"""

enabled_precisions: dtype = field(default_factory=lambda: ENABLED_PRECISIONS)
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
debug: bool = DEBUG
workspace_size: int = WORKSPACE_SIZE
min_block_size: int = MIN_BLOCK_SIZE
Expand All @@ -81,7 +81,7 @@ class CompilationSettings:
version_compatible: bool = VERSION_COMPATIBLE
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
truncate_double: bool = TRUNCATE_DOUBLE
use_fast_partitioner: bool = USE_FAST_PARTITIONER
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
device: Device = field(default_factory=default_device)
Expand Down
9 changes: 5 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@
from torch.fx.node import Argument, Node, Target, _get_qualified_name
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

import tensorrt as trt

logger = logging.getLogger(__name__)

LegacyConverterImplSignature = Callable[
[
TRTNetwork,
trt.INetworkDefinition,
Target,
Tuple[Argument, ...],
Dict[str, Argument],
str,
],
Union[TRTTensor, Sequence[TRTTensor]],
Union[trt.ITensor, Sequence[trt.ITensor]],
]

DynamoConverterImplSignature = Callable[
Expand All @@ -44,7 +45,7 @@
Dict[str, Argument],
str,
],
Union[TRTTensor, Sequence[TRTTensor]],
Union[trt.ITensor, Sequence[trt.ITensor]],
]

ConverterImplSignature = Union[
Expand Down
27 changes: 17 additions & 10 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
Expand All @@ -26,6 +25,7 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
)

for i, output in enumerate(outputs):
name = f"output{i}"

output_dtype = dtype.unknown
if any(
op_name in output.name.split("_")
for op_name in (
Expand All @@ -514,16 +517,20 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
"any",
)
):
output_bool = True
else:
output_bool = False
name = f"output{i}"
output.name = name
self.ctx.net.mark_output(output)
if output_bool:
output.dtype = trt.DataType.BOOL
output_dtype = dtype.b
elif self.output_dtypes is not None:
output.dtype = self.output_dtypes[i].to(trt.DataType)
if self.output_dtypes[i] == dtype.i64:
output = self.ctx.net.add_cast(
output, dtype.i64.to(trt.DataType)
).get_output(0)
output_dtype = dtype.i64
else:
output_dtype = self.output_dtypes[i]

self.ctx.net.mark_output(output)
if output_dtype is not dtype.unknown:
output.dtype = output_dtype.to(trt.DataType, use_default=True)
output.name = name

self._output_names.append(name)
_LOGGER.debug(
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import * # noqa: F403
from ._TRTInterpreter import * # noqa: F403
from .truncate_long_and_double import repair_long_or_double_inputs
from .truncate_double import repair_double_inputs
Loading

0 comments on commit 0499493

Please sign in to comment.