diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 51b5d899eb..8119f3bac7 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -691,10 +691,13 @@ def acc_ops_layer_norm(network, target, args, kwargs, name): eps_field = trt.PluginField( "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32 ) + normalized_shape = kwargs["normalized_shape"] try: - normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32) + normalized_shape = np.array(normalized_shape, dtype=np.int32) except TypeError: - _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") + _LOGGER.error( + f"Unable to convert normalized_shape with value {normalized_shape} to a field, fall back to []" + ) normalized_shape = np.array([], dtype=np.int32) normalized_shape_filed = trt.PluginField( diff --git a/py/torch_tensorrt/fx/diagnostics.py b/py/torch_tensorrt/fx/diagnostics.py index 0ba2a30652..0d78513a81 100644 --- a/py/torch_tensorrt/fx/diagnostics.py +++ b/py/torch_tensorrt/fx/diagnostics.py @@ -87,12 +87,14 @@ class DiagnosticsWriter: def __init__(self): self._root_dir = tempfile.mkdtemp(prefix="fx2trt.") + self._data = "" _LOGGER.info(f"Initializing DiagnosticsWriter with root_dir: {self._root_dir}") def write(self, file_name: str, data: WriteObj): """ TODO: Can be disabled by regex on file_name """ + self._data = data # Only write if we are inside a collect_when() context. if not _IS_IN_COLLECT_CONTEXT.get(False): return @@ -117,6 +119,9 @@ def write(self, file_name: str, data: WriteObj): def root_dir(self) -> str: return self._root_dir + def data(self) -> WriteObj: + return self._data + def _write(self, file_name: str, to_write: bytes): # ms granularity - no naming collash, otherwise file will be # overwritten. 
@@ -271,6 +276,9 @@ def collect(self) -> str: finally: os.remove(fp) + def data(self) -> WriteObj: + return self._write.data() + def _res_or_err(data: WriteObj) -> t.Tuple[TWrite, str]: if isinstance(data, (str, bytes)): diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index d0a6bdf0a1..846c90bdd5 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -1,4 +1,5 @@ import logging +import os import warnings from datetime import datetime from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence @@ -211,6 +212,11 @@ def run( builder_config = self.builder.create_builder_config() builder_config.max_workspace_size = max_workspace_size + # Speed up TRT build time in the test environment + if trt.__version__ >= "8.6" and os.environ.get("TRT_TEST_ENV", "0") == "1": + _LOGGER.info("Set TRT optimization level to 0") + builder_config.builder_optimization_level = 0 + cache = None if timing_cache: cache_file = numpy.array(timing_cache) diff --git a/py/torch_tensorrt/fx/input_tensor_spec.py b/py/torch_tensorrt/fx/input_tensor_spec.py index 781c11f32c..8128fc1760 100644 --- a/py/torch_tensorrt/fx/input_tensor_spec.py +++ b/py/torch_tensorrt/fx/input_tensor_spec.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple +from typing import Any, Iterable, List, NamedTuple, Optional, Sequence, Tuple import torch @@ -18,6 +18,12 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None): # is the dynamic batch dimension. Otherwise, we use the additional # inputs to determine the batch dimension. 
if additional_inputs is None: + batch_dims = None + if not isinstance(inputs, torch.Tensor) and len(inputs) > 1: + bs = inputs[0].size(0) + batch_dims = None + if not all(x.size(0) == bs for x in inputs): + batch_dims = InputTensorSpec.find_batch_size_dim(inputs) return InputTensorSpec.from_tensors_with_dynamic_batch_size( inputs, ( @@ -26,6 +32,7 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None): lower_setting.max_batch_size, ), lower_setting.opt_profile_replica, + batch_dims, ) else: batch_dims = [] @@ -147,25 +154,69 @@ def from_tensors_with_dynamic_batch_size( A list of InputTensorSpec named tuples with dynamic ranges. """ if batch_dims is None: - batch_dims = [0] * len(tensors) + batch_dims = cls.find_batch_size_dim(tensors) input_specs = [] batch_size = tensors[0].size(batch_dims[0]) for i, tensor in enumerate(tensors): batch_dim = batch_dims[i] - assert batch_size == tensor.size( - batch_dim - ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." - shape = list(tensor.shape) - shape[batch_dim] = -1 - shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] - input_specs.append( - cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) - ) + if batch_dim == -1: + input_specs.append(cls.from_tensor(tensor)) + else: + shape = list(tensor.shape) + assert batch_size == tensor.size( + batch_dim + ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." 
+ shape[batch_dim] = -1 + shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] + input_specs.append( + cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) + ) return input_specs + @classmethod + # pyre-ignore [2]: Parameter `sample_input` must have a type other than `Any` + def find_batch_size_dim(cls, inputs: Any) -> []: + if isinstance(inputs, torch.Tensor) or len(inputs) <= 1: + return [0] + shapes = [i.shape for i in inputs] + frequency_map = {} + first_dims = set() + for shape in shapes: + if len(shape) < 2: + # By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info + continue + # Dedup shape value for single tensor + first_dims.add(shape[0]) + shape = set(shape) + for i in shape: + frequency_map[i] = frequency_map.get(i, 0) + 1 + + if len(first_dims) == 1: + # first dim is the same in every input: we use it as batch_size + batch_size = first_dims.pop() + elif frequency_map: + # first dims are different: we use the most frequent dim as batch_size + sorted_frequency = sorted(frequency_map.items(), key=lambda x: -x[1]) + batch_size = sorted_frequency[0][0] + else: + # no dims to sort: no batch_size + batch_size = -1 + + bs_dim = [] + for i in inputs: + # Default batch size dim = -1, indicate no batch_size + dim = -1 + for index, val in enumerate(i.shape): + if val == batch_size: + dim = index + break + bs_dim.append(dim) + + return bs_dim + def to_random_tensor(self, id=1): shape = tuple(self.shape) if len(get_dynamic_dims(shape)): diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index f96f1db6b9..6572fe9588 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -41,6 +41,8 @@ def compile( dynamic_batch=True, is_aten=False, use_experimental_fx_rt=False, + correctness_atol=1e-1, + correctness_rtol=1e-1, ) -> nn.Module: """ Takes in original module, input 
and lowering setting, run lowering workflow to turn module @@ -81,6 +83,8 @@ def compile( dynamic_batch=dynamic_batch, is_aten=is_aten, use_experimental_rt=use_experimental_fx_rt, + correctness_atol=correctness_atol, + correctness_rtol=correctness_rtol, ) lowerer = Lowerer.create(lower_setting=lower_setting) return lowerer(module, input) diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index e753d6e227..e98a9371c5 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -54,10 +54,14 @@ def fill_with_mul_zero_and_add(*args): def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule: - # Now we do constant folding on traced module. We want to skip pattern like - # weights -> quant -> dequant -> op during constant folding when the model is - # a quantized int8 model. - def skip_folding_quant_dequant(node: torch.fx.Node): + def skip_folding_ops(node: torch.fx.Node): + # dtype op + if node.target == acc_ops.dtype: + return True + # Now we do constant folding on traced module. We want to skip pattern like + # weights -> quant -> dequant -> op during constant folding when the model is + # a quantized int8 model. + # quant_dequant if node.target != acc_ops.quantize_per_tensor: return False # If quantize_per_node -> dequantize, then skip folding. 
@@ -66,7 +70,7 @@ def skip_folding_quant_dequant(node: torch.fx.Node): return True return False - const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant) + const_split_mod = split_const_subgraphs(traced_mod, skip_folding_ops) const_split_mod.run_folding() return const_split_mod @@ -630,3 +634,35 @@ def fix_clamp_numerical_limits_to_fp16( mod.recompile() return mod + + +@log_before_after +@validate_inference(atol=1e-3, rtol=1e-2) +def remove_dtype_and_to_pattern( + mod: torch.fx.GraphModule, input: Input +) -> torch.fx.GraphModule: + """ + Remove this pattern since it is unnecessary to cast to dtype + %dtype : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.dtype](args = (), kwargs = {input: %_attention_layers_0__uva}) + %to_18 : [#users=2] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.to_dtype](args = (), kwargs = {input: %x}) + """ + for node in mod.graph.nodes: + if node.op == "call_function" and node.target == acc_ops.dtype: + # find its first user + next_node = next(iter(node.users)) + # acc_op or pt op is treated differently + input = ( + next_node.kwargs["input"] + if "input" in next_node.kwargs + else next_node.args[0] + ) + if len(node.users) == 1 and ( + next_node.target == acc_ops.to_dtype or next_node.target == "to" + ): + next_node.replace_all_uses_with(input) + mod.graph.erase_node(next_node) + mod.graph.erase_node(node) + + mod.graph.eliminate_dead_code() + mod.recompile() + return mod diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 61052b21af..6e6b40d42f 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -8,6 +8,7 @@ from torch.fx.passes.pass_manager import inplace_wrapper, PassManager from torch.fx.passes.shape_prop import ShapeProp from torch.fx.passes.splitter_base import 
generate_inputs_for_submodules, SplitResult +from torch_tensorrt.fx.passes.pass_utils import apply_bfloat_float_conversion from torch_tensorrt.fx.utils import LowerPrecision from ..input_tensor_spec import generate_input_specs @@ -229,10 +230,9 @@ def lower_func(split_result: SplitResult) -> nn.Module: submod = getattr(split_result.split_module, submod_name) LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs) - # Only acc submodules will be lowered. if not submod_name.startswith(split_result.non_acc_submodule_prefix): - _LOGGER.info(f"Now lowering submodule {submod_name}") + _LOGGER.info(f"ACC submodule graph: {submod.graph}") lowering_start_time = datetime.datetime.now() self.lower_setting.additional_inputs = ( @@ -251,6 +251,9 @@ def lower_func(split_result: SplitResult) -> nn.Module: _LOGGER.info( f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}" ) + else: + _LOGGER.info(f"GPU submodule graph: {submod.graph}") + apply_bfloat_float_conversion(submod, submod_inputs, submod_name) return split_result.split_module diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index fabc92881d..0b8578ffba 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -1,12 +1,17 @@ +import contextlib import io +import json import logging import tempfile from datetime import datetime from functools import wraps +from traceback import TracebackException from typing import Any, Callable, List, Optional import torch +import torch_tensorrt.fx.diagnostics as diagnostics from torch import fx +from torch.fx.node import Node from torch.fx.passes.shape_prop import ShapeProp # Create an alias for module input type to avoid littering pyre-ignore for Any @@ -20,6 +25,11 @@ FINAL_CHECK_ATOL_MULTIPLIER: float = 10 FINAL_CHECK_RTOL_MULTIPLIER: float = 10 +# A global override of the alternative batch size used in validate_variable_batch_sizes 
+ALTERNATIVE_BATCH_SIZE_OVERRIDE: Optional[int] = None +# If exception during validate_variable_batch_sizes should be thrown +ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW: bool = False + class RelaxAccuracyCheckMode: """ @@ -83,6 +93,46 @@ def __exit__(self, type, value, traceback): ) +@contextlib.contextmanager +def override_alternative_batch_size(alternative_batch_size: int = -1): + """ + A context manager to override alternative_batch_size + + Example: + + >>> # disables run_alternative_batch_size verification + >>> with override_alternative_batch_size(-1): + >>> fx2ait() + """ + + global ALTERNATIVE_BATCH_SIZE_OVERRIDE + old_value = ALTERNATIVE_BATCH_SIZE_OVERRIDE + ALTERNATIVE_BATCH_SIZE_OVERRIDE = alternative_batch_size + _LOGGER.info(f"Override {ALTERNATIVE_BATCH_SIZE_OVERRIDE=} ({old_value=})") + try: + yield + finally: + ALTERNATIVE_BATCH_SIZE_OVERRIDE = old_value + _LOGGER.info(f"Restored old value: {ALTERNATIVE_BATCH_SIZE_OVERRIDE=})") + + +@contextlib.contextmanager +def override_alternative_batch_size_exception_should_throw( + exception_should_throw: bool, +): + """ + A context manager to set if exception during alternative batch size verification + should be thrown. + """ + global ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW + old_value = ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW + ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW = exception_should_throw + try: + yield + finally: + ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW = old_value + + def chain_passes(*passes: PassFunc) -> PassFunc: """ Chains a sequence of pass functions to form a single pass function @@ -100,11 +150,28 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall # on pass that failed accuracy check. 
-def validate_inference(rtol=None, atol=None): +def validate_inference( + rtol=None, atol=None, run_alternative_batch_size: int = -1 +) -> "Decorator": + """ + Returns a decorator on a PassFunc to sanity check the model outputs + difference before/after the transformation is within tolerance. + + Args: + rtol: reletive tolerance + atol: absoluate tolerance + run_alternative_batch_size (int): + In addition to running inference at original batch size in the + input, also run at an alternative batch size. If set to -1, do not + run at alternative batch size. It must be smaller than the original + batch size. This is useful to check the model can run at different + batch sizes. Usually we can set this to 1. + """ + def _validate_inference(pass_: PassFunc) -> PassFunc: """ - Wraps a pass function to validate that its inference results before and - after the pass run should be `close`. + A decorator to wrap a pass function to validate that its inference + results before and after the pass run should be `close`. """ @wraps(pass_) @@ -162,6 +229,120 @@ def pass_with_validation( return _validate_inference +def validate_variable_batch_sizes(run_alternative_batch_size: int = -1) -> "Decorator": + """ + Returns a decorator on a PassFunc to verify the model can run with + different batch sizes before/after the transformation is within tolerance. + + Args: + run_alternative_batch_size (int): + In addition to running inference at original batch size in the + input, also run at an alternative batch size. If set to -1, do not + run at alternative batch size. It must be smaller than the original + batch size. This is useful to check the model can run at different + batch sizes. Usually we can set this to 1. + + If the global variable `ALTERNATIVE_BATCH_SIZE_OVERRIDE` is set, it + overrides `run_alternative_batch_size`. + `ALTERNATIVE_BATCH_SIZE_OVERRIDE` can be set via: + + with override_alternative_batch_size(...): ... 
+ """ + + def _run_alternative_batch_size(pass_: PassFunc) -> PassFunc: + """ + A decorator for PassFunc to check that the model (both before and after + the transformation by pass func) can run at alternative batch size. + """ + + @wraps(pass_) + def pass_with_validation( + module: fx.GraphModule, + input: Input, + *args, + **kwargs, + ) -> fx.GraphModule: + _run_alternative_batch_size = ( + ALTERNATIVE_BATCH_SIZE_OVERRIDE + if ALTERNATIVE_BATCH_SIZE_OVERRIDE is not None + else run_alternative_batch_size + ) + + if _run_alternative_batch_size < 0: + return pass_(module, input, *args, **kwargs) + + if not isinstance(input, (list, tuple)): + _LOGGER.info( + f"Skip run_alternative_batch_size: input must be list, tuple. Actual: {type(input)}" + ) + return pass_(module, input, *args, **kwargs) + + if not all(isinstance(x, torch.Tensor) for x in input): + _LOGGER.info( + "Skip run_alternative_batch_size: input elements must all be tensors" + ) + return pass_(module, input, *args, **kwargs) + + if not all(len(x.shape) > 0 for x in input): + _LOGGER.info( + "Skip run_alternative_batch_size: some input tensor(s) are scalar" + ) + return pass_(module, input, *args, **kwargs) + + batch_size_candidates = {x.shape[0] for x in input} + if len(batch_size_candidates) > 1: + _LOGGER.info( + f"Skip run_alternative_batch_size: input tensors' first dim must be the same, actual: {batch_size_candidates}" + ) + return pass_(module, input, *args, **kwargs) + + batch_size = next(iter(batch_size_candidates)) + assert ( + _run_alternative_batch_size <= batch_size + ), f"{_run_alternative_batch_size=} must be smaller or equal to {batch_size=}" + + input_alt_bs = [x[:_run_alternative_batch_size, ...] 
for x in input] + + def run_module(mod, stage: str): + """Run module with full bs and alternative bs""" + _LOGGER.info( + f"Running {stage} model at alternative batch size: {_run_alternative_batch_size}" + ) + try: + mod(*input) + mod(*input_alt_bs) + except Exception as e: + _LOGGER.warning( + f"Failed running {stage} module at full or alternative batch size: {e}" + ) + diagnostics.write( + "lowering_diagnostics", + json.dumps( + { + "validate_variable_batch_sizes_exception": repr(e), + "validate_variable_batch_sizes_exception_type": type( + e + ).__name__, + "validate_variable_batch_sizes_exception_traceback": "".join( + TracebackException.from_exception(e).format() + ), + } + ), + ) + if ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW: + raise + + run_module(module, "original") + module_after = pass_(module, input, *args, **kwargs) + run_module(module_after, "transformed") + + return module_after + + return pass_with_validation + + return _run_alternative_batch_size + + Decorator = Callable[[Callable], Callable] @@ -269,3 +450,67 @@ def collect(x: fx.node.Argument) -> fx.node.Argument: fx.node.map_aggregate(arg, collect) return res + + +class InputOutputDtypeInferInterpreter(torch.fx.Interpreter): + """ + Interprete a graph to propagate the output tensor dtype from its inputs, extracing + input and output graph node that need dtype cast to float32/bfloat16. + """ + + def __init__(self, module: torch.fx.GraphModule): + super().__init__(module) + self.need_cast_to_float32 = [] + self.need_cast_to_bfloat = [] + + def _need_cast(self, node: Node, run_result) -> None: + if node.op == "placeholder" and ( + run_result.dtype not in (torch.int32, torch.int64) + ): + _LOGGER.info( + f"Encountered node: {node.format_node()} need dtype cast to float32." 
+ ) + self.need_cast_to_float32.append(node) + # Process node that will be used as final output + elif "output" in set(i.name for i in node.users.keys()): + if run_result.dtype not in (torch.int32, torch.int64): + _LOGGER.info( + f"Encountered node: {node.format_node()} need dtype cast to bfloat16." + ) + self.need_cast_to_bfloat.append(node) + + def run_node(self, n: Node) -> Any: + run_result = super().run_node(n) + + if torch.is_tensor(run_result): + n.meta["tensor_dtype"] = run_result.dtype + self._need_cast(n, run_result) + return run_result + + +def apply_bfloat_float_conversion( + gm: torch.fx.GraphModule, inputs: Any, name: str +) -> None: + _LOGGER.info("Apply bfloat-float32 conversion on {name}") + interpreter = InputOutputDtypeInferInterpreter(gm) + interpreter.run(*inputs) + + def to_bfloat(x): + return x.to(torch.bfloat16) + + def to_float(x): + return x.to(torch.float32) + + for node in interpreter.need_cast_to_float32: + with gm.graph.inserting_after(node): + cast = gm.graph.call_function( + to_float, + (node,), + {}, + ) + node.replace_all_uses_with(cast) + + for node in interpreter.need_cast_to_bfloat: + with gm.graph.inserting_after(node): + cast = gm.graph.call_function(to_bfloat, (node,), {}) + node.replace_all_uses_with(cast) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py index cfeb235af3..55bd7b1e8b 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py @@ -9,7 +9,7 @@ class TestCatConverter(DispatchTestCase): @parameterized.expand( [ ("pos", 1), - # ("neg", -2), #Dynamo tracer issue + # ("neg", -2), #dim can not have dynamic input ] ) def test_cat(self, _, dim): @@ -27,7 +27,7 @@ def forward(self, x, y, z): @parameterized.expand( [ ("pos", 1), - # ("neg", -2), #Dynamo tracer issue + # ("neg", -2), #dim can not have dynamic input ] ) def test_cat_dynamic_shape(self, 
_, dim): diff --git a/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py b/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py index a604f4b75a..384d55d44e 100644 --- a/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py +++ b/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py @@ -81,8 +81,7 @@ def forward(self, x): ("tuple_parameters", 1, (1, 1, 1), (0, 0, 0)), param("non_zero_padding", 1, padding=1), param("dilation", 1, dilation=2), - # TODO: Enable this when TRT fixes https://github.com/pytorch/TensorRT/issues/1445 - # param("groups", 1, groups=3), + param("groups", 1, groups=3), ] ) def test_conv3d( diff --git a/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py index db848eaf1c..0443278460 100644 --- a/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py +++ b/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py @@ -47,6 +47,23 @@ def test_from_tensors_with_dynamic_batch_size(self): self.assertEqual(batch_size, shape[0]) self.assertSequenceEqual(tensor.shape[1:], shape[1:]) + def test_from_tensors_with_dynamic_batch_size_no_bs_input(self): + tensors = [torch.randn(1, 2, 3), torch.randn(1, 4), torch.randn(72, 16)] + batch_size_range = [2, 3, 4] + specs = InputTensorSpec.from_tensors_with_dynamic_batch_size( + tensors, batch_size_range + ) + for index, (spec, tensor) in enumerate(zip(specs, tensors)): + if index == 2: + for a, b in zip(spec.shape, tensor.shape): + self.assertEqual(a, b) + else: + self._validate_spec(spec, tensor, dynamic_dims=[0]) + + for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]): + self.assertEqual(batch_size, shape[0]) + self.assertSequenceEqual(tensor.shape[1:], shape[1:]) + def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self): tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)] batch_size_range = [2, 3, 4] @@ -88,6 +105,22 @@ def 
test_generate_input_specs(self): self._validate_spec(spec, tensor, dynamic_dims=[1]) self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica) + # Explicit batch dim with inputs w/ different batch dims. + bs = 10 + inputs = [ + torch.randn(bs, 1, 2), + torch.randn(bs, 10, 3), + torch.randn(4, bs, 5), + torch.randn(bs, 2, 5), + ] + specs = generate_input_specs(inputs, lower_setting) + for idx, (spec, tensor) in enumerate(zip(specs, inputs)): + if idx == 2: + self._validate_spec(spec, tensor, dynamic_dims=[1]) + else: + self._validate_spec(spec, tensor, dynamic_dims=[0]) + self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py index 36420375f8..4edc2ef706 100644 --- a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py +++ b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py @@ -76,6 +76,8 @@ def forward(self, x): inputs, {trt_transposed_linear}, apply_passes=[fuse_permute_linear], + rtol=5e-3, + atol=2e-3, ) diff --git a/py/torch_tensorrt/fx/test/passes/test_pass_utils.py b/py/torch_tensorrt/fx/test/passes/test_pass_utils.py new file mode 100644 index 0000000000..6f5edde004 --- /dev/null +++ b/py/torch_tensorrt/fx/test/passes/test_pass_utils.py @@ -0,0 +1,97 @@ +import logging +import unittest +from typing import Optional + +import torch +import torch_tensorrt.fx.diagnostics as diagnostics +from torch_tensorrt.fx.passes.pass_utils import ( + override_alternative_batch_size, + override_alternative_batch_size_exception_should_throw, + validate_variable_batch_sizes, +) + +diagnostics.set_current_collector( + diagnostics.ZipDiagnosticsCollector(writer=diagnostics.get_current_writer()) +) + + +_LOGGER: logging.Logger = logging.getLogger(__name__) +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) # configure 
root logger + + +class BatchSizeError(Exception): + pass + + +class PassUtilsTest(unittest.TestCase): + def setUp(self): + torch.manual_seed(0) + + def test_run_alternative_batch_size(self): + class TestModule(torch.nn.Module): + should_fail_at_bs: Optional[int] = None + + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + if x.shape[0] == self.should_fail_at_bs: + raise BatchSizeError(self.should_fail_at_bs) + + return x + y + z + + def gen_input(bs: int): + return [ + torch.rand(bs, 64), + torch.rand(bs, 64), + torch.rand(bs, 64), + ] + + @validate_variable_batch_sizes(1) + def model_transform_pass_good(model, input): + """ + This is a good transformation. Meaning that the model it + produces will not fail at any batch sizes + """ + model.should_fail_at_bs = None + return model + + @validate_variable_batch_sizes(1) + def model_transform_pass_bad(model, input): + """ + This is a bad transformation. Meaning that the model it produces + will fail when the given input batch size is 1 + """ + model.should_fail_at_bs = 1 + return model + + model = TestModule() + input = gen_input(bs=10) + + with diagnostics.collect_when(diagnostics.CollectionConditions.always()): + + with override_alternative_batch_size_exception_should_throw(True): + # This should succeed: the validate_inference decorator will + # run both bs=10 and bs=1 successfully + model_transform_pass_good(model, input) + + # This should fail: the validate_inference decorator will run the + # model (post transform) at bs=1. + model.should_fail_at_bs = None # reset + self.assertRaises( + BatchSizeError, lambda: model_transform_pass_bad(model, input) + ) + + # Test override_alternative_batch_size can disable run alt bs: + # This should success: the validate_inference decorator will + # NOT run alternative batch size, because it is disabled via + # override_alternative_batch_size. 
+ model.should_fail_at_bs = None # reset + with override_alternative_batch_size(alternative_batch_size=-1): + model_transform_pass_bad(model, input) + + # Test that by default alt bs failures won't cause exception + # thrown, because of no + # `override_alternative_batch_size_exception_should_throw` + model_transform_pass_bad(model, input) diff --git a/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py b/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py index 8f9c1a887f..cb7ff8f906 100644 --- a/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py +++ b/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py @@ -1,7 +1,6 @@ import torch import torch._dynamo as torchdynamo from parameterized import parameterized -from torch._dynamo.optimizations import backends from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase @@ -507,93 +506,93 @@ def transform_fx(gm, example_inputs): optimize_mod(*inputs) # test with torchdynamo - def test_setitem1d_trt(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[1] = x - return y - - inputs = [torch.randn(1), torch.randn(3)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - ref_output = m(*inputs) - - optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) - - output = optimize_mod(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) - - @parameterized.expand( - [ - ("c1", (4, 2), (4, 5), 0, 2), - ("c2", (4, 2), (4, 5), 1, 3), - ] - ) - def test_setitem2d_1v_trt(self, name, x_shape, y_shape, y_start, y_end): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[:, y_start:y_end] = x - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - 
m.cuda() - - ref_output = m(*inputs) - optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) - output = optimize_mod(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) - - @parameterized.expand( - [ - ("c1", (2, 3, 4, 5), (4, 5, 6, 7), 0, 2, 0, 3, 0, 4, 0, 5), - ("c2", (2, 3, 4, 5), (4, 5, 6, 7), 1, 3, 1, 4, 1, 5, 1, 6), - ] - ) - def test_setitem4d_4v_trt( - self, - name, - x_shape, - y_shape, - start_0, - end_0, - start_1, - end_1, - start_2, - end_2, - start_3, - end_3, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] = x - y = y + 3 - x = y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] - return x - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - ref_output = m(*inputs) - optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) - output = optimize_mod(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) + # def test_setitem1d_trt(self): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + + # def forward(self, x, y): + # y[1] = x + # return y + + # inputs = [torch.randn(1), torch.randn(3)] + # m = TestModule() + + # inputs = [i.cuda() for i in inputs] + # m.cuda() + # ref_output = m(*inputs) + + # optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + + # output = optimize_mod(*inputs) + # self.assertTrue(torch.allclose(ref_output, output)) + + # @parameterized.expand( + # [ + # ("c1", (4, 2), (4, 5), 0, 2), + # ("c2", (4, 2), (4, 5), 1, 3), + # ] + # ) + # def test_setitem2d_1v_trt(self, name, x_shape, y_shape, y_start, y_end): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + + # def forward(self, x, y): + # y[:, y_start:y_end] = x + # return y + + # inputs = 
[torch.randn(x_shape), torch.randn(y_shape)] + # m = TestModule() + + # inputs = [i.cuda() for i in inputs] + # m.cuda() + + # ref_output = m(*inputs) + # optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + # output = optimize_mod(*inputs) + # self.assertTrue(torch.allclose(ref_output, output)) + + # @parameterized.expand( + # [ + # ("c1", (2, 3, 4, 5), (4, 5, 6, 7), 0, 2, 0, 3, 0, 4, 0, 5), + # ("c2", (2, 3, 4, 5), (4, 5, 6, 7), 1, 3, 1, 4, 1, 5, 1, 6), + # ] + # ) + # def test_setitem4d_4v_trt( + # self, + # name, + # x_shape, + # y_shape, + # start_0, + # end_0, + # start_1, + # end_1, + # start_2, + # end_2, + # start_3, + # end_3, + # ): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + + # def forward(self, x, y): + # y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] = x + # y = y + 3 + # x = y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] + # return x + + # inputs = [torch.randn(x_shape), torch.randn(y_shape)] + # m = TestModule() + + # inputs = [i.cuda() for i in inputs] + # m.cuda() + + # ref_output = m(*inputs) + # optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + # output = optimize_mod(*inputs) + # self.assertTrue(torch.allclose(ref_output, output)) if __name__ == "__main__": diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index 633359127f..74715d6030 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -1494,6 +1494,21 @@ def test_dropout(self): lambda x: nn.functional.dropout(x, training=False), input_shape=(1, 2, 3), ) + self._make_acc_op_function_test( + None, + lambda x: nn.functional.dropout1d(x, training=False), + input_shape=(4, 2, 3), + ) + self._make_acc_op_function_test( + None, + lambda x: nn.functional.dropout2d(x, training=False), + input_shape=(4, 2, 3), + ) + 
self._make_acc_op_function_test( + None, + lambda x: nn.functional.dropout3d(x, training=False), + input_shape=(4, 2, 3), + ) def test_stochastic_depth(self): self._make_acc_op_function_test( @@ -1727,6 +1742,11 @@ def test_ceil(self): def test_softmax(self): self._make_acc_op_function_test(acc_ops.softmax, torch.nn.functional.softmax) + def test_normalize(self): + self._make_acc_op_function_test( + acc_ops.normalize, torch.nn.functional.normalize + ) + def test_tensor_squeeze(self): self._make_acc_op_function_test(acc_ops.squeeze, lambda x: x.squeeze()) @@ -2628,6 +2648,40 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertIsNotNone(getitem) self.assertTrue(torch.equal(m(x), traced(x))) + def test_skip_normalization_if_none_repeat_interleave(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + repeats = y[0] + return torch.repeat_interleave(x, repeats, 1) + + # TODO: finish test later + m = TestModule() + inputs = (torch.randn(3, 4), torch.tensor([1])) + traced = acc_tracer.trace(m, inputs) + # Make sure repeat_interleave wasn't mapped into tiles + self.assertTrue("torch.repeat_interleave" in str(traced.graph)) + self.assertFalse("tile" in str(traced.graph)) + + def test_skip_normalization_if_none_repeat(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + repeats = [y[0], y[2], 3] + return x.repeat(repeats) + + # TODO: finish test later + m = TestModule() + inputs = (torch.randn(3, 4, 5), torch.tensor([1, 2, 3])) + traced = acc_tracer.trace(m, inputs) + # Make sure repeat wasn't mapped into tiles + self.assertTrue("repeat" in str(traced.graph)) + self.assertFalse("tile" in str(traced.graph)) + def test_acc_normalization_block_list(self): class TestModule(nn.Module): def forward(self, x: List[torch.Tensor]) -> torch.Tensor: @@ -2668,6 +2722,32 @@ 
def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertTrue(torch.equal(m(*sample_inputs), traced(*sample_inputs))) + def test_threshold_bwd(self): + class TestModule(nn.Module): + def __init__(self, threshold): + super().__init__() + self._threshold = threshold + + def forward(self, grad: torch.Tensor, input: torch.Tensor) -> torch.Tensor: + return torch.ops.aten.threshold_backward.default( + grad, input, self._threshold + ) + + m = TestModule(0.0) + grad = torch.randn(4096) + sample_inputs = torch.randn(4096) + traced = acc_tracer.trace(m, [grad, sample_inputs]) + + output = None + for node in traced.graph.nodes: + if node.op == "output": + assert output is None + output = node + + ref = m(grad, sample_inputs) + res = traced(grad, sample_inputs) + self.assertTrue(torch.equal(ref, res)) + def test_all_acc_ops_registered(self): self.assertEqual( acc_normalizer._acc_ops, @@ -2689,6 +2769,7 @@ def test_all_acc_ops_registered(self): acc_ops.minimum, acc_ops.cat, acc_ops.softmax, + acc_ops.normalize, acc_ops.sign, acc_ops.permute, acc_ops.matmul, @@ -2713,6 +2794,8 @@ def test_all_acc_ops_registered(self): acc_ops.tuple_construct, acc_ops.unsqueeze, acc_ops.sigmoid, + acc_ops.sigmoid_backward, + acc_ops.threshold_backward, acc_ops.sum, acc_ops.prod, acc_ops.max_full_reduce, @@ -2726,6 +2809,7 @@ def test_all_acc_ops_registered(self): acc_ops.atan, acc_ops.exp, acc_ops.log, + acc_ops.log_softmax, acc_ops.sqrt, acc_ops.reciprocal, acc_ops.abs, @@ -2797,5 +2881,16 @@ def test_all_acc_ops_registered(self): acc_ops.var, acc_ops.grid_sample, acc_ops.xl_weight, + acc_ops.clone, + acc_ops.unbind, + acc_ops.group_norm, + acc_ops.long, + acc_ops.full_like, + acc_ops.new_full, + acc_ops.ones_like, + acc_ops.zeros_like, + acc_ops.new_zeros, + acc_ops.index_add, + acc_ops.masked_select, }, ) diff --git a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py index e160626cf2..b5db157663 100644 --- 
a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py @@ -7,17 +7,12 @@ import torch._dynamo.config import torchvision from functorch.experimental import functionalize -from torch._dynamo.optimizations import backends -from torch._dynamo.optimizations.normalize import normalize_ir from torch.library import Library from torch_tensorrt.fx.lower import compile from torch_tensorrt.fx.tracer.dispatch_tracer.tracer import make_fx from torch_tensorrt.fx.utils import LowerPrecision, proxytensor_trace -# TODO(ezyang): remove this after we properly support fake example inputs -torch._dynamo.config.DO_NOT_USE_legacy_non_fake_example_inputs = True - torch.manual_seed(0) wrap_lib = Library("wrap", "DEF") @@ -65,19 +60,93 @@ def forward(self, x, y): ref_output = mod(*inputs_new) torch.testing.assert_close(output, ref_output) - def test_resnet18_dynamo(self): + def test_simple(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x, y): + y = y + x + y = y.mul(x) + y = y + x + y = y + x + y = y / x + y = y + x + y = y + x + y = y / x + y = y + x + y = self.relu(y) + return y + + mod = TestModule() + mod = mod.cuda().half().eval() + + def f(x, y): + return mod(x, y) + + inputs = [torch.randn(2, 5), torch.ones(2, 5)] + inputs = [i.cuda().half() for i in inputs] + ref_output = f(*inputs) + + mod = compile( + mod, + inputs, + max_batch_size=100, + explicit_batch_dimension=True, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + dynamic_batch=True, + is_aten=True, + ) + output = mod(*inputs) + torch.testing.assert_close(output, ref_output) + + def test_resnet18_aten(self): mod = torchvision.models.resnet18() mod = mod.cuda().half().eval() inputs = [torch.ones(32, 3, 224, 224)] inputs = [i.cuda().half() for i in inputs] - ref_output 
= mod(*inputs) - torchdynamo.reset() - dynamo_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod) - dynamo_output = dynamo_mod(*inputs) + aten_mod = compile( + mod, + inputs, + max_batch_size=32, + explicit_batch_dimension=True, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + dynamic_batch=False, + is_aten=True, + ) + aten_output = aten_mod(*inputs) + fx_mod = compile( + mod, + inputs, + max_batch_size=32, + explicit_batch_dimension=True, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + dynamic_batch=False, + is_aten=False, + ) + fx_output = fx_mod(*inputs) + # Kernel selection is tricky in TRT with big variance as shown below: + # Mismatched elements: 30816 / 32000 (96.3%) + # Greatest absolute difference: 0.05859375 at index (0, 499) (up to 1e-05 allowed) + # Greatest relative difference: 3.293713681986265 at index (0, 142) (up to 0.001 allowed) + # so we choose to use cosine similarity cos_val = torch.nn.functional.cosine_similarity( - dynamo_output.flatten(), ref_output.flatten(), dim=0, eps=1e-4 + aten_output.flatten(), fx_output.flatten(), dim=0, eps=1e-4 ) self.assertTrue(cos_val.detach().cpu().numpy() > 0.999) @@ -224,8 +293,6 @@ def f(x, y): ref_output = f(*inputs) def compile_dispatch(gm, example_inputs): - # after normalization, relu in-place is removed - gm = normalize_ir(gm, example_inputs) # dispatch tracer nargs = len(example_inputs) diff --git a/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py b/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py index 916394e944..96584c59bd 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py +++ b/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py @@ -10,7 +10,11 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.fx.passes import splitter_base from 
torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting +from torch_tensorrt.fx.tools.trt_splitter import ( + create_trt_operator_support, + TRTSplitter, + TRTSplitterSetting, +) from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer ERROR_MSG_NO_ACC_MODULE = "FX split failed: Did not find any ACC submodule!" @@ -625,6 +629,50 @@ def test_splitter(splitter): test_splitter(splitter) + def test_decline_if_input_dtype(self): + operator_support = create_trt_operator_support() + + class TestModule(torch.nn.Module): + def forward(self, a): + b = torch.relu(a) + return b + + test_mod = TestModule().cuda().eval() + x = torch.randn(2, 3) + mod = acc_tracer.trace(test_mod, [x]) + settings = TRTSplitterSetting() + settings.min_acc_module_size = 0 + # nodes w/ float16 input should be lowered + splitter = TRTSplitter( + mod, + (x.half().cuda(),), + operator_support, + settings, + ) + split_results_half = splitter.generate_split_results() + self.assertTrue(len(split_results_half), 1) + self.assertEqual( + dict(split_results_half.split_module.named_children()).keys(), + {"_run_on_acc_0"}, + ) + + # nodes w/ float64 input should not be lowered + mod = acc_tracer.trace(test_mod, [x]) + splitter = TRTSplitter( + mod, + (x.double().cuda(),), + operator_support, + settings, + ) + + split_results_double = splitter.generate_split_results() + + self.assertTrue(len(split_results_double), 1) + self.assertEqual( + dict(split_results_double.split_module.named_children()).keys(), + {"_run_on_gpu_0"}, + ) + class TestSplitComplexGraph(TestCase): """ diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py index 30d6dc96c9..6d883a4f62 100644 --- a/py/torch_tensorrt/fx/tools/common_fx2trt.py +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -3,6 +3,8 @@ import unittest from typing import Callable, List, Optional, Set, Tuple +# 
@manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt import torch import torch.fx @@ -257,6 +259,8 @@ def run_test( pass_tracer = chain_passes(*apply_passes) mod = pass_tracer(mod, inputs) + if trt.__version__ >= "8.6": + test_implicit_batch_dim = False if test_implicit_batch_dim: interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) super().run_test( diff --git a/py/torch_tensorrt/fx/tools/model_packager.py b/py/torch_tensorrt/fx/tools/model_packager.py index 0ef0ff05a4..b86c21e809 100644 --- a/py/torch_tensorrt/fx/tools/model_packager.py +++ b/py/torch_tensorrt/fx/tools/model_packager.py @@ -51,12 +51,34 @@ def generate_standalone_repro( "", "import torch", "from torch import nn", + ] + code = str(model.code) + + import_modules = set() + import_map = { + "torch_tensorrt_fx_tracer_acc_tracer_acc_ops": "torch_tensorrt.fx.tracer.acc_tracer.acc_ops", + "torch_tensorrt_fx_passes_lower_basic_pass": "torch_tensorrt.fx.passes.lower_basic_pass", + } + for line in code.split("\n"): + for k, v in import_map.items(): + if k in line: + sub_string = line.split("(")[0].split()[-1] + if sub_string.startswith(k): + mod = sub_string.replace(k + "_", "") + import_modules.add( + "from " + v + " import " + mod + " as " + sub_string + ) + for mod in sorted(import_modules): + lines.append(mod) + + lines += [ "", "", "class ExportedModule(nn.Module):", f"{INDENT}def __init__(self):", f"{INDENT * 2}super().__init__()", ] + for k, v in model._holder.named_parameters(): shape = ", ".join([str(i) for i in v.shape]) rand_func = "randn" if torch.is_floating_point(v) else "randint" @@ -64,7 +86,6 @@ def generate_standalone_repro( lines.append( f"{INDENT * 2}self.{k} = nn.Parameter(torch.{rand_func}({int_range}{shape}, dtype={v.dtype}))" ) - code = str(model.code) def dump(f): f.write(prelude) diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py index bea925453f..aa3d930bfb 100644 --- 
a/py/torch_tensorrt/fx/tools/trt_splitter.py +++ b/py/torch_tensorrt/fx/tools/trt_splitter.py @@ -34,8 +34,9 @@ def create_trt_operator_support( return ops.chain( ops.OpSupports.decline_if_node_in_names(exclude_support_node_name), - # 1. Node is not supported if it has args with int64 dtype: + # 1. Node is not supported if it has args with int64 or float64 dtype: ops.OpSupports.decline_if_input_dtype(torch.int64), + ops.OpSupports.decline_if_input_dtype(torch.float64), # 2. Node is supported if it has TRT converter: supported_if_converter_registered, ) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py index 55cb39d4a5..1271b6f30c 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py @@ -69,6 +69,7 @@ class NormalizationInfo(NamedTuple): List[Union[Tuple[str, str, bool], Tuple[str, str]]] ] needs_shapes_for_normalization: bool + skip_normalization_if_none: bool # Dict from (op, target) to NormalizationInfo for that op. 
@@ -88,6 +89,7 @@ def _insert_fun( ] = None, needs_shapes_for_normalization=False, allow_normalize_from_torch_package=False, + skip_normalization_if_none=False, ): if op_and_target[0] == "call_function": assert callable(op_and_target[1]) @@ -129,6 +131,7 @@ def _insert_fun( custom_mapping_fn=custom_mapping_fn, kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty, needs_shapes_for_normalization=needs_shapes_for_normalization, + skip_normalization_if_none=skip_normalization_if_none, ) _normalization_dict[op_and_target] = norm_info @@ -217,6 +220,7 @@ def register_custom_acc_mapper_fn( ], needs_shapes_for_normalization=False, allow_normalize_from_torch_package=False, + skip_normalization_if_none=False, ): def insert(custom_mapping_fn: Callable): _insert_fun( @@ -225,6 +229,7 @@ def insert(custom_mapping_fn: Callable): arg_replacement_tuples=arg_replacement_tuples, # type: ignore[arg-type] needs_shapes_for_normalization=needs_shapes_for_normalization, allow_normalize_from_torch_package=allow_normalize_from_torch_package, + skip_normalization_if_none=skip_normalization_if_none, ) return custom_mapping_fn @@ -363,12 +368,18 @@ def normalize_to_acc_op( if normalization_info.custom_mapping_fn is not None: # For custom mapping, the normalized_kwargs are used for the original op, # i.e. *before* custom acc_ops normalization. Do that now. + if normalization_info.skip_normalization_if_none: + original_args = node.args + original_kwargs = node.kwargs node.args = normalized_args node.kwargs = normalized_kwargs new_node = normalization_info.custom_mapping_fn(node, mod) # If a new node is returned then use it to replace the old node. Otherwise # the custom mapping function did its own replacement, so return early. 
if new_node is None: + if normalization_info.skip_normalization_if_none: + node.args = original_args + node.kwargs = original_kwargs return else: # If there's kwargs_to_move_to_acc_out_ty then use it to setup acc_out_ty in diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index 8abad9c509..1ed25d66f1 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -1,9 +1,10 @@ # encoding: utf-8 +import logging import operator import warnings import torch # isort:skip -from typing import cast, Iterable, List, Sequence +from typing import cast, Iterable, List, Optional, Sequence import torch.nn as nn from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata @@ -16,6 +17,8 @@ ) from .acc_op_properties import AccOpProperty, register_acc_op_properties +logger: logging.Logger = logging.getLogger(__name__) + this_arg_is_optional = True move_to_qparams = True dont_move_to_qparams = False @@ -161,6 +164,18 @@ def max_pool3d( ) +@register_acc_op_mapping(op_and_target=("call_function", nn.functional.normalize)) +@register_acc_op +def normalize(*, input, p, dim, eps, out): + return nn.functional.normalize( + input=input, + p=p, + dim=dim, + eps=eps, + out=out, + ) + + @register_acc_op_mapping( op_and_target=("call_function", nn.functional.adaptive_avg_pool2d) ) @@ -364,9 +379,10 @@ def custom_getattr_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: getitem_node.meta = node.meta.copy() return getitem_node - assert ( - input_obj_type == torch.Tensor - ), f"Expected torch.Tensor type for {input_obj_type}" + assert input_obj_type in [ + torch.Tensor, + torch.nn.parameter.Parameter, + ], f"Expected torch.Tensor type for {input_obj_type}" assert ( attr_name == "shape" or attr_name == "device" or attr_name == "dtype" ), f"Only supporting shape, device and dtype getattr for now, not {attr_name}" @@ -417,7 +433,10 @@ def 
tensor_size_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: @register_acc_op_mapping(op_and_target=("call_method", "add")) @register_acc_op def add(*, input, other): - return input + other + if not (isinstance(input, torch.Tensor) or isinstance(other, torch.Tensor)): + return operator.add(input, other) + else: + return input + other @register_acc_op_properties(AccOpProperty.unary) @@ -442,14 +461,27 @@ def tile(*, input, dims): ("input", "input"), ("*", "sizes"), ], + skip_normalization_if_none=True, ) -def repeat_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: +def repeat_mapper(node: torch.fx.Node, _: nn.Module) -> Optional[torch.fx.Node]: """ Map repeat to tile. """ with node.graph.inserting_before(node): inputs = node.kwargs["input"] dims = node.kwargs["sizes"] + # Skip repeat mapping when the list of dims is not all ints (ie. contains + # some calculated value). torch.tile cannot support cases where dims + # are Proxy nodes + if ( + isinstance(dims, (list, tuple)) + and len(dims) > 0 + and not all(isinstance(x, int) for x in dims) + ): + logger.info( + "Not mapping repeat to an acc op. We can't handle variable dims." 
+ ) + return new_node = node.graph.create_node( "call_function", tile, @@ -468,6 +500,7 @@ def repeat_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: ("dim", "dim", this_arg_is_optional), ("output_size", "output_size", this_arg_is_optional), ], + skip_normalization_if_none=True, ) @register_custom_acc_mapper_fn( op_and_target=("call_function", torch.repeat_interleave), @@ -477,14 +510,17 @@ def repeat_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: ("dim", "dim", this_arg_is_optional), ("output_size", "output_size", this_arg_is_optional), ], + skip_normalization_if_none=True, ) def repeat_interleave_mapper(node: torch.fx.Node, _: nn.Module): input_node = node.kwargs["input"] repeats = cast(int, node.kwargs["repeats"]) dim = node.kwargs["dim"] - assert ( - type(repeats) is int - ), "We currently only support `repeat_interleave` with int repeats" + if not (type(repeats) is int): + logger.info( + "Not mapping repeat_interleave to an acc op. We currently only support `repeat_interleave` with int repeats" + ) + return rank = node.meta["tensor_rank"] if dim is None: repeat_dim = rank - 1 @@ -825,6 +861,18 @@ def matmul(*, input, other): op_and_target=("call_function", nn.functional.dropout), arg_replacement_tuples=[("input", "input")], ) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", nn.functional.dropout1d), + arg_replacement_tuples=[("input", "input")], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", nn.functional.dropout2d), + arg_replacement_tuples=[("input", "input")], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", nn.functional.dropout3d), + arg_replacement_tuples=[("input", "input")], +) @register_custom_acc_mapper_fn( op_and_target=("call_method", "detach"), arg_replacement_tuples=[("input", "input")] ) @@ -1055,7 +1103,10 @@ def rescale_quantize_per_channel(*, input, acc_out_ty=None): @register_acc_op_mapping(op_and_target=("call_method", "sub")) @register_acc_op def sub(*, 
input, other): - return input - other + if not (isinstance(input, torch.Tensor) or isinstance(other, torch.Tensor)): + return operator.sub(input, other) + else: + return input - other @register_acc_op_properties(AccOpProperty.pointwise) @@ -1067,6 +1118,19 @@ def mul(*, input, other): return input * other +@register_acc_op_mapping( + op_and_target=("call_function", torch.ops.aten.threshold_backward.default), + arg_replacement_tuples=[ + ("grad", "grad"), + ("self", "input"), + ("threshold", "threshold"), + ], +) +@register_acc_op +def threshold_backward(*, grad, input, threshold): + return torch.ops.aten.threshold_backward.default(grad, input, threshold) + + @register_custom_acc_mapper_fn( op_and_target=("call_method", "div"), arg_replacement_tuples=[ @@ -1367,7 +1431,7 @@ def std_mapper(node, mod): mean_kwargs = { "input": input_node, "dim": dim, - "keepdim": keepdim, + "keepdim": True, } mean_node = node.graph.call_function(mean, kwargs=mean_kwargs) mean_node.meta["type"] = torch.Tensor @@ -1385,7 +1449,7 @@ def std_mapper(node, mod): } pow_node = node.graph.call_function(pow, kwargs=pow_kwargs) pow_node.meta["type"] = torch.Tensor - # sum(pow(X-mean(X))))/N + # mean(pow(X-mean(X))) post_mean_kwargs = { "input": pow_node, "dim": dim, @@ -1393,7 +1457,7 @@ def std_mapper(node, mod): } post_mean_node = node.graph.call_function(mean, kwargs=post_mean_kwargs) post_mean_node.meta["type"] = torch.Tensor - # sqrt(sum(pow(X-mean(X))))/N) + # sqrt( mean(pow(X-mean(X))) ) sqrt_kwargs = { "input": post_mean_node, } @@ -1653,12 +1717,26 @@ def fmod(*, input, other): @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sigmoid)) +@register_acc_op_mapping( + op_and_target=("call_function", torch.ops.aten.sigmoid.default) +) @register_acc_op_mapping(op_and_target=("call_method", "sigmoid")) @register_acc_op def sigmoid(*, input): return torch.sigmoid(input=input) 
+@register_acc_op_properties(AccOpProperty.pointwise) +@register_acc_op_mapping( + op_and_target=("call_function", torch.ops.aten.sigmoid_backward.default) +) +@register_acc_op_mapping(op_and_target=("call_method", "sigmoid_backward")) +@register_acc_op +# first argument's name needs to be input to use same_shape_and_dtype_as_input +def sigmoid_backward(*, input, dest): + return torch.ops.aten.sigmoid_backward(grad_output=input, output=dest) + + @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sinh)) @register_acc_op @@ -1716,6 +1794,23 @@ def log(*, input): return torch.log(input=input) +@register_acc_op_properties(AccOpProperty.unary) +@register_acc_op_mapping( + op_and_target=("call_function", torch.nn.functional.log_softmax), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim"), + ("dtype", "dtype", this_arg_is_optional), + ], +) +@register_acc_op +def log_softmax(*, input, dim, dtype=None): + """ + _stacklevel are ignored here. 
+ """ + return torch.nn.functional.log_softmax(input=input, dim=dim, dtype=dtype) + + @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sqrt)) @register_acc_op_mapping(op_and_target=("call_method", "sqrt")) @@ -1773,7 +1868,10 @@ def abs(*, input): @register_acc_op_mapping(op_and_target=("call_function", torch.neg)) @register_acc_op def neg(*, input): - return torch.neg(input=input) + if not isinstance(input, torch.Tensor): + return operator.neg(input) + else: + return torch.neg(input=input) @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @@ -2282,6 +2380,7 @@ def embedding_bag_4bit_rowwise_offsets( @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sin)) +@register_acc_op_mapping(op_and_target=("call_method", "sin")) @register_acc_op def sin(*, input): return torch.sin(input=input) @@ -2289,6 +2388,7 @@ def sin(*, input): @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.cos)) +@register_acc_op_mapping(op_and_target=("call_method", "cos")) @register_acc_op def cos(*, input): return torch.cos(input=input) @@ -2314,13 +2414,53 @@ def getitem(*, input, idx): return input[idx] -@register_acc_op_mapping(op_and_target=("call_function", torch.nan_to_num)) -@register_acc_op_mapping(op_and_target=("call_method", "nan_to_num")) @register_acc_op -def nan_to_num(*, input, nan=0.0, posinf=None, neginf=None): +def nan_to_num(*, input, nan=None, posinf=None, neginf=None): return torch.nan_to_num(input, nan=nan, posinf=posinf, neginf=neginf) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.nan_to_num), + arg_replacement_tuples=[ + ("input", "input"), + ("nan", "nan"), + ("posinf", "posinf"), + ("neginf", "neginf"), + ], +) +@register_custom_acc_mapper_fn( + 
op_and_target=("call_method", "nan_to_num"), + arg_replacement_tuples=[ + ("input", "input"), + ("nan", "nan"), + ("posinf", "posinf"), + ("neginf", "neginf"), + ], +) +def custom_nan_to_num_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: + nan_val, posinf, neginf = ( + node.kwargs["nan"], + node.kwargs["posinf"], + node.kwargs["neginf"], + ) + if nan_val is None: + nan_val = 0 + if posinf is None: + posinf = torch.finfo(torch.float16).max + if neginf is None: + neginf = torch.finfo(torch.float16).min + kwargs = { + "input": node.kwargs["input"], + "nan": nan_val, + "posinf": posinf, + "neginf": neginf, + } + with node.graph.inserting_before(node): + new_node = node.graph.call_function(nan_to_num, kwargs=kwargs) + new_node.meta = node.meta.copy() + return new_node + + @register_acc_op_properties(AccOpProperty.unary) @register_acc_op_mapping( op_and_target=("call_method", "expand"), @@ -2422,7 +2562,10 @@ def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: @register_acc_op def reshape(*, input, acc_out_ty=None): assert acc_out_ty is not None - return input.reshape(acc_out_ty.shape) + shape = acc_out_ty.shape + if len(shape) == 1 and not isinstance(shape[0], int): + return input.reshape(shape[0]) + return input.reshape(shape) @register_custom_acc_mapper_fn( @@ -2977,22 +3120,6 @@ def tensor_split(*, input, indices_or_sections, dim=0): ) -@register_acc_op_mapping( - op_and_target=("call_method", "new_ones"), - arg_replacement_tuples=[ - ("input", "input"), - ("size", "size"), - ("dtype", "dtype", this_arg_is_optional), - ("device", "device", this_arg_is_optional), - ("requires_grad", "requires_grad", this_arg_is_optional), - ], -) -@register_acc_op -def new_ones(*, input, size, dtype=None, device=None, requires_grad=False): - assert requires_grad is False, f"requires_grad != False, it is {requires_grad}" - return input.new_ones(size, dtype=dtype, device=device) - - @register_acc_op_mapping( op_and_target=("call_method", 
"new_empty"), arg_replacement_tuples=[ @@ -3080,33 +3207,6 @@ def xl_weight(weight_id: str, metadata: TensorMetadata, proxy_shape, dtype): return torch.zeros(proxy_shape, dtype=dtype) -@register_custom_acc_mapper_fn( - op_and_target=("call_function", torch.nn.functional.log_softmax), - arg_replacement_tuples=[ - ("input", "input"), - ("dim", "dim"), - ("dtype", "dtype"), - ], -) -def log_softmax_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node: - with node.graph.inserting_after(node): - - softmax_kwargs = { - "input": node.kwargs["input"], - "dim": node.kwargs["dim"], - "dtype": node.kwargs["dtype"], - } - softmax_node = node.graph.call_function(softmax, kwargs=softmax_kwargs) - softmax_node.meta = node.meta.copy() - - with softmax_node.graph.inserting_after(softmax_node): - log_kwargs = {"input": softmax_node} - log_node = node.graph.call_function(log, kwargs=log_kwargs) - log_node.meta = node.meta.copy() - - return log_node - - @register_custom_acc_mapper_fn( op_and_target=("call_function", torch.nn.functional.softplus), arg_replacement_tuples=[ @@ -3256,6 +3356,124 @@ def baddbmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: return add_node +@register_acc_op_mapping(op_and_target=("call_function", torch.clone)) +@register_acc_op_mapping(op_and_target=("call_method", "clone")) +@register_acc_op +def clone(*, input): + return torch.clone(input) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.unbind)) +@register_acc_op +def unbind(*, input, dim=0): + return torch.unbind(input, dim=dim) + + +@register_acc_op_mapping( + op_and_target=("call_function", torch.nn.functional.group_norm), + arg_replacement_tuples=[ + ("input", "input"), + ("num_groups", "num_groups"), + ("weight", "weight"), + ("bias", "bias"), + ("eps", "eps"), + ], +) +@register_acc_op +def group_norm(*, input, num_groups, weight=None, bias=None, eps=1e-05): + return torch.nn.functional.group_norm( + input, num_groups, weight=weight, bias=bias, eps=eps + 
) + + +@register_acc_op_mapping(op_and_target=("call_method", "long")) +@register_acc_op +def long(*, input): + return input.long() + + +@register_acc_op_mapping( + op_and_target=("call_method", "new_full"), + arg_replacement_tuples=[ + ("input", "input"), + ("size", "size"), + ("fill_value", "fill_value"), + ("dtype", "dtype", this_arg_is_optional), + ("device", "device", this_arg_is_optional), + ("requires_grad", "requires_grad", this_arg_is_optional), + ], +) +@register_acc_op +def new_full(*, input, size, fill_value, dtype=None, device=None, requires_grad=False): + return input.new_full(size, fill_value=fill_value, dtype=dtype, device=device) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.full_like)) +@register_acc_op +def full_like(*, input, fill_value, dtype=None, device=None): + return torch.full_like( + input=input, fill_value=fill_value, dtype=dtype, device=device + ) + + +@register_acc_op_mapping( + op_and_target=("call_method", "new_ones"), + arg_replacement_tuples=[ + ("input", "input"), + ("size", "size"), + ("dtype", "dtype", this_arg_is_optional), + ("device", "device", this_arg_is_optional), + ("requires_grad", "requires_grad", this_arg_is_optional), + ], +) +@register_acc_op +def new_ones(*, input, size, dtype=None, device=None, requires_grad=False): + assert requires_grad is False, f"requires_grad != False, it is {requires_grad}" + return input.new_ones(size, dtype=dtype, device=device) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.ones_like)) +@register_acc_op +def ones_like(*, input, dtype=None, device=None): + return torch.ones_like(input=input, dtype=dtype, device=device) + + +@register_acc_op_mapping( + op_and_target=("call_method", "new_zeros"), + arg_replacement_tuples=[ + ("input", "input"), + ("size", "size"), + ("dtype", "dtype", this_arg_is_optional), + ("device", "device", this_arg_is_optional), + ("requires_grad", "requires_grad", this_arg_is_optional), + ], +) +@register_acc_op +def 
new_zeros(*, input, size, dtype=None, device=None, requires_grad=False): + return input.new_zeros(size, dtype=dtype, device=device) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.zeros_like)) +@register_acc_op +def zeros_like(*, input, dtype=None, device=None): + return torch.zeros_like(input=input, dtype=dtype, device=device) + + +@register_acc_op_mapping( + op_and_target=("call_method", "index_add_"), +) +@register_acc_op_mapping(op_and_target=("call_function", torch.index_add)) +@register_acc_op +def index_add(*, input, dim, index, source, alpha=1): + return torch.index_add(input, dim, index, source, alpha=alpha) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.masked_select)) +@register_acc_op +def masked_select(*, input, mask): + return torch.masked_select(input=input, mask=mask) + + ############################################################################### # Set ops as side-effectul, this prevents them from being optimized away or diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py index c3a5ad850e..bc8c613fee 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py @@ -581,9 +581,15 @@ def _replace_transpose_last_dims(gm: torch.fx.GraphModule): gm.recompile() -def rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list): +def rewriter_base_trace( + mod, + ast_rewriter_allow_list, + leaf_module_list, + concrete_args: Optional[Dict[str, Any]] = None, +): rewritten_graph, rewritten_mod = AccRewritingTracer().trace( mod, + concrete_args, ast_rewriter_allow_list=ast_rewriter_allow_list, leaf_module_list=leaf_module_list, ) @@ -605,6 +611,8 @@ def trace( acc_normalization_block_list: Optional[ Set[Tuple[str, Union[str, Callable]]] ] = None, + dont_retrace_gm: bool = False, + concrete_args: Optional[Dict[str, Any]] = None, ) -> torch.fx.GraphModule: """ Performs tracing and 
arg normalization specialized for accelerator lowering. @@ -653,6 +661,10 @@ def trace( normalization to. Just like the register_acc_op decarators, the target can either be a string (e.g. for op == "call_method") or a callable (e.g. for op == "call_function"). + + dont_retrace_gm (bool): Optional bool for whether to re-trace the provided + module if it's a graph module already. + """ if mod.training: warnings.warn( @@ -664,7 +676,12 @@ def trace( assert isinstance(sample_inputs, (list, tuple)) # Rewrite the module to make it symbolic traceable, and then trace it. - traced = rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list) + if dont_retrace_gm and isinstance(mod, torch.fx.GraphModule): + traced = mod + else: + traced = rewriter_base_trace( + mod, ast_rewriter_allow_list, leaf_module_list, concrete_args + ) # Now remove all assertions and exceptions if requested. if remove_assertions: diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py index ab3207925f..75418034cb 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py @@ -110,7 +110,8 @@ def get_model_info_str(gm: torch.fx.GraphModule, header: Optional[str] = None): If `header` is provided then it's included in the printed string. """ ops_and_counts: Dict[Callable, int] = {} - placeholder_count = get_attr_count = call_method_count = call_module_count = 0 + placeholder_count = get_attr_count = 0 + call_method_count = call_module_count = output_count = 0 for node in gm.graph.nodes: if node.op == "call_function": ops_and_counts[node.target] = ops_and_counts.get(node.target, 0) + 1 @@ -141,7 +142,8 @@ def get_model_info_str(gm: torch.fx.GraphModule, header: Optional[str] = None): # easier to parse. 
def strip_module_prefixes(op_name):
    """Return ``op_name`` with the known tracer package paths removed.

    Every occurrence of each package path below is deleted from
    ``op_name`` (anywhere in the string, not only at the start). The
    paths are applied in the listed order, so the longer
    ``glow.fb.fx.acc_tracer.`` path is handled before its shorter
    ``glow.fb.fx.`` parent.
    """
    removable_paths = (
        "torch_tensorrt.fx.tracer.acc_tracer.",
        "glow.fb.fx.acc_tracer.",
        "glow.fb.fx.",
    )
    shortened = op_name
    for package_path in removable_paths:
        shortened = shortened.replace(package_path, "")
    return shortened
class LowerPrecision(Enum):
    """Precision labels; each member's value is its short lowercase name."""

    FP32 = "fp32"
    FP16 = "fp16"
    INT8 = "int8"
    BF16 = "bf16"

    @staticmethod
    def from_str(label: str) -> Optional["LowerPrecision"]:
        """Map a precision label to its ``LowerPrecision`` member.

        Args:
            label: One of the spellings listed in the branches below.
                Matching is exact and case-sensitive.

        Returns:
            The matching member, or ``None`` when ``label`` is not one of
            the recognized spellings.
        """
        if label in ("fp32", "float32", "float", "torch.float32"):
            return LowerPrecision.FP32
        if label in ("fp16", "float16", "half", "torch.half", "torch.float16"):
            return LowerPrecision.FP16
        # One-element tuple on purpose: a bare ("int8") is just the string
        # "int8", which turns `in` into a substring test that accepts "",
        # "int", "8", ... and raises TypeError for non-string labels.
        if label in ("int8",):
            return LowerPrecision.INT8
        if label in ("bf16", "bfloat16", "torch.bfloat16"):
            return LowerPrecision.BF16
        return None