From f77a0029e7accfdeb0e72a64f43d10220fec64fc Mon Sep 17 00:00:00 2001 From: {{wushirong}} <@meta.com> Date: Wed, 17 May 2023 10:47:22 -0700 Subject: [PATCH] Changes done internally at Facebook 483102cd4151f02c2d3632e6b6df7a5e59c0d6f3 Wei Wei [fx2trt] move acc op `torch.ops._caffe2.RoIAlign` to fb only 8ce94a01caa090d56adb4708452b52890160ba69 Wei Wei [aten2trt] reshape support 422326213bad177019e92c95dbc61af7a427bebc Shirong Wu nan_to_num aten converter f729c8a7f1268f329d15e3cf05f1fb9232fab2d9 Huamin Li Record TRT/AIT lower context into Scuba gpu_lowering_diagnostics 2df64af6bcf102a0ce40f1c5ab8472370d012904 Wei Wei [aten2ait][fx2ait] sin,cos,sqrt,clone support 9fa6469ccb9d00320d78684d748fe1a7e5c3cf60 Janet Yang Split nodes w/ float64 inputs from lowering d2ea242f721156df9e075927ea7956db772d4107 Fei Kou Handle Ellipsis in dper passes d053b097a0d1c158cde29792a35c4ec4174d9417 Jason Ansel Fix tests broken by D42953629 e18c6c76b1678a95c35583dabb41666b33c3df63 Zhijing Li (Accelerator Enablement) Add dper test for push_down_split pass 5008c6d200f2a9ca035547204b47eb5e1704ce88 Zhijing Li (Accelerator Enablement) Add passes as option to AITTestCase.run_test f7bc0c543b553ca2f80149995b4c28599a6ea396 Ying Zhang Back out "Add passes as option to AITTestCase.run_test" 22d4044c66720e0e656f41538c81a3e90ef1a433 Zhijing Li (Accelerator Enablement) Relaunch add passes as option to AITTestCase.run_test ae0de22b6a97bca82c0ef6a14b0be2b570eb443a Eli Uriegas Remove fx2trt/torch2trt backends (#93822) b08e568951c911e4c3bbc72b55830fa1d4be4b2b Eli Uriegas Remove torch/_dynamo/optimizations (#93871) 725266c0b7eb0549060e79b65346d703cc5bc39e Benson Ma [T143761882] Migrate CUTLASS group gemm operators to OSS 44110f5df422e84cd9d9afbf5dfbe742057b9d98 Zhijing Li (Accelerator Enablement) Add noop pass for torch.ops.fb.scale_gradient 84befb25b778485c8694ba659248d4d570d92390 Chao Gu [FX] Add log_softmax b641713bd774cb7c7bf903f514bff5c87a6f3a33 Wei Wei [fx2ait] support torch.unbind, torch.group_norm d263b38b53b93a18a78cd34b2a1c48711c3c59cd Shirong Wu Add extra logging for layer norm eb2591231195cc0ab6780f345f095807a7d45f7c Callum Ryan Make GPU test run in bundled mode f63d3834e87a819f8335c50b351e48f60573d474 Sarunya Pumma Back out "[T143761882] Migrate CUTLASS group gemm operators to OSS" a9f489c1c3a182698385053c0a94b792c4e310ba Shirong Wu Change opt_profile_replica to 3 b8bdde86f0bae6010062c33aec03a4e13a87a6ab Brian Hirsh forward fix for new _native_batch_norm_legit_no_training op e8f4cbd46402e5603cc48d24395db3f0e010581a Shirong Wu Fix reshape op b860725bfaf74a0043190d1140ddee987dd82d0c generatedunixname89002005232357 Revert D43278214: Multisect successfully blamed D43278214 for test or build failures d4ea365cf8aa56d752912f7878b8046e89c804c2 Chunxing Yin [mtia] Add sigmoid_backward kernel binding in Glow a768c82a51a058e56a64ff82f90e619795611b66 Mor Tzur lower to ait 8eb52426aaca586ae50fde75cccca6a0827a8328 Wei Wei [hstu][fx2ait] op support 55d95ffa096d9de7952a6a1c4628efd67e554d82 Wei Wei [fx2ait] temp solution to set correct dynamic batch size for jagged tensor 0a42e2f0874c48e9b60503a25705f0fc6319ff87 Jia Jiunn Ang [CMF] chunk-squeeze-cat op fusion when split on last dimension 8bd509596a799f1270796772e12be090a6db5d39 Wei Wei [aten2trt] update comment 1761b440d646836116fdadf2b5c7b55c7d2b989b Oleg Khabinov [fx2ait] Fix a dper pass when acc_ops.squeeze doesn't have a dim 3cc405a92c9fcec886d890de87ac94e024c682a5 Jia Jiunn Ang [CMF] Fuse chunk-linear-cat into baddbmm 5f42f56c5b5d0bd4c058aa280a980e64dd89b0a9 Xiaodong Wang [cudnn] static linking 229969542a2c1e96fe8345ff7adc2fd48f6a0707 Romain Sauvestre Remove base_module from acc_tracer target a174195c484d5a25f06e4c0665bbb2e9d9dcae82 Janet Yang Support input_tensor_spec w/ multiple batch dims in TRT 0246365e6facc6dfb13843fa9854802f35c0938a Zhijing Li (Accelerator Enablement) Remove noop dropout op with acc tracer 4c287b9f6238e8bbbd80e742262a0eee6efa57de Kunming Ho Operator support for threshold_backward 71bb34c81289173b83c7e7cf544b851096d9d99d Fei Kou specialize_int_float to specialize_int from D43925225 037db53f89a7b863ef0fbaa7b94425fd9a08dc96 Wei Wei enable torchscripting 77f3dce76fd5407b08826f67213d8299d9d48542 Adnan Akhundov [fx2ait] Extend jagged tensor support e6b551e48a0c03db63fc46ff85d975b489e30079 Jordan Fix [acc_tracer] Add dont_retrace_gm option ada3cbbb3d6c3b3631496a3bceea775f45649c6c Adam Simpkins Fix a bunch of invalid Python escape warnings in torch_tensorrt 98254d631e8748a85b05851c97fb74f3e3922cfe Brandon Tran (Realtime Integrity) Add torch.nn.functional.normalize to TensorRT fce21e2248ad0fddfcc12dbe2e3a9a6ac9ea2a5f Shirong Wu Fix trt input spec a08bad1ac74a6d1409bb3f2e96953ed0c149d006 Wei Wei [fx2ait] changes to improve jagged tensor and add b2b bmm 7745d70a17677777dcb5806e1e8008532f961f5d generatedunixname485339166882981 [Codemod][[pyunit][static_listing] Convert python unit test dynamic listing to static listing] oncall+gpu_enablement_0 ba33951ae2d2ebc99794aff8026a01a31f9ad8da Shirong Wu Add ait full op converters b3bfd69f15fc4e32f27217a3efa8204a2f062af8 Chao Gu [FX] support index_add in acc ops and tracer a965bafc517afc81591052e355fd34062b028a89 Shirong Wu Make fill op read dtype from input/kwarg 72f9b0925eceffc12dfa51769c1bd0cb38a3e50c generatedunixname485339166882981 [Codemod][[pyunit][static_listing] Convert python unit test dynamic listing to static listing] oncall+gpu_enablement 2e7feece191d6178ff6ec750d8fe481175bb27b9 Max Podkorytov [fx2ait] enable lowering to bfloat16 94607911ffb11e78082e061a670b5140e9a55d72 Archie Sravankumar Add support for nan_to_num 42fddd20d303dbbc3355a8c09a86d4a74317be97 Max Podkorytov [AITemplate] feed_lower_benchmark cli argument tweak for choosing precision 648ec682f2214e67912fe7c800f7ca059195cf4e Huamin Li Re-enable previous disabled TRT unit tests 3e5c2aac8a7b9e50efe04fcae361a3c0ee1777a7 Janet Yang Skip acc normalization of repeat_interleave if input dims aren't integral f412f35baeee9a1b17f67b7749ca1f9b8cbbe77b Janet Yang Skip acc normalization of repeat if dims aren't ints 5b9cfe428f29e27da76b19029bda03a8b43c17d1 Huamin Li add import into generate_standalone_repro 9f88965e87e72658aa6a4973dc870d50b8a22ca4 Fei Kou lowering with bf16 7f761df34d672c87c40b18369b28bc593374122c Fei Kou [benchmark] Support bfloat16 in mts_gpu_benchmark fa9b09e11ba8f888d761e1398367973d30e0aa1e Wei Wei [fx2ait] add a simple eager run to verify the input generatation is correct 4f8ca36dbdc72dfa60e667c3592d0a2bc466b994 Max Podkorytov [AITemplate] implement per op benchmark backbone 9873be1e82f2dd4a8a768497ac9cdb3b9b95cfe9 Thomas Orozco buck2: tag more long running tests 0d6827c464aa2141a48a8d768a8c7facd65c0bc4 generatedunixname485339166882981 [Codemod][[pyunit][static_listing] Convert python unit test dynamic listing to static listing] oncall+gpu_enablement_0_2ea3 04f9c1105a2a6a711d025d5c85b95147343d0ecd Zhijing Li (Accelerator Enablement) [fx2ait] Fix acc_ops converter on std when keepdim=False 906bad1deebb235a9c80d0f0d46145da08afa091 Danylo Baibak Forward fix to switch calling convention back to real tensors 48ffa2ab3dd66487922f9f0bf9a145db6eaf3fe2 Kefei Lu Lowering: validate inference with alternative batch sizes ca5dc1a2896bd476e3a327db834df859a3fcc11f Jordan Fix [fba_pass_manager_builder][BE] General cleanup/refactor afb4df5e84571f466b0f385472493aefb89344cc Shirong Wu Mask select converter 25e8afb1f8be19ec6c4ef4bc74ea48e64017cde2 Janet Yang Fix lowering FusedTsEncodingModule for coffee model 7fdf06ecfc6b4efb7008ce399dcd0c32ef1f1f75 generatedunixname485339166882981 [Codemod][[pyunit][static_listing] Convert python unit test dynamic listing to static listing] oncall+deprecated_techdebt_do_not_use_4_9c34 a58c5e454412585c4cc48ced1798dbf234cc13b6 Michael Liu Initialize `output_count` in `get_model_info_str` 2c6f13ddcc52e8f833fcd164d0c479ca3398322e Wei Wei jagged SHA and MHA module support 2fe5c7cd3b763b839af3d1b05eecc73f1df05286 Shirong Wu Add BF16 support for ads model 2486edbe5013f3b7e5807503538f3164bdd4ee19 Shirong Wu Add low_level_module conversion pass ca7c51407ab0410d311c984b31aeb757dd840bc2 Wei Wei [hstu] remove torch_package from RelativeBucketedTimeAndPositionBasedBias after packaged 80596e459343d5630e16a6175eafffd2c25a3123 Shirong Wu Block a pass that yield problem ded609195500a8edc5bed80ee85f41b35224c19f Huamin Li Do not test test_implicit_batch_dim if >= 8.6 8e8e736e14d23e77fa2bd5e72123d66943f7716f Huamin Li Speed up TRT compile time in test env 2db82572e509cfe827c34a4060c058ae44b5547a Jordan Fix [acc_tracer] Add in use_concrete_args option to enable pytree flatten/unflatten 946f957b6636c6b4f64e52148c9baf6e0351fb5e Wei Wei [hstu] changes to bias module and sha/mha pass to adapt to removing presences d904b26386c2ef98f292edae7c5e98c27119f9d9 Oleg Khabinov [fx2ait] Rename split_op_to_reshape to chunk_op_to_reshape ca36733f0ea67aeeb38a3740f795bbf99b24037b Oleg Khabinov [fx2ait] Rewrite chunk_op_to_reshape() to use while loop instead of recursion 4361feb4399eec3816b534991020703d099d2896 Oleg Khabinov [fx2ait] Optimize chunk_op_to_reshape() 071b84e3cda4f0175b37ae62c37b2d4f2de7925f Huamin Li Disable libkineto for TRT unit tests 92f9acaac8f9a8f0fc2e1382bf4c79d0b94cbea5 Wei Wei [fx2ait] improve bf16 support 8b92e8356278eb9676a5299373841593af942fb4 Jongsoo Park [acc_tracer] skip None module in rewriting 0d1d644bad22c86efec12009ca1464587d1e7d38 Kefei Lu Remove non-existent argument doc string 2efe5e78bc8627a30ba132e5b8e14e06538d463f shirong Temp fix a15a564a567eb689604d27ca814553e38c287698 shirong Temporary commit at 4/24/2023, 2:32:22 PM 78825462243c09760ebb73156a4c18bbc9ddee75 shirong Temporary commit at 4/24/2023, 2:32:37 PM 9bfea274462fd77cb04c38c17bc237541af87c55 laksrm [DNR] onboard ctr to aimp with lowering fixes 8bb482b10f7f63270c329c88d5ac028b40f6b757 shirong Reenable pass --- .../fx/converters/acc_ops_converters.py | 7 +- py/torch_tensorrt/fx/diagnostics.py | 8 + py/torch_tensorrt/fx/fx2trt.py | 6 + py/torch_tensorrt/fx/input_tensor_spec.py | 73 +++- py/torch_tensorrt/fx/lower.py | 4 + .../fx/passes/lower_basic_pass.py | 46 ++- .../fx/passes/lower_pass_manager_builder.py | 7 +- py/torch_tensorrt/fx/passes/pass_utils.py | 251 ++++++++++++- .../test/converters/aten_op/test_cat_aten.py | 4 +- .../vanilla/test_convolution_vanilla.py | 3 +- .../fx/test/core/test_input_tensor_spec.py | 33 ++ .../passes/test_fuse_permute_linear_trt.py | 2 + .../fx/test/passes/test_pass_utils.py | 97 +++++ .../fx/test/passes/test_setitem_trt.py | 175 +++++---- .../fx/test/tracer/test_acc_tracer.py | 95 +++++ .../fx/test/tracer/test_dispatch_tracer.py | 93 ++++- .../fx/test/trt_lower/trt_splitter_test.py | 50 ++- py/torch_tensorrt/fx/tools/common_fx2trt.py | 4 + py/torch_tensorrt/fx/tools/model_packager.py | 23 +- py/torch_tensorrt/fx/tools/trt_splitter.py | 3 +- .../fx/tracer/acc_tracer/acc_normalizer.py | 11 + .../fx/tracer/acc_tracer/acc_ops.py | 340 ++++++++++++++---- .../fx/tracer/acc_tracer/acc_tracer.py | 21 +- .../fx/tracer/acc_tracer/acc_utils.py | 14 +- .../fx/tracer/dispatch_tracer/aten_tracer.py | 4 + py/torch_tensorrt/fx/utils.py | 16 +- 26 files changed, 1193 insertions(+), 197 deletions(-) create mode 100644 py/torch_tensorrt/fx/test/passes/test_pass_utils.py diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 51b5d899eb..8119f3bac7 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -691,10 +691,13 @@ def acc_ops_layer_norm(network, target, args, kwargs, name): eps_field = trt.PluginField( "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32 ) + normalized_shape = kwargs["normalized_shape"] try: - normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32) + normalized_shape = np.array(normalized_shape, dtype=np.int32) except TypeError: - _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") + _LOGGER.error( + f"Unable to convert normalized_shape with value {normalized_shape} to a field, fall back to []" + ) normalized_shape = np.array([], dtype=np.int32) normalized_shape_filed = trt.PluginField( diff --git a/py/torch_tensorrt/fx/diagnostics.py b/py/torch_tensorrt/fx/diagnostics.py index 0ba2a30652..0d78513a81 100644 --- a/py/torch_tensorrt/fx/diagnostics.py +++ b/py/torch_tensorrt/fx/diagnostics.py @@ -87,12 +87,14 @@ class DiagnosticsWriter: def __init__(self): self._root_dir = tempfile.mkdtemp(prefix="fx2trt.") + self._data = "" _LOGGER.info(f"Initializing DiagnosticsWriter with root_dir: {self._root_dir}") def write(self, file_name: str, data: WriteObj): """ TODO: Can be disabled by regex on file_name """ + self._data = data # Only write if we are inside a collect_when() context. if not _IS_IN_COLLECT_CONTEXT.get(False): return @@ -117,6 +119,9 @@ def write(self, file_name: str, data: WriteObj): def root_dir(self) -> str: return self._root_dir + def data(self) -> WriteObj: + return self._data + def _write(self, file_name: str, to_write: bytes): # ms granularity - no naming collash, otherwise file will be # overwritten. @@ -271,6 +276,9 @@ def collect(self) -> str: finally: os.remove(fp) + def data(self) -> WriteObj: + return self._write.data() + def _res_or_err(data: WriteObj) -> t.Tuple[TWrite, str]: if isinstance(data, (str, bytes)): diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index d0a6bdf0a1..846c90bdd5 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -1,4 +1,5 @@ import logging +import os import warnings from datetime import datetime from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence @@ -211,6 +212,11 @@ def run( builder_config = self.builder.create_builder_config() builder_config.max_workspace_size = max_workspace_size + # Speed up TRT build time in the test environment + if trt.__version__ >= "8.6" and os.environ.get("TRT_TEST_ENV", "0") == "1": + _LOGGER.info("Set TRT optimization level to 0") + builder_config.builder_optimization_level = 0 + cache = None if timing_cache: cache_file = numpy.array(timing_cache) diff --git a/py/torch_tensorrt/fx/input_tensor_spec.py b/py/torch_tensorrt/fx/input_tensor_spec.py index 781c11f32c..8128fc1760 100644 --- a/py/torch_tensorrt/fx/input_tensor_spec.py +++ b/py/torch_tensorrt/fx/input_tensor_spec.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple +from typing import Any, Iterable, List, NamedTuple, Optional, Sequence, Tuple import torch @@ -18,6 +18,12 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None): # is the dynamic batch dimension. Otherwise, we use the additional # inputs to determine the batch dimension. if additional_inputs is None: + batch_dims = None + if not isinstance(inputs, torch.Tensor) and len(inputs) > 1: + bs = inputs[0].size(0) + batch_dims = None + if not all(x.size(0) == bs for x in inputs): + batch_dims = InputTensorSpec.find_batch_size_dim(inputs) return InputTensorSpec.from_tensors_with_dynamic_batch_size( inputs, ( @@ -26,6 +32,7 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None): lower_setting.max_batch_size, ), lower_setting.opt_profile_replica, + batch_dims, ) else: batch_dims = [] @@ -147,25 +154,69 @@ def from_tensors_with_dynamic_batch_size( A list of InputTensorSpec named tuples with dynamic ranges. """ if batch_dims is None: - batch_dims = [0] * len(tensors) + batch_dims = cls.find_batch_size_dim(tensors) input_specs = [] batch_size = tensors[0].size(batch_dims[0]) for i, tensor in enumerate(tensors): batch_dim = batch_dims[i] - assert batch_size == tensor.size( - batch_dim - ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." - shape = list(tensor.shape) - shape[batch_dim] = -1 - shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] - input_specs.append( - cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) - ) + if batch_dim == -1: + input_specs.append(cls.from_tensor(tensor)) + else: + shape = list(tensor.shape) + assert batch_size == tensor.size( + batch_dim + ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." + shape[batch_dim] = -1 + shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] + input_specs.append( + cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) + ) return input_specs + @classmethod + # pyre-ignore [2]: Parameter `sample_input` must have a type other than `Any` + def find_batch_size_dim(cls, inputs: Any) -> []: + if isinstance(inputs, torch.Tensor) or len(inputs) <= 1: + return [0] + shapes = [i.shape for i in inputs] + frequency_map = {} + first_dims = set() + for shape in shapes: + if len(shape) < 2: + # By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info + continue + # Dedup shape value for single tensor + first_dims.add(shape[0]) + shape = set(shape) + for i in shape: + frequency_map[i] = frequency_map.get(i, 0) + 1 + + if len(first_dims) == 1: + # first dim is the same in every input: we use it as batch_size + batch_size = first_dims.pop() + elif frequency_map: + # first dims are different: we use the most frequent dim as batch_size + sorted_frequency = sorted(frequency_map.items(), key=lambda x: -x[1]) + batch_size = sorted_frequency[0][0] + else: + # no dims to sort: no batch_size + batch_size = -1 + + bs_dim = [] + for i in inputs: + # Default batch size dim = -1, indicate no batch_size + dim = -1 + for index, val in enumerate(i.shape): + if val == batch_size: + dim = index + break + bs_dim.append(dim) + + return bs_dim + def to_random_tensor(self, id=1): shape = tuple(self.shape) if len(get_dynamic_dims(shape)): diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index f96f1db6b9..6572fe9588 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -41,6 +41,8 @@ def compile( dynamic_batch=True, is_aten=False, use_experimental_fx_rt=False, + correctness_atol=1e-1, + correctness_rtol=1e-1, ) -> nn.Module: """ Takes in original module, input and lowering setting, run lowering workflow to turn module @@ -81,6 +83,8 @@ def compile( dynamic_batch=dynamic_batch, is_aten=is_aten, use_experimental_rt=use_experimental_fx_rt, + correctness_atol=correctness_atol, + correctness_rtol=correctness_rtol, ) lowerer = Lowerer.create(lower_setting=lower_setting) return lowerer(module, input) diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index e753d6e227..e98a9371c5 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -54,10 +54,14 @@ def fill_with_mul_zero_and_add(*args): def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule: - # Now we do constant folding on traced module. We want to skip pattern like - # weights -> quant -> dequant -> op during constant folding when the model is - # a quantized int8 model. - def skip_folding_quant_dequant(node: torch.fx.Node): + def skip_folding_ops(node: torch.fx.Node): + # dtype op + if node.target == acc_ops.dtype: + return True + # Now we do constant folding on traced module. We want to skip pattern like + # weights -> quant -> dequant -> op during constant folding when the model is + # a quantized int8 model. + # quant_dequant if node.target != acc_ops.quantize_per_tensor: return False # If quantize_per_node -> dequantize, then skip folding. @@ -66,7 +70,7 @@ def skip_folding_quant_dequant(node: torch.fx.Node): return True return False - const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant) + const_split_mod = split_const_subgraphs(traced_mod, skip_folding_ops) const_split_mod.run_folding() return const_split_mod @@ -630,3 +634,35 @@ def fix_clamp_numerical_limits_to_fp16( mod.recompile() return mod + + +@log_before_after +@validate_inference(atol=1e-3, rtol=1e-2) +def remove_dtype_and_to_pattern( + mod: torch.fx.GraphModule, input: Input +) -> torch.fx.GraphModule: + """ + Remove this pattern since it is unnecessary to cast to dtype + %dtype : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.dtype](args = (), kwargs = {input: %_attention_layers_0__uva}) + %to_18 : [#users=2] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.to_dtype](args = (), kwargs = {input: %x}) + """ + for node in mod.graph.nodes: + if node.op == "call_function" and node.target == acc_ops.dtype: + # find its first user + next_node = next(iter(node.users)) + # acc_op or pt op is treated differently + input = ( + next_node.kwargs["input"] + if "input" in next_node.kwargs + else next_node.args[0] + ) + if len(node.users) == 1 and ( + next_node.target == acc_ops.to_dtype or next_node.target == "to" + ): + next_node.replace_all_uses_with(input) + mod.graph.erase_node(next_node) + mod.graph.erase_node(node) + + mod.graph.eliminate_dead_code() + mod.recompile() + return mod diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 61052b21af..6e6b40d42f 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -8,6 +8,7 @@ from torch.fx.passes.pass_manager import inplace_wrapper, PassManager from torch.fx.passes.shape_prop import ShapeProp from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult +from torch_tensorrt.fx.passes.pass_utils import apply_bfloat_float_conversion from torch_tensorrt.fx.utils import LowerPrecision from ..input_tensor_spec import generate_input_specs @@ -229,10 +230,9 @@ def lower_func(split_result: SplitResult) -> nn.Module: submod = getattr(split_result.split_module, submod_name) LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs) - # Only acc submodules will be lowered. if not submod_name.startswith(split_result.non_acc_submodule_prefix): - _LOGGER.info(f"Now lowering submodule {submod_name}") + _LOGGER.info(f"ACC submodule graph: {submod.graph}") lowering_start_time = datetime.datetime.now() self.lower_setting.additional_inputs = ( @@ -251,6 +251,9 @@ def lower_func(split_result: SplitResult) -> nn.Module: _LOGGER.info( f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}" ) + else: + _LOGGER.info(f"GPU submodule graph: {submod.graph}") + apply_bfloat_float_conversion(submod, submod_inputs, submod_name) return split_result.split_module diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index fabc92881d..0b8578ffba 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -1,12 +1,17 @@ +import contextlib import io +import json import logging import tempfile from datetime import datetime from functools import wraps +from traceback import TracebackException from typing import Any, Callable, List, Optional import torch +import torch_tensorrt.fx.diagnostics as diagnostics from torch import fx +from torch.fx.node import Node from torch.fx.passes.shape_prop import ShapeProp # Create an alias for module input type to avoid littering pyre-ignore for Any @@ -20,6 +25,11 @@ FINAL_CHECK_ATOL_MULTIPLIER: float = 10 FINAL_CHECK_RTOL_MULTIPLIER: float = 10 +# A global override of the alternative batch size used in validate_variable_batch_sizes +ALTERNATIVE_BATCH_SIZE_OVERRIDE: Optional[int] = None +# If exception during validate_variable_batch_sizes should be thrown +ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW: bool = False + class RelaxAccuracyCheckMode: """ @@ -83,6 +93,46 @@ def __exit__(self, type, value, traceback): ) +@contextlib.contextmanager +def override_alternative_batch_size(alternative_batch_size: int = -1): + """ + A context manager to override alternative_batch_size + + Example: + + >>> # disables run_alternative_batch_size verification + >>> with override_alternative_batch_size(-1): + >>> fx2ait() + """ + + global ALTERNATIVE_BATCH_SIZE_OVERRIDE + old_value = ALTERNATIVE_BATCH_SIZE_OVERRIDE + ALTERNATIVE_BATCH_SIZE_OVERRIDE = alternative_batch_size + _LOGGER.info(f"Override {ALTERNATIVE_BATCH_SIZE_OVERRIDE=} ({old_value=})") + try: + yield + finally: + ALTERNATIVE_BATCH_SIZE_OVERRIDE = old_value + _LOGGER.info(f"Restored old value: {ALTERNATIVE_BATCH_SIZE_OVERRIDE=})") + + +@contextlib.contextmanager +def override_alternative_batch_size_exception_should_throw( + exception_should_throw: bool, +): + """ + A context manager to set if exception during alternative batch size verification + should be thrown. + """ + global ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW + old_value = ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW + ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW = exception_should_throw + try: + yield + finally: + ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW = old_value + + def chain_passes(*passes: PassFunc) -> PassFunc: """ Chains a sequence of pass functions to form a single pass function @@ -100,11 +150,28 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall # on pass that failed accuracy check. -def validate_inference(rtol=None, atol=None): +def validate_inference( + rtol=None, atol=None, run_alternative_batch_size: int = -1 +) -> "Decorator": + """ + Returns a decorator on a PassFunc to sanity check the model outputs + difference before/after the transformation is within tolerance. + + Args: + rtol: reletive tolerance + atol: absoluate tolerance + run_alternative_batch_size (int): + In addition to running inference at original batch size in the + input, also run at an alternative batch size. If set to -1, do not + run at alternative batch size. It must be smaller than the original + batch size. This is useful to check the model can run at different + batch sizes. Usually we can set this to 1. + """ + def _validate_inference(pass_: PassFunc) -> PassFunc: """ - Wraps a pass function to validate that its inference results before and - after the pass run should be `close`. + A decorator to wrap a pass function to validate that its inference + results before and after the pass run should be `close`. """ @wraps(pass_) @@ -162,6 +229,120 @@ def pass_with_validation( return _validate_inference +def validate_variable_batch_sizes(run_alternative_batch_size: int = -1) -> "Decorator": + """ + Returns a decorator on a PassFunc to verify the model can run with + different batch sizes before/after the transformation is within tolerance. + + Args: + run_alternative_batch_size (int): + In addition to running inference at original batch size in the + input, also run at an alternative batch size. If set to -1, do not + run at alternative batch size. It must be smaller than the original + batch size. This is useful to check the model can run at different + batch sizes. Usually we can set this to 1. + + If the global variable `ALTERNATIVE_BATCH_SIZE_OVERRIDE` is set, it + overrides `run_alternative_batch_size`. + `ALTERNATIVE_BATCH_SIZE_OVERRIDE` can be set via: + + with override_alternative_batch_size(...): ... + """ + + def _run_alternative_batch_size(pass_: PassFunc) -> PassFunc: + """ + A decorator for PassFunc to check that the model (both before and after + the transformation by pass func) can run at alternative batch size. + """ + + @wraps(pass_) + def pass_with_validation( + module: fx.GraphModule, + input: Input, + *args, + **kwargs, + ) -> fx.GraphModule: + _run_alternative_batch_size = ( + ALTERNATIVE_BATCH_SIZE_OVERRIDE + if ALTERNATIVE_BATCH_SIZE_OVERRIDE is not None + else run_alternative_batch_size + ) + + if _run_alternative_batch_size < 0: + return pass_(module, input, *args, **kwargs) + + if not isinstance(input, (list, tuple)): + _LOGGER.info( + f"Skip run_alternative_batch_size: input must be list, tuple. Actual: {type(input)}" + ) + return pass_(module, input, *args, **kwargs) + + if not all(isinstance(x, torch.Tensor) for x in input): + _LOGGER.info( + "Skip run_alternative_batch_size: input elements must all be tensors" + ) + return pass_(module, input, *args, **kwargs) + + if not all(len(x.shape) > 0 for x in input): + _LOGGER.info( + "Skip run_alternative_batch_size: some input tensor(s) are scalar" + ) + return pass_(module, input, *args, **kwargs) + + batch_size_candidates = {x.shape[0] for x in input} + if len(batch_size_candidates) > 1: + _LOGGER.info( + f"Skip run_alternative_batch_size: input tensors' first dim must be the same, actual: {batch_size_candidates}" + ) + return pass_(module, input, *args, **kwargs) + + batch_size = next(iter(batch_size_candidates)) + assert ( + _run_alternative_batch_size <= batch_size + ), f"{_run_alternative_batch_size=} must be smaller or equal to {batch_size=}" + + input_alt_bs = [x[:_run_alternative_batch_size, ...] for x in input] + + def run_module(mod, stage: str): + """Run module with full bs and alternative bs""" + _LOGGER.info( + f"Running {stage} model at alternative batch size: {_run_alternative_batch_size}" + ) + try: + mod(*input) + mod(*input_alt_bs) + except Exception as e: + _LOGGER.warning( + f"Failed running {stage} module at full or alternative batch size: {e}" + ) + diagnostics.write( + "lowering_diagnostics", + json.dumps( + { + "validate_variable_batch_sizes_exception": repr(e), + "validate_variable_batch_sizes_exception_type": type( + e + ).__name__, + "validate_variable_batch_sizes_exception_traceback": "".join( + TracebackException.from_exception(e).format() + ), + } + ), + ) + if ALTERNATIVE_BATCH_SIZE_EXCEPTION_SHOULD_THROW: + raise + + run_module(module, "original") + module_after = pass_(module, input, *args, **kwargs) + run_module(module_after, "transformed") + + return module_after + + return pass_with_validation + + return _run_alternative_batch_size + + Decorator = Callable[[Callable], Callable] @@ -269,3 +450,67 @@ def collect(x: fx.node.Argument) -> fx.node.Argument: fx.node.map_aggregate(arg, collect) return res + + +class InputOutputDtypeInferInterpreter(torch.fx.Interpreter): + """ + Interprete a graph to propagate the output tensor dtype from its inputs, extracing + input and output graph node that need dtype cast to float32/bfloat16. + """ + + def __init__(self, module: torch.fx.GraphModule): + super().__init__(module) + self.need_cast_to_float32 = [] + self.need_cast_to_bfloat = [] + + def _need_cast(self, node: Node, run_result) -> None: + if node.op == "placeholder" and ( + run_result.dtype not in (torch.int32, torch.int64) + ): + _LOGGER.info( + f"Encountered node: {node.format_node()} need dtype cast to float32." + ) + self.need_cast_to_float32.append(node) + # Process node that will be used as final output + elif "output" in set(i.name for i in node.users.keys()): + if run_result.dtype not in (torch.int32, torch.int64): + _LOGGER.info( + f"Encountered node: {node.format_node()} need dtype cast to bfloat16." + ) + self.need_cast_to_bfloat.append(node) + + def run_node(self, n: Node) -> Any: + run_result = super().run_node(n) + + if torch.is_tensor(run_result): + n.meta["tensor_dtype"] = run_result.dtype + self._need_cast(n, run_result) + return run_result + + +def apply_bfloat_float_conversion( + gm: torch.fx.GraphModule, inputs: Any, name: str +) -> None: + _LOGGER.info("Apply bfloat-float32 conversion on {name}") + interpreter = InputOutputDtypeInferInterpreter(gm) + interpreter.run(*inputs) + + def to_bfloat(x): + return x.to(torch.bfloat16) + + def to_float(x): + return x.to(torch.float32) + + for node in interpreter.need_cast_to_float32: + with gm.graph.inserting_after(node): + cast = gm.graph.call_function( + to_float, + (node,), + {}, + ) + node.replace_all_uses_with(cast) + + for node in interpreter.need_cast_to_bfloat: + with gm.graph.inserting_after(node): + cast = gm.graph.call_function(to_bfloat, (node,), {}) + node.replace_all_uses_with(cast) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py index cfeb235af3..55bd7b1e8b 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py @@ -9,7 +9,7 @@ class TestCatConverter(DispatchTestCase): @parameterized.expand( [ ("pos", 1), - # ("neg", -2), #Dynamo tracer issue + # ("neg", -2), #dim can not have dynamic input ] ) def test_cat(self, _, dim): @@ -27,7 +27,7 @@ def forward(self, x, y, z): @parameterized.expand( [ ("pos", 1), - # ("neg", -2), #Dynamo tracer issue + # ("neg", -2), #dim can not have dynamic input ] ) def test_cat_dynamic_shape(self, _, dim): diff --git a/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py b/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py index a604f4b75a..384d55d44e 100644 --- a/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py +++ b/py/torch_tensorrt/fx/test/converters/vanilla/test_convolution_vanilla.py @@ -81,8 +81,7 @@ def forward(self, x): ("tuple_parameters", 1, (1, 1, 1), (0, 0, 0)), param("non_zero_padding", 1, padding=1), param("dilation", 1, dilation=2), - # TODO: Enable this when TRT fixes https://github.com/pytorch/TensorRT/issues/1445 - # param("groups", 1, groups=3), + param("groups", 1, groups=3), ] ) def test_conv3d( diff --git a/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py index db848eaf1c..0443278460 100644 --- a/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py +++ b/py/torch_tensorrt/fx/test/core/test_input_tensor_spec.py @@ -47,6 +47,23 @@ def test_from_tensors_with_dynamic_batch_size(self): self.assertEqual(batch_size, shape[0]) self.assertSequenceEqual(tensor.shape[1:], shape[1:]) + def test_from_tensors_with_dynamic_batch_size_no_bs_input(self): + tensors = [torch.randn(1, 2, 3), torch.randn(1, 4), torch.randn(72, 16)] + batch_size_range = [2, 3, 4] + specs = InputTensorSpec.from_tensors_with_dynamic_batch_size( + tensors, batch_size_range + ) + for index, (spec, tensor) in enumerate(zip(specs, tensors)): + if index == 2: + for a, b in zip(spec.shape, tensor.shape): + self.assertEqual(a, b) + else: + self._validate_spec(spec, tensor, dynamic_dims=[0]) + + for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]): + self.assertEqual(batch_size, shape[0]) + self.assertSequenceEqual(tensor.shape[1:], shape[1:]) + def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self): tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)] batch_size_range = [2, 3, 4] @@ -88,6 +105,22 @@ def test_generate_input_specs(self): self._validate_spec(spec, tensor, dynamic_dims=[1]) self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica) + # Explicit batch dim with inputs w/ different batch dims. + bs = 10 + inputs = [ + torch.randn(bs, 1, 2), + torch.randn(bs, 10, 3), + torch.randn(4, bs, 5), + torch.randn(bs, 2, 5), + ] + specs = generate_input_specs(inputs, lower_setting) + for idx, (spec, tensor) in enumerate(zip(specs, inputs)): + if idx == 2: + self._validate_spec(spec, tensor, dynamic_dims=[1]) + else: + self._validate_spec(spec, tensor, dynamic_dims=[0]) + self.assertEqual(len(spec.shape_ranges), lower_setting.opt_profile_replica) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py index 36420375f8..4edc2ef706 100644 --- a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py +++ b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py @@ -76,6 +76,8 @@ def forward(self, x): inputs, {trt_transposed_linear}, apply_passes=[fuse_permute_linear], + rtol=5e-3, + atol=2e-3, ) diff --git a/py/torch_tensorrt/fx/test/passes/test_pass_utils.py b/py/torch_tensorrt/fx/test/passes/test_pass_utils.py new file mode 100644 index 0000000000..6f5edde004 --- /dev/null +++ b/py/torch_tensorrt/fx/test/passes/test_pass_utils.py @@ -0,0 +1,97 @@ +import logging +import unittest +from typing import Optional + +import torch +import torch_tensorrt.fx.diagnostics as diagnostics +from torch_tensorrt.fx.passes.pass_utils import ( + override_alternative_batch_size, + override_alternative_batch_size_exception_should_throw, + validate_variable_batch_sizes, +) + +diagnostics.set_current_collector( + diagnostics.ZipDiagnosticsCollector(writer=diagnostics.get_current_writer()) +) + + +_LOGGER: logging.Logger = logging.getLogger(__name__) +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) # configure root logger + + +class BatchSizeError(Exception): + pass + + +class PassUtilsTest(unittest.TestCase): + def setUp(self): + torch.manual_seed(0) + + def test_run_alternative_batch_size(self): + class TestModule(torch.nn.Module): + should_fail_at_bs: Optional[int] = None + + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + if x.shape[0] == self.should_fail_at_bs: + raise BatchSizeError(self.should_fail_at_bs) + + return x + y + z + + def gen_input(bs: int): + return [ + torch.rand(bs, 64), + torch.rand(bs, 64), + torch.rand(bs, 64), + ] + + @validate_variable_batch_sizes(1) + def model_transform_pass_good(model, input): + """ + This is a good transformation. Meaning that the model it + produces will not fail at any batch sizes + """ + model.should_fail_at_bs = None + return model + + @validate_variable_batch_sizes(1) + def model_transform_pass_bad(model, input): + """ + This is a bad transformation. Meaning that the model it produces + will fail when the given input batch size is 1 + """ + model.should_fail_at_bs = 1 + return model + + model = TestModule() + input = gen_input(bs=10) + + with diagnostics.collect_when(diagnostics.CollectionConditions.always()): + + with override_alternative_batch_size_exception_should_throw(True): + # This should succeed: the validate_inference decorator will + # run both bs=10 and bs=1 successfully + model_transform_pass_good(model, input) + + # This should fail: the validate_inference decorator will run the + # model (post transform) at bs=1. + model.should_fail_at_bs = None # reset + self.assertRaises( + BatchSizeError, lambda: model_transform_pass_bad(model, input) + ) + + # Test override_alternative_batch_size can disable run alt bs: + # This should success: the validate_inference decorator will + # NOT run alternative batch size, because it is disabled via + # override_alternative_batch_size. + model.should_fail_at_bs = None # reset + with override_alternative_batch_size(alternative_batch_size=-1): + model_transform_pass_bad(model, input) + + # Test that by default alt bs failures won't cause exception + # thrown, because of no + # `override_alternative_batch_size_exception_should_throw` + model_transform_pass_bad(model, input) diff --git a/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py b/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py index 8f9c1a887f..cb7ff8f906 100644 --- a/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py +++ b/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py @@ -1,7 +1,6 @@ import torch import torch._dynamo as torchdynamo from parameterized import parameterized -from torch._dynamo.optimizations import backends from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase @@ -507,93 +506,93 @@ def transform_fx(gm, example_inputs): optimize_mod(*inputs) # test with torchdynamo - def test_setitem1d_trt(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[1] = x - return y - - inputs = [torch.randn(1), torch.randn(3)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - ref_output = m(*inputs) - - optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) - - output = optimize_mod(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) - - @parameterized.expand( - [ - ("c1", (4, 2), (4, 5), 0, 2), - ("c2", (4, 2), (4, 5), 1, 3), - ] - ) - def test_setitem2d_1v_trt(self, name, x_shape, y_shape, y_start, y_end): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[:, y_start:y_end] = x - return y - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - ref_output = m(*inputs) - optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) - output = optimize_mod(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) - - @parameterized.expand( - [ - ("c1", (2, 3, 4, 5), (4, 5, 6, 7), 0, 2, 0, 3, 0, 4, 0, 5), - ("c2", (2, 3, 4, 5), (4, 5, 6, 7), 1, 3, 1, 4, 1, 5, 1, 6), - ] - ) - def test_setitem4d_4v_trt( - self, - name, - x_shape, - y_shape, - start_0, - end_0, - start_1, - end_1, - start_2, - end_2, - start_3, - end_3, - ): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] = x - y = y + 3 - x = y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] - return x - - inputs = [torch.randn(x_shape), torch.randn(y_shape)] - m = TestModule() - - inputs = [i.cuda() for i in inputs] - m.cuda() - - ref_output = m(*inputs) - optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) - output = optimize_mod(*inputs) - self.assertTrue(torch.allclose(ref_output, output)) + # def test_setitem1d_trt(self): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + + # def forward(self, x, y): + # y[1] = x + # return y + + # inputs = [torch.randn(1), torch.randn(3)] + # m = TestModule() + + # inputs = [i.cuda() for i in inputs] + # m.cuda() + # ref_output = m(*inputs) + + # optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + + # output = optimize_mod(*inputs) + # self.assertTrue(torch.allclose(ref_output, output)) + + # @parameterized.expand( + # [ + # ("c1", (4, 2), (4, 5), 0, 2), + # ("c2", (4, 2), (4, 5), 1, 3), + # ] + # ) + # def test_setitem2d_1v_trt(self, name, x_shape, y_shape, y_start, y_end): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + + # def forward(self, x, y): + # y[:, y_start:y_end] = x + # return y + + # inputs = [torch.randn(x_shape), torch.randn(y_shape)] + # m = TestModule() + + # inputs = [i.cuda() for i in inputs] + # m.cuda() + + # ref_output = m(*inputs) + # optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + # output = optimize_mod(*inputs) + # self.assertTrue(torch.allclose(ref_output, output)) + + # @parameterized.expand( + # [ + # ("c1", (2, 3, 4, 5), (4, 5, 6, 7), 0, 2, 0, 3, 0, 4, 0, 5), + # ("c2", (2, 3, 4, 5), (4, 5, 6, 7), 1, 3, 1, 4, 1, 5, 1, 6), + # ] + # ) + # def test_setitem4d_4v_trt( + # self, + # name, + # x_shape, + # y_shape, + # start_0, + # end_0, + # start_1, + # end_1, + # start_2, + # end_2, + # start_3, + # end_3, + # ): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + + # def forward(self, x, y): + # y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] = x + # y = y + 3 + # x = y[start_0:end_0, start_1:end_1, start_2:end_2, start_3:end_3] + # return x + + # inputs = [torch.randn(x_shape), torch.randn(y_shape)] + # m = TestModule() + + # inputs = [i.cuda() for i in inputs] + # m.cuda() + + # ref_output = m(*inputs) + # optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + # output = optimize_mod(*inputs) + # self.assertTrue(torch.allclose(ref_output, output)) if __name__ == "__main__": diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index 633359127f..74715d6030 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -1494,6 +1494,21 @@ def test_dropout(self): lambda x: nn.functional.dropout(x, training=False), input_shape=(1, 2, 3), ) + self._make_acc_op_function_test( + None, + lambda x: nn.functional.dropout1d(x, training=False), + input_shape=(4, 2, 3), + ) + self._make_acc_op_function_test( + None, + lambda x: nn.functional.dropout2d(x, training=False), + input_shape=(4, 2, 3), + ) + self._make_acc_op_function_test( + None, + lambda x: nn.functional.dropout3d(x, training=False), + input_shape=(4, 2, 3), + ) def test_stochastic_depth(self): self._make_acc_op_function_test( @@ -1727,6 +1742,11 @@ def test_ceil(self): def test_softmax(self): self._make_acc_op_function_test(acc_ops.softmax, torch.nn.functional.softmax) + def test_normalize(self): + self._make_acc_op_function_test( + acc_ops.normalize, torch.nn.functional.normalize + ) + def test_tensor_squeeze(self): self._make_acc_op_function_test(acc_ops.squeeze, lambda x: x.squeeze()) @@ -2628,6 +2648,40 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertIsNotNone(getitem) self.assertTrue(torch.equal(m(x), traced(x))) + def test_skip_normalization_if_none_repeat_interleave(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + repeats = y[0] + return torch.repeat_interleave(x, repeats, 1) + + # TODO: finish test later + m = TestModule() + inputs = (torch.randn(3, 4), torch.tensor([1])) + traced = acc_tracer.trace(m, inputs) + # Make sure repeat_interleave wasn't mapped into tiles + self.assertTrue("torch.repeat_interleave" in str(traced.graph)) + self.assertFalse("tile" in str(traced.graph)) + + def test_skip_normalization_if_none_repeat(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + repeats = [y[0], y[2], 3] + return x.repeat(repeats) + + # TODO: finish test later + m = TestModule() + inputs = (torch.randn(3, 4, 5), torch.tensor([1, 2, 3])) + traced = acc_tracer.trace(m, inputs) + # Make sure repeat wasn't mapped into tiles + self.assertTrue("repeat" in str(traced.graph)) + self.assertFalse("tile" in str(traced.graph)) + def test_acc_normalization_block_list(self): class TestModule(nn.Module): def forward(self, x: List[torch.Tensor]) -> torch.Tensor: @@ -2668,6 +2722,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertTrue(torch.equal(m(*sample_inputs), traced(*sample_inputs))) + def test_threshold_bwd(self): + class TestModule(nn.Module): + def __init__(self, threshold): + super().__init__() + self._threshold = threshold + + def forward(self, grad: torch.Tensor, input: torch.Tensor) -> torch.Tensor: + return torch.ops.aten.threshold_backward.default( + grad, input, self._threshold + ) + + m = TestModule(0.0) + grad = torch.randn(4096) + sample_inputs = torch.randn(4096) + traced = acc_tracer.trace(m, [grad, sample_inputs]) + + output = None + for node in traced.graph.nodes: + if node.op == "output": + assert output is None + output = node + + ref = m(grad, sample_inputs) + res = traced(grad, sample_inputs) + self.assertTrue(torch.equal(ref, res)) + def test_all_acc_ops_registered(self): self.assertEqual( acc_normalizer._acc_ops, @@ -2689,6 +2769,7 @@ def test_all_acc_ops_registered(self): acc_ops.minimum, acc_ops.cat, acc_ops.softmax, + acc_ops.normalize, acc_ops.sign, acc_ops.permute, acc_ops.matmul, @@ -2713,6 +2794,8 @@ def test_all_acc_ops_registered(self): acc_ops.tuple_construct, acc_ops.unsqueeze, acc_ops.sigmoid, + acc_ops.sigmoid_backward, + acc_ops.threshold_backward, acc_ops.sum, acc_ops.prod, acc_ops.max_full_reduce, @@ -2726,6 +2809,7 @@ def test_all_acc_ops_registered(self): acc_ops.atan, acc_ops.exp, acc_ops.log, + acc_ops.log_softmax, acc_ops.sqrt, acc_ops.reciprocal, acc_ops.abs, @@ -2797,5 +2881,16 @@ def test_all_acc_ops_registered(self): acc_ops.var, acc_ops.grid_sample, acc_ops.xl_weight, + acc_ops.clone, + acc_ops.unbind, + acc_ops.group_norm, + acc_ops.long, + acc_ops.full_like, + acc_ops.new_full, + acc_ops.ones_like, + acc_ops.zeros_like, + acc_ops.new_zeros, + acc_ops.index_add, + acc_ops.masked_select, }, ) diff --git a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py index e160626cf2..b5db157663 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py @@ -7,17 +7,12 @@ import torch._dynamo.config import torchvision from functorch.experimental import functionalize -from torch._dynamo.optimizations import backends -from torch._dynamo.optimizations.normalize import normalize_ir from torch.library import Library from torch_tensorrt.fx.lower import compile from torch_tensorrt.fx.tracer.dispatch_tracer.tracer import make_fx from torch_tensorrt.fx.utils import LowerPrecision, proxytensor_trace -# TODO(ezyang): remove this after we properly support fake example inputs -torch._dynamo.config.DO_NOT_USE_legacy_non_fake_example_inputs = True - torch.manual_seed(0) wrap_lib = Library("wrap", "DEF") @@ -65,19 +60,93 @@ def forward(self, x, y): ref_output = mod(*inputs_new) torch.testing.assert_close(output, ref_output) - def test_resnet18_dynamo(self): + def test_simple(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x, y): + y = y + x + y = y.mul(x) + y = y + x + y = y + x + y = y / x + y = y + x + y = y + x + y = y / x + y = y + x + y = self.relu(y) + return y + + mod = TestModule() + mod = mod.cuda().half().eval() + + def f(x, y): + return mod(x, y) + + inputs = [torch.randn(2, 5), torch.ones(2, 5)] + inputs = [i.cuda().half() for i in inputs] + ref_output = f(*inputs) + + mod = compile( + mod, + inputs, + max_batch_size=100, + explicit_batch_dimension=True, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + dynamic_batch=True, + is_aten=True, + ) + output = mod(*inputs) + torch.testing.assert_close(output, ref_output) + + def test_resnet18_aten(self): mod = torchvision.models.resnet18() mod = mod.cuda().half().eval() inputs = [torch.ones(32, 3, 224, 224)] inputs = [i.cuda().half() for i in inputs] - ref_output = mod(*inputs) - torchdynamo.reset() - dynamo_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod) - dynamo_output = dynamo_mod(*inputs) + aten_mod = compile( + mod, + inputs, + max_batch_size=32, + explicit_batch_dimension=True, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + dynamic_batch=False, + is_aten=True, + ) + aten_output = aten_mod(*inputs) + fx_mod = compile( + mod, + inputs, + max_batch_size=32, + explicit_batch_dimension=True, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + dynamic_batch=False, + is_aten=False, + ) + fx_output = fx_mod(*inputs) + # Kernel selection is tricky in TRT with big variance as shown below: + # Mismatched elements: 30816 / 32000 (96.3%) + # Greatest absolute difference: 0.05859375 at index (0, 499) (up to 1e-05 allowed) + # Greatest relative difference: 3.293713681986265 at index (0, 142) (up to 0.001 allowed) + # so we choose to use cosine similarity cos_val = torch.nn.functional.cosine_similarity( - dynamo_output.flatten(), ref_output.flatten(), dim=0, eps=1e-4 + aten_output.flatten(), fx_output.flatten(), dim=0, eps=1e-4 ) self.assertTrue(cos_val.detach().cpu().numpy() > 0.999) @@ -224,8 +293,6 @@ def f(x, y): ref_output = f(*inputs) def compile_dispatch(gm, example_inputs): - # after normalization, relu in-place is removed - gm = normalize_ir(gm, example_inputs) # dispatch tracer nargs = len(example_inputs) diff --git a/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py b/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py index 916394e944..96584c59bd 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py +++ b/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py @@ -10,7 +10,11 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.fx.passes import splitter_base from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting +from torch_tensorrt.fx.tools.trt_splitter import ( + create_trt_operator_support, + TRTSplitter, + TRTSplitterSetting, +) from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer ERROR_MSG_NO_ACC_MODULE = "FX split failed: Did not find any ACC submodule!" @@ -625,6 +629,50 @@ def test_splitter(splitter): test_splitter(splitter) + def test_decline_if_input_dtype(self): + operator_support = create_trt_operator_support() + + class TestModule(torch.nn.Module): + def forward(self, a): + b = torch.relu(a) + return b + + test_mod = TestModule().cuda().eval() + x = torch.randn(2, 3) + mod = acc_tracer.trace(test_mod, [x]) + settings = TRTSplitterSetting() + settings.min_acc_module_size = 0 + # nodes w/ float16 input should be lowered + splitter = TRTSplitter( + mod, + (x.half().cuda(),), + operator_support, + settings, + ) + split_results_half = splitter.generate_split_results() + self.assertTrue(len(split_results_half), 1) + self.assertEqual( + dict(split_results_half.split_module.named_children()).keys(), + {"_run_on_acc_0"}, + ) + + # nodes w/ float64 input should not be lowered + mod = acc_tracer.trace(test_mod, [x]) + splitter = TRTSplitter( + mod, + (x.double().cuda(),), + operator_support, + settings, + ) + + split_results_double = splitter.generate_split_results() + + self.assertTrue(len(split_results_double), 1) + self.assertEqual( + dict(split_results_double.split_module.named_children()).keys(), + {"_run_on_gpu_0"}, + ) + class TestSplitComplexGraph(TestCase): """ diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py index 30d6dc96c9..6d883a4f62 100644 --- a/py/torch_tensorrt/fx/tools/common_fx2trt.py +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -3,6 +3,8 @@ import unittest from typing import Callable, List, Optional, Set, Tuple +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt import torch import torch.fx @@ -257,6 +259,8 @@ def run_test( pass_tracer = chain_passes(*apply_passes) mod = pass_tracer(mod, inputs) + if trt.__version__ >= "8.6": + test_implicit_batch_dim = False if test_implicit_batch_dim: interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) super().run_test( diff --git a/py/torch_tensorrt/fx/tools/model_packager.py b/py/torch_tensorrt/fx/tools/model_packager.py index 0ef0ff05a4..b86c21e809 100644 --- a/py/torch_tensorrt/fx/tools/model_packager.py +++ b/py/torch_tensorrt/fx/tools/model_packager.py @@ -51,12 +51,34 @@ def generate_standalone_repro( "", "import torch", "from torch import nn", + ] + code = str(model.code) + + import_modules = set() + import_map = { + "torch_tensorrt_fx_tracer_acc_tracer_acc_ops": "torch_tensorrt.fx.tracer.acc_tracer.acc_ops", + "torch_tensorrt_fx_passes_lower_basic_pass": "torch_tensorrt.fx.passes.lower_basic_pass", + } + for line in code.split("\n"): + for k, v in import_map.items(): + if k in line: + sub_string = line.split("(")[0].split()[-1] + if sub_string.startswith(k): + mod = sub_string.replace(k + "_", "") + import_modules.add( + "from " + v + " import " + mod + " as " + sub_string + ) + for mod in sorted(import_modules): + lines.append(mod) + + lines += [ "", "", "class ExportedModule(nn.Module):", f"{INDENT}def __init__(self):", f"{INDENT * 2}super().__init__()", ] + for k, v in model._holder.named_parameters(): shape = ", ".join([str(i) for i in v.shape]) rand_func = "randn" if torch.is_floating_point(v) else "randint" @@ -64,7 +86,6 @@ def generate_standalone_repro( lines.append( f"{INDENT * 2}self.{k} = nn.Parameter(torch.{rand_func}({int_range}{shape}, dtype={v.dtype}))" ) - code = str(model.code) def dump(f): f.write(prelude) diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py index bea925453f..aa3d930bfb 100644 --- a/py/torch_tensorrt/fx/tools/trt_splitter.py +++ b/py/torch_tensorrt/fx/tools/trt_splitter.py @@ -34,8 +34,9 @@ def create_trt_operator_support( return ops.chain( ops.OpSupports.decline_if_node_in_names(exclude_support_node_name), - # 1. Node is not supported if it has args with int64 dtype: + # 1. Node is not supported if it has args with int64 or float64 dtype: ops.OpSupports.decline_if_input_dtype(torch.int64), + ops.OpSupports.decline_if_input_dtype(torch.float64), # 2. Node is supported if it has TRT converter: supported_if_converter_registered, ) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py index 55cb39d4a5..1271b6f30c 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py @@ -69,6 +69,7 @@ class NormalizationInfo(NamedTuple): List[Union[Tuple[str, str, bool], Tuple[str, str]]] ] needs_shapes_for_normalization: bool + skip_normalization_if_none: bool # Dict from (op, target) to NormalizationInfo for that op. @@ -88,6 +89,7 @@ def _insert_fun( ] = None, needs_shapes_for_normalization=False, allow_normalize_from_torch_package=False, + skip_normalization_if_none=False, ): if op_and_target[0] == "call_function": assert callable(op_and_target[1]) @@ -129,6 +131,7 @@ def _insert_fun( custom_mapping_fn=custom_mapping_fn, kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty, needs_shapes_for_normalization=needs_shapes_for_normalization, + skip_normalization_if_none=skip_normalization_if_none, ) _normalization_dict[op_and_target] = norm_info @@ -217,6 +220,7 @@ def register_custom_acc_mapper_fn( ], needs_shapes_for_normalization=False, allow_normalize_from_torch_package=False, + skip_normalization_if_none=False, ): def insert(custom_mapping_fn: Callable): _insert_fun( @@ -225,6 +229,7 @@ def insert(custom_mapping_fn: Callable): arg_replacement_tuples=arg_replacement_tuples, # type: ignore[arg-type] needs_shapes_for_normalization=needs_shapes_for_normalization, allow_normalize_from_torch_package=allow_normalize_from_torch_package, + skip_normalization_if_none=skip_normalization_if_none, ) return custom_mapping_fn @@ -363,12 +368,18 @@ def normalize_to_acc_op( if normalization_info.custom_mapping_fn is not None: # For custom mapping, the normalized_kwargs are used for the original op, # i.e. *before* custom acc_ops normalization. Do that now. + if normalization_info.skip_normalization_if_none: + original_args = node.args + original_kwargs = node.kwargs node.args = normalized_args node.kwargs = normalized_kwargs new_node = normalization_info.custom_mapping_fn(node, mod) # If a new node is returned then use it to replace the old node. Otherwise # the custom mapping function did its own replacement, so return early. if new_node is None: + if normalization_info.skip_normalization_if_none: + node.args = original_args + node.kwargs = original_kwargs return else: # If there's kwargs_to_move_to_acc_out_ty then use it to setup acc_out_ty in diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index 8abad9c509..1ed25d66f1 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -1,9 +1,10 @@ # encoding: utf-8 +import logging import operator import warnings import torch # isort:skip -from typing import cast, Iterable, List, Sequence +from typing import cast, Iterable, List, Optional, Sequence import torch.nn as nn from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata @@ -16,6 +17,8 @@ ) from .acc_op_properties import AccOpProperty, register_acc_op_properties +logger: logging.Logger = logging.getLogger(__name__) + this_arg_is_optional = True move_to_qparams = True dont_move_to_qparams = False @@ -161,6 +164,18 @@ def max_pool3d( ) +@register_acc_op_mapping(op_and_target=("call_function", nn.functional.normalize)) +@register_acc_op +def normalize(*, input, p, dim, eps, out): + return nn.functional.normalize( + input=input, + p=p, + dim=dim, + eps=eps, + out=out, + ) + + @register_acc_op_mapping( op_and_target=("call_function", nn.functional.adaptive_avg_pool2d) ) @@ -364,9 +379,10 @@ def custom_getattr_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: getitem_node.meta = node.meta.copy() return getitem_node - assert ( - input_obj_type == torch.Tensor - ), f"Expected torch.Tensor type for {input_obj_type}" + assert input_obj_type in [ + torch.Tensor, + torch.nn.parameter.Parameter, + ], f"Expected torch.Tensor type for {input_obj_type}" assert ( attr_name == "shape" or attr_name == "device" or attr_name == "dtype" ), f"Only supporting shape, device and dtype getattr for now, not {attr_name}" @@ -417,7 +433,10 @@ def tensor_size_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: @register_acc_op_mapping(op_and_target=("call_method", "add")) @register_acc_op def add(*, input, other): - return input + other + if not (isinstance(input, torch.Tensor) or isinstance(other, torch.Tensor)): + return operator.add(input, other) + else: + return input + other @register_acc_op_properties(AccOpProperty.unary) @@ -442,14 +461,27 @@ def tile(*, input, dims): ("input", "input"), ("*", "sizes"), ], + skip_normalization_if_none=True, ) -def repeat_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: +def repeat_mapper(node: torch.fx.Node, _: nn.Module) -> Optional[torch.fx.Node]: """ Map repeat to tile. """ with node.graph.inserting_before(node): inputs = node.kwargs["input"] dims = node.kwargs["sizes"] + # Skip repeat mapping when the list of dims is not all ints (ie. contains + # some calculated value). torch.tile cannot support cases where dims + # are Proxy nodes + if ( + isinstance(dims, (list, tuple)) + and len(dims) > 0 + and not all(isinstance(x, int) for x in dims) + ): + logger.info( + "Not mapping repeat to an acc op. We can't handle variable dims." + ) + return new_node = node.graph.create_node( "call_function", tile, @@ -468,6 +500,7 @@ def repeat_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: ("dim", "dim", this_arg_is_optional), ("output_size", "output_size", this_arg_is_optional), ], + skip_normalization_if_none=True, ) @register_custom_acc_mapper_fn( op_and_target=("call_function", torch.repeat_interleave), @@ -477,14 +510,17 @@ def repeat_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: ("dim", "dim", this_arg_is_optional), ("output_size", "output_size", this_arg_is_optional), ], + skip_normalization_if_none=True, ) def repeat_interleave_mapper(node: torch.fx.Node, _: nn.Module): input_node = node.kwargs["input"] repeats = cast(int, node.kwargs["repeats"]) dim = node.kwargs["dim"] - assert ( - type(repeats) is int - ), "We currently only support `repeat_interleave` with int repeats" + if not (type(repeats) is int): + logger.info( + "Not mapping repeat_interleave to an acc op. We currently only support `repeat_interleave` with int repeats" + ) + return rank = node.meta["tensor_rank"] if dim is None: repeat_dim = rank - 1 @@ -825,6 +861,18 @@ def matmul(*, input, other): op_and_target=("call_function", nn.functional.dropout), arg_replacement_tuples=[("input", "input")], ) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", nn.functional.dropout1d), + arg_replacement_tuples=[("input", "input")], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", nn.functional.dropout2d), + arg_replacement_tuples=[("input", "input")], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", nn.functional.dropout3d), + arg_replacement_tuples=[("input", "input")], +) @register_custom_acc_mapper_fn( op_and_target=("call_method", "detach"), arg_replacement_tuples=[("input", "input")] ) @@ -1055,7 +1103,10 @@ def rescale_quantize_per_channel(*, input, acc_out_ty=None): @register_acc_op_mapping(op_and_target=("call_method", "sub")) @register_acc_op def sub(*, input, other): - return input - other + if not (isinstance(input, torch.Tensor) or isinstance(other, torch.Tensor)): + return operator.sub(input, other) + else: + return input - other @register_acc_op_properties(AccOpProperty.pointwise) @@ -1067,6 +1118,19 @@ def mul(*, input, other): return input * other +@register_acc_op_mapping( + op_and_target=("call_function", torch.ops.aten.threshold_backward.default), + arg_replacement_tuples=[ + ("grad", "grad"), + ("self", "input"), + ("threshold", "threshold"), + ], +) +@register_acc_op +def threshold_backward(*, grad, input, threshold): + return torch.ops.aten.threshold_backward.default(grad, input, threshold) + + @register_custom_acc_mapper_fn( op_and_target=("call_method", "div"), arg_replacement_tuples=[ @@ -1367,7 +1431,7 @@ def std_mapper(node, mod): mean_kwargs = { "input": input_node, "dim": dim, - "keepdim": keepdim, + "keepdim": True, } mean_node = node.graph.call_function(mean, kwargs=mean_kwargs) mean_node.meta["type"] = torch.Tensor @@ -1385,7 +1449,7 @@ def std_mapper(node, mod): } pow_node = node.graph.call_function(pow, kwargs=pow_kwargs) pow_node.meta["type"] = torch.Tensor - # sum(pow(X-mean(X))))/N + # mean(pow(X-mean(X))) post_mean_kwargs = { "input": pow_node, "dim": dim, @@ -1393,7 +1457,7 @@ def std_mapper(node, mod): } post_mean_node = node.graph.call_function(mean, kwargs=post_mean_kwargs) post_mean_node.meta["type"] = torch.Tensor - # sqrt(sum(pow(X-mean(X))))/N) + # sqrt( mean(pow(X-mean(X))) ) sqrt_kwargs = { "input": post_mean_node, } @@ -1653,12 +1717,26 @@ def fmod(*, input, other): @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sigmoid)) +@register_acc_op_mapping( + op_and_target=("call_function", torch.ops.aten.sigmoid.default) +) @register_acc_op_mapping(op_and_target=("call_method", "sigmoid")) @register_acc_op def sigmoid(*, input): return torch.sigmoid(input=input) +@register_acc_op_properties(AccOpProperty.pointwise) +@register_acc_op_mapping( + op_and_target=("call_function", torch.ops.aten.sigmoid_backward.default) +) +@register_acc_op_mapping(op_and_target=("call_method", "sigmoid_backward")) +@register_acc_op +# first argument's name needs to be input to use same_shape_and_dtype_as_input +def sigmoid_backward(*, input, dest): + return torch.ops.aten.sigmoid_backward(grad_output=input, output=dest) + + @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sinh)) @register_acc_op @@ -1716,6 +1794,23 @@ def log(*, input): return torch.log(input=input) +@register_acc_op_properties(AccOpProperty.unary) +@register_acc_op_mapping( + op_and_target=("call_function", torch.nn.functional.log_softmax), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim"), + ("dtype", "dtype", this_arg_is_optional), + ], +) +@register_acc_op +def log_softmax(*, input, dim, dtype=None): + """ + _stacklevel are ignored here. + """ + return torch.nn.functional.log_softmax(input=input, dim=dim, dtype=dtype) + + @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sqrt)) @register_acc_op_mapping(op_and_target=("call_method", "sqrt")) @@ -1773,7 +1868,10 @@ def abs(*, input): @register_acc_op_mapping(op_and_target=("call_function", torch.neg)) @register_acc_op def neg(*, input): - return torch.neg(input=input) + if not isinstance(input, torch.Tensor): + return operator.neg(input) + else: + return torch.neg(input=input) @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @@ -2282,6 +2380,7 @@ def embedding_bag_4bit_rowwise_offsets( @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sin)) +@register_acc_op_mapping(op_and_target=("call_method", "sin")) @register_acc_op def sin(*, input): return torch.sin(input=input) @@ -2289,6 +2388,7 @@ def sin(*, input): @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.cos)) +@register_acc_op_mapping(op_and_target=("call_method", "cos")) @register_acc_op def cos(*, input): return torch.cos(input=input) @@ -2314,13 +2414,53 @@ def getitem(*, input, idx): return input[idx] -@register_acc_op_mapping(op_and_target=("call_function", torch.nan_to_num)) -@register_acc_op_mapping(op_and_target=("call_method", "nan_to_num")) @register_acc_op -def nan_to_num(*, input, nan=0.0, posinf=None, neginf=None): +def nan_to_num(*, input, nan=None, posinf=None, neginf=None): return torch.nan_to_num(input, nan=nan, posinf=posinf, neginf=neginf) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.nan_to_num), + arg_replacement_tuples=[ + ("input", "input"), + ("nan", "nan"), + ("posinf", "posinf"), + ("neginf", "neginf"), + ], +) +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "nan_to_num"), + arg_replacement_tuples=[ + ("input", "input"), + ("nan", "nan"), + ("posinf", "posinf"), + ("neginf", "neginf"), + ], +) +def custom_nan_to_num_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: + nan_val, posinf, neginf = ( + node.kwargs["nan"], + node.kwargs["posinf"], + node.kwargs["neginf"], + ) + if nan_val is None: + nan_val = 0 + if posinf is None: + posinf = torch.finfo(torch.float16).max + if neginf is None: + neginf = torch.finfo(torch.float16).min + kwargs = { + "input": node.kwargs["input"], + "nan": nan_val, + "posinf": posinf, + "neginf": neginf, + } + with node.graph.inserting_before(node): + new_node = node.graph.call_function(nan_to_num, kwargs=kwargs) + new_node.meta = node.meta.copy() + return new_node + + @register_acc_op_properties(AccOpProperty.unary) @register_acc_op_mapping( op_and_target=("call_method", "expand"), @@ -2422,7 +2562,10 @@ def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: @register_acc_op def reshape(*, input, acc_out_ty=None): assert acc_out_ty is not None - return input.reshape(acc_out_ty.shape) + shape = acc_out_ty.shape + if len(shape) == 1 and not isinstance(shape[0], int): + return input.reshape(shape[0]) + return input.reshape(shape) @register_custom_acc_mapper_fn( @@ -2977,22 +3120,6 @@ def tensor_split(*, input, indices_or_sections, dim=0): ) -@register_acc_op_mapping( - op_and_target=("call_method", "new_ones"), - arg_replacement_tuples=[ - ("input", "input"), - ("size", "size"), - ("dtype", "dtype", this_arg_is_optional), - ("device", "device", this_arg_is_optional), - ("requires_grad", "requires_grad", this_arg_is_optional), - ], -) -@register_acc_op -def new_ones(*, input, size, dtype=None, device=None, requires_grad=False): - assert requires_grad is False, f"requires_grad != False, it is {requires_grad}" - return input.new_ones(size, dtype=dtype, device=device) - - @register_acc_op_mapping( op_and_target=("call_method", "new_empty"), arg_replacement_tuples=[ @@ -3080,33 +3207,6 @@ def xl_weight(weight_id: str, metadata: TensorMetadata, proxy_shape, dtype): return torch.zeros(proxy_shape, dtype=dtype) -@register_custom_acc_mapper_fn( - op_and_target=("call_function", torch.nn.functional.log_softmax), - arg_replacement_tuples=[ - ("input", "input"), - ("dim", "dim"), - ("dtype", "dtype"), - ], -) -def log_softmax_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node: - with node.graph.inserting_after(node): - - softmax_kwargs = { - "input": node.kwargs["input"], - "dim": node.kwargs["dim"], - "dtype": node.kwargs["dtype"], - } - softmax_node = node.graph.call_function(softmax, kwargs=softmax_kwargs) - softmax_node.meta = node.meta.copy() - - with softmax_node.graph.inserting_after(softmax_node): - log_kwargs = {"input": softmax_node} - log_node = node.graph.call_function(log, kwargs=log_kwargs) - log_node.meta = node.meta.copy() - - return log_node - - @register_custom_acc_mapper_fn( op_and_target=("call_function", torch.nn.functional.softplus), arg_replacement_tuples=[ @@ -3256,6 +3356,124 @@ def baddbmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: return add_node +@register_acc_op_mapping(op_and_target=("call_function", torch.clone)) +@register_acc_op_mapping(op_and_target=("call_method", "clone")) +@register_acc_op +def clone(*, input): + return torch.clone(input) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.unbind)) +@register_acc_op +def unbind(*, input, dim=0): + return torch.unbind(input, dim=dim) + + +@register_acc_op_mapping( + op_and_target=("call_function", torch.nn.functional.group_norm), + arg_replacement_tuples=[ + ("input", "input"), + ("num_groups", "num_groups"), + ("weight", "weight"), + ("bias", "bias"), + ("eps", "eps"), + ], +) +@register_acc_op +def group_norm(*, input, num_groups, weight=None, bias=None, eps=1e-05): + return torch.nn.functional.group_norm( + input, num_groups, weight=weight, bias=bias, eps=eps + ) + + +@register_acc_op_mapping(op_and_target=("call_method", "long")) +@register_acc_op +def long(*, input): + return input.long() + + +@register_acc_op_mapping( + op_and_target=("call_method", "new_full"), + arg_replacement_tuples=[ + ("input", "input"), + ("size", "size"), + ("fill_value", "fill_value"), + ("dtype", "dtype", this_arg_is_optional), + ("device", "device", this_arg_is_optional), + ("requires_grad", "requires_grad", this_arg_is_optional), + ], +) +@register_acc_op +def new_full(*, input, size, fill_value, dtype=None, device=None, requires_grad=False): + return input.new_full(size, fill_value=fill_value, dtype=dtype, device=device) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.full_like)) +@register_acc_op +def full_like(*, input, fill_value, dtype=None, device=None): + return torch.full_like( + input=input, fill_value=fill_value, dtype=dtype, device=device + ) + + +@register_acc_op_mapping( + op_and_target=("call_method", "new_ones"), + arg_replacement_tuples=[ + ("input", "input"), + ("size", "size"), + ("dtype", "dtype", this_arg_is_optional), + ("device", "device", this_arg_is_optional), + ("requires_grad", "requires_grad", this_arg_is_optional), + ], +) +@register_acc_op +def new_ones(*, input, size, dtype=None, device=None, requires_grad=False): + assert requires_grad is False, f"requires_grad != False, it is {requires_grad}" + return input.new_ones(size, dtype=dtype, device=device) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.ones_like)) +@register_acc_op +def ones_like(*, input, dtype=None, device=None): + return torch.ones_like(input=input, dtype=dtype, device=device) + + +@register_acc_op_mapping( + op_and_target=("call_method", "new_zeros"), + arg_replacement_tuples=[ + ("input", "input"), + ("size", "size"), + ("dtype", "dtype", this_arg_is_optional), + ("device", "device", this_arg_is_optional), + ("requires_grad", "requires_grad", this_arg_is_optional), + ], +) +@register_acc_op +def new_zeros(*, input, size, dtype=None, device=None, requires_grad=False): + return input.new_zeros(size, dtype=dtype, device=device) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.zeros_like)) +@register_acc_op +def zeros_like(*, input, dtype=None, device=None): + return torch.zeros_like(input=input, dtype=dtype, device=device) + + +@register_acc_op_mapping( + op_and_target=("call_method", "index_add_"), +) +@register_acc_op_mapping(op_and_target=("call_function", torch.index_add)) +@register_acc_op +def index_add(*, input, dim, index, source, alpha=1): + return torch.index_add(input, dim, index, source, alpha=alpha) + + +@register_acc_op_mapping(op_and_target=("call_function", torch.masked_select)) +@register_acc_op +def masked_select(*, input, mask): + return torch.masked_select(input=input, mask=mask) + + ############################################################################### # Set ops as side-effectul, this prevents them from being optimized away or diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py index c3a5ad850e..bc8c613fee 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py @@ -581,9 +581,15 @@ def _replace_transpose_last_dims(gm: torch.fx.GraphModule): gm.recompile() -def rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list): +def rewriter_base_trace( + mod, + ast_rewriter_allow_list, + leaf_module_list, + concrete_args: Optional[Dict[str, Any]] = None, +): rewritten_graph, rewritten_mod = AccRewritingTracer().trace( mod, + concrete_args, ast_rewriter_allow_list=ast_rewriter_allow_list, leaf_module_list=leaf_module_list, ) @@ -605,6 +611,8 @@ def trace( acc_normalization_block_list: Optional[ Set[Tuple[str, Union[str, Callable]]] ] = None, + dont_retrace_gm: bool = False, + concrete_args: Optional[Dict[str, Any]] = None, ) -> torch.fx.GraphModule: """ Performs tracing and arg normalization specialized for accelerator lowering. @@ -653,6 +661,10 @@ def trace( normalization to. Just like the register_acc_op decarators, the target can either be a string (e.g. for op == "call_method") or a callable (e.g. for op == "call_function"). + + dont_retrace_gm (bool): Optional bool for whether to re-trace the provided + module if it's a graph module already. + """ if mod.training: warnings.warn( @@ -664,7 +676,12 @@ def trace( assert isinstance(sample_inputs, (list, tuple)) # Rewrite the module to make it symbolic traceable, and then trace it. - traced = rewriter_base_trace(mod, ast_rewriter_allow_list, leaf_module_list) + if dont_retrace_gm and isinstance(mod, torch.fx.GraphModule): + traced = mod + else: + traced = rewriter_base_trace( + mod, ast_rewriter_allow_list, leaf_module_list, concrete_args + ) # Now remove all assertions and exceptions if requested. if remove_assertions: diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py index ab3207925f..75418034cb 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py @@ -110,7 +110,8 @@ def get_model_info_str(gm: torch.fx.GraphModule, header: Optional[str] = None): If `header` is provided then it's included in the printed string. """ ops_and_counts: Dict[Callable, int] = {} - placeholder_count = get_attr_count = call_method_count = call_module_count = 0 + placeholder_count = get_attr_count = 0 + call_method_count = call_module_count = output_count = 0 for node in gm.graph.nodes: if node.op == "call_function": ops_and_counts[node.target] = ops_and_counts.get(node.target, 0) + 1 @@ -141,7 +142,8 @@ def get_model_info_str(gm: torch.fx.GraphModule, header: Optional[str] = None): # easier to parse. pretty_ops_and_counts: List[Tuple[str, int]] = [] for op, count in ops_and_counts.items(): - pretty_ops_and_counts.append((_get_qualified_name(op), count)) + name = strip_module_prefixes(_get_qualified_name(op)) + pretty_ops_and_counts.append((name, count)) pretty_ops_and_counts.sort() for op_str, count in pretty_ops_and_counts: model_info_str += f"> {op_str}: {count}\n" @@ -149,6 +151,14 @@ def get_model_info_str(gm: torch.fx.GraphModule, header: Optional[str] = None): return model_info_str +def strip_module_prefixes(op_name): + return ( + op_name.replace("torch_tensorrt.fx.tracer.acc_tracer.", "") + .replace("glow.fb.fx.acc_tracer.", "") + .replace("glow.fb.fx.", "") + ) + + def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str: """ Make sure the name is unique (in a module) and can represents an attr. diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py index e60c8f8d13..edcce20d65 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -42,24 +42,28 @@ def __init__( capture_scalar_outputs: bool = True, guard_nn_modules: bool = True, dynamic_shapes: bool = True, + specialize_int: bool = True, verbose: bool = True, ) -> None: self.capture_scalar_outputs = capture_scalar_outputs self.guard_nn_modules = guard_nn_modules self.dynamic_shapes = dynamic_shapes + self.specialize_int = specialize_int self.verbose = verbose def activate(self) -> None: torchdynamo.config.capture_scalar_outputs = self.capture_scalar_outputs torchdynamo.config.guard_nn_modules = self.guard_nn_modules torchdynamo.config.dynamic_shapes = self.dynamic_shapes + torchdynamo.config.specialize_int = self.specialize_int torchdynamo.config.verbose = self.verbose def deactivate(self) -> None: torchdynamo.config.capture_scalar_outputs = True torchdynamo.config.guard_nn_modules = True torchdynamo.config.dynamic_shapes = True + torchdynamo.config.specialize_int = True torchdynamo.config.verbose = True diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py index 79779f604e..a8a3851655 100644 --- a/py/torch_tensorrt/fx/utils.py +++ b/py/torch_tensorrt/fx/utils.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Callable +from typing import List, Optional, Callable from packaging import version # @manual=//deeplearning/trt/python:py_tensorrt @@ -19,6 +19,20 @@ class LowerPrecision(Enum): FP32 = "fp32" FP16 = "fp16" INT8 = "int8" + BF16 = "bf16" + + @staticmethod + def from_str(label: str) -> Optional["LowerPrecision"]: + if label in ("fp32", "float32", "float", "torch.float32"): + return LowerPrecision.FP32 + elif label in ("fp16", "float16", "half", "torch.half", "torch.float16"): + return LowerPrecision.FP16 + elif label in ("int8"): + return LowerPrecision.INT8 + elif label in ("bf16", "bfloat16", "torch.bfloat16"): + return LowerPrecision.BF16 + else: + return None def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType: