From b547a33de1d5e018db19fde388244f7f4395f896 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Wed, 22 Jun 2022 14:08:57 -0700 Subject: [PATCH 01/10] compile interface --- py/torch_tensorrt/_compile.py | 145 +++++++++--------- py/torch_tensorrt/fx/example/lower_example.py | 6 +- py/torch_tensorrt/fx/utils.py | 95 +++++++++++- 3 files changed, 170 insertions(+), 76 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 52ca551142..13198398c8 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -108,78 +108,79 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums ts_mod = torch.jit.script(module) return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs) elif target_ir == _IRType.fx: - from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer - from torch_tensorrt.fx import InputTensorSpec - from torch_tensorrt.fx import TRTInterpreter - from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem - from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter - from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting - from torch_tensorrt.fx.trt_module import TRTModule - from torch_tensorrt.fx.utils import LowerPrecision - acc_model = acc_tracer.trace(module, inputs) - - splitter_setting = TRTSplitterSetting() - splitter_setting.use_implicit_batch_dim = False - splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting) - splitter.node_support_preview() - split_mod = splitter() - num_piece = 0 - for name, _ in split_mod.named_children(): - print(f"graph is split into {name}") - num_piece += 1 - - # if the graph module is split into pieces larger than 8, we consider its perf - # is not good and fall back to non-TRT - if num_piece > 8: - print( - f"The graph module is split into {num_piece} which is large than the \ - threshold=8. Fall back to non-TRT module." 
- ) - return None - - if torch.float16 in enabled_precisions or torch.half in enabled_precisions: - precision = LowerPrecision.FP16 - else: - precision = LowerPrecision.FP32 - - def get_submod_inputs(mod, submod, inputs): - acc_inputs = None - - def get_input(self, inputs): - nonlocal acc_inputs - acc_inputs = inputs - - handle = submod.register_forward_pre_hook(get_input) - mod(*inputs) - handle.remove() - return acc_inputs - - for name, _ in split_mod.named_children(): - if "_run_on_acc" in name: - submod = getattr(split_mod, name) - # Get submodule inputs for fx2trt - acc_inputs = get_submod_inputs(split_mod, submod, inputs) - - # fx2trt replacement - interp = TRTInterpreter( - submod, - InputTensorSpec.from_tensors(acc_inputs), - explicit_batch_dimension=True, - ) - r = interp.run( - max_workspace_size=20 << 30, - lower_precision=precision, - # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile - ) - # For profile - # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module - # profile_trt_module("", trt_mod, acc_inputs) - trt_mod = TRTModule(*r) - - setattr(split_mod, name, trt_mod) - else: - submod = getattr(split_mod, name) - return split_mod + return torch_tensorrt.fx.util.compile(module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs) + # from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer + # from torch_tensorrt.fx import InputTensorSpec + # from torch_tensorrt.fx import TRTInterpreter + # from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem + # from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter + # from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting + # from torch_tensorrt.fx.trt_module import TRTModule + # from torch_tensorrt.fx.utils import LowerPrecision + # acc_model = acc_tracer.trace(module, inputs) + + # splitter_setting = TRTSplitterSetting() + # splitter_setting.use_implicit_batch_dim = False + # splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting) + # splitter.node_support_preview() + # split_mod = splitter() + # num_piece = 0 + # for name, _ in split_mod.named_children(): + # print(f"graph is split into {name}") + # num_piece += 1 + + # # if the graph module is split into pieces larger than 8, we consider its perf + # # is not good and fall back to non-TRT + # if num_piece > 8: + # print( + # f"The graph module is split into {num_piece} which is large than the \ + # threshold=8. Fall back to non-TRT module." 
+ # ) + # return None + + # if torch.float16 in enabled_precisions or torch.half in enabled_precisions: + # precision = LowerPrecision.FP16 + # else: + # precision = LowerPrecision.FP32 + + # def get_submod_inputs(mod, submod, inputs): + # acc_inputs = None + + # def get_input(self, inputs): + # nonlocal acc_inputs + # acc_inputs = inputs + + # handle = submod.register_forward_pre_hook(get_input) + # mod(*inputs) + # handle.remove() + # return acc_inputs + + # for name, _ in split_mod.named_children(): + # if "_run_on_acc" in name: + # submod = getattr(split_mod, name) + # # Get submodule inputs for fx2trt + # acc_inputs = get_submod_inputs(split_mod, submod, inputs) + + # # fx2trt replacement + # interp = TRTInterpreter( + # submod, + # InputTensorSpec.from_tensors(acc_inputs), + # explicit_batch_dimension=True, + # ) + # r = interp.run( + # max_workspace_size=20 << 30, + # lower_precision=precision, + # # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile + # ) + # # For profile + # # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module + # # profile_trt_module("", trt_mod, acc_inputs) + # trt_mod = TRTModule(*r) + + # setattr(split_mod, name, trt_mod) + # else: + # submod = getattr(split_mod, name) + # return split_mod else: raise RuntimeError("Module is an unknown format or the ir requested is unknown") diff --git a/py/torch_tensorrt/fx/example/lower_example.py b/py/torch_tensorrt/fx/example/lower_example.py index b93e93598e..7990315e5f 100644 --- a/py/torch_tensorrt/fx/example/lower_example.py +++ b/py/torch_tensorrt/fx/example/lower_example.py @@ -198,6 +198,6 @@ def run_configuration_benchmark( if __name__ == "__main__": - test_model = torchvision.models.resnet101() - input = [torch.cuda.FloatTensor(1024, 3, 224, 224)] # type: ignore[attr-defined] - benchmark(test_model, input, 100, 1024) + test_model = torchvision.models.resnet18() + input = [torch.cuda.FloatTensor(32, 3, 224, 224)] # type: ignore[attr-defined] + benchmark(test_model, input, 30, 32) diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py index 863e4b3f85..86d3fe4877 100644 --- a/py/torch_tensorrt/fx/utils.py +++ b/py/torch_tensorrt/fx/utils.py @@ -6,7 +6,14 @@ import torch from .types import Shape, TRTDataType - +from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer +from torch_tensorrt.fx import InputTensorSpec +from torch_tensorrt.fx import TRTInterpreter +from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem +from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter +from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt.fx.utils import LowerPrecision class LowerPrecision(Enum): FP32 = "fp32" @@ -82,3 +89,89 @@ def get_dynamic_dims(shape: Shape) -> List[int]: dynamic_dims.append(i) return dynamic_dims + +def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums.dtype.float]), **kwargs): + """Compile a PyTorch module through fx + + Takes a existing PyTorch module and a set of settings to configure the compiler + and using the path specified in ``ir`` lower and compile the module to TensorRT + returning a PyTorch Module back + + Converts specifically the forward method of a Module + + Arguments: + module (torch.nn.Module): Source module + + Keyword Arguments: + inputs (List[torch.Tensor]): for fixed shape scenario, inputs shapes can not change + enabled_precision (torch.dtype): The datatype that TensorRT can use when selecting 
kernels. If torch.float is chosen, the kernel is running with fp32; If torch.float16 is chosen, the kernel is running with fp16 or fp32 which selected by TensorRT + ir (str): The requested strategy to compile. (default is ts - TorchScript with scripting path, fx is FX based path) + **kwargs: Additional settings for the specific requested strategy (See submodules for more info) + + Returns: + torch.nn.Module: Compiled Module, when run it will execute via TensorRT + """ + acc_model = acc_tracer.trace(module, inputs) + + splitter_setting = TRTSplitterSetting() + splitter_setting.use_implicit_batch_dim = False + splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting) + splitter.node_support_preview() + split_mod = splitter() + num_piece = 0 + for name, _ in split_mod.named_children(): + print(f"graph is split into {name}") + num_piece += 1 + + # if the graph module is split into pieces larger than 8, we consider its perf + # is not good and fall back to non-TRT + if num_piece > 8: + print( + f"The graph module is split into {num_piece} which is large than the \ + threshold=8. Fall back to non-TRT module." + ) + return None + + if torch.float16 in enabled_precisions or torch.half in enabled_precisions: + precision = LowerPrecision.FP16 + else: + precision = LowerPrecision.FP32 + + def get_submod_inputs(mod, submod, inputs): + acc_inputs = None + + def get_input(self, inputs): + nonlocal acc_inputs + acc_inputs = inputs + + handle = submod.register_forward_pre_hook(get_input) + mod(*inputs) + handle.remove() + return acc_inputs + + for name, _ in split_mod.named_children(): + if "_run_on_acc" in name: + submod = getattr(split_mod, name) + # Get submodule inputs for fx2trt + acc_inputs = get_submod_inputs(split_mod, submod, inputs) + + # fx2trt replacement + interp = TRTInterpreter( + submod, + InputTensorSpec.from_tensors(acc_inputs), + explicit_batch_dimension=True, + ) + r = interp.run( + max_workspace_size=20 << 30, + lower_precision=precision, + # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile + ) + # For profile + # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module + # profile_trt_module("", trt_mod, acc_inputs) + trt_mod = TRTModule(*r) + + setattr(split_mod, name, trt_mod) + else: + submod = getattr(split_mod, name) + return split_mod From 0f5ef06dcf11cd5b5372f29581e2c2c6a05cfba4 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Wed, 22 Jun 2022 16:33:38 -0700 Subject: [PATCH 02/10] add compile method --- py/torch_tensorrt/_compile.py | 78 +----------------------- py/torch_tensorrt/fx/__init__.py | 1 + py/torch_tensorrt/fx/compile.py | 101 +++++++++++++++++++++++++++++++ py/torch_tensorrt/fx/utils.py | 94 ---------------------------- 4 files changed, 105 insertions(+), 169 deletions(-) create mode 100644 py/torch_tensorrt/fx/compile.py diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 13198398c8..ef727211c2 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -3,9 +3,9 @@ import torch_tensorrt.ts from torch_tensorrt import logging import torch -from torch import fx +import torch.fx from enum import Enum -from torch_tensorrt import fx +import torch_tensorrt.fx class _IRType(Enum): """Enum to set the minimum required logging level to print a message to stdout @@ -108,79 +108,7 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums ts_mod = torch.jit.script(module) return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, 
enabled_precisions=enabled_precisions, **kwargs) elif target_ir == _IRType.fx: - return torch_tensorrt.fx.util.compile(module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs) - # from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer - # from torch_tensorrt.fx import InputTensorSpec - # from torch_tensorrt.fx import TRTInterpreter - # from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem - # from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter - # from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting - # from torch_tensorrt.fx.trt_module import TRTModule - # from torch_tensorrt.fx.utils import LowerPrecision - # acc_model = acc_tracer.trace(module, inputs) - - # splitter_setting = TRTSplitterSetting() - # splitter_setting.use_implicit_batch_dim = False - # splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting) - # splitter.node_support_preview() - # split_mod = splitter() - # num_piece = 0 - # for name, _ in split_mod.named_children(): - # print(f"graph is split into {name}") - # num_piece += 1 - - # # if the graph module is split into pieces larger than 8, we consider its perf - # # is not good and fall back to non-TRT - # if num_piece > 8: - # print( - # f"The graph module is split into {num_piece} which is large than the \ - # threshold=8. Fall back to non-TRT module." - # ) - # return None - - # if torch.float16 in enabled_precisions or torch.half in enabled_precisions: - # precision = LowerPrecision.FP16 - # else: - # precision = LowerPrecision.FP32 - - # def get_submod_inputs(mod, submod, inputs): - # acc_inputs = None - - # def get_input(self, inputs): - # nonlocal acc_inputs - # acc_inputs = inputs - - # handle = submod.register_forward_pre_hook(get_input) - # mod(*inputs) - # handle.remove() - # return acc_inputs - - # for name, _ in split_mod.named_children(): - # if "_run_on_acc" in name: - # submod = getattr(split_mod, name) - # # Get submodule inputs for fx2trt - # acc_inputs = get_submod_inputs(split_mod, submod, inputs) - - # # fx2trt replacement - # interp = TRTInterpreter( - # submod, - # InputTensorSpec.from_tensors(acc_inputs), - # explicit_batch_dimension=True, - # ) - # r = interp.run( - # max_workspace_size=20 << 30, - # lower_precision=precision, - # # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile - # ) - # # For profile - # # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module - # # profile_trt_module("", trt_mod, acc_inputs) - # trt_mod = TRTModule(*r) - - # setattr(split_mod, name, trt_mod) - # else: - # submod = getattr(split_mod, name) - # return split_mod + return torch_tensorrt.fx.compile(module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs) else: raise RuntimeError("Module is an unknown format or the ir requested is unknown") diff --git a/py/torch_tensorrt/fx/__init__.py b/py/torch_tensorrt/fx/__init__.py index fa0afc33d1..d0b0657c74 100644 --- a/py/torch_tensorrt/fx/__init__.py +++ b/py/torch_tensorrt/fx/__init__.py @@ -8,3 +8,4 @@ from .fx2trt import TRTInterpreter, TRTInterpreterResult # noqa from .input_tensor_spec import InputTensorSpec # noqa from .trt_module import TRTModule # noqa +from .compile import compile # noqa diff --git a/py/torch_tensorrt/fx/compile.py b/py/torch_tensorrt/fx/compile.py new file mode 100644 index 0000000000..10c1b96948 --- /dev/null +++ b/py/torch_tensorrt/fx/compile.py @@ -0,0 +1,101 @@ + + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch + + +from 
torch_tensorrt.fx.tracer.acc_tracer import acc_tracer +from torch_tensorrt.fx import InputTensorSpec +from torch_tensorrt.fx import TRTInterpreter +from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem +from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter +from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt.fx.utils import LowerPrecision + +def compile(module: torch.nn.Module, ir="default", inputs=[], enabled_precisions=torch.dtype, **kwargs): + """Compile a PyTorch module through fx + + Takes a existing PyTorch module and a set of settings to configure the compiler + and using the path specified in ``ir`` lower and compile the module to TensorRT + returning a PyTorch Module back + + Converts specifically the forward method of a Module + + Arguments: + module (torch.nn.Module): Source module + + Keyword Arguments: + inputs (List[torch.Tensor]): for fixed shape scenario, inputs shapes can not change + enabled_precision (torch.dtype): The datatype that TensorRT can use when selecting kernels. If torch.float is chosen, the kernel is running with fp32; If torch.float16 is chosen, the kernel is running with fp16 or fp32 which selected by TensorRT + ir (str): The requested strategy to compile. (default is ts - TorchScript with scripting path, fx is FX based path) + **kwargs: Additional settings for the specific requested strategy (See submodules for more info) + + Returns: + torch.nn.Module: Compiled Module, when run it will execute via TensorRT + """ + acc_model = acc_tracer.trace(module, inputs) + + splitter_setting = TRTSplitterSetting() + splitter_setting.use_implicit_batch_dim = False + splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting) + splitter.node_support_preview() + split_mod = splitter() + num_piece = 0 + for name, _ in split_mod.named_children(): + print(f"graph is split into {name}") + num_piece += 1 + + # if the graph module is split into pieces larger than 8, we consider its perf + # is not good and fall back to non-TRT + if num_piece > 8: + print( + f"The graph module is split into {num_piece} which is large than the \ + threshold=8. Fall back to non-TRT module." 
+ ) + return None + + if torch.float16 in enabled_precisions or torch.half in enabled_precisions: + precision = LowerPrecision.FP16 + else: + precision = LowerPrecision.FP32 + + def get_submod_inputs(mod, submod, inputs): + acc_inputs = None + + def get_input(self, inputs): + nonlocal acc_inputs + acc_inputs = inputs + + handle = submod.register_forward_pre_hook(get_input) + mod(*inputs) + handle.remove() + return acc_inputs + + for name, _ in split_mod.named_children(): + if "_run_on_acc" in name: + submod = getattr(split_mod, name) + # Get submodule inputs for fx2trt + acc_inputs = get_submod_inputs(split_mod, submod, inputs) + + # fx2trt replacement + interp = TRTInterpreter( + submod, + InputTensorSpec.from_tensors(acc_inputs), + explicit_batch_dimension=True, + ) + r = interp.run( + max_workspace_size=20 << 30, + lower_precision=precision, + # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile + ) + # For profile + # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module + # profile_trt_module("", trt_mod, acc_inputs) + trt_mod = TRTModule(*r) + + setattr(split_mod, name, trt_mod) + else: + submod = getattr(split_mod, name) + return split_mod diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py index 86d3fe4877..826112dcc8 100644 --- a/py/torch_tensorrt/fx/utils.py +++ b/py/torch_tensorrt/fx/utils.py @@ -6,14 +6,6 @@ import torch from .types import Shape, TRTDataType -from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer -from torch_tensorrt.fx import InputTensorSpec -from torch_tensorrt.fx import TRTInterpreter -from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem -from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter -from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting -from torch_tensorrt.fx.trt_module import TRTModule -from torch_tensorrt.fx.utils import LowerPrecision class LowerPrecision(Enum): FP32 = "fp32" @@ -89,89 +81,3 @@ def get_dynamic_dims(shape: Shape) -> List[int]: dynamic_dims.append(i) return dynamic_dims - -def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums.dtype.float]), **kwargs): - """Compile a PyTorch module through fx - - Takes a existing PyTorch module and a set of settings to configure the compiler - and using the path specified in ``ir`` lower and compile the module to TensorRT - returning a PyTorch Module back - - Converts specifically the forward method of a Module - - Arguments: - module (torch.nn.Module): Source module - - Keyword Arguments: - inputs (List[torch.Tensor]): for fixed shape scenario, inputs shapes can not change - enabled_precision (torch.dtype): The datatype that TensorRT can use when selecting kernels. If torch.float is chosen, the kernel is running with fp32; If torch.float16 is chosen, the kernel is running with fp16 or fp32 which selected by TensorRT - ir (str): The requested strategy to compile. 
(default is ts - TorchScript with scripting path, fx is FX based path) - **kwargs: Additional settings for the specific requested strategy (See submodules for more info) - - Returns: - torch.nn.Module: Compiled Module, when run it will execute via TensorRT - """ - acc_model = acc_tracer.trace(module, inputs) - - splitter_setting = TRTSplitterSetting() - splitter_setting.use_implicit_batch_dim = False - splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting) - splitter.node_support_preview() - split_mod = splitter() - num_piece = 0 - for name, _ in split_mod.named_children(): - print(f"graph is split into {name}") - num_piece += 1 - - # if the graph module is split into pieces larger than 8, we consider its perf - # is not good and fall back to non-TRT - if num_piece > 8: - print( - f"The graph module is split into {num_piece} which is large than the \ - threshold=8. Fall back to non-TRT module." - ) - return None - - if torch.float16 in enabled_precisions or torch.half in enabled_precisions: - precision = LowerPrecision.FP16 - else: - precision = LowerPrecision.FP32 - - def get_submod_inputs(mod, submod, inputs): - acc_inputs = None - - def get_input(self, inputs): - nonlocal acc_inputs - acc_inputs = inputs - - handle = submod.register_forward_pre_hook(get_input) - mod(*inputs) - handle.remove() - return acc_inputs - - for name, _ in split_mod.named_children(): - if "_run_on_acc" in name: - submod = getattr(split_mod, name) - # Get submodule inputs for fx2trt - acc_inputs = get_submod_inputs(split_mod, submod, inputs) - - # fx2trt replacement - interp = TRTInterpreter( - submod, - InputTensorSpec.from_tensors(acc_inputs), - explicit_batch_dimension=True, - ) - r = interp.run( - max_workspace_size=20 << 30, - lower_precision=precision, - # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile - ) - # For profile - # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module - # profile_trt_module("", trt_mod, acc_inputs) - trt_mod = TRTModule(*r) - - setattr(split_mod, name, trt_mod) - else: - submod = getattr(split_mod, name) - return split_mod From 5f99c1107132132ca1a868c734f7db5a2c45bb2d Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Wed, 22 Jun 2022 16:36:10 -0700 Subject: [PATCH 03/10] update --- py/torch_tensorrt/fx/compile.py | 3 --- py/torch_tensorrt/fx/utils.py | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/py/torch_tensorrt/fx/compile.py b/py/torch_tensorrt/fx/compile.py index 10c1b96948..c824f106b1 100644 --- a/py/torch_tensorrt/fx/compile.py +++ b/py/torch_tensorrt/fx/compile.py @@ -1,10 +1,7 @@ - - # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch - from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer from torch_tensorrt.fx import InputTensorSpec from torch_tensorrt.fx import TRTInterpreter diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py index 826112dcc8..863e4b3f85 100644 --- a/py/torch_tensorrt/fx/utils.py +++ b/py/torch_tensorrt/fx/utils.py @@ -7,6 +7,7 @@ from .types import Shape, TRTDataType + class LowerPrecision(Enum): FP32 = "fp32" FP16 = "fp16" From 7be21ca952d04ce88b8543b0633a2e73dce6e009 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Thu, 23 Jun 2022 19:41:31 -0700 Subject: [PATCH 04/10] update --- py/torch_tensorrt/_compile.py | 11 ++- py/torch_tensorrt/fx/__init__.py | 1 - py/torch_tensorrt/fx/compile.py | 98 ------------------- py/torch_tensorrt/fx/example/lower_example.py | 6 +- py/torch_tensorrt/fx/example/test_fx2trt.py | 54 ---------- 
.../fx/example/torch_trt_simple_example.py | 57 +++++++++++ py/torch_tensorrt/fx/lower.py | 6 +- py/torch_tensorrt/fx/lower_setting.py | 4 + 8 files changed, 78 insertions(+), 159 deletions(-) delete mode 100644 py/torch_tensorrt/fx/compile.py delete mode 100644 py/torch_tensorrt/fx/example/test_fx2trt.py create mode 100644 py/torch_tensorrt/fx/example/torch_trt_simple_example.py diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index ef727211c2..c6550ae7c7 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -6,6 +6,8 @@ import torch.fx from enum import Enum import torch_tensorrt.fx +from torch_tensorrt.fx.lower import lower_to_trt +from torch_tensorrt.fx.utils import LowerPrecision class _IRType(Enum): """Enum to set the minimum required logging level to print a message to stdout @@ -108,7 +110,14 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums ts_mod = torch.jit.script(module) return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs) elif target_ir == _IRType.fx: - return torch_tensorrt.fx.compile(module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs) + if torch.float16 in enabled_precisions or torch_tensorrt.dtype.half in enabled_precisions: + lower_precision = LowerPrecision.FP16 + elif torch.float32 in enabled_precisions or torch_tensorrt.dtype.float in enabled_precisions: + lower_precision = LowerPrecision.FP32 + else: + raise ValueError(f"Precision {enabled_precisions} not supported on FX") + + return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True, dynamic_batch=False) else: raise RuntimeError("Module is an unknown format or the ir requested is unknown") diff --git a/py/torch_tensorrt/fx/__init__.py b/py/torch_tensorrt/fx/__init__.py index d0b0657c74..fa0afc33d1 100644 --- a/py/torch_tensorrt/fx/__init__.py +++ b/py/torch_tensorrt/fx/__init__.py @@ -8,4 +8,3 @@ from .fx2trt import TRTInterpreter, TRTInterpreterResult # noqa from .input_tensor_spec import InputTensorSpec # noqa from .trt_module import TRTModule # noqa -from .compile import compile # noqa diff --git a/py/torch_tensorrt/fx/compile.py b/py/torch_tensorrt/fx/compile.py deleted file mode 100644 index c824f106b1..0000000000 --- a/py/torch_tensorrt/fx/compile.py +++ /dev/null @@ -1,98 +0,0 @@ -# @manual=//deeplearning/trt/python:py_tensorrt -import tensorrt as trt -import torch - -from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer -from torch_tensorrt.fx import InputTensorSpec -from torch_tensorrt.fx import TRTInterpreter -from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem -from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter -from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting -from torch_tensorrt.fx.trt_module import TRTModule -from torch_tensorrt.fx.utils import LowerPrecision - -def compile(module: torch.nn.Module, ir="default", inputs=[], enabled_precisions=torch.dtype, **kwargs): - """Compile a PyTorch module through fx - - Takes a existing PyTorch module and a set of settings to configure the compiler - and using the path specified in ``ir`` lower and compile the module to TensorRT - returning a PyTorch Module back - - Converts specifically the forward method of a Module - - Arguments: - module (torch.nn.Module): Source module - - Keyword Arguments: - inputs (List[torch.Tensor]): for fixed shape scenario, inputs shapes can not change - 
enabled_precision (torch.dtype): The datatype that TensorRT can use when selecting kernels. If torch.float is chosen, the kernel is running with fp32; If torch.float16 is chosen, the kernel is running with fp16 or fp32 which selected by TensorRT - ir (str): The requested strategy to compile. (default is ts - TorchScript with scripting path, fx is FX based path) - **kwargs: Additional settings for the specific requested strategy (See submodules for more info) - - Returns: - torch.nn.Module: Compiled Module, when run it will execute via TensorRT - """ - acc_model = acc_tracer.trace(module, inputs) - - splitter_setting = TRTSplitterSetting() - splitter_setting.use_implicit_batch_dim = False - splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting) - splitter.node_support_preview() - split_mod = splitter() - num_piece = 0 - for name, _ in split_mod.named_children(): - print(f"graph is split into {name}") - num_piece += 1 - - # if the graph module is split into pieces larger than 8, we consider its perf - # is not good and fall back to non-TRT - if num_piece > 8: - print( - f"The graph module is split into {num_piece} which is large than the \ - threshold=8. Fall back to non-TRT module." - ) - return None - - if torch.float16 in enabled_precisions or torch.half in enabled_precisions: - precision = LowerPrecision.FP16 - else: - precision = LowerPrecision.FP32 - - def get_submod_inputs(mod, submod, inputs): - acc_inputs = None - - def get_input(self, inputs): - nonlocal acc_inputs - acc_inputs = inputs - - handle = submod.register_forward_pre_hook(get_input) - mod(*inputs) - handle.remove() - return acc_inputs - - for name, _ in split_mod.named_children(): - if "_run_on_acc" in name: - submod = getattr(split_mod, name) - # Get submodule inputs for fx2trt - acc_inputs = get_submod_inputs(split_mod, submod, inputs) - - # fx2trt replacement - interp = TRTInterpreter( - submod, - InputTensorSpec.from_tensors(acc_inputs), - explicit_batch_dimension=True, - ) - r = interp.run( - max_workspace_size=20 << 30, - lower_precision=precision, - # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile - ) - # For profile - # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module - # profile_trt_module("", trt_mod, acc_inputs) - trt_mod = TRTModule(*r) - - setattr(split_mod, name, trt_mod) - else: - submod = getattr(split_mod, name) - return split_mod diff --git a/py/torch_tensorrt/fx/example/lower_example.py b/py/torch_tensorrt/fx/example/lower_example.py index 7990315e5f..71f15a2f88 100644 --- a/py/torch_tensorrt/fx/example/lower_example.py +++ b/py/torch_tensorrt/fx/example/lower_example.py @@ -198,6 +198,6 @@ def run_configuration_benchmark( if __name__ == "__main__": - test_model = torchvision.models.resnet18() - input = [torch.cuda.FloatTensor(32, 3, 224, 224)] # type: ignore[attr-defined] - benchmark(test_model, input, 30, 32) + test_model = torchvision.models.resnet18(pretrained=True) + input = [torch.rand(128, 3, 224, 224)] # type: ignore[attr-defined] + benchmark(test_model, input, 50, 128) diff --git a/py/torch_tensorrt/fx/example/test_fx2trt.py b/py/torch_tensorrt/fx/example/test_fx2trt.py deleted file mode 100644 index effc188e7a..0000000000 --- a/py/torch_tensorrt/fx/example/test_fx2trt.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torch_tensorrt - - -class MyModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(5, 3) - self.relu = torch.nn.functional.relu - - def forward(self, x): - x = 
self.linear(x) - x = self.relu(x) - return x - - -model = MyModel().eval() # torch module needs to be in eval (not training) mode - -# torch tensorrt -inputs = [ - torch_tensorrt.Input( - (2, 5), - dtype=torch.half, - ) -] -enabled_precisions = {torch.float, torch.half} # Run with fp16 - -trt_ts_module = torch_tensorrt.compile( - model, inputs=inputs, enabled_precisions=enabled_precisions -) - -inputs_ts = [torch.ones(2, 5)] -inputs_ts = [i.cuda().half() for i in inputs_ts] -result = trt_ts_module(*inputs_ts) -print(result) - -model.cuda().half() -ref = model(*inputs_ts) -print(ref) - -# fx2trt -inputs_fx = [torch.ones((2, 5))] - -model.cuda().half() -inputs_fx = [i.cuda().half() for i in inputs_fx] - -trt_fx_module = torch_tensorrt.compile( - model, ir="fx", inputs=inputs_fx, enabled_precisions={torch.half} -) -result = trt_fx_module(*inputs_fx) -print(result) - -ref = model(*inputs_fx) -print(ref) diff --git a/py/torch_tensorrt/fx/example/torch_trt_simple_example.py b/py/torch_tensorrt/fx/example/torch_trt_simple_example.py new file mode 100644 index 0000000000..a6dd732c84 --- /dev/null +++ b/py/torch_tensorrt/fx/example/torch_trt_simple_example.py @@ -0,0 +1,57 @@ +import torch +import copy +import torchvision +import torch_tensorrt +from torch_tensorrt.fx import InputTensorSpec + + +def test_torch_tensorrt(model, inputs): + # torchscript path + model_ts = copy.deepcopy(model) + inputs_ts = copy.deepcopy(inputs) + # fp32 test + with torch.inference_mode(): + ref_fp32 = model_ts(*inputs_ts) + trt_ts_module = torch_tensorrt.compile( + model_ts, inputs=inputs_ts, enabled_precisions={torch.float32} + ) + result_fp32 = trt_ts_module(*inputs_ts) + assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0)>0.9999) + # fp16 test + model_ts = model_ts.half() + inputs_ts = [i.cuda().half() for i in inputs_ts] + with torch.inference_mode(): + ref_fp16 = model_ts(*inputs_ts) + trt_ts_module = torch_tensorrt.compile( + model_ts, inputs=inputs_ts, enabled_precisions={torch.float16} + ) + result_fp16 = trt_ts_module(*inputs_ts) + assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0)>0.99) + + # FX path + model_fx = copy.deepcopy(model) + inputs_fx = copy.deepcopy(inputs) + # fp32 test + with torch.inference_mode(): + ref_fp32 = model_fx(*inputs_fx) + trt_fx_module = torch_tensorrt.compile( + model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float32} + ) + result_fp32 = trt_fx_module(*inputs_fx) + assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0)>0.9999) + # fp16 test + model_fx = model_fx.cuda().half() + inputs_fx = [i.cuda().half() for i in inputs_fx] + with torch.inference_mode(): + ref_fp16 = model_fx(*inputs_fx) + trt_fx_module = torch_tensorrt.compile( + model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float16} + ) + result_fp16 = trt_fx_module(*inputs_fx) + assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0)>0.99 ) + + +if __name__ == "__main__": + model = torchvision.models.resnet18(pretrained=True).cuda().eval() + inputs = [torch.ones((32, 3, 224, 224), device=torch.device('cuda'))] # type: ignore[attr-defined] + test_torch_tensorrt(model, inputs) diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 318136be56..9d7adff541 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -42,6 +42,7 @@ def lower_to_trt( timing_cache_prefix="", 
save_timing_cache=False, cuda_graph_batch_size=-1, + dynamic_batch=True, ) -> nn.Module: """ Takes in original module, input and lowering setting, run lowering workflow to turn module @@ -71,6 +72,7 @@ def lower_to_trt( timing_cache_prefix=timing_cache_prefix, save_timing_cache=save_timing_cache, cuda_graph_batch_size=cuda_graph_batch_size, + dynamic_batch=dynamic_batch, ) lowerer = Lowerer.create(lower_setting=lower_setting) return lowerer(module, input) @@ -100,12 +102,12 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: self.lower_setting.max_batch_size, self.lower_setting.max_batch_size, ), + self.lower_setting.opt_profile_replica, ) - if self.lower_setting.explicit_batch_dimension + if self.lower_setting.explicit_batch_dimension and self.lower_setting.dynamic_batch else InputTensorSpec.from_tensors(input) ) ) - # Prepare algorithm selector and timing_cache for TRTInterpreter algo_selector = None if self.lower_setting.algo_selector: diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index b1a32c2cff..98435e6fe1 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -69,6 +69,8 @@ class LowerSetting(LowerSettingBasic): how presets are applied. Refer to `caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how to add a preset. + opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is + only used by explicit batch dim with dynamic shape mode. """ input_specs: List[InputTensorSpec] = dc.field(default_factory=list) @@ -86,3 +88,5 @@ class LowerSetting(LowerSettingBasic): save_timing_cache: bool = False cuda_graph_batch_size: int = -1 preset_lowerer: str = "" + opt_profile_replica: int = 1 + dynamic_batch: bool = True From 96f9aa3993c135630b3c23f3c2f27f762d389c75 Mon Sep 17 00:00:00 2001 From: Wei Date: Thu, 23 Jun 2022 19:56:06 -0700 Subject: [PATCH 05/10] Update lower_setting.py --- py/torch_tensorrt/fx/lower_setting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index a297ca9736..1719b4efe8 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -89,4 +89,5 @@ class LowerSetting(LowerSettingBasic): cuda_graph_batch_size: int = -1 preset_lowerer: str = "" opt_profile_replica: int = 1 - dynamic_batch: bool = True \ No newline at end of file + dynamic_batch: bool = True + From 596ac14a9c7aa418d7a4afc73ea0bb7404db53a5 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Thu, 23 Jun 2022 20:33:25 -0700 Subject: [PATCH 06/10] update fx2trt_example --- .../fx/example/fx2trt_example.py | 76 ++++++++++++------- 1 file changed, 47 insertions(+), 29 deletions(-) diff --git a/py/torch_tensorrt/fx/example/fx2trt_example.py b/py/torch_tensorrt/fx/example/fx2trt_example.py index 8c648ec065..24cf7fd2b4 100644 --- a/py/torch_tensorrt/fx/example/fx2trt_example.py +++ b/py/torch_tensorrt/fx/example/fx2trt_example.py @@ -3,11 +3,11 @@ import torch import torch.fx import torch.nn as nn +from torch_tensorrt.fx.utils import LowerPrecision import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter - # The purpose of this example is to demonstrate the overall flow of lowering a PyTorch # model to TensorRT via FX with existing FX based tooling. 
 # would be like:
@@ -30,11 +30,12 @@ def forward(self, x):
         x = self.linear(x)
         x = self.relu(x)
         x = torch.linalg.norm(x, ord=2, dim=1)
+        x = self.relu(x)
         return x
 
 
-inputs = [torch.randn(1, 10)]
-model = Model().eval()
+inputs = [torch.randn((1, 10), device=torch.device('cuda'))]
+model = Model().cuda().eval()
 
 # acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators
 # to acc ops.
@@ -64,20 +65,23 @@ def forward(self, x):
 # Split.
 split_mod = splitter()
 
-# After split we have two submodules, _run_on_acc_0 and _run_on_gpu_1.
+# After split we have three submodules: _run_on_acc_0, _run_on_gpu_1 and _run_on_acc_2.
 print(split_mod.graph)
 """
 graph():
    %x : [#users=1] = placeholder[target=x]
    %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
    %_run_on_gpu_1 : [#users=1] = call_module[target=_run_on_gpu_1](args = (%_run_on_acc_0,), kwargs = {})
-    return _run_on_gpu_1
+    %_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_on_gpu_1,), kwargs = {})
+    return _run_on_acc_2
 """
 
 # Take a look at what inside each submodule. _run_on_acc_0 contains linear and relu while
-# _run_on_gpu_1 contains linalg_norm which currently is not supported by fx2trt.
+# _run_on_gpu_1 contains linalg_norm which currently is not supported by fx2trt. _run_on_acc_2
+# is another supported submodule.
 print(split_mod._run_on_acc_0.graph)
 print(split_mod._run_on_gpu_1.graph)
+print(split_mod._run_on_acc_2.graph)
 """
 graph():
    %x : [#users=1] = placeholder[target=x]
@@ -90,32 +94,46 @@ def forward(self, x):
    %relu_1 : [#users=1] = placeholder[target=relu_1]
    %linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
    return linalg_norm_1
+graph():
+    %linalg_norm_1 : [#users=1] = placeholder[target=linalg_norm_1]
+    %relu_3 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linalg_norm_1, inplace: False})
+    return relu_3
 """
 
-# Now let's lower split_mod._run_on_acc_0. If we know the model can be fully lowered,
-# we can skip the splitter part.
-interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs))
-r = interp.run()
-trt_mod = TRTModule(r.engine, r.input_names, r.output_names)
-split_mod._run_on_acc_0 = trt_mod
-
-cuda_inputs = [input.cuda() for input in inputs]
-split_mod.cuda()
-lowered_model_output = split_mod(*cuda_inputs)
+def get_submod_inputs(mod, submod, inputs):
+    acc_inputs = None
+
+    def get_input(self, inputs):
+        nonlocal acc_inputs
+        acc_inputs = inputs
+
+    handle = submod.register_forward_pre_hook(get_input)
+    mod(*inputs)
+    handle.remove()
+    return acc_inputs
+
+# Since the model is split into three segments, we need to lower each TRT-eligible segment.
+# If we know the model can be fully lowered, we can skip the splitter part.
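+# get_submod_inputs() relies on register_forward_pre_hook: the hook fires right
+# before submod.forward() runs, so one full-model forward pass is enough to
+# capture the exact sample tensors each TRT-eligible submodule receives. Those
+# tensors drive the InputTensorSpec for TRTInterpreter below, and the built
+# engine is wrapped in a TRTModule that replaces the original submodule.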
+for name, _ in split_mod.named_children(): + if "_run_on_acc" in name: + submod = getattr(split_mod, name) + # Get submodule inputs for fx2trt + acc_inputs = get_submod_inputs(split_mod, submod, inputs) + + # fx2trt replacement + interp = TRTInterpreter( + submod, + InputTensorSpec.from_tensors(acc_inputs), + explicit_batch_dimension=True, + ) + r = interp.run(lower_precision=LowerPrecision.FP32) + trt_mod = TRTModule(*r) + setattr(split_mod, name, trt_mod) + +lowered_model_output = split_mod(*inputs) # Make sure the results match -model.cuda() -regular_model_output = model(*cuda_inputs) +regular_model_output = model(*inputs) torch.testing.assert_close( - lowered_model_output, regular_model_output.to(torch.float16), atol=3e-3, rtol=1e-2 + lowered_model_output, regular_model_output, atol=3e-3, rtol=1e-2 ) - -# We can utilize the trt profiler to print out the time spend on each layer. -trt_mod.enable_profiling() -trt_mod(*cuda_inputs) -""" -Reformatting CopyNode for Input Tensor 0 to LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.027392ms -LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.023072ms -PWN(ActivationType.RELU_acc_ops.relu_relu_1): 0.008928ms -""" -trt_mod.disable_profiling() From e367e11a87cb1b50088760ea25c01f13f10179e4 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Thu, 23 Jun 2022 22:33:13 -0700 Subject: [PATCH 07/10] add docstring --- py/torch_tensorrt/fx/lower_setting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index 1719b4efe8..775da55e6d 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -71,6 +71,7 @@ class LowerSetting(LowerSettingBasic): to add a preset. opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is only used by explicit batch dim with dynamic shape mode. + dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension. 
""" input_specs: List[InputTensorSpec] = dc.field(default_factory=list) @@ -90,4 +91,3 @@ class LowerSetting(LowerSettingBasic): preset_lowerer: str = "" opt_profile_replica: int = 1 dynamic_batch: bool = True - From 834a4b07d0cf52a8b9e4cac7e7cbbba011f091bc Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Fri, 24 Jun 2022 09:55:13 -0700 Subject: [PATCH 08/10] update dynamic_batch default to False --- py/torch_tensorrt/_compile.py | 2 +- py/torch_tensorrt/fx/lower.py | 2 +- py/torch_tensorrt/fx/lower_setting.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index c6550ae7c7..c697979a43 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -117,7 +117,7 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums else: raise ValueError(f"Precision {enabled_precisions} not supported on FX") - return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True, dynamic_batch=False) + return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True) else: raise RuntimeError("Module is an unknown format or the ir requested is unknown") diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 9d7adff541..763ffdc653 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -42,7 +42,7 @@ def lower_to_trt( timing_cache_prefix="", save_timing_cache=False, cuda_graph_batch_size=-1, - dynamic_batch=True, + dynamic_batch=False, ) -> nn.Module: """ Takes in original module, input and lowering setting, run lowering workflow to turn module diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index 775da55e6d..6695c8ff85 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -90,4 +90,4 @@ class LowerSetting(LowerSettingBasic): cuda_graph_batch_size: int = -1 preset_lowerer: str = "" opt_profile_replica: int = 1 - dynamic_batch: bool = True + dynamic_batch: bool = False From 09babb57cbe9fefe394b7102fe72804b3d1cb9b2 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Sat, 25 Jun 2022 16:25:15 -0700 Subject: [PATCH 09/10] add docstring --- py/torch_tensorrt/fx/fx2trt.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index 16da30575f..29b1490586 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -164,6 +164,21 @@ def run( timing_cache=None, profiling_verbosity=None, ) -> TRTInterpreterResult: + """ + Build TensorRT engine with some configs. + Args: + max_batch_size: set accordingly for maximum batch size you will use. + max_workspace_size: set to the maximum size we can afford for temporary buffer + lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision). + sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity + force_fp32_output: force output to be fp32 + strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. 
+            algorithm_selector: set up algorithm selection for certain layers
+            timing_cache: enable timing cache for TensorRT
+            profiling_verbosity: TensorRT logging level
+        Return:
+            TRTInterpreterResult
+        """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
 
         # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and

From 9eb349d3925661edd52e7ff768b6b1daeff428eb Mon Sep 17 00:00:00 2001
From: Wei Wei
Date: Sat, 25 Jun 2022 16:58:23 -0700
Subject: [PATCH 10/10] add save/load module

---
 py/torch_tensorrt/fx/example/fx2trt_example.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/example/fx2trt_example.py b/py/torch_tensorrt/fx/example/fx2trt_example.py
index 24cf7fd2b4..b9fbc05f17 100644
--- a/py/torch_tensorrt/fx/example/fx2trt_example.py
+++ b/py/torch_tensorrt/fx/example/fx2trt_example.py
@@ -132,8 +132,13 @@ def get_input(self, inputs):
 
 lowered_model_output = split_mod(*inputs)
 
+# Save and load the lowered module
+torch.save(split_mod, "trt.pt")
+reload_trt_mod = torch.load("trt.pt")
+reload_model_output = reload_trt_mod(*inputs)
+
 # Make sure the results match
 regular_model_output = model(*inputs)
 torch.testing.assert_close(
-    lowered_model_output, regular_model_output, atol=3e-3, rtol=1e-2
+    reload_model_output, regular_model_output, atol=3e-3, rtol=1e-2
 )
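
With PATCH 10 applied, the user-facing entry point for the FX path is torch_tensorrt.compile(..., ir="fx"), which dispatches to lower_to_trt(). A minimal usage sketch of the series' final state, mirroring the model and input shapes of the torch_trt_simple_example.py added by PATCH 04 (assumes a CUDA build of torch_tensorrt with TensorRT available):

    import torch
    import torchvision
    import torch_tensorrt

    # resnet18 on GPU in eval mode, as in the example script added by this series
    model = torchvision.models.resnet18(pretrained=True).cuda().eval()
    inputs = [torch.ones((32, 3, 224, 224), device="cuda")]

    # {torch.float32} selects LowerPrecision.FP32; max_batch_size is taken from
    # inputs[0].size(0) and dynamic_batch keeps its post-PATCH-08 default of False
    trt_fx_module = torch_tensorrt.compile(
        model, ir="fx", inputs=inputs, enabled_precisions={torch.float32}
    )
    print(trt_fx_module(*inputs).shape)  # torch.Size([32, 1000])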