diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 68a2355c3a..32b0ca65d7 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -2,6 +2,7 @@ import collections.abc import logging +import warnings from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import torch @@ -22,7 +23,7 @@ UnsupportedOperatorException, convert_module, interpret_module_to_result, - repair_long_or_double_inputs, + repair_double_inputs, ) from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( DYNAMO_CONVERTERS as CONVERTERS, @@ -58,7 +59,7 @@ def compile( dla_sram_size: int = _defaults.DLA_SRAM_SIZE, dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, - truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE, + truncate_double: bool = _defaults.TRUNCATE_DOUBLE, require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, min_block_size: int = _defaults.MIN_BLOCK_SIZE, torch_executed_ops: Optional[Collection[Target]] = None, @@ -74,7 +75,7 @@ def compile( hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, **kwargs: Any, ) -> torch.fx.GraphModule: - """Compile a TorchScript module for NVIDIA GPUs using TensorRT + """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT Takes a existing TorchScript module and a set of settings to configure the compiler and will convert methods to JIT Graphs which call equivalent TensorRT engines @@ -115,7 +116,7 @@ def compile( dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution - truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32 + truncate_double (bool): Truncate weights provided in double (float64) to float32 calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration require_full_compilation (bool): Require modules to be compiled end to end or return an error as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT @@ -138,6 +139,19 @@ def compile( if debug: set_log_level(logger.parent, logging.DEBUG) + if "truncate_long_and_double" in kwargs.keys(): + if truncate_double is not _defaults.TRUNCATE_DOUBLE: + raise ValueError( + 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"' + ) + else: + truncate_double = kwargs["truncate_long_and_double"] + warnings.warn( + 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version', + DeprecationWarning, + stacklevel=2, + ) + engine_capability = EngineCapability._from(engine_capability) if torch_executed_modules is not None and torch_executed_modules: @@ -185,7 +199,7 @@ def compile( "version_compatible": version_compatible, "optimization_level": optimization_level, "use_python_runtime": use_python_runtime, - "truncate_long_and_double": truncate_long_and_double, + "truncate_double": truncate_double, "use_fast_partitioner": use_fast_partitioner, "num_avg_timing_iters": num_avg_timing_iters, "enable_experimental_decompositions": enable_experimental_decompositions, @@ -349,8 +363,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: assert submodule_inputs is not None # Handle long/double inputs if requested by the user - if settings.truncate_long_and_double: - submodule_inputs = repair_long_or_double_inputs( + if settings.truncate_double: + submodule_inputs = repair_double_inputs( partitioned_module, submodule, submodule_inputs, @@ -423,7 +437,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: def convert_module_to_trt_engine( exported_program: ExportedProgram, - inputs: Optional[Sequence[Input | torch.Tensor]] = None, + inputs: Tuple[Any, ...], + *, enabled_precisions: ( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] ) = _defaults.ENABLED_PRECISIONS, @@ -436,7 +451,7 @@ def convert_module_to_trt_engine( version_compatible: bool = _defaults.VERSION_COMPATIBLE, optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME, - truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE, + truncate_double: bool = _defaults.TRUNCATE_DOUBLE, use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, device: Device = Device._current_device(), @@ -451,6 +466,7 @@ def convert_module_to_trt_engine( dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, calibrator: object = None, allow_shape_tensors: bool = False, + **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -488,7 +504,7 @@ def convert_module_to_trt_engine( use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the argument as None - truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32 + truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32 use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system enable_experimental_decompositions (bool): Whether to enable all core aten decompositions or only a selected subset of them @@ -512,6 +528,19 @@ def convert_module_to_trt_engine( if debug: set_log_level(logger.parent, logging.DEBUG) + if "truncate_long_and_double" in kwargs.keys(): + if truncate_double is not _defaults.TRUNCATE_DOUBLE: + raise ValueError( + 'Provided configuration for "truncate_double" and deprecated API "truncate_long_and_double", please only use "truncate_double"' + ) + else: + truncate_double = kwargs["truncate_long_and_double"] + warnings.warn( + 'Compiler option "truncate_long_and_double" is deprecated in favor of "truncate_double" as int64 is now natively supported, this option will be removed in the next version', + DeprecationWarning, + stacklevel=2, + ) + input_list = list(inputs) if inputs is not None else [] torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() # Prepare torch_trt inputs @@ -531,7 +560,7 @@ def convert_module_to_trt_engine( "version_compatible": version_compatible, "optimization_level": optimization_level, "use_python_runtime": use_python_runtime, - "truncate_long_and_double": truncate_long_and_double, + "truncate_double": truncate_double, "use_fast_partitioner": use_fast_partitioner, "enable_experimental_decompositions": enable_experimental_decompositions, "device": device, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 27db215466..97430137c0 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -18,7 +18,7 @@ VERSION_COMPATIBLE = False OPTIMIZATION_LEVEL = None SPARSE_WEIGHTS = False -TRUNCATE_LONG_AND_DOUBLE = False +TRUNCATE_DOUBLE = False USE_PYTHON_RUNTIME = False USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 3aee629812..d3cf6709ef 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -23,7 +23,7 @@ REFIT, REQUIRE_FULL_COMPILATION, SPARSE_WEIGHTS, - TRUNCATE_LONG_AND_DOUBLE, + TRUNCATE_DOUBLE, USE_FAST_PARTITIONER, USE_PYTHON_RUNTIME, VERSION_COMPATIBLE, @@ -50,7 +50,7 @@ class CompilationSettings: use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the argument as None - truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32 + truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32 use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system enable_experimental_decompositions (bool): Whether to enable all core aten decompositions or only a selected subset of them @@ -81,7 +81,7 @@ class CompilationSettings: version_compatible: bool = VERSION_COMPATIBLE optimization_level: Optional[int] = OPTIMIZATION_LEVEL use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME - truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE + truncate_double: bool = TRUNCATE_DOUBLE use_fast_partitioner: bool = USE_FAST_PARTITIONER enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS device: Device = field(default_factory=default_device) diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index 03c7bd0ca0..5351f02bb6 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -3,4 +3,4 @@ from ._ConversionContext import ConversionContext from ._ConverterRegistry import * # noqa: F403 from ._TRTInterpreter import * # noqa: F403 -from .truncate_long_and_double import repair_long_or_double_inputs +from .truncate_double import repair_double_inputs diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 93fc73b4e2..ea078c7d64 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -4,7 +4,6 @@ import logging from typing import List, Sequence -import tensorrt as trt import torch from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype @@ -18,6 +17,8 @@ from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule from torch_tensorrt.dynamo.utils import get_torch_inputs +import tensorrt as trt + logger = logging.getLogger(__name__) @@ -25,7 +26,7 @@ def infer_module_output_dtypes( module: torch.fx.GraphModule, inputs: Sequence[Input], device: Device, - truncate_long_and_double: bool = False, + truncate_double: bool = False, ) -> List[dtype]: torch_inputs = get_torch_inputs(inputs, device) module = module.to(device.to(torch.device)) @@ -48,10 +49,8 @@ def infer_module_output_dtypes( else: output_ = torch.tensor(output) - if truncate_long_and_double and output_.dtype == dtype.float64: + if truncate_double and output_.dtype == dtype.float64: output_dtypes.append(dtype.float32) - elif truncate_long_and_double and output_.dtype == dtype.int64: - output_dtypes.append(dtype.int32) else: output_dtypes.append(dtype._from(output_.dtype)) @@ -75,7 +74,7 @@ def interpret_module_to_result( module, inputs, settings.device, - truncate_long_and_double=settings.truncate_long_and_double, + truncate_double=settings.truncate_double, ) interpreter = TRTInterpreter( diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 04f048c5f3..8f11e7fb91 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -297,21 +297,11 @@ def get_trt_tensor( A TensorRT ITensor that represents the given value. """ # If the input is 64-bit, cast it to 32-bit for TRT freezing - if ( - isinstance(input_val, torch.Tensor) - and ctx.compilation_settings.truncate_long_and_double - ): - if input_val.dtype == torch.int64: - input_val = input_val.to(torch.int32) - elif input_val.dtype == torch.float64: + if isinstance(input_val, torch.Tensor) and ctx.compilation_settings.truncate_double: + if input_val.dtype == torch.float64: input_val = input_val.to(torch.float32) - elif ( - isinstance(input_val, np.ndarray) - and ctx.compilation_settings.truncate_long_and_double - ): - if input_val.dtype == np.int64: - input_val = input_val.astype(np.int32) - elif input_val.dtype == np.float64: + elif isinstance(input_val, np.ndarray) and ctx.compilation_settings.truncate_double: + if input_val.dtype == np.float64: input_val = input_val.astype(np.float32) if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)): diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py index ac9faf9f4d..ee3354ae08 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -9,7 +9,8 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_numpy from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTTensor + +import tensorrt as trt def embedding( @@ -17,17 +18,13 @@ def embedding( target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTTensor, - weight: TRTTensor, + input: trt.ITensor, + weight: trt.ITensor, scale_grad_by_freq: bool, sparse: bool, -) -> TRTTensor: +) -> trt.ITensor: indices_tensor = input embedding_tensor = weight - if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64: - raise RuntimeError( - "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT." - ) indices_tensor = get_trt_tensor(ctx, indices_tensor, f"{name}_indices_tensor") embedding_tensor = get_trt_tensor(ctx, embedding_tensor, f"{name}_embedding_tensor") # unsupported parameters @@ -52,15 +49,15 @@ def embedding_bag( target: Target, source_ir: Optional[SourceIR], name: str, - weight: TRTTensor, - indices: TRTTensor, + weight: trt.ITensor, + indices: trt.ITensor, offsets: Union[torch.Tensor, np.ndarray, Sequence[int]], scale_grad_by_freq: bool, mode: int, sparse: bool, - per_sample_weights: Optional[TRTTensor], + per_sample_weights: Optional[trt.ITensor], include_last_offset: bool, -) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]: +) -> Tuple[trt.ITensor, trt.ITensor, trt.ITensor, trt.ITensor]: """ This function is for calculating embedding bags. @@ -143,7 +140,7 @@ def embedding_bag( # however, pytorch doc says if `include_last_offset` is True, the size of offsets # is equal to the number of bags + 1. The last element is the size of the input, # or the ending index position of the last bag (sequence). - offsets[-1] = indices.shape[0] + offsets[-1] = indices.shape[0] # type: ignore[index] # separately reduce embeddings for different bags reduced_embed = [] diff --git a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py deleted file mode 100644 index 6a14c9f443..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py +++ /dev/null @@ -1,224 +0,0 @@ -from __future__ import annotations - -from typing import Optional, Sequence, Set - -import torch -from torch.fx.node import _get_qualified_name -from torch_tensorrt._enums import dtype -from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo.utils import get_torch_inputs - - -def _extract_downstream_get_nodes( - module_node: torch.fx.Node, output_indices: Set[int] -) -> Sequence[torch.fx.Node]: - """Extracts downstream users of a node which get the item at a particular index - - Certain module-type nodes have multiple outputs (tuple of outputs). This function - returns downstream nodes which call the _operator.getitem function, which extracts - the element at a particular index in the tuple - - Args: - module_node: FX module-type node to analyze - output_index: Indices in the module node output to search for - Returns: - List of nodes which get the item at the specified index in the module node output - """ - get_nodes = [] - - # Iterate over all downstream users of the node object - for user in module_node.users: - # If the user is a "get" node accessing the specified index, store it - if _get_qualified_name(user.target) == "_operator.getitem" and ( - user.args[1] in output_indices - ): - get_nodes.append(user) - - return get_nodes - - -def _repair_64bit_input( - gm: torch.fx.GraphModule, - position: int, - submodule_name: str, - submodule_outputs: Optional[torch.Tensor | Sequence[torch.Tensor]], - dtype: torch.dtype, -) -> None: - """Fixes a single Long/Double input to a TRT-accelerated subgraph - - In-Place modifies the provided graph - - Inserts a cast to the 32-bit equivalent type for TRT, then if necessary, - inserts an upcast back to the 64-bit type for subsequent Torch operations - - Args: - gm: FX GraphModule enclosing the TRT subgraph - position: Index in the submodule inputs at which the long or double input is found - submodule_name: Name of TRT-accelerated subgraph module in FX graph - submodule_outputs: Output tensor(s) of TRT-accelerated subgraph (used for dtypes/structure) - dtype: Data type of tensor at position in submodule (double/long) - """ - assert dtype in ( - torch.float64, - ), f"dtype argument must be torch.float64, got {dtype}" - - # Determine target data type in 32 and 64 bit forms - dtype_64bit = dtype - dtype_32bit = torch.float32 - - # Find the node representing the submodule in the graph - module_node = None - - # Iterate over all nodes in the graph, seeking target module name match - for n in gm.graph.nodes: - if n.op == "call_module" and str(n.target) == submodule_name: - module_node = n - break - - if module_node is None: - raise AssertionError( - f"Sought module node {submodule_name}, could not find in graph:\n{gm.graph}" - ) - - # Extract the 64-bit node of the input - node_64bit = module_node.all_input_nodes[position] - - # Prior to the module, insert a cast to the 32-bit equivalent node - with gm.graph.inserting_before(module_node): - node_32bit = gm.graph.call_function( - torch.ops.aten._to_copy.default, - args=(node_64bit,), - kwargs={"dtype": dtype_32bit}, - ) - - # Replace 64-bit input to TRT module with new 32-bit cast node - module_node.replace_input_with(node_64bit, node_32bit) - - output_positions_64bit = set() - - # Determine if any outputs of the model are 64-bit type and store their indices - if submodule_outputs is not None: - outputs_list = ( - [submodule_outputs] - if isinstance(submodule_outputs, torch.Tensor) - else submodule_outputs - ) - - for output_position, output in enumerate(outputs_list): - if output.dtype == dtype_64bit: - output_positions_64bit.add(output_position) - - # Only enter this code block if there exists a 64-bit output - # This implies a cast is needed, since TRT cannot output 64-bit tensors - if output_positions_64bit: - # Determine whther the outputs of the module are tuple-type or not - is_collection_output = False - if isinstance(submodule_outputs, tuple): - is_collection_output = True - - if not is_collection_output: - # If the output is a single tensor, insert a cast back to int64 - with gm.graph.inserting_after(module_node): - cast_node_64bit = gm.graph.call_function( - torch.ops.aten._to_copy.default, - args=(module_node,), - kwargs={"dtype": dtype_64bit}, - ) - - # Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent - module_node.replace_all_uses_with( - cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit) - ) - - else: - # If the output is a tuple of tensors, extract downstream users for each 64-bit output - get_nodes = _extract_downstream_get_nodes( - module_node, output_positions_64bit - ) - - # For each downstream user, append a cast node back to the 64-bit precision - for get_node in get_nodes: - with gm.graph.inserting_after(get_node): - cast_node_64bit = gm.graph.call_function( - torch.ops.aten._to_copy.default, - args=(get_node,), - kwargs={"dtype": torch.float64}, - ) - - get_node.replace_all_uses_with( - cast_node_64bit, - delete_user_cb=lambda user: (user != cast_node_64bit), - ) - - # Clean up graph and ensure invariants are preserved - gm.graph.eliminate_dead_code() - gm.graph.lint() - gm.recompile() - - -def repair_long_or_double_inputs( - parent_graph: torch.fx.GraphModule, - submodule: torch.fx.GraphModule, - submodule_inputs: Sequence[Input], - device: torch.device, - submodule_name: Optional[str] = None, -) -> Sequence[Input]: - """Fixes all Long/Double type inputs to a TRT-accelerated subgraph - - In-Place modifies the provided graph - - Inserts a cast to the 32-bit equivalent type for TRT, then if necessary, - inserts an upcast back to the 64-bit type for subsequent Torch operations - - Args: - parent_graph: FX GraphModule enclosing the TRT subgraph - submodule: Child submodule to repair inputs on - submodule_inputs: Input tensor(s) of TRT-accelerated subgraph (used for dtypes/structure) - submodule_name: Optionally specify the name of the submodule target in the parent graph - Returns: - New submodule inputs, updated accordingly with long/double truncation - """ - submodule_torch_inputs = get_torch_inputs(submodule_inputs, device) - num_submodule_inputs = len(submodule_inputs) - repaired_outputs_once = False - - # For each input to the TRT subgraph, check if its type is long/double - for position in range(num_submodule_inputs): - param = submodule_torch_inputs[position] - - # If the data type of the input is long/double, insert necessary - # casts to replace the operation - if param.dtype == torch.float64: - # Ensure outputs are only repaired once per submodule to avoid - # unnecessary ops showing up in the graph - if not repaired_outputs_once: - submodule_outputs = submodule(*submodule_torch_inputs) - - _repair_64bit_input( - parent_graph, - position, - submodule_name if submodule_name is not None else submodule._get_name(), - None if repaired_outputs_once else submodule_outputs, - param.dtype, - ) - - repaired_outputs_once = True - - # Repair submodule inputs in accordance with inserted casts - dtype_32bit = torch.float32 - submodule_torch_inputs = ( - list(submodule_torch_inputs[:position]) - + [ - param.to(dtype_32bit), - ] - + list(submodule_torch_inputs[position + 1 :]) - ) - - # Set the 32bit inputs and their types to the submodule Inputs - for idx in range(len(submodule_inputs)): - submodule_inputs[idx].torch_tensor = submodule_torch_inputs[idx] - submodule_inputs[idx].dtype = dtype._from( - submodule_torch_inputs[idx].dtype - ) - - return submodule_inputs diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 506c9a1959..affd3d2286 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -221,7 +221,7 @@ def forward(self, x, y): inputs, min_block_size=1, pass_through_build_failures=True, - truncate_long_and_double=True, + truncate_double=True, debug=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() @@ -240,16 +240,14 @@ def forward(self, x, y): def test_int64_input_partial_support(self): class PartiallySupportedMultiOp(torch.nn.Module): def forward(self, x, y): - return torch.ops.aten.div.Tensor_mode( - x, torch.ops.aten.add.Tensor(y, y), rounding_mode=None - ) + return torch.ops.aten.abs(torch.ops.aten.add.Tensor(x, y)) fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) unexpected_ops = {torch.ops.aten.add.Tensor} inputs = [ - torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(), - torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(), + torch.randint(-40, 40, (1, 16, 7, 5), dtype=torch.long).cuda(), + torch.randint(1, 40, (1, 16, 7, 5), dtype=torch.long).cuda(), ] ( @@ -265,17 +263,17 @@ def forward(self, x, y): testing_partitioning=True, ) - self.assertEquals( + self.assertEqual( len(unexpected_ops_seen), 0, f"The following unexpected ops were encountered: {unexpected_ops_seen}", ) - self.assertEquals( + self.assertEqual( len(partitioned_graphs), 1, "Without control flow breaks, there should only be a single graph", ) - self.assertEquals( + self.assertEqual( len( [ 1 @@ -296,8 +294,9 @@ def forward(self, x, y): inputs, min_block_size=1, pass_through_build_failures=True, - truncate_long_and_double=True, + truncate_double=False, debug=True, + torch_executed_ops={"torch.ops.aten.add.Tensor"}, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py index 613fc167bb..20c32a3cd2 100644 --- a/tests/py/dynamo/backend/test_specialized_models.py +++ b/tests/py/dynamo/backend/test_specialized_models.py @@ -223,7 +223,7 @@ def forward(self, x): inputs, min_block_size=1, pass_through_build_failures=True, - truncate_long_and_double=True, + truncate_double=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index ef034c914f..72c701014b 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -99,8 +99,6 @@ def run_test( if not isinstance(ref, torch.Tensor): ref = torch.tensor([ref]) ref = ref.cpu() # to_dtype test has cases with gpu output - if ref.dtype == torch.int64: - ref = ref.int() # convert torch.max's index output tensor to int32 torch.testing.assert_close( out.cpu(), ref, @@ -238,7 +236,7 @@ def run_test( # We replicate this behavior here compilation_settings = CompilationSettings( enabled_precisions={dtype._from(precision)}, - truncate_long_and_double=True, + truncate_double=True, debug=True, ) @@ -250,7 +248,7 @@ def run_test( mod, input_specs, compilation_settings.device, - truncate_long_and_double=compilation_settings.truncate_long_and_double, + truncate_double=compilation_settings.truncate_double, ) interp = TRTInterpreter( @@ -289,7 +287,7 @@ def run_test_with_dynamic_shape( # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here - compilation_settings = CompilationSettings(truncate_long_and_double=True) + compilation_settings = CompilationSettings(truncate_double=True) interp = TRTInterpreter( mod, diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py index 84234db857..88260ba771 100644 --- a/tests/py/dynamo/conversion/test_casts.py +++ b/tests/py/dynamo/conversion/test_casts.py @@ -64,7 +64,7 @@ def forward(self, x): precision=torch.float, ) - def test_to_copy_unsupported(self): + def test_to_copy_i64b(self): class ToCopy64Bit(nn.Module): def forward(self, x): y = torch.ops.aten._to_copy.default(x, dtype=torch.int64) @@ -72,11 +72,10 @@ def forward(self, x): inputs = [torch.randn((1, 3, 10)).int()] - with self.assertRaises(UnsupportedOperatorException): - self.run_test( - ToCopy64Bit(), - inputs, - ) + self.run_test( + ToCopy64Bit(), + inputs, + ) def test_to_copy_multiple_returns(self): class ToCopyReturns(nn.Module): diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 533e9d84d3..a0155f1383 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -469,7 +469,7 @@ def forward(self, x): "torch_compile", inputs, min_block_size=1, - truncate_long_and_double=True, + truncate_double=True, pass_through_build_failures=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() diff --git a/tests/py/dynamo/models/conftest.py b/tests/py/dynamo/models/conftest.py index a779a0c5df..0dedfa3d2f 100644 --- a/tests/py/dynamo/models/conftest.py +++ b/tests/py/dynamo/models/conftest.py @@ -1,3 +1,5 @@ +# type: ignore + import pytest @@ -10,10 +12,10 @@ def pytest_addoption(parser): required=False, help="IR to compile with", choices=["dynamo", "torch_compile"], - default="dynamo", ) @pytest.fixture def ir(request): - return request.config.getoption("--ir")[0] + ir_opt = request.config.getoption("--ir") + return ir_opt[0] if ir_opt else "dynamo" diff --git a/tests/py/dynamo/models/test_dtype_support.py b/tests/py/dynamo/models/test_dtype_support.py index 713145cd93..4af933da60 100644 --- a/tests/py/dynamo/models/test_dtype_support.py +++ b/tests/py/dynamo/models/test_dtype_support.py @@ -38,7 +38,7 @@ def forward(self, x): exp_mod, inputs=[in_tensor], pass_through_build_failures=True, - truncate_long_and_double=True, + truncate_double=True, output_format="fx", min_block_size=1, use_python_runtime=False, @@ -77,7 +77,7 @@ def forward(self, x): exp_mod, inputs=[in_tensor], pass_through_build_failures=True, - truncate_long_and_double=True, + truncate_double=True, output_format="fx", min_block_size=1, use_python_runtime=True, @@ -122,7 +122,7 @@ def forward(self, x): exp_mod, inputs=[in_tensor], pass_through_build_failures=True, - truncate_long_and_double=False, + truncate_double=False, output_format="fx", min_block_size=1, use_python_runtime=False, @@ -162,7 +162,7 @@ def forward(self, x): exp_mod, inputs=[in_tensor], pass_through_build_failures=True, - truncate_long_and_double=False, + truncate_double=False, output_format="fx", min_block_size=1, use_python_runtime=True, diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py index eef05725be..f2da38d746 100644 --- a/tests/py/dynamo/models/test_dyn_models.py +++ b/tests/py/dynamo/models/test_dyn_models.py @@ -1,3 +1,5 @@ +# type: ignore + import unittest import pytest diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index b986a4d158..2d45af2b49 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -1,3 +1,5 @@ +# type: ignore + import unittest import pytest @@ -124,7 +126,7 @@ def test_bert_base_uncased(ir): ], "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, - "truncate_long_and_double": True, + "truncate_double": True, "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 32ec1315ff..f16d60c9be 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -127,7 +127,7 @@ def test_bert_base_uncased(ir): ], "device": torchtrt.Device("cuda:0"), "enabled_precisions": {torch.float}, - "truncate_long_and_double": True, + "truncate_double": True, "ir": ir, "min_block_size": 10, } diff --git a/tests/py/dynamo/runtime/test_compilation_settings.py b/tests/py/dynamo/runtime/test_compilation_settings.py index daa67ad032..47f700038a 100644 --- a/tests/py/dynamo/runtime/test_compilation_settings.py +++ b/tests/py/dynamo/runtime/test_compilation_settings.py @@ -36,7 +36,7 @@ def forward(self, x): refit=True, num_avg_timing_iters=5, workspace_size=1 << 10, - truncate_long_and_double=True, + truncate_double=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu()