[FX] refactor the fx path in compile function #1141

Merged (14 commits) on Jun 28, 2022
77 changes: 3 additions & 74 deletions py/torch_tensorrt/_compile.py
@@ -3,9 +3,9 @@
 import torch_tensorrt.ts
 from torch_tensorrt import logging
 import torch
-from torch import fx
+import torch.fx
 from enum import Enum
-from torch_tensorrt import fx
+import torch_tensorrt.fx
 
 class _IRType(Enum):
     """Enum to set the minimum required logging level to print a message to stdout
@@ -108,78 +108,7 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
         ts_mod = torch.jit.script(module)
         return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
     elif target_ir == _IRType.fx:
-        from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
-        from torch_tensorrt.fx import InputTensorSpec
-        from torch_tensorrt.fx import TRTInterpreter
-        from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
-        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
-        from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
-        from torch_tensorrt.fx.trt_module import TRTModule
-        from torch_tensorrt.fx.utils import LowerPrecision
-        acc_model = acc_tracer.trace(module, inputs)
-
-        splitter_setting = TRTSplitterSetting()
-        splitter_setting.use_implicit_batch_dim = False
-        splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
-        splitter.node_support_preview()
-        split_mod = splitter()
-        num_piece = 0
-        for name, _ in split_mod.named_children():
-            print(f"graph is split into {name}")
-            num_piece += 1
-
-        # if the graph module is split into pieces larger than 8, we consider its perf
-        # is not good and fall back to non-TRT
-        if num_piece > 8:
-            print(
-                f"The graph module is split into {num_piece} which is large than the \
-threshold=8. Fall back to non-TRT module."
-            )
-            return None
-
-        if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
-            precision = LowerPrecision.FP16
-        else:
-            precision = LowerPrecision.FP32
-
-        def get_submod_inputs(mod, submod, inputs):
-            acc_inputs = None
-
-            def get_input(self, inputs):
-                nonlocal acc_inputs
-                acc_inputs = inputs
-
-            handle = submod.register_forward_pre_hook(get_input)
-            mod(*inputs)
-            handle.remove()
-            return acc_inputs
-
-        for name, _ in split_mod.named_children():
-            if "_run_on_acc" in name:
-                submod = getattr(split_mod, name)
-                # Get submodule inputs for fx2trt
-                acc_inputs = get_submod_inputs(split_mod, submod, inputs)
-
-                # fx2trt replacement
-                interp = TRTInterpreter(
-                    submod,
-                    InputTensorSpec.from_tensors(acc_inputs),
-                    explicit_batch_dimension=True,
-                )
-                r = interp.run(
-                    max_workspace_size=20 << 30,
-                    lower_precision=precision,
-                    # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
-                )
-                # For profile
-                # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
-                # profile_trt_module("", trt_mod, acc_inputs)
-                trt_mod = TRTModule(*r)
-
-                setattr(split_mod, name, trt_mod)
-            else:
-                submod = getattr(split_mod, name)
-        return split_mod
+        return torch_tensorrt.fx.compile(module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")

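With this change, the fx branch of the top-level compile function becomes a one-line delegation to torch_tensorrt.fx.compile. A minimal sketch of the call path, assuming a CUDA-capable environment; the model and input shape below are illustrative, not part of the PR:

    import torch
    import torchvision
    import torch_tensorrt

    # Any traceable nn.Module works; resnet18 is just a stand-in.
    model = torchvision.models.resnet18().eval().cuda()
    inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

    # ir="fx" now routes through torch_tensorrt.fx.compile under the hood.
    trt_model = torch_tensorrt.compile(
        model,
        ir="fx",
        inputs=inputs,
        enabled_precisions={torch.float16},
    )
    out = trt_model(*inputs)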
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/__init__.py
@@ -8,3 +8,4 @@
 from .fx2trt import TRTInterpreter, TRTInterpreterResult # noqa
 from .input_tensor_spec import InputTensorSpec # noqa
 from .trt_module import TRTModule # noqa
+from .compile import compile # noqa
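With this re-export in place, the new entry point is importable from the fx package directly; a one-line sketch:

    from torch_tensorrt.fx import compile as fx_compile  # same function as torch_tensorrt.fx.compile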
98 changes: 98 additions & 0 deletions py/torch_tensorrt/fx/compile.py
@@ -0,0 +1,98 @@
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
+import torch
+
+from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
+from torch_tensorrt.fx import InputTensorSpec
+from torch_tensorrt.fx import TRTInterpreter
+from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
+from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
+from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
+from torch_tensorrt.fx.trt_module import TRTModule
+from torch_tensorrt.fx.utils import LowerPrecision
+
+def compile(module: torch.nn.Module, ir="default", inputs=[], enabled_precisions=set(), **kwargs):
+    """Compile a PyTorch module through fx
+
+    Takes an existing PyTorch module and a set of settings to configure the compiler,
+    then uses the path specified in ``ir`` to lower and compile the module to TensorRT,
+    returning a PyTorch Module back.
+
+    Converts specifically the forward method of a Module.
+
+    Arguments:
+        module (torch.nn.Module): Source module
+
+    Keyword Arguments:
+        inputs (List[torch.Tensor]): Example inputs for the fixed-shape scenario; input shapes cannot change
+        enabled_precisions (Set[torch.dtype]): The datatypes that TensorRT can use when selecting kernels. If torch.float is chosen, kernels run in fp32; if torch.float16 is chosen, kernels run in fp16 or fp32, as selected by TensorRT
+        ir (str): The requested strategy to compile (default is ts, the TorchScript scripting path; fx is the FX-based path)
+        **kwargs: Additional settings for the specific requested strategy (See submodules for more info)
+
+    Returns:
+        torch.nn.Module: Compiled Module; when run it will execute via TensorRT
+    """
+    acc_model = acc_tracer.trace(module, inputs)
+
+    splitter_setting = TRTSplitterSetting()
+    splitter_setting.use_implicit_batch_dim = False
+    splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
+    splitter.node_support_preview()
+    split_mod = splitter()
+    num_piece = 0
+    for name, _ in split_mod.named_children():
+        print(f"graph is split into {name}")
+        num_piece += 1
+
+    # If the graph module is split into more than 8 pieces, we consider its
+    # performance poor and fall back to the non-TRT module.
+    if num_piece > 8:
+        print(
+            f"The graph module is split into {num_piece} pieces, which is larger than the \
+threshold=8. Fall back to non-TRT module."
+        )
+        return None
+
+    if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
+        precision = LowerPrecision.FP16
+    else:
+        precision = LowerPrecision.FP32
+
+    def get_submod_inputs(mod, submod, inputs):
+        acc_inputs = None
+
+        def get_input(self, inputs):
+            nonlocal acc_inputs
+            acc_inputs = inputs
+
+        handle = submod.register_forward_pre_hook(get_input)
+        mod(*inputs)
+        handle.remove()
+        return acc_inputs
+
+    for name, _ in split_mod.named_children():
+        if "_run_on_acc" in name:
+            submod = getattr(split_mod, name)
+            # Get submodule inputs for fx2trt
+            acc_inputs = get_submod_inputs(split_mod, submod, inputs)
+
+            # fx2trt replacement
+            interp = TRTInterpreter(
+                submod,
+                InputTensorSpec.from_tensors(acc_inputs),
+                explicit_batch_dimension=True,
+            )
+            r = interp.run(
+                max_workspace_size=20 << 30,
+                lower_precision=precision,
+                # profiling_verbosity=trt.ProfilingVerbosity.DETAILED,  # For profiling
+            )
+            # For profiling:
+            # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
+            # profile_trt_module("", trt_mod, acc_inputs)
+            trt_mod = TRTModule(*r)
+
+            setattr(split_mod, name, trt_mod)
+        else:
+            submod = getattr(split_mod, name)
+    return split_mod
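The module-level compile above can also be invoked directly. A minimal sketch, assuming a working CUDA build of TensorRT; the tiny model is hypothetical, chosen only so acc_tracer and the splitter have something trivial to work on:

    import torch
    import torch_tensorrt.fx

    # Deliberately tiny model; the forward-pre-hook trick in get_submod_inputs
    # records the tensors each _run_on_acc submodule actually receives.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.ReLU(),
    ).eval().cuda()
    inputs = [torch.randn(1, 3, 32, 32, device="cuda")]

    trt_mod = torch_tensorrt.fx.compile(model, inputs=inputs, enabled_precisions={torch.float32})
    if trt_mod is not None:  # compile returns None when the graph splits into more than 8 pieces
        print(trt_mod(*inputs).shape)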
6 changes: 3 additions & 3 deletions py/torch_tensorrt/fx/example/lower_example.py
@@ -198,6 +198,6 @@ def run_configuration_benchmark(
 
 
 if __name__ == "__main__":
-    test_model = torchvision.models.resnet101()
-    input = [torch.cuda.FloatTensor(1024, 3, 224, 224)]  # type: ignore[attr-defined]
-    benchmark(test_model, input, 100, 1024)
+    test_model = torchvision.models.resnet18()
+    input = [torch.cuda.FloatTensor(32, 3, 224, 224)]  # type: ignore[attr-defined]
+    benchmark(test_model, input, 30, 32)
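For context on the new, smaller settings (resnet18, batch 32, 30 iterations), here is a standalone timing sketch in the spirit of the example's benchmark helper; it assumes nothing beyond vanilla PyTorch and torchvision, and is not the benchmark function from lower_example.py itself:

    import time
    import torch
    import torchvision

    model = torchvision.models.resnet18().cuda().eval()
    inputs = [torch.randn(32, 3, 224, 224, device="cuda")]

    with torch.no_grad():
        for _ in range(5):  # warm-up
            model(*inputs)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(30):  # mirrors the example's 30 iterations
            model(*inputs)
        torch.cuda.synchronize()

    elapsed = time.time() - start
    print(f"{30 * 32 / elapsed:.1f} images/sec")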