feat: Adding support for native int64 #2789

Merged · 3 commits · Apr 30, 2024
.pre-commit-config.yaml (2 changes: 1 addition & 1 deletion)

@@ -16,7 +16,7 @@ repos:
          - --fix=lf
        exclude: ^docs
  - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.1
+    rev: v14.0.6
    hooks:
    - id: clang-format
      types_or: [c++, c, cuda]
core/runtime/register_jit_hooks.cpp (5 changes: 5 additions & 0 deletions)

@@ -122,6 +122,11 @@ TORCH_LIBRARY(tensorrt, m) {
  m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
    MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
  });
+ m.def("set_logging_level", [](int64_t level) -> void {
+   util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
+ });
+ m.def(
+     "get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); });
}

} // namespace
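These hooks expose the C++ logger to Python through the torch.ops.tensorrt namespace. A minimal sketch of calling them; the integer-to-LogLevel mapping in the comment is an assumption, not something this diff spells out:

```python
import torch
import torch_tensorrt  # noqa: F401  # importing loads the extension that registers the ops

# `level` is cast straight to core/util/logging's LogLevel enum; the exact
# numeric values (e.g. 0 = internal errors ... 5 = graph debug) are assumed here.
torch.ops.tensorrt.set_logging_level(4)
print(torch.ops.tensorrt.get_logging_level())  # echoes the level back as an int
```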
core/util/trt_util.cpp (3 changes: 2 additions & 1 deletion)

@@ -292,7 +292,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_map() {
      {at::kFloat, nvinfer1::DataType::kFLOAT},
      {at::kHalf, nvinfer1::DataType::kHALF},
      {at::kInt, nvinfer1::DataType::kINT32},
-     {at::kLong, nvinfer1::DataType::kINT32},
+     {at::kLong, nvinfer1::DataType::kINT64},
      {at::kChar, nvinfer1::DataType::kINT8},
      {at::kByte, nvinfer1::DataType::kINT8},
      {at::kBool, nvinfer1::DataType::kBOOL}};
@@ -304,6 +304,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_map() {
      {nvinfer1::DataType::kFLOAT, at::kFloat},
      {nvinfer1::DataType::kHALF, at::kHalf},
      {nvinfer1::DataType::kINT32, at::kInt},
+     {nvinfer1::DataType::kINT64, at::kLong},
      {nvinfer1::DataType::kINT8, at::kChar},
      {nvinfer1::DataType::kBOOL, at::kBool},
  };
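The kLong remap above is the heart of the PR: aten int64 tensors now lower to TensorRT's native kINT64 instead of being silently narrowed to kINT32. A hedged end-to-end sketch (the module and inputs are illustrative, not taken from this PR):

```python
import torch
import torch_tensorrt

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

example = torch.arange(4, dtype=torch.int64).cuda()
ep = torch.export.export(AddOne(), (example,))

# Previously int64 inputs required truncate_long_and_double=True and were
# downcast to int32; with native support they should round-trip as int64.
trt_gm = torch_tensorrt.dynamo.compile(ep, inputs=[example])
print(trt_gm(example).dtype)  # expected: torch.int64
```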
core/util/trt_util.h (2 changes: 2 additions & 0 deletions)

@@ -53,6 +53,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
      return stream << "Int8";
    case nvinfer1::DataType::kINT32:
      return stream << "Int32";
+   case nvinfer1::DataType::kINT64:
+     return stream << "Int64";
    case nvinfer1::DataType::kBOOL:
      return stream << "Bool";
    default:
py/torch_tensorrt/_enums.py (7 changes: 6 additions & 1 deletion)

@@ -5,10 +5,11 @@
from typing import Any, Optional, Type, Union

import numpy as np
-import tensorrt as trt
import torch
from torch_tensorrt._features import ENABLED_FEATURES

+import tensorrt as trt
+

class dtype(Enum):
    """Enum to set supported dtypes in the compiler"""
@@ -103,6 +104,8 @@ def _from(
            return dtype.i8
        elif t == trt.int32:
            return dtype.i32
+       elif t == trt.int64:
+           return dtype.i64
        elif t == trt.float16:
            return dtype.f16
        elif t == trt.float32:
@@ -227,6 +230,8 @@ def to(
            return trt.DataType.INT8
        elif self == dtype.i32:
            return trt.DataType.INT32
+       elif self == dtype.i64:
+           return trt.DataType.INT64
        elif self == dtype.f16:
            return trt.DataType.HALF
        elif self == dtype.f32:
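A quick round-trip through the new branches (a sketch; `_from` accepting torch.int64 relies on the surrounding enum code rather than this hunk, and trt.int64 assumes a TensorRT 10 build):

```python
import tensorrt as trt
import torch
from torch_tensorrt._enums import dtype

assert dtype._from(trt.int64) == dtype.i64                # new TRT -> dtype branch
assert dtype._from(torch.int64) == dtype.i64              # existing torch -> dtype path
assert dtype.i64.to(trt.DataType) == trt.DataType.INT64   # new dtype -> TRT branch
```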
py/torch_tensorrt/dynamo/_compiler.py (51 changes: 40 additions & 11 deletions)

@@ -2,6 +2,7 @@

import collections.abc
import logging
+import warnings
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union

import torch
@@ -22,7 +23,7 @@
    UnsupportedOperatorException,
    convert_module,
    interpret_module_to_result,
-   repair_long_or_double_inputs,
+   repair_double_inputs,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,
@@ -58,7 +59,7 @@ def compile(
    dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
    dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
    dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
-   truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
+   truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
    require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
    min_block_size: int = _defaults.MIN_BLOCK_SIZE,
    torch_executed_ops: Optional[Collection[Target]] = None,
@@ -74,7 +75,7 @@
    hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
    **kwargs: Any,
) -> torch.fx.GraphModule:
-   """Compile a TorchScript module for NVIDIA GPUs using TensorRT
+   """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT

    Takes a existing TorchScript module and a set of settings to configure the compiler
    and will convert methods to JIT Graphs which call equivalent TensorRT engines
@@ -115,7 +116,7 @@
        dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
        dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
        dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
-       truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
+       truncate_double (bool): Truncate weights provided in double (float64) to float32
        calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
        require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
        min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
@@ -138,6 +139,19 @@
    if debug:
        set_log_level(logger.parent, logging.DEBUG)

+   if "truncate_long_and_double" in kwargs.keys():
+       if truncate_double is not _defaults.TRUNCATE_DOUBLE:
+           raise ValueError(
+               'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
+           )
+       else:
+           truncate_double = kwargs["truncate_long_and_double"]
+           warnings.warn(
+               'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
+               DeprecationWarning,
+               stacklevel=2,
+           )
Review comment on lines +149 to +153:

Collaborator: Any reason for not using logger here?

Author: This is apparently the recommended way to handle deprecation warnings; iirc I configured the logger to pull these messages in, in an earlier PR.

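For reference, the standard way to route warnings.warn output into logging is logging.captureWarnings; a sketch of the wiring the author alludes to (the actual torch_tensorrt configuration may differ):

```python
import logging
import warnings

logging.basicConfig(level=logging.WARNING)
logging.captureWarnings(True)  # redirects warnings to the "py.warnings" logger

warnings.warn("example deprecation", DeprecationWarning, stacklevel=2)
# -> WARNING:py.warnings:...: DeprecationWarning: example deprecation
```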
engine_capability = EngineCapability._from(engine_capability)

if torch_executed_modules is not None and torch_executed_modules:
@@ -185,7 +199,7 @@
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
-       "truncate_long_and_double": truncate_long_and_double,
+       "truncate_double": truncate_double,
        "use_fast_partitioner": use_fast_partitioner,
        "num_avg_timing_iters": num_avg_timing_iters,
        "enable_experimental_decompositions": enable_experimental_decompositions,
@@ -349,8 +363,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

        assert submodule_inputs is not None
        # Handle long/double inputs if requested by the user
-       if settings.truncate_long_and_double:
-           submodule_inputs = repair_long_or_double_inputs(
+       if settings.truncate_double:
+           submodule_inputs = repair_double_inputs(
                partitioned_module,
                submodule,
                submodule_inputs,
@@ -423,7 +437,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

def convert_module_to_trt_engine(
    exported_program: ExportedProgram,
-   inputs: Optional[Sequence[Input | torch.Tensor]] = None,
+   inputs: Tuple[Any, ...],
+   *,
    enabled_precisions: (
        Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
    ) = _defaults.ENABLED_PRECISIONS,
@@ -436,7 +451,7 @@ def convert_module_to_trt_engine(
    version_compatible: bool = _defaults.VERSION_COMPATIBLE,
    optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
    use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME,
-   truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
+   truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
    use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
    enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
    device: Device = Device._current_device(),
@@ -451,6 +466,7 @@
    dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
    calibrator: object = None,
    allow_shape_tensors: bool = False,
+   **kwargs: Any,
) -> bytes:
    """Convert an ExportedProgram to a serialized TensorRT engine

@@ -488,7 +504,7 @@
        use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
            argument as None
-       truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+       truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
        use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
        enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
            or only a selected subset of them
@@ -512,6 +528,19 @@
    if debug:
        set_log_level(logger.parent, logging.DEBUG)

+   if "truncate_long_and_double" in kwargs.keys():
+       if truncate_double is not _defaults.TRUNCATE_DOUBLE:
+           raise ValueError(
+               'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
+           )
+       else:
+           truncate_double = kwargs["truncate_long_and_double"]
+           warnings.warn(
+               'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
+               DeprecationWarning,
+               stacklevel=2,
+           )
+
input_list = list(inputs) if inputs is not None else []
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
# Prepare torch_trt inputs
@@ -531,7 +560,7 @@
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
-       "truncate_long_and_double": truncate_long_and_double,
+       "truncate_double": truncate_double,
        "use_fast_partitioner": use_fast_partitioner,
        "enable_experimental_decompositions": enable_experimental_decompositions,
        "device": device,
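Taken together: the old spelling keeps working for one release, emitting a DeprecationWarning and being remapped onto truncate_double, while passing both flags raises ValueError. A usage sketch (model and inputs are placeholders):

```python
import torch
import torch_tensorrt

class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

ep = torch.export.export(Scale(), (torch.randn(2, 2),))

# Deprecated spelling: warns, then behaves as truncate_double=True.
trt_gm = torch_tensorrt.dynamo.compile(
    ep, inputs=[torch.randn(2, 2)], truncate_long_and_double=True
)

# Passing truncate_double together with truncate_long_and_double raises ValueError.
```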
py/torch_tensorrt/dynamo/_defaults.py (2 changes: 1 addition & 1 deletion)

@@ -18,7 +18,7 @@
VERSION_COMPATIBLE = False
OPTIMIZATION_LEVEL = None
SPARSE_WEIGHTS = False
-TRUNCATE_LONG_AND_DOUBLE = False
+TRUNCATE_DOUBLE = False
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
py/torch_tensorrt/dynamo/_settings.py (6 changes: 3 additions & 3 deletions)

@@ -23,7 +23,7 @@
    REFIT,
    REQUIRE_FULL_COMPILATION,
    SPARSE_WEIGHTS,
-   TRUNCATE_LONG_AND_DOUBLE,
+   TRUNCATE_DOUBLE,
    USE_FAST_PARTITIONER,
    USE_PYTHON_RUNTIME,
    VERSION_COMPATIBLE,
@@ -50,7 +50,7 @@ class CompilationSettings:
        use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
            argument as None
-       truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+       truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
        use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
        enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
            or only a selected subset of them
@@ -81,7 +81,7 @@ class CompilationSettings:
    version_compatible: bool = VERSION_COMPATIBLE
    optimization_level: Optional[int] = OPTIMIZATION_LEVEL
    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
-   truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
+   truncate_double: bool = TRUNCATE_DOUBLE
    use_fast_partitioner: bool = USE_FAST_PARTITIONER
    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
    device: Device = field(default_factory=default_device)
py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py (9 changes: 5 additions & 4 deletions)

@@ -21,19 +21,20 @@
from torch.fx.node import Argument, Node, Target, _get_qualified_name
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

+import tensorrt as trt
+
logger = logging.getLogger(__name__)

LegacyConverterImplSignature = Callable[
    [
-       TRTNetwork,
+       trt.INetworkDefinition,
        Target,
        Tuple[Argument, ...],
        Dict[str, Argument],
        str,
    ],
-   Union[TRTTensor, Sequence[TRTTensor]],
+   Union[trt.ITensor, Sequence[trt.ITensor]],
]

DynamoConverterImplSignature = Callable[
@@ -44,7 +45,7 @@
        Dict[str, Argument],
        str,
    ],
-   Union[TRTTensor, Sequence[TRTTensor]],
+   Union[trt.ITensor, Sequence[trt.ITensor]],
]

ConverterImplSignature = Union[
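With the torch_tensorrt.fx.types aliases gone, converters are typed directly against TensorRT's Python classes. A hypothetical stub matching the updated Dynamo signature (name and body are illustrative only):

```python
from typing import Dict, Sequence, Tuple, Union

import tensorrt as trt
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext

def example_converter(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    # A real converter would build TRT layers via ctx.net and return their outputs.
    raise NotImplementedError
```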
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (27 changes: 17 additions & 10 deletions)

@@ -4,7 +4,6 @@
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

import numpy as np
-import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
@@ -26,6 +25,7 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

+import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
            )

        for i, output in enumerate(outputs):
+           name = f"output{i}"
+
+           output_dtype = dtype.unknown
            if any(
                op_name in output.name.split("_")
                for op_name in (
@@ -514,16 +517,20 @@
                    "any",
                )
            ):
-               output_bool = True
-           else:
-               output_bool = False
-           name = f"output{i}"
-           output.name = name
-           self.ctx.net.mark_output(output)
-           if output_bool:
-               output.dtype = trt.DataType.BOOL
+               output_dtype = dtype.b
            elif self.output_dtypes is not None:
-               output.dtype = self.output_dtypes[i].to(trt.DataType)
+               if self.output_dtypes[i] == dtype.i64:
+                   output = self.ctx.net.add_cast(
+                       output, dtype.i64.to(trt.DataType)
+                   ).get_output(0)
+                   output_dtype = dtype.i64
+               else:
+                   output_dtype = self.output_dtypes[i]
+
+           self.ctx.net.mark_output(output)
+           if output_dtype is not dtype.unknown:
+               output.dtype = output_dtype.to(trt.DataType, use_default=True)
+           output.name = name

            self._output_names.append(name)
            _LOGGER.debug(
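The interpreter change matters because an output's dtype was previously set by assignment alone; for int64 results an explicit cast layer is now inserted before the tensor is marked as an output, so the engine itself reports INT64. The same pattern against the raw TensorRT Python API, as a sketch (generic builder boilerplate, TensorRT >= 10 assumed, not torch_tensorrt code):

```python
import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network(0)

x = network.add_input("x", trt.int32, (4,))
doubled = network.add_elementwise(x, x, trt.ElementWiseOperation.SUM)

# Mirror of the interpreter logic: cast to INT64, then mark and type the output.
cast = network.add_cast(doubled.get_output(0), trt.int64)
out = cast.get_output(0)
network.mark_output(out)
out.dtype = trt.int64
```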
py/torch_tensorrt/dynamo/conversion/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -3,4 +3,4 @@
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import *  # noqa: F403
from ._TRTInterpreter import *  # noqa: F403
-from .truncate_long_and_double import repair_long_or_double_inputs
+from .truncate_double import repair_double_inputs