From 07238c890e511a74fdc0bafe10c0fd342b3d159b Mon Sep 17 00:00:00 2001 From: Wei Date: Wed, 3 Aug 2022 21:48:05 -0700 Subject: [PATCH] Changes done internally at Facebook (#1221) 6703b98dff0695d91026f057b951dba1355825fa Shreyansh Prajapati Test dynamic shape support for acc_ops.prod c822345d6d673e1653c2208435e34ab400bada3d Jason Park Add support for generic torch ops to be used in training. e5758602a0592d6c2b71d6d66a0398c4dd9b5e20 Shreyansh Prajapati Test dynamic shape support for repeat interleave c13c633f04df162500eed477c0569eb2b81eb070 Shreyansh Prajapati Test dynamic shape support for reduce ops 863476cf43b210922b88585b8f196dd84fbebb56 Shreyansh Prajapati Test dynamic shape support for acc_op.convolution 68dff39793e5c30c20010919a855bb3d984015d7 Ruichao Xiao [fbcode][GPU][DHEN]fuse split squeeze cat as reshape f8b920769507ebd2ff02419b4aece25451298a95 Ruichao Xiao [fbcode][DHEN][GPU] reorder and merge cats whose input is a sublist of another cat 5b6a8d2d6be979983a52ac96225fefb510c3817c Andrew Or [Quant][fx] Rename convert_to_reference to convert_to_reference_fx 996a0e080b8a8bc0b292a7c2ac92f41f6db33a2e Shreyansh Prajapati Test dynamic shape support for acc_op.expand 084631fe74b304fbb9481ca15fd452a3714fb1b8 Shreyansh Prajapati Test dynamic shape support for acc_op.to_dtype b3195e76329ccddbb5c4640cfa884d0e457d2d34 Shreyansh Prajapati Test dynamic shape support for std a5d964e62bdf769cf8c2e67321138b33e1f524a7 Shreyansh Prajapati Test dynamic shape support for acc_op.tile 3d33d45b2fc7f10f25c22946ba474b227e4b6529 Shreyansh Prajapati Test dynamic shape support for squeeze 09085abf63d7e7732e2cd66e600e8afc6d58964f Shreyansh Prajapati Test dynamic shape support for acc_op.topk 65edc7ea12899e9bd2af42c890a64de853d9b7fe Huamin Li temporarily skip gelu tests d11e521f9b90554ca86912a49920afa4406bb40d Shirong Wu Suppress accuracy check for remove_reshape_with_batch_size_change 6d948298b2327d229e010a34f1c221b11d2eb504 Ankur Singla [GPULowering] Suppress accuracy check for 
fuse_unsqueeze_cat_sum e780b647fc9571b77d9f41c963041a6ac3d66f33 Janet Yang Lower xrayvideo2022 to fx2trt 433c7207fef16b1fdff985546ea969c39fa83e7c generatedunixname89002005287564 [Codemod][Remove @noautodeps and @autodeps-skip tags] deeplearning/trt 1/2 66fdb65cffa925660c77b4758388399db3cbfe48 Scott Wolchok [fx2ait] Minor Python cleanup in acc_ops_getitem 188132ecb2c19bcbf83cb2dc381f6e3798629f87 generatedunixname89002005324833 [AutoAccept][Codemod][FBSourceBuckFormatLinter] Daily `arc lint --take BUCKFORMAT` 4536bae4686dd01f2149541ea7fb330e178a4969 Wei Wei [fx2trt] support sub 064602e666f86c110d931cd90a8536112a19b4ad Shreyansh Prajapati Test dynamic shape support for acc_ops.interpolate 9dfd0ee0cecb1975e3f53c44de237d67ca443ec5 Shreyansh Prajapati Test dynamic shape support for unary_ops 39b9efad8d5d82463a2016d135c0cf277de1c3c6 Shreyansh Prajapati Test dynamic shape support for unsqueeze 2bb17667d1dabc95391950426fc1f921eb3d0959 Shreyansh Prajapati Test dynamic shape support for acc_ops.split 64dfb7b096686cb2fd33197340dc72f30d525456 Shirong Wu Group LN trt plugin 438f670e28df59b0734baa092a514fba3d75eb4f Shreyansh Prajapati Test dynamic shape support for acc_ops.avgpool df0fe32dae4343827bd9b37b72daae761b02f228 Shreyansh Prajapati Test dynamic shape support for acc_ops masked fill 44fe735d3493ea2d05a56b49093e4a23dd63a98e Shreyansh Prajapati Test dynamic shape support for acc_ops.pad 4f931acca706d8ce79045ceafef2ea0486609149 Wei Wei [fx2trt] torch.max dynamic shape test bf6f6cbe217d26a95ca9122574adf7de3966db9e Shreyansh Prajapati Change the name of the test from full_reduce to dim_reduce 1c5680ed107d9206f3514eff4069a3f6c870ba8c Shreyansh Prajapati Test dynamic shape support for acc_ops.type_as 33e4c175a4f5fec78ac0b1c8eb262ca777c7aaba Shreyansh Prajapati Test dynamic shape support for acc_ops.min f37be34bcef9716080b8bafbd1f4ad72e412c44c Wei Wei [fx2trt] plugin for grid_sample 57b5cc6a0f4839686ae360361a3a13b424794ee7 generatedunixname89002005367269
[AutoAccept][Codemod][FBSourceBlackLinter] Daily `arc lint --take BLACK` eb741cc5e5a7babdc94e72d411670905f54da3e0 Shreyansh Prajapati Updated the dynamic shape support for narrow op 521c36b96a14741ae89d7af6cbb658120bcec2ea Shreyansh Prajapati Removing the comment for 4 dims dynamic shape support after analysis e947343375967fe9efb0a16fdb9f63bff1449328 Shreyansh Prajapati Updated the pad test for dynamic batch for analysis 3d64087014e91bc301a315eae43683b1aa2b66bc Oleg Khabinov [trt_bc] Some improvements dfd937a56fa01aca88a89b46176befdac4c202c4 Shreyansh Prajapati Updated the test for as_strided op for analysis 11d76d0420dcaa4bb8890dcdeb86b6e534af831c Bangsheng Tang [gpu][infer] replace fx2trt_layer_norm with fbgemm layer_norm 932046ff6ea6dead114c0222b23ca3854690cffa Wei Wei [fx2trt] bridge the dynamic batch and fixed shape f911463393d8a671cfee6de6d1b5ef4d4f3991a6 Shirong Wu group swish LN plugin ea65970f23dd7a468e5bc43240f2a9bfa07c9b3b Shirong Wu Create backend specific lower pass 38183e4a724e5514db2be7193cf4897b59759252 Alex Beloi [fx] run acc_linter.lint in acc_tracer.trace 088abb6a790a62ca9f8515298a54117cc7fa31d4 Alex Beloi [fx] re-add pointwise property to acc_ops.clamp 9905c34f2bd28e9b64f10336f9ac326cc39eb60d Oleg Khabinov [trt] Comment out torch.ops.fbgemm dependency in TRT converters 8252e779476d2ff22ad78185af97a526b2f70fe3 Alex Beloi [fx] add operator test suite to test_acc_tracer.py 7b93a89c903bc0b6c59efb73a510c3dce8ef793a Shirong Wu Add option for lower and trt_splitter e08dabcbcd8c3e8ae92484e14cf07bb26993a8d6 Wei Wei [fx2trt] convert print to logging 3d61dc169b8a7dd1aecad35891a628e44e2c5a02 Shreyansh Prajapati Readme.md file for dynamic shape support 6337f62a38d73b08aa68762a4583d974d60a21b4 Ying Zhang fx2ait benchmark fixes dea33c518cae303e77ce80fdb5bf6d1fafd82ff9 Jason Park Prepare for group linear. 
d259d5a16d6203f0e4975ee5ded166e8b24d39d8 Dmitry Kulikovsky [re] migrate from gpu-remote-execution-1 to gpu-remote-execution 339f34a3951908d744a2c185b23c83ed427947bf Yinghai Lu [fx2trt] Add an option for tactic sources 18c65779e10d22e83f8b23683adb5c6f58730bef Shirong Wu integrate plugin 300b0fdd448e3c4637ff8454382809d71123c149 Wei Wei [fx2trt] fix bug in log format --- py/torch_tensorrt/fx/fx2trt.py | 4 ++++ py/torch_tensorrt/fx/lower.py | 1 + py/torch_tensorrt/fx/lower_setting.py | 5 ++++- .../fx/passes/lower_pass_manager_builder.py | 10 ++++------ py/torch_tensorrt/fx/tools/trt_splitter.py | 2 +- 5 files changed, 14 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index 7deed3e470..ca16e2ad9b 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -163,6 +163,7 @@ def run( algorithm_selector=None, timing_cache=None, profiling_verbosity=None, + tactic_sources=None, ) -> TRTInterpreterResult: """ Build TensorRT engine with some configs. 
@@ -245,6 +246,9 @@ def run( builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE) builder_config.algorithm_selector = algorithm_selector + if tactic_sources is not None: + builder_config.set_tactic_sources(tactic_sources=tactic_sources) + engine = self.builder.build_engine(self.network, builder_config) assert engine diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 470f78c407..387b4db841 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -120,6 +120,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: profiling_verbosity=trt.ProfilingVerbosity.DETAILED if self.lower_setting.verbose_profile else trt.ProfilingVerbosity.LAYER_NAMES_ONLY, + tactic_sources=self.lower_setting.tactic_sources, ) # Update timing cache file if needed diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index d3f2cc9a14..c1d02229e3 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -1,5 +1,5 @@ import dataclasses as dc -from typing import List, Optional, Sequence, Set, Type +from typing import List, Optional, Set, Type from torch import nn from torch.fx.passes.pass_manager import PassManager @@ -68,6 +68,8 @@ class LowerSetting(LowerSettingBasic): opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is only used by explicit batch dim with dynamic shape mode. dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension. + tactic_sources: tactic sources for TensorRT kernel selection. Default to None, + meaning all possible tactic sources. 
""" input_specs: List[InputTensorSpec] = dc.field(default_factory=list) @@ -87,3 +89,4 @@ class LowerSetting(LowerSettingBasic): preset_lowerer: str = "" opt_profile_replica: int = 1 dynamic_batch: bool = True + tactic_sources: Optional[int] = None diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 937737b60d..047ceb3ad2 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -148,7 +148,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: # Only acc submodules will be lowered. if not submod_name.startswith(split_result.non_acc_submodule_prefix): - _LOGGER.info("Now lowering submodule", submod_name) + _LOGGER.info(f"Now lowering submodule {submod_name}") lowering_start_time = datetime.datetime.now() self.lower_setting.input_specs = generate_input_specs( @@ -166,8 +166,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: submod_name, lowered_module, submod_inputs ) _LOGGER.info( - f"Lowering submodule {submod_name} elapsed time", - datetime.datetime.now() - lowering_start_time, + f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}" ) return split_result.split_module @@ -184,7 +183,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: # Only acc submodules will be lowered. 
if not submod_name.startswith(split_result.non_acc_submodule_prefix): - _LOGGER.info("Now lowering submodule", submod_name) + _LOGGER.info(f"Now lowering submodule {submod_name}") lowering_start_time = datetime.datetime.now() lowered_module = self._lower_func( @@ -195,8 +194,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: submod_name, lowered_module, submod_inputs ) _LOGGER.info( - f"Lowering submodule {submod_name} elapsed time", - datetime.datetime.now() - lowering_start_time, + f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}" ) return split_result.split_module diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py index 7fbca8d99a..28279a117d 100644 --- a/py/torch_tensorrt/fx/tools/trt_splitter.py +++ b/py/torch_tensorrt/fx/tools/trt_splitter.py @@ -49,7 +49,7 @@ def __init__(self): # During split, we'll split out the operators that # don't support the batch dim. self.use_implicit_batch_dim: bool = True - self.exclude_support_node_name: set = set() + self.exclude_support_node_name: set = set(self.op_lowering_disallow_list) class TRTSplitter(splitter_base._SplitterBase):