chore: Deprecate truncate_long_and_double for the dynamo frontend
`truncate_long_and_double` has been deprecated in favor of
`truncate_double`, since int64 is now natively supported.

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Apr 29, 2024
1 parent a6770cd commit da25720
Showing 21 changed files with 102 additions and 85 deletions.
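For downstream users the migration is a one-keyword rename. A minimal sketch, assuming a toy float64 model (the module and inputs below are hypothetical illustrations, not taken from this diff):

    import torch
    import torch_tensorrt

    # Hypothetical example model and input
    model = torch.nn.Linear(4, 4).double().cuda().eval()
    inputs = [torch.randn(2, 4, dtype=torch.float64, device="cuda")]
    exp_program = torch.export.export(model, tuple(inputs))

    # Before this commit (still works, now emits a DeprecationWarning):
    # trt_gm = torch_tensorrt.dynamo.compile(
    #     exp_program, inputs, truncate_long_and_double=True
    # )

    # After this commit:
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs, truncate_double=True)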
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -16,7 +16,7 @@ repos:
- --fix=lf
exclude: ^docs
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.1
rev: v14.0.6
hooks:
- id: clang-format
types_or: [c++, c, cuda]
4 changes: 2 additions & 2 deletions core/util/trt_util.cpp
@@ -164,8 +164,8 @@ nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val, bool use
// Acceptable range for pos is [-d.nbDims - 1, d.nbDims]
TORCHTRT_ASSERT(
pos >= (-d.nbDims - 1) && pos <= d.nbDims,
"ERROR: Index to unsqueeze is out of bounds. " << "Expected value in range [" << (-d.nbDims - 1) << ", "
<< d.nbDims << "], but got " << pos);
"ERROR: Index to unsqueeze is out of bounds. "
<< "Expected value in range [" << (-d.nbDims - 1) << ", " << d.nbDims << "], but got " << pos);

// Unsqueeze with negative dimensions creates a new dimension at that index
pos = (pos < 0) ? (pos + d.nbDims + 1) : pos;
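For intuition, this is the same convention PyTorch itself uses for unsqueeze: a negative index i resolves to i + nbDims + 1. A quick Python illustration:

    import torch

    x = torch.zeros(2, 3)             # nbDims = 2, valid index range is [-3, 2]
    torch.unsqueeze(x, -1).shape      # torch.Size([2, 3, 1]); -1 -> -1 + 2 + 1 = 2
    torch.unsqueeze(x, -3).shape      # torch.Size([1, 2, 3]); -3 -> -3 + 2 + 1 = 0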
51 changes: 40 additions & 11 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -2,6 +2,7 @@

import collections.abc
import logging
import warnings
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union

import torch
@@ -22,7 +23,7 @@
UnsupportedOperatorException,
convert_module,
interpret_module_to_result,
repair_long_or_double_inputs,
repair_double_inputs,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
@@ -58,7 +59,7 @@ def compile(
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
torch_executed_ops: Optional[Collection[Target]] = None,
@@ -74,7 +75,7 @@
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
Takes an existing TorchScript module and a set of settings to configure the compiler
and will convert methods to JIT Graphs which call equivalent TensorRT engines
@@ -115,7 +116,7 @@ def compile(
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
truncate_double (bool): Truncate weights provided in double (float64) to float32
calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
@@ -138,6 +139,19 @@
if debug:
set_log_level(logger.parent, logging.DEBUG)

if "truncate_long_and_double" in kwargs.keys():
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
raise ValueError(
'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
)
else:
truncate_double = kwargs["truncate_long_and_double"]
warnings.warn(
'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
DeprecationWarning,
stacklevel=2,
)
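Reusing the hypothetical exp_program and inputs from the migration sketch above, the observable behavior of this shim is:

    # Warns (DeprecationWarning) and forwards the value to truncate_double:
    trt_gm = torch_tensorrt.dynamo.compile(
        exp_program, inputs, truncate_long_and_double=True
    )

    # Raises ValueError: both spellings supplied, with truncate_double
    # set to a non-default value:
    trt_gm = torch_tensorrt.dynamo.compile(
        exp_program, inputs, truncate_double=True, truncate_long_and_double=True
    )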

engine_capability = EngineCapability._from(engine_capability)

if torch_executed_modules is not None and torch_executed_modules:
@@ -185,7 +199,7 @@
"version_compatible": version_compatible,
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
"truncate_double": truncate_double,
"use_fast_partitioner": use_fast_partitioner,
"num_avg_timing_iters": num_avg_timing_iters,
"enable_experimental_decompositions": enable_experimental_decompositions,
@@ -349,8 +363,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

assert submodule_inputs is not None
# Handle long/double inputs if requested by the user
if settings.truncate_long_and_double:
submodule_inputs = repair_long_or_double_inputs(
if settings.truncate_double:
submodule_inputs = repair_double_inputs(
partitioned_module,
submodule,
submodule_inputs,
@@ -423,7 +437,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

def convert_module_to_trt_engine(
exported_program: ExportedProgram,
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
inputs: Tuple[Any, ...],
*,
enabled_precisions: (
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
) = _defaults.ENABLED_PRECISIONS,
@@ -436,7 +451,7 @@ def convert_module_to_trt_engine(
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME,
truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
device: Device = Device._current_device(),
@@ -451,6 +466,7 @@
dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
calibrator: object = None,
allow_shape_tensors: bool = False,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -488,7 +504,7 @@ def convert_module_to_trt_engine(
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
argument as None
truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
or only a selected subset of them
@@ -512,6 +528,19 @@
if debug:
set_log_level(logger.parent, logging.DEBUG)

if "truncate_long_and_double" in kwargs.keys():
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
raise ValueError(
'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"'
)
else:
truncate_double = kwargs["truncate_long_and_double"]
warnings.warn(
'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version',
DeprecationWarning,
stacklevel=2,
)

input_list = list(inputs) if inputs is not None else []
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
# Prepare torch_trt inputs
@@ -531,7 +560,7 @@
"version_compatible": version_compatible,
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
"truncate_double": truncate_double,
"use_fast_partitioner": use_fast_partitioner,
"enable_experimental_decompositions": enable_experimental_decompositions,
"device": device,
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -18,7 +18,7 @@
VERSION_COMPATIBLE = False
OPTIMIZATION_LEVEL = None
SPARSE_WEIGHTS = False
TRUNCATE_LONG_AND_DOUBLE = False
TRUNCATE_DOUBLE = False
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -23,7 +23,7 @@
REFIT,
REQUIRE_FULL_COMPILATION,
SPARSE_WEIGHTS,
TRUNCATE_LONG_AND_DOUBLE,
TRUNCATE_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
VERSION_COMPATIBLE,
@@ -50,7 +50,7 @@ class CompilationSettings:
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
argument as None
truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
or only a selected subset of them
@@ -81,7 +81,7 @@ class CompilationSettings:
version_compatible: bool = VERSION_COMPATIBLE
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
truncate_double: bool = TRUNCATE_DOUBLE
use_fast_partitioner: bool = USE_FAST_PARTITIONER
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
device: Device = field(default_factory=default_device)
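Code built on top of the settings dataclass picks the renamed field up directly; a minimal construction sketch:

    from torch_tensorrt.dynamo._settings import CompilationSettings

    settings = CompilationSettings(truncate_double=True)
    assert settings.truncate_double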
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -3,4 +3,4 @@
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import * # noqa: F403
from ._TRTInterpreter import * # noqa: F403
from .truncate_long_and_double import repair_long_or_double_inputs
from .truncate_double import repair_double_inputs
11 changes: 5 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -4,7 +4,6 @@
import logging
from typing import List, Sequence

import tensorrt as trt
import torch
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
@@ -18,14 +17,16 @@
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_torch_inputs

import tensorrt as trt

logger = logging.getLogger(__name__)


def infer_module_output_dtypes(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
device: Device,
truncate_long_and_double: bool = False,
truncate_double: bool = False,
) -> List[dtype]:
torch_inputs = get_torch_inputs(inputs, device)
module = module.to(device.to(torch.device))
@@ -48,10 +49,8 @@ def infer_module_output_dtypes(
else:
output_ = torch.tensor(output)

if truncate_long_and_double and output_.dtype == dtype.float64:
if truncate_double and output_.dtype == dtype.float64:
output_dtypes.append(dtype.float32)
elif truncate_long_and_double and output_.dtype == dtype.int64:
output_dtypes.append(dtype.int32)
else:
output_dtypes.append(dtype._from(output_.dtype))
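Under the new flag only float64 outputs are narrowed; int64 now passes through untouched. A self-contained restatement of the branch above (the helper name is ours, not the library's):

    import torch
    from torch_tensorrt._enums import dtype

    def map_output_dtype(t: torch.dtype, truncate_double: bool) -> dtype:
        # Mirrors infer_module_output_dtypes: only float64 is narrowed now
        if truncate_double and t == torch.float64:
            return dtype.float32
        return dtype._from(t)

    assert map_output_dtype(torch.float64, True) == dtype.float32
    assert map_output_dtype(torch.int64, True) == dtype._from(torch.int64)  # unchanged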

@@ -75,7 +74,7 @@ def interpret_module_to_result(
module,
inputs,
settings.device,
truncate_long_and_double=settings.truncate_long_and_double,
truncate_double=settings.truncate_double,
)

interpreter = TRTInterpreter(
18 changes: 4 additions & 14 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -297,21 +297,11 @@ def get_trt_tensor(
A TensorRT ITensor that represents the given value.
"""
# If the input is 64-bit, cast it to 32-bit for TRT freezing
if (
isinstance(input_val, torch.Tensor)
and ctx.compilation_settings.truncate_long_and_double
):
if input_val.dtype == torch.int64:
input_val = input_val.to(torch.int32)
elif input_val.dtype == torch.float64:
if isinstance(input_val, torch.Tensor) and ctx.compilation_settings.truncate_double:
if input_val.dtype == torch.float64:
input_val = input_val.to(torch.float32)
elif (
isinstance(input_val, np.ndarray)
and ctx.compilation_settings.truncate_long_and_double
):
if input_val.dtype == np.int64:
input_val = input_val.astype(np.int32)
elif input_val.dtype == np.float64:
elif isinstance(input_val, np.ndarray) and ctx.compilation_settings.truncate_double:
if input_val.dtype == np.float64:
input_val = input_val.astype(np.float32)

if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)):
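The freezing-time cast is now float64-only on both the torch and numpy paths. A standalone restatement of the narrowing rule (function name is ours):

    import numpy as np
    import torch

    def narrow_for_freezing(val):
        # Only 64-bit floats are narrowed; int64 is left for native TRT handling
        if isinstance(val, torch.Tensor) and val.dtype == torch.float64:
            return val.to(torch.float32)
        if isinstance(val, np.ndarray) and val.dtype == np.float64:
            return val.astype(np.float32)
        return val

    assert narrow_for_freezing(torch.ones(2, dtype=torch.float64)).dtype == torch.float32
    assert narrow_for_freezing(np.ones(2, dtype=np.int64)).dtype == np.int64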
23 changes: 10 additions & 13 deletions py/torch_tensorrt/dynamo/conversion/impl/embedding.py
@@ -9,25 +9,22 @@
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor

import tensorrt as trt


def embedding(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
weight: TRTTensor,
input: trt.ITensor,
weight: trt.ITensor,
scale_grad_by_freq: bool,
sparse: bool,
) -> TRTTensor:
) -> trt.ITensor:
indices_tensor = input
embedding_tensor = weight
if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
raise RuntimeError(
"The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT."
)
indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor")
embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor")
# unsupported parameters
@@ -52,15 +49,15 @@ def embedding_bag(
target: Target,
source_ir: Optional[SourceIR],
name: str,
weight: TRTTensor,
indices: TRTTensor,
weight: trt.ITensor,
indices: trt.ITensor,
offsets: Union[torch.Tensor, np.ndarray, Sequence[int]],
scale_grad_by_freq: bool,
mode: int,
sparse: bool,
per_sample_weights: Optional[TRTTensor],
per_sample_weights: Optional[trt.ITensor],
include_last_offset: bool,
) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
) -> Tuple[trt.ITensor, trt.ITensor, trt.ITensor, trt.ITensor]:
"""
This function is for calculating embedding bags.
@@ -143,7 +140,7 @@ def embedding_bag(
# however, pytorch doc says if `include_last_offset` is True, the size of offsets
# is equal to the number of bags + 1. The last element is the size of the input,
# or the ending index position of the last bag (sequence).
offsets[-1] = indices.shape[0]
offsets[-1] = indices.shape[0] # type: ignore[index]

# separately reduce embeddings for different bags
reduced_embed = []
py/torch_tensorrt/dynamo/conversion/{truncate_long_and_double.py → truncate_double.py}
@@ -156,7 +156,7 @@ def _repair_64bit_input(
gm.recompile()


def repair_long_or_double_inputs(
def repair_double_inputs(
parent_graph: torch.fx.GraphModule,
submodule: torch.fx.GraphModule,
submodule_inputs: Sequence[Input],
(Diff for the remaining changed files not shown.)
