diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 61d97503a2..4d6b9eacc7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - --fix=lf exclude: ^docs - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.1 + rev: v14.0.6 hooks: - id: clang-format types_or: [c++, c, cuda] diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index c2a344a307..901923ce20 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -122,6 +122,11 @@ TORCH_LIBRARY(tensorrt, m) { m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void { MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode; }); + m.def("set_logging_level", [](int64_t level) -> void { + util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level)); + }); + m.def( + "get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); }); } } // namespace diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp index 50b58a0bdb..3ca5780603 100644 --- a/core/util/trt_util.cpp +++ b/core/util/trt_util.cpp @@ -292,7 +292,7 @@ const std::unordered_map& get_at_trt_type_ma {at::kFloat, nvinfer1::DataType::kFLOAT}, {at::kHalf, nvinfer1::DataType::kHALF}, {at::kInt, nvinfer1::DataType::kINT32}, - {at::kLong, nvinfer1::DataType::kINT32}, + {at::kLong, nvinfer1::DataType::kINT64}, {at::kChar, nvinfer1::DataType::kINT8}, {at::kByte, nvinfer1::DataType::kINT8}, {at::kBool, nvinfer1::DataType::kBOOL}}; @@ -304,6 +304,7 @@ const std::unordered_map& get_trt_at_type_ma {nvinfer1::DataType::kFLOAT, at::kFloat}, {nvinfer1::DataType::kHALF, at::kHalf}, {nvinfer1::DataType::kINT32, at::kInt}, + {nvinfer1::DataType::kINT64, at::kLong}, {nvinfer1::DataType::kINT8, at::kChar}, {nvinfer1::DataType::kBOOL, at::kBool}, }; diff --git a/core/util/trt_util.h b/core/util/trt_util.h index a09407a5cd..da6653bef3 100644 --- a/core/util/trt_util.h +++ b/core/util/trt_util.h @@ -53,6 +53,8 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType& return stream << "Int8"; case nvinfer1::DataType::kINT32: return stream << "Int32"; + case nvinfer1::DataType::kINT64: + return stream << "Int64"; case nvinfer1::DataType::kBOOL: return stream << "Bool"; default: diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index 062abb9a87..8eeb55cf36 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -5,10 +5,11 @@ from typing import Any, Optional, Type, Union import numpy as np -import tensorrt as trt import torch from torch_tensorrt._features import ENABLED_FEATURES +import tensorrt as trt + class dtype(Enum): """Enum to set supported dtypes in the compiler""" @@ -103,6 +104,8 @@ def _from( return dtype.i8 elif t == trt.int32: return dtype.i32 + elif t == trt.int64: + return dtype.i64 elif t == trt.float16: return dtype.f16 elif t == trt.float32: @@ -227,6 +230,8 @@ def to( return trt.DataType.INT8 elif self == dtype.i32: return trt.DataType.INT32 + elif self == dtype.i64: + return trt.DataType.INT64 elif self == dtype.f16: return trt.DataType.HALF elif self == dtype.f32: diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 68a2355c3a..32b0ca65d7 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -2,6 +2,7 @@ import collections.abc import logging +import warnings from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import torch @@ -22,7 +23,7 @@ UnsupportedOperatorException, convert_module, interpret_module_to_result, - repair_long_or_double_inputs, + repair_double_inputs, ) from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, @@ -58,7 +59,7 @@ def compile( dla_sram_size: int = _defaults.DLA_SRAM_SIZE, dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, - truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE, + truncate_double: bool = _defaults.TRUNCATE_DOUBLE, require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, min_block_size: int = _defaults.MIN_BLOCK_SIZE, torch_executed_ops: Optional[Collection[Target]] = None, @@ -74,7 +75,7 @@ def compile( hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, **kwargs: Any, ) -> torch.fx.GraphModule: - """Compile a TorchScript module for NVIDIA GPUs using TensorRT + """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT Takes a existing TorchScript module and a set of settings to configure the compiler and will convert methods to JIT Graphs which call equivalent TensorRT engines @@ -115,7 +116,7 @@ def compile( dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution - truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32 + truncate_double (bool): Truncate weights provided in double (float64) to float32 calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT @@ -138,6 +139,19 @@ def compile( if debug: set_log_level(logger.parent, logging.DEBUG) + if "truncate_long_and_double" in kwargs.keys(): + if truncate_double is not _defaults.TRUNCATE_DOUBLE: + raise ValueError( + 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"' + ) + else: + truncate_double = kwargs["truncate_long_and_double"] + warnings.warn( + 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version', + DeprecationWarning, + stacklevel=2, + ) + engine_capability = EngineCapability._from(engine_capability) if torch_executed_modules is not None and torch_executed_modules: @@ -185,7 +199,7 @@ def compile( "version_compatible": version_compatible, "optimization_level": optimization_level, "use_python_runtime": use_python_runtime, - "truncate_long_and_double": truncate_long_and_double, + "truncate_double": truncate_double, "use_fast_partitioner": use_fast_partitioner, "num_avg_timing_iters": num_avg_timing_iters, "enable_experimental_decompositions": enable_experimental_decompositions, @@ -349,8 +363,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: assert submodule_inputs is not None # Handle long/double inputs if requested by the user - if settings.truncate_long_and_double: - submodule_inputs = repair_long_or_double_inputs( + if settings.truncate_double: + submodule_inputs = repair_double_inputs( partitioned_module, submodule, submodule_inputs, @@ -423,7 +437,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: def convert_module_to_trt_engine( exported_program: ExportedProgram, - inputs: Optional[Sequence[Input | torch.Tensor]] = None, + inputs: Tuple[Any, ...], + *, enabled_precisions: ( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] ) = _defaults.ENABLED_PRECISIONS, @@ -436,7 +451,7 @@ def convert_module_to_trt_engine( version_compatible: bool = _defaults.VERSION_COMPATIBLE, optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, - truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE, + truncate_double: bool = _defaults.TRUNCATE_DOUBLE, use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, device: Device = Device._current_device(), @@ -451,6 +466,7 @@ def convert_module_to_trt_engine( dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, calibrator: object = None, allow_shape_tensors: bool = False, + **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -488,7 +504,7 @@ def convert_module_to_trt_engine( use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the argument as None - truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32 + truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32 use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system enable_experimental_decompositions (bool): Whether to enable all core aten decompositions or only a selected subset of them @@ -512,6 +528,19 @@ def convert_module_to_trt_engine( if debug: set_log_level(logger.parent, logging.DEBUG) + if "truncate_long_and_double" in kwargs.keys(): + if truncate_double is not _defaults.TRUNCATE_DOUBLE: + raise ValueError( + 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"' + ) + else: + truncate_double = kwargs["truncate_long_and_double"] + warnings.warn( + 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version', + DeprecationWarning, + stacklevel=2, + ) + input_list = list(inputs) if inputs is not None else [] torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() # Prepare torch_trt inputs @@ -531,7 +560,7 @@ def convert_module_to_trt_engine( "version_compatible": version_compatible, "optimization_level": optimization_level, "use_python_runtime": use_python_runtime, - "truncate_long_and_double": truncate_long_and_double, + "truncate_double": truncate_double, "use_fast_partitioner": use_fast_partitioner, "enable_experimental_decompositions": enable_experimental_decompositions, "device": device, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 27db215466..97430137c0 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -18,7 +18,7 @@ VERSION_COMPATIBLE = False OPTIMIZATION_LEVEL = None SPARSE_WEIGHTS = False -TRUNCATE_LONG_AND_DOUBLE = False +TRUNCATE_DOUBLE = False USE_PYTHON_RUNTIME = False USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 3aee629812..d3cf6709ef 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -23,7 +23,7 @@ REFIT, REQUIRE_FULL_COMPILATION, SPARSE_WEIGHTS, - TRUNCATE_LONG_AND_DOUBLE, + TRUNCATE_DOUBLE, USE_FAST_PARTITIONER, USE_PYTHON_RUNTIME, VERSION_COMPATIBLE, @@ -50,7 +50,7 @@ class CompilationSettings: use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the argument as None - truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32 + truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32 use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system enable_experimental_decompositions (bool): Whether to enable all core aten decompositions or only a selected subset of them @@ -81,7 +81,7 @@ class CompilationSettings: version_compatible: bool = VERSION_COMPATIBLE optimization_level: Optional[int] = OPTIMIZATION_LEVEL use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME - truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE + truncate_double: bool = TRUNCATE_DOUBLE use_fast_partitioner: bool = USE_FAST_PARTITIONER enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS device: Device = field(default_factory=default_device) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index 050a62ef3e..6a8a7e3d41 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -21,19 +21,20 @@ from torch.fx.node import Argument, Node, Target, _get_qualified_name from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +import tensorrt as trt logger = logging.getLogger(__name__) LegacyConverterImplSignature = Callable[ [ - TRTNetwork, + trt.INetworkDefinition, Target, Tuple[Argument, ...], Dict[str, Argument], str, ], - Union[TRTTensor, Sequence[TRTTensor]], + Union[trt.ITensor, Sequence[trt.ITensor]], ] DynamoConverterImplSignature = Callable[ @@ -44,7 +45,7 @@ Dict[str, Argument], str, ], - Union[TRTTensor, Sequence[TRTTensor]], + Union[trt.ITensor, Sequence[trt.ITensor]], ] ConverterImplSignature = Union[ diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 9a75add755..59d2c5d6c0 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -4,7 +4,6 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set import numpy as np -import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name @@ -26,6 +25,7 @@ from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER +import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -498,6 +498,9 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: ) for i, output in enumerate(outputs): + name = f"output{i}" + + output_dtype = dtype.unknown if any( op_name in output.name.split("_") for op_name in ( @@ -514,16 +517,20 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]: "any", ) ): - output_bool = True - else: - output_bool = False - name = f"output{i}" - output.name = name - self.ctx.net.mark_output(output) - if output_bool: - output.dtype = trt.DataType.BOOL + output_dtype = dtype.b elif self.output_dtypes is not None: - output.dtype = self.output_dtypes[i].to(trt.DataType) + if self.output_dtypes[i] == dtype.i64: + output = self.ctx.net.add_cast( + output, dtype.i64.to(trt.DataType) + ).get_output(0) + output_dtype = dtype.i64 + else: + output_dtype = self.output_dtypes[i] + + self.ctx.net.mark_output(output) + if output_dtype is not dtype.unknown: + output.dtype = output_dtype.to(trt.DataType, use_default=True) + output.name = name self._output_names.append(name) _LOGGER.debug( diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index 03c7bd0ca0..5351f02bb6 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -3,4 +3,4 @@ from ._ConversionContext import ConversionContext from ._ConverterRegistry import * # noqa: F403 from ._TRTInterpreter import * # noqa: F403 -from .truncate_long_and_double import repair_long_or_double_inputs +from .truncate_double import repair_double_inputs diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 93fc73b4e2..ea078c7d64 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -4,7 +4,6 @@ import logging from typing import List, Sequence -import tensorrt as trt import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype @@ -18,6 +17,8 @@ from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule from torch_tensorrt.dynamo.utils import get_torch_inputs +import tensorrt as trt + logger = logging.getLogger(__name__) @@ -25,7 +26,7 @@ def infer_module_output_dtypes( module: torch.fx.GraphModule, inputs: Sequence[Input], device: Device, - truncate_long_and_double: bool = False, + truncate_double: bool = False, ) -> List[dtype]: torch_inputs = get_torch_inputs(inputs, device) module = module.to(device.to(torch.device)) @@ -48,10 +49,8 @@ def infer_module_output_dtypes( else: output_ = torch.tensor(output) - if truncate_long_and_double and output_.dtype == dtype.float64: + if truncate_double and output_.dtype == dtype.float64: output_dtypes.append(dtype.float32) - elif truncate_long_and_double and output_.dtype == dtype.int64: - output_dtypes.append(dtype.int32) else: output_dtypes.append(dtype._from(output_.dtype)) @@ -75,7 +74,7 @@ def interpret_module_to_result( module, inputs, settings.device, - truncate_long_and_double=settings.truncate_long_and_double, + truncate_double=settings.truncate_double, ) interpreter = TRTInterpreter( diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index c566d9de0a..fe6bd11579 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1,3 +1,5 @@ +# mypy: disallow-untyped-decorators=False + import logging import operator from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union @@ -858,6 +860,7 @@ def validate_dtype(to_copy_node: Node) -> bool: allowed_casts = { torch.float, torch.int32, + torch.int64, torch.bool, torch.int8, torch.float16, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 04f048c5f3..8f11e7fb91 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -297,21 +297,11 @@ def get_trt_tensor( A TensorRT ITensor that represents the given value. """ # If the input is 64-bit, cast it to 32-bit for TRT freezing - if ( - isinstance(input_val, torch.Tensor) - and ctx.compilation_settings.truncate_long_and_double - ): - if input_val.dtype == torch.int64: - input_val = input_val.to(torch.int32) - elif input_val.dtype == torch.float64: + if isinstance(input_val, torch.Tensor) and ctx.compilation_settings.truncate_double: + if input_val.dtype == torch.float64: input_val = input_val.to(torch.float32) - elif ( - isinstance(input_val, np.ndarray) - and ctx.compilation_settings.truncate_long_and_double - ): - if input_val.dtype == np.int64: - input_val = input_val.astype(np.int32) - elif input_val.dtype == np.float64: + elif isinstance(input_val, np.ndarray) and ctx.compilation_settings.truncate_double: + if input_val.dtype == np.float64: input_val = input_val.astype(np.float32) if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)): diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py index ac9faf9f4d..ee3354ae08 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -9,7 +9,8 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTTensor + +import tensorrt as trt def embedding( @@ -17,17 +18,13 @@ def embedding( target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTTensor, - weight: TRTTensor, + input: trt.ITensor, + weight: trt.ITensor, scale_grad_by_freq: bool, sparse: bool, -) -> TRTTensor: +) -> trt.ITensor: indices_tensor = input embedding_tensor = weight - if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64: - raise RuntimeError( - "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT." - ) indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor") embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor") # unsupported parameters @@ -52,15 +49,15 @@ def embedding_bag( target: Target, source_ir: Optional[SourceIR], name: str, - weight: TRTTensor, - indices: TRTTensor, + weight: trt.ITensor, + indices: trt.ITensor, offsets: Union[torch.Tensor, np.ndarray, Sequence[int]], scale_grad_by_freq: bool, mode: int, sparse: bool, - per_sample_weights: Optional[TRTTensor], + per_sample_weights: Optional[trt.ITensor], include_last_offset: bool, -) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]: +) -> Tuple[trt.ITensor, trt.ITensor, trt.ITensor, trt.ITensor]: """ This function is for calculating embedding bags. @@ -143,7 +140,7 @@ def embedding_bag( # however, pytorch doc says if `include_last_offset` is True, the size of offsets # is equal to the number of bags + 1. The last element is the size of the input, # or the ending index position of the last bag (sequence). - offsets[-1] = indices.shape[0] + offsets[-1] = indices.shape[0] # type: ignore[index] # separately reduce embeddings for different bags reduced_embed = [] diff --git a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py b/py/torch_tensorrt/dynamo/conversion/truncate_double.py similarity index 95% rename from py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py rename to py/torch_tensorrt/dynamo/conversion/truncate_double.py index d5670be1db..832a7ebbb6 100644 --- a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py +++ b/py/torch_tensorrt/dynamo/conversion/truncate_double.py @@ -59,13 +59,12 @@ def _repair_64bit_input( dtype: Data type of tensor at position in submodule (double/long) """ assert dtype in ( - torch.int64, torch.float64, - ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}" + ), f"dtype argument must be torch.float64, got {dtype}" # Determine target data type in 32 and 64 bit forms dtype_64bit = dtype - dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32 + dtype_32bit = torch.float32 # Find the node representing the submodule in the graph module_node = None @@ -143,7 +142,7 @@ def _repair_64bit_input( cast_node_64bit = gm.graph.call_function( torch.ops.aten._to_copy.default, args=(get_node,), - kwargs={"dtype": torch.int64}, + kwargs={"dtype": torch.float64}, ) get_node.replace_all_uses_with( @@ -157,7 +156,7 @@ def _repair_64bit_input( gm.recompile() -def repair_long_or_double_inputs( +def repair_double_inputs( parent_graph: torch.fx.GraphModule, submodule: torch.fx.GraphModule, submodule_inputs: Sequence[Input], @@ -189,7 +188,7 @@ def repair_long_or_double_inputs( # If the data type of the input is long/double, insert necessary # casts to replace the operation - if param.dtype in (torch.int64, torch.float64): + if param.dtype == torch.float64: # Ensure outputs are only repaired once per submodule to avoid # unnecessary ops showing up in the graph if not repaired_outputs_once: @@ -206,7 +205,7 @@ def repair_long_or_double_inputs( repaired_outputs_once = True # Repair submodule inputs in accordance with inserted casts - dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32 + dtype_32bit = torch.float32 submodule_torch_inputs = ( list(submodule_torch_inputs[:position]) + [ diff --git a/py/torch_tensorrt/logging.py b/py/torch_tensorrt/logging.py index 4cbb686b0d..e75998b870 100644 --- a/py/torch_tensorrt/logging.py +++ b/py/torch_tensorrt/logging.py @@ -1,6 +1,7 @@ import logging from typing import Any +import torch from torch_tensorrt._features import ENABLED_FEATURES import tensorrt as trt @@ -51,6 +52,12 @@ def __enter__(self) -> None: self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.InternalError) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + self.rt_level = torch.ops.tensorrt.get_logging_level() + torch.ops.tensorrt.set_logging_level( + int(trt.ILogger.Severity.INTERNAL_ERROR) + ) + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) @@ -59,6 +66,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: ts_logging.set_reportable_log_level(self.ts_level) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_logging_level(self.rt_level) + class errors: """Context-manager to limit displayed log messages to just errors and above @@ -79,6 +89,10 @@ def __enter__(self) -> None: self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Error) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + self.rt_level = torch.ops.tensorrt.get_logging_level() + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR)) + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) @@ -87,6 +101,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: ts_logging.set_reportable_log_level(self.ts_level) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_logging_level(self.rt_level) + class warnings: """Context-manager to limit displayed log messages to just warnings and above @@ -107,6 +124,10 @@ def __enter__(self) -> None: self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Warning) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + self.rt_level = torch.ops.tensorrt.get_logging_level() + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING)) + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) @@ -115,6 +136,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: ts_logging.set_reportable_log_level(self.ts_level) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_logging_level(self.rt_level) + class info: """Context-manager to display all info and greater severity messages @@ -135,6 +159,10 @@ def __enter__(self) -> None: self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Info) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + self.rt_level = torch.ops.tensorrt.get_logging_level() + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO)) + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) @@ -143,6 +171,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: ts_logging.set_reportable_log_level(self.ts_level) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_logging_level(self.rt_level) + class debug: """Context-manager to display full debug information through the logger @@ -163,6 +194,10 @@ def __enter__(self) -> None: self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Debug) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + self.rt_level = torch.ops.tensorrt.get_logging_level() + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE)) + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) @@ -171,6 +206,9 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: ts_logging.set_reportable_log_level(self.ts_level) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_logging_level(self.rt_level) + class graphs: """Context-manager to display the results of intermediate lowering passes @@ -192,6 +230,10 @@ def __enter__(self) -> None: self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Graph) + elif ENABLED_FEATURES.torch_tensorrt_runtime: + self.rt_level = torch.ops.tensorrt.get_logging_level() + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE) + 1) + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) @@ -199,3 +241,6 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) + + elif ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_logging_level(self.rt_level) diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 08cefcd524..0f138c7100 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -221,7 +221,7 @@ def forward(self, x, y): inputs, min_block_size=1, pass_through_build_failures=True, - truncate_long_and_double=True, + truncate_double=True, debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() @@ -240,16 +240,14 @@ def forward(self, x, y): def test_int64_input_partial_support(self): class PartiallySupportedMultiOp(torch.nn.Module): def forward(self, x, y): - return torch.ops.aten.div.Tensor_mode( - x, torch.ops.aten.add.Tensor(y, y), rounding_mode=None - ) + return torch.ops.aten.abs(torch.ops.aten.add.Tensor(x, y)) fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) unexpected_ops = {torch.ops.aten.add.Tensor} inputs = [ - torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(), - torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(), + torch.randint(-40, 40, (1, 16, 7, 5), dtype=torch.long).cuda(), + torch.randint(1, 40, (1, 16, 7, 5), dtype=torch.long).cuda(), ] ( @@ -296,8 +294,9 @@ def forward(self, x, y): inputs, min_block_size=1, pass_through_build_failures=True, - truncate_long_and_double=True, + truncate_double=False, debug=True, + torch_executed_ops={"torch.ops.aten.add.Tensor"}, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py index dc2620197e..a11b414f64 100644 --- a/tests/py/dynamo/backend/test_specialized_models.py +++ b/tests/py/dynamo/backend/test_specialized_models.py @@ -223,7 +223,7 @@ def forward(self, x): inputs, min_block_size=1, pass_through_build_failures=True, - truncate_long_and_double=True, + truncate_double=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index ef034c914f..7ce3939371 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -99,8 +99,6 @@ def run_test( if not isinstance(ref, torch.Tensor): ref = torch.tensor([ref]) ref = ref.cpu() # to_dtype test has cases with gpu output - if ref.dtype == torch.int64: - ref = ref.int() # convert torch.max's index output tensor to int32 torch.testing.assert_close( out.cpu(), ref, @@ -238,7 +236,7 @@ def run_test( # We replicate this behavior here compilation_settings = CompilationSettings( enabled_precisions={dtype._from(precision)}, - truncate_long_and_double=True, + truncate_double=True, debug=True, ) @@ -250,9 +248,13 @@ def run_test( mod, input_specs, compilation_settings.device, - truncate_long_and_double=compilation_settings.truncate_long_and_double, + truncate_double=compilation_settings.truncate_double, ) + _LOGGER.debug(f"Compilation settings: {compilation_settings}") + _LOGGER.debug(f"Inputs: {input_specs}") + _LOGGER.debug(f"Output types: {output_dtypes}") + interp = TRTInterpreter( mod, input_specs, @@ -289,7 +291,7 @@ def run_test_with_dynamic_shape( # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here - compilation_settings = CompilationSettings(truncate_long_and_double=True) + compilation_settings = CompilationSettings(truncate_double=True) interp = TRTInterpreter( mod, diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py index 84234db857..88260ba771 100644 --- a/tests/py/dynamo/conversion/test_casts.py +++ b/tests/py/dynamo/conversion/test_casts.py @@ -64,7 +64,7 @@ def forward(self, x): precision=torch.float, ) - def test_to_copy_unsupported(self): + def test_to_copy_i64b(self): class ToCopy64Bit(nn.Module): def forward(self, x): y = torch.ops.aten._to_copy.default(x, dtype=torch.int64) @@ -72,11 +72,10 @@ def forward(self, x): inputs = [torch.randn((1, 3, 10)).int()] - with self.assertRaises(UnsupportedOperatorException): - self.run_test( - ToCopy64Bit(), - inputs, - ) + self.run_test( + ToCopy64Bit(), + inputs, + ) def test_to_copy_multiple_returns(self): class ToCopyReturns(nn.Module): diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 650ef70d3f..81e8661878 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -469,7 +469,7 @@ def forward(self, x): "torch_compile", inputs, min_block_size=1, - truncate_long_and_double=True, + truncate_double=True, pass_through_build_failures=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() diff --git a/tests/py/dynamo/models/conftest.py b/tests/py/dynamo/models/conftest.py index 3fbabc3360..0dedfa3d2f 100644 --- a/tests/py/dynamo/models/conftest.py +++ b/tests/py/dynamo/models/conftest.py @@ -1,3 +1,5 @@ +# type: ignore + import pytest @@ -7,7 +9,7 @@ def pytest_addoption(parser): metavar="Internal Representation", nargs=1, type=str, - required=True, + required=False, help="IR to compile with", choices=["dynamo", "torch_compile"], ) @@ -15,4 +17,5 @@ def pytest_addoption(parser): @pytest.fixture def ir(request): - return request.config.getoption("--ir")[0] + ir_opt = request.config.getoption("--ir") + return ir_opt[0] if ir_opt else "dynamo" diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py new file mode 100644 index 0000000000..e88c85de75 --- /dev/null +++ b/tests/py/dynamo/models/test_dtype_support.py @@ -0,0 +1,178 @@ +# type: ignore + +import math +import unittest + +import torch +import torch_tensorrt +from torch import nn +from torch.nn.parameter import Parameter, UninitializedParameter +from torch.testing._internal.common_utils import TestCase, run_tests + +from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing + + +class Test64BitSupport(TestCase): + + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT Runtime is not available", + ) + def test_truncate_f64_weights_cpp(self): + class f64_weight_module(nn.Module): + def __init__(self, h, w): + super().__init__() + factory_kwargs = {"dtype": torch.float64} + self.weight = Parameter(torch.empty((h, w), **factory_kwargs)) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, x): + return x + self.weight + + h, w = 4, 4 + in_tensor = torch.randn((h, w), dtype=torch.float64, device="cuda") + mod = f64_weight_module(h, w).to("cuda") + + exp_mod = torch.export.export(mod, (in_tensor,)) + trt_mod = torch_tensorrt.dynamo.compile( + exp_mod, + inputs=[in_tensor], + pass_through_build_failures=True, + truncate_double=True, + min_block_size=1, + use_python_runtime=False, + ) + + torch_model_results = mod(in_tensor) + optimized_model_results = trt_mod(in_tensor) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Torch outputs and TRT outputs don't match close enough.", + ) + + def test_truncate_f64_weights_py(self): + class f64_weight_module(nn.Module): + def __init__(self, h, w): + super().__init__() + factory_kwargs = {"dtype": torch.float64} + self.weight = Parameter(torch.empty((h, w), **factory_kwargs)) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, x): + return x + self.weight + + h, w = 4, 4 + in_tensor = torch.randn((h, w), dtype=torch.float64, device="cuda") + mod = f64_weight_module(h, w).to("cuda") + + exp_mod = torch.export.export(mod, (in_tensor,)) + trt_mod = torch_tensorrt.dynamo.compile( + exp_mod, + inputs=[in_tensor], + pass_through_build_failures=True, + truncate_double=True, + min_block_size=1, + use_python_runtime=True, + ) + + torch_model_results = mod(in_tensor) + with torch_tensorrt.logging.debug(): + optimized_model_results = trt_mod(in_tensor) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Torch outputs and TRT outputs don't match close enough.", + ) + + @unittest.skipIf( + not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime, + "Torch-TensorRT Runtime is not available", + ) + def test_native_i64_cpp(self): + class i64_module(nn.Module): + def __init__(self, h, w): + super().__init__() + self.const_tensor = Parameter( + torch.randint(0, 100, (h, w), dtype=torch.int64), + requires_grad=False, + ) + + def forward(self, x): + return (x + self.const_tensor) * 10 + + h, w = 4, 4 + in_tensor = torch.randint(0, 100, (h, w), dtype=torch.int64, device="cuda") + mod = i64_module(h, w).to("cuda") + + exp_mod = torch.export.export(mod, (in_tensor,)) + trt_mod = torch_tensorrt.dynamo.compile( + exp_mod, + inputs=[in_tensor], + pass_through_build_failures=True, + truncate_double=False, + min_block_size=1, + use_python_runtime=False, + ) + + torch_model_results = mod(in_tensor) + optimized_model_results = trt_mod(in_tensor) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Torch outputs and TRT outputs don't match close enough.", + ) + + def test_native_i64_py(self): + class i64_module(nn.Module): + def __init__(self, h, w): + super().__init__() + self.const_tensor = Parameter( + torch.randint(0, 100, (h, w), dtype=torch.int64), + requires_grad=False, + ) + + def forward(self, x): + return (x + self.const_tensor) * 10 + + h, w = 4, 4 + in_tensor = torch.randint(0, 100, (h, w), dtype=torch.int64, device="cuda") + mod = i64_module(h, w).to("cuda") + + exp_mod = torch.export.export(mod, (in_tensor,)) + trt_mod = torch_tensorrt.dynamo.compile( + exp_mod, + inputs=[in_tensor], + pass_through_build_failures=True, + truncate_double=False, + min_block_size=1, + use_python_runtime=True, + ) + + torch_model_results = mod(in_tensor) + optimized_model_results = trt_mod(in_tensor) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Torch outputs and TRT outputs don't match close enough.", + ) diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py index 47d58ee7ed..4c6b98e555 100644 --- a/tests/py/dynamo/models/test_dyn_models.py +++ b/tests/py/dynamo/models/test_dyn_models.py @@ -1,3 +1,5 @@ +# type: ignore + import unittest import pytest diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index b986a4d158..2d45af2b49 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -1,3 +1,5 @@ +# type: ignore + import unittest import pytest @@ -124,7 +126,7 @@ def test_bert_base_uncased(ir): ], "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, - "truncate_long_and_double": True, + "truncate_double": True, "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 32ec1315ff..f16d60c9be 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -127,7 +127,7 @@ def test_bert_base_uncased(ir): ], "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, - "truncate_long_and_double": True, + "truncate_double": True, "ir": ir, "min_block_size": 10, } diff --git a/tests/py/dynamo/runtime/test_compilation_settings.py b/tests/py/dynamo/runtime/test_compilation_settings.py index daa67ad032..47f700038a 100644 --- a/tests/py/dynamo/runtime/test_compilation_settings.py +++ b/tests/py/dynamo/runtime/test_compilation_settings.py @@ -36,7 +36,7 @@ def forward(self, x): refit=True, num_avg_timing_iters=5, workspace_size=1 << 10, - truncate_long_and_double=True, + truncate_double=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu()