Skip to content

Commit

Permalink
feat: Adding support for native int64
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Apr 30, 2024
1 parent 2eb4c19 commit d4e59b1
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 15 deletions.
5 changes: 5 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
});
m.def("set_logging_level", [](int64_t level) -> void {
util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
});
m.def(
"get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); });
}

} // namespace
Expand Down
7 changes: 4 additions & 3 deletions core/util/trt_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val, bool use
// Acceptable range for pos is [-d.nbDims - 1, d.nbDims]
TORCHTRT_ASSERT(
pos >= (-d.nbDims - 1) && pos <= d.nbDims,
"ERROR: Index to unsqueeze is out of bounds. "
<< "Expected value in range [" << (-d.nbDims - 1) << ", " << d.nbDims << "], but got " << pos);
"ERROR: Index to unsqueeze is out of bounds. " << "Expected value in range [" << (-d.nbDims - 1) << ", "
<< d.nbDims << "], but got " << pos);

// Unsqueeze with negative dimensions creates a new dimension at that index
pos = (pos < 0) ? (pos + d.nbDims + 1) : pos;
Expand Down Expand Up @@ -292,7 +292,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
{at::kFloat, nvinfer1::DataType::kFLOAT},
{at::kHalf, nvinfer1::DataType::kHALF},
{at::kInt, nvinfer1::DataType::kINT32},
{at::kLong, nvinfer1::DataType::kINT32},
{at::kLong, nvinfer1::DataType::kINT64},
{at::kChar, nvinfer1::DataType::kINT8},
{at::kByte, nvinfer1::DataType::kINT8},
{at::kBool, nvinfer1::DataType::kBOOL}};
Expand All @@ -304,6 +304,7 @@ const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_ma
{nvinfer1::DataType::kFLOAT, at::kFloat},
{nvinfer1::DataType::kHALF, at::kHalf},
{nvinfer1::DataType::kINT32, at::kInt},
{nvinfer1::DataType::kINT64, at::kLong},
{nvinfer1::DataType::kINT8, at::kChar},
{nvinfer1::DataType::kBOOL, at::kBool},
};
Expand Down
2 changes: 2 additions & 0 deletions core/util/trt_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
return stream << "Int8";
case nvinfer1::DataType::kINT32:
return stream << "Int32";
case nvinfer1::DataType::kINT64:
return stream << "Int64";
case nvinfer1::DataType::kBOOL:
return stream << "Bool";
default:
Expand Down
7 changes: 6 additions & 1 deletion py/torch_tensorrt/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from typing import Any, Optional, Type, Union

import numpy as np
import tensorrt as trt
import torch
from torch_tensorrt._features import ENABLED_FEATURES

import tensorrt as trt


class dtype(Enum):
"""Enum to set supported dtypes in the compiler"""
Expand Down Expand Up @@ -103,6 +104,8 @@ def _from(
return dtype.i8
elif t == trt.int32:
return dtype.i32
elif t == trt.int64:
return dtype.i64
elif t == trt.float16:
return dtype.f16
elif t == trt.float32:
Expand Down Expand Up @@ -227,6 +230,8 @@ def to(
return trt.DataType.INT8
elif self == dtype.i32:
return trt.DataType.INT32
elif self == dtype.i64:
return trt.DataType.INT64
elif self == dtype.f16:
return trt.DataType.HALF
elif self == dtype.f32:
Expand Down
9 changes: 5 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@
from torch.fx.node import Argument, Node, Target, _get_qualified_name
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

import tensorrt as trt

logger = logging.getLogger(__name__)

LegacyConverterImplSignature = Callable[
[
TRTNetwork,
trt.INetworkDefinition,
Target,
Tuple[Argument, ...],
Dict[str, Argument],
str,
],
Union[TRTTensor, Sequence[TRTTensor]],
Union[trt.ITensor, Sequence[trt.ITensor]],
]

DynamoConverterImplSignature = Callable[
Expand All @@ -44,7 +45,7 @@
Dict[str, Argument],
str,
],
Union[TRTTensor, Sequence[TRTTensor]],
Union[trt.ITensor, Sequence[trt.ITensor]],
]

ConverterImplSignature = Union[
Expand Down
11 changes: 5 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,12 @@ def _repair_64bit_input(
dtype: Data type of tensor at position in submodule (double/long)
"""
assert dtype in (
torch.int64,
torch.float64,
), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
), f"dtype argument must be torch.float64, got {dtype}"

# Determine target data type in 32 and 64 bit forms
dtype_64bit = dtype
dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
dtype_32bit = torch.float32

# Find the node representing the submodule in the graph
module_node = None
Expand Down Expand Up @@ -143,7 +142,7 @@ def _repair_64bit_input(
cast_node_64bit = gm.graph.call_function(
torch.ops.aten._to_copy.default,
args=(get_node,),
kwargs={"dtype": torch.int64},
kwargs={"dtype": torch.float64},
)

get_node.replace_all_uses_with(
Expand Down Expand Up @@ -189,7 +188,7 @@ def repair_long_or_double_inputs(

# If the data type of the input is long/double, insert necessary
# casts to replace the operation
if param.dtype in (torch.int64, torch.float64):
if param.dtype == torch.float64:
# Ensure outputs are only repaired once per submodule to avoid
# unnecessary ops showing up in the graph
if not repaired_outputs_once:
Expand All @@ -206,7 +205,7 @@ def repair_long_or_double_inputs(
repaired_outputs_once = True

# Repair submodule inputs in accordance with inserted casts
dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32
dtype_32bit = torch.float32
submodule_torch_inputs = (
list(submodule_torch_inputs[:position])
+ [
Expand Down
45 changes: 45 additions & 0 deletions py/torch_tensorrt/logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Any

import torch
from torch_tensorrt._features import ENABLED_FEATURES

import tensorrt as trt
Expand Down Expand Up @@ -51,6 +52,12 @@ def __enter__(self) -> None:
self.ts_level = ts_logging.get_reportable_log_level()
ts_logging.set_reportable_log_level(ts_logging.Level.InternalError)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
self.rt_level = torch.ops.tensorrt.get_logging_level()
torch.ops.tensorrt.set_logging_level(
int(trt.ILogger.Severity.INTERNAL_ERROR)
)

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
_LOGGER.setLevel(self.external_lvl)

Expand All @@ -59,6 +66,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:

ts_logging.set_reportable_log_level(self.ts_level)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
torch.ops.tensorrt.set_logging_level(self.rt_level)


class errors:
"""Context-manager to limit displayed log messages to just errors and above
Expand All @@ -79,6 +89,10 @@ def __enter__(self) -> None:
self.ts_level = ts_logging.get_reportable_log_level()
ts_logging.set_reportable_log_level(ts_logging.Level.Error)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
self.rt_level = torch.ops.tensorrt.get_logging_level()
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR))

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
_LOGGER.setLevel(self.external_lvl)

Expand All @@ -87,6 +101,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:

ts_logging.set_reportable_log_level(self.ts_level)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
torch.ops.tensorrt.set_logging_level(self.rt_level)


class warnings:
"""Context-manager to limit displayed log messages to just warnings and above
Expand All @@ -107,6 +124,10 @@ def __enter__(self) -> None:
self.ts_level = ts_logging.get_reportable_log_level()
ts_logging.set_reportable_log_level(ts_logging.Level.Warning)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
self.rt_level = torch.ops.tensorrt.get_logging_level()
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING))

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
_LOGGER.setLevel(self.external_lvl)

Expand All @@ -115,6 +136,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:

ts_logging.set_reportable_log_level(self.ts_level)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
torch.ops.tensorrt.set_logging_level(self.rt_level)


class info:
"""Context-manager to display all info and greater severity messages
Expand All @@ -135,6 +159,10 @@ def __enter__(self) -> None:
self.ts_level = ts_logging.get_reportable_log_level()
ts_logging.set_reportable_log_level(ts_logging.Level.Info)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
self.rt_level = torch.ops.tensorrt.get_logging_level()
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO))

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
_LOGGER.setLevel(self.external_lvl)

Expand All @@ -143,6 +171,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:

ts_logging.set_reportable_log_level(self.ts_level)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
torch.ops.tensorrt.set_logging_level(self.rt_level)


class debug:
"""Context-manager to display full debug information through the logger
Expand All @@ -163,6 +194,10 @@ def __enter__(self) -> None:
self.ts_level = ts_logging.get_reportable_log_level()
ts_logging.set_reportable_log_level(ts_logging.Level.Debug)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
self.rt_level = torch.ops.tensorrt.get_logging_level()
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE))

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
_LOGGER.setLevel(self.external_lvl)

Expand All @@ -171,6 +206,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:

ts_logging.set_reportable_log_level(self.ts_level)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
torch.ops.tensorrt.set_logging_level(self.rt_level)


class graphs:
"""Context-manager to display the results of intermediate lowering passes
Expand All @@ -192,10 +230,17 @@ def __enter__(self) -> None:
self.ts_level = ts_logging.get_reportable_log_level()
ts_logging.set_reportable_log_level(ts_logging.Level.Graph)

elif ENABLED_FEATURES.torch_tensorrt_runtime:
self.rt_level = torch.ops.tensorrt.get_logging_level()
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE) + 1)

def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
    """Put back the log levels stored on this context manager.

    The Python-side ``_LOGGER`` is always reset to ``self.external_lvl``.
    After that, exactly one backend is reset: the TorchScript frontend logger
    (to ``self.ts_level``) when that frontend is enabled, otherwise the
    runtime logger (to ``self.rt_level``) when the runtime is enabled.

    Args:
        exc_type: Exception class passed in by the ``with`` statement (unused).
        exc_value: Exception instance passed in by the ``with`` statement (unused).
        exc_tb: Traceback passed in by the ``with`` statement (unused).
    """
    _LOGGER.setLevel(self.external_lvl)

    if not ENABLED_FEATURES.torchscript_frontend:
        # No TorchScript frontend: fall back to the runtime logger, if present.
        if ENABLED_FEATURES.torch_tensorrt_runtime:
            torch.ops.tensorrt.set_logging_level(self.rt_level)
        return

    # Imported lazily so the module is only touched when the frontend is enabled.
    from torch_tensorrt.ts import logging as ts_logging

    ts_logging.set_reportable_log_level(self.ts_level)
3 changes: 2 additions & 1 deletion tests/py/dynamo/models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ def pytest_addoption(parser):
metavar="Internal Representation",
nargs=1,
type=str,
required=True,
required=False,
help="IR to compile with",
choices=["dynamo", "torch_compile"],
default="dynamo",
)


Expand Down
Loading

0 comments on commit d4e59b1

Please sign in to comment.