[FX] refactor the fx path in compile function (#1141)
* compile interface

* add compile method

* update

* update

* Update lower_setting.py

* update fx2trt_example

* add docstring

* update dynamic_batch default to False

* add docstring

* add save/load module
Wei authored Jun 28, 2022
1 parent 5b03083 commit 3c87214
Showing 8 changed files with 143 additions and 161 deletions.
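At a high level, the change to py/torch_tensorrt/_compile.py below replaces the inline splitter/interpreter logic in the FX branch of compile() with a single call to lower_to_trt. A minimal sketch of how the refactored entry point is exercised, mirroring the new torch_trt_simple_example.py added in this commit (the model, input shapes, and precision are illustrative and assume a CUDA-capable environment):

import torch
import torchvision
import torch_tensorrt

# Illustrative model and inputs; any FX-traceable module on CUDA works the same way.
model = torchvision.models.resnet18(pretrained=True).cuda().eval()
inputs = [torch.ones((32, 3, 224, 224), device=torch.device('cuda'))]

# ir="fx" now routes through lower_to_trt rather than the hand-rolled
# splitter/interpreter loop removed below.
trt_module = torch_tensorrt.compile(
    model, ir="fx", inputs=inputs, enabled_precisions={torch.float32}
)
result = trt_module(*inputs)

The updated fx2trt_example.py below also shows that the lowered split module can be saved and reloaded with torch.save / torch.load, per the "add save/load module" item in the commit message.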
84 changes: 11 additions & 73 deletions py/torch_tensorrt/_compile.py
@@ -3,9 +3,11 @@
import torch_tensorrt.ts
from torch_tensorrt import logging
import torch
from torch import fx
import torch.fx
from enum import Enum
from torch_tensorrt import fx
import torch_tensorrt.fx
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision

class _IRType(Enum):
"""Enum to set the minimum required logging level to print a message to stdout
@@ -108,78 +110,14 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
ts_mod = torch.jit.script(module)
return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
elif target_ir == _IRType.fx:
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx import TRTInterpreter
from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt.fx.utils import LowerPrecision
acc_model = acc_tracer.trace(module, inputs)

splitter_setting = TRTSplitterSetting()
splitter_setting.use_implicit_batch_dim = False
splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
splitter.node_support_preview()
split_mod = splitter()
num_piece = 0
for name, _ in split_mod.named_children():
print(f"graph is split into {name}")
num_piece += 1

# if the graph module is split into pieces larger than 8, we consider its perf
# is not good and fall back to non-TRT
if num_piece > 8:
print(
f"The graph module is split into {num_piece} which is large than the \
threshold=8. Fall back to non-TRT module."
)
return None

if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
precision = LowerPrecision.FP16
if torch.float16 in enabled_precisions or torch_tensorrt.dtype.half in enabled_precisions:
lower_precision = LowerPrecision.FP16
elif torch.float32 in enabled_precisions or torch_tensorrt.dtype.float in enabled_precisions:
lower_precision = LowerPrecision.FP32
else:
precision = LowerPrecision.FP32

def get_submod_inputs(mod, submod, inputs):
acc_inputs = None

def get_input(self, inputs):
nonlocal acc_inputs
acc_inputs = inputs

handle = submod.register_forward_pre_hook(get_input)
mod(*inputs)
handle.remove()
return acc_inputs

for name, _ in split_mod.named_children():
if "_run_on_acc" in name:
submod = getattr(split_mod, name)
# Get submodule inputs for fx2trt
acc_inputs = get_submod_inputs(split_mod, submod, inputs)

# fx2trt replacement
interp = TRTInterpreter(
submod,
InputTensorSpec.from_tensors(acc_inputs),
explicit_batch_dimension=True,
)
r = interp.run(
max_workspace_size=20 << 30,
lower_precision=precision,
# profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
)
# For profile
# from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
# profile_trt_module("", trt_mod, acc_inputs)
trt_mod = TRTModule(*r)

setattr(split_mod, name, trt_mod)
else:
submod = getattr(split_mod, name)
return split_mod
raise ValueError(f"Precision {enabled_precisions} not supported on FX")

return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True)
else:
raise RuntimeError("Module is an unknown format or the ir requested is unknown")

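For reference, the lower_to_trt call that replaces the removed block above can also be invoked directly; a hedged sketch under the same signature used in the new code (the toy model, shapes, and precision choice are illustrative):

import torch
import torch.nn as nn
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision

# A small illustrative model; any module the FX/acc tracer can handle works the same way.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU()).cuda().eval()
inputs = [torch.randn(8, 10, device=torch.device('cuda'))]

trt_module = lower_to_trt(
    model,
    inputs,
    lower_precision=LowerPrecision.FP32,   # compile() maps enabled_precisions onto this enum
    max_batch_size=inputs[0].size(0),      # batch size taken from the first input, as in compile()
    explicit_batch_dimension=True,
)
out = trt_module(*inputs)

In the refactored compile(), lower_precision is derived from enabled_precisions (FP16 for torch.half, FP32 for torch.float), and any other precision set raises a ValueError.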
81 changes: 52 additions & 29 deletions py/torch_tensorrt/fx/example/fx2trt_example.py
@@ -3,11 +3,11 @@
import torch
import torch.fx
import torch.nn as nn
from torch_tensorrt.fx.utils import LowerPrecision
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter


# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
# model to TensorRT via FX with existing FX based tooling. The general lowering flow
# would be like:
@@ -30,11 +30,12 @@ def forward(self, x):
x = self.linear(x)
x = self.relu(x)
x = torch.linalg.norm(x, ord=2, dim=1)
x = self.relu(x)
return x


inputs = [torch.randn(1, 10)]
model = Model().eval()
inputs = [torch.randn((1, 10), device=torch.device('cuda'))]
model = Model().cuda().eval()

# acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators
# to acc ops.
@@ -64,20 +65,23 @@ def forward(self, x):
# Split.
split_mod = splitter()

# After split we have two submodules, _run_on_acc_0 and _run_on_gpu_1.
# After split we have three submodules: _run_on_acc_0, _run_on_gpu_1 and _run_on_acc_2.
print(split_mod.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
%_run_on_gpu_1 : [#users=1] = call_module[target=_run_on_gpu_1](args = (%_run_on_acc_0,), kwargs = {})
return _run_on_gpu_1
%_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_on_gpu_1,), kwargs = {})
return _run_on_acc_2
"""

# Take a look at what is inside each submodule. _run_on_acc_0 contains linear and relu while
# _run_on_gpu_1 contains linalg_norm which currently is not supported by fx2trt.
# _run_on_gpu_1 contains linalg_norm, which currently is not supported by fx2trt. _run_on_acc_2
# is the other supported submodule.
print(split_mod._run_on_acc_0.graph)
print(split_mod._run_on_gpu_1.graph)
print(split_mod._run_on_acc_2.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
@@ -90,32 +94,51 @@ def forward(self, x):
%relu_1 : [#users=1] = placeholder[target=relu_1]
%linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
return linalg_norm_1
graph():
%linalg_norm_1 : [#users=1] = placeholder[target=linalg_norm_1]
%relu_3 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linalg_norm_1, inplace: False})
return relu_3
"""

# Now let's lower split_mod._run_on_acc_0. If we know the model can be fully lowered,
# we can skip the splitter part.
interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs))
r = interp.run()
trt_mod = TRTModule(r.engine, r.input_names, r.output_names)
split_mod._run_on_acc_0 = trt_mod

cuda_inputs = [input.cuda() for input in inputs]
split_mod.cuda()
lowered_model_output = split_mod(*cuda_inputs)
def get_submod_inputs(mod, submod, inputs):
acc_inputs = None

def get_input(self, inputs):
nonlocal acc_inputs
acc_inputs = inputs

handle = submod.register_forward_pre_hook(get_input)
mod(*inputs)
handle.remove()
return acc_inputs

# Since the model is split into three segments, we need to lower each TRT-eligible segment.
# If we know the model can be fully lowered, we can skip the splitter part.
for name, _ in split_mod.named_children():
if "_run_on_acc" in name:
submod = getattr(split_mod, name)
# Get submodule inputs for fx2trt
acc_inputs = get_submod_inputs(split_mod, submod, inputs)

# fx2trt replacement
interp = TRTInterpreter(
submod,
InputTensorSpec.from_tensors(acc_inputs),
explicit_batch_dimension=True,
)
r = interp.run(lower_precision=LowerPrecision.FP32)
trt_mod = TRTModule(*r)
setattr(split_mod, name, trt_mod)

lowered_model_output = split_mod(*inputs)

# Save and load model
torch.save(split_mod, "trt.pt")
reload_trt_mod = torch.load("trt.pt")
reload_model_output = reload_trt_mod(*inputs)

# Make sure the results match
model.cuda()
regular_model_output = model(*cuda_inputs)
regular_model_output = model(*inputs)
torch.testing.assert_close(
lowered_model_output, regular_model_output.to(torch.float16), atol=3e-3, rtol=1e-2
reload_model_output, regular_model_output, atol=3e-3, rtol=1e-2
)

# We can utilize the trt profiler to print out the time spend on each layer.
trt_mod.enable_profiling()
trt_mod(*cuda_inputs)
"""
Reformatting CopyNode for Input Tensor 0 to LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.027392ms
LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.023072ms
PWN(ActivationType.RELU_acc_ops.relu_relu_1): 0.008928ms
"""
trt_mod.disable_profiling()
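The get_submod_inputs helper introduced in the example above captures a submodule's inputs with register_forward_pre_hook, which fires just before the submodule's forward() and receives its positional inputs. A self-contained sketch of that mechanism on a toy module (names and shapes are illustrative):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
captured = None

def capture(module, inputs):
    # A forward pre-hook is called as hook(module, positional_inputs) before forward() runs.
    global captured
    captured = inputs

handle = model[1].register_forward_pre_hook(capture)
model(torch.randn(2, 4))
handle.remove()

# captured now holds the tensors that reached the ReLU submodule; the example uses the
# same trick to build an InputTensorSpec for each _run_on_acc_* segment.
print(captured[0].shape)  # torch.Size([2, 4])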
6 changes: 3 additions & 3 deletions py/torch_tensorrt/fx/example/lower_example.py
@@ -198,6 +198,6 @@ def run_configuration_benchmark(


if __name__ == "__main__":
test_model = torchvision.models.resnet101()
input = [torch.cuda.FloatTensor(1024, 3, 224, 224)] # type: ignore[attr-defined]
benchmark(test_model, input, 100, 1024)
test_model = torchvision.models.resnet18(pretrained=True)
input = [torch.rand(128, 3, 224, 224)] # type: ignore[attr-defined]
benchmark(test_model, input, 50, 128)
54 changes: 0 additions & 54 deletions py/torch_tensorrt/fx/example/test_fx2trt.py

This file was deleted.

57 changes: 57 additions & 0 deletions py/torch_tensorrt/fx/example/torch_trt_simple_example.py
@@ -0,0 +1,57 @@
import torch
import copy
import torchvision
import torch_tensorrt
from torch_tensorrt.fx import InputTensorSpec


def test_torch_tensorrt(model, inputs):
# torchscript path
model_ts = copy.deepcopy(model)
inputs_ts = copy.deepcopy(inputs)
# fp32 test
with torch.inference_mode():
ref_fp32 = model_ts(*inputs_ts)
trt_ts_module = torch_tensorrt.compile(
model_ts, inputs=inputs_ts, enabled_precisions={torch.float32}
)
result_fp32 = trt_ts_module(*inputs_ts)
assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0)>0.9999)
# fp16 test
model_ts = model_ts.half()
inputs_ts = [i.cuda().half() for i in inputs_ts]
with torch.inference_mode():
ref_fp16 = model_ts(*inputs_ts)
trt_ts_module = torch_tensorrt.compile(
model_ts, inputs=inputs_ts, enabled_precisions={torch.float16}
)
result_fp16 = trt_ts_module(*inputs_ts)
assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0)>0.99)

# FX path
model_fx = copy.deepcopy(model)
inputs_fx = copy.deepcopy(inputs)
# fp32 test
with torch.inference_mode():
ref_fp32 = model_fx(*inputs_fx)
trt_fx_module = torch_tensorrt.compile(
model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float32}
)
result_fp32 = trt_fx_module(*inputs_fx)
assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0)>0.9999)
# fp16 test
model_fx = model_fx.cuda().half()
inputs_fx = [i.cuda().half() for i in inputs_fx]
with torch.inference_mode():
ref_fp16 = model_fx(*inputs_fx)
trt_fx_module = torch_tensorrt.compile(
model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float16}
)
result_fp16 = trt_fx_module(*inputs_fx)
assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0)>0.99 )


if __name__ == "__main__":
model = torchvision.models.resnet18(pretrained=True).cuda().eval()
inputs = [torch.ones((32, 3, 224, 224), device=torch.device('cuda'))] # type: ignore[attr-defined]
test_torch_tensorrt(model, inputs)
15 changes: 15 additions & 0 deletions py/torch_tensorrt/fx/fx2trt.py
@@ -164,6 +164,21 @@ def run(
timing_cache=None,
profiling_verbosity=None,
) -> TRTInterpreterResult:
"""
Build TensorRT engine with some configs.
Args:
max_batch_size: the maximum batch size you will use.
max_workspace_size: the maximum size we can afford for the temporary workspace buffer.
lower_precision: the precision model layers run at (TensorRT will choose the best-performing precision).
sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity.
force_fp32_output: force the output to be fp32.
strict_type_constraints: usually set to False unless we need to control the precision of certain layers for numeric reasons.
algorithm_selector: set up algorithm selection for certain layers.
timing_cache: enable the TensorRT timing cache.
profiling_verbosity: TensorRT profiling verbosity level.
Return:
TRTInterpreterResult
"""
TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

# For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
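A hedged sketch of how these run() arguments are typically supplied, based on the interpreter usage elsewhere in this commit (the toy model, workspace size, and precision are illustrative choices rather than required values):

import torch
import torch.nn as nn
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.utils import LowerPrecision

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU()).cuda().eval()
inputs = [torch.randn(8, 10, device=torch.device('cuda'))]

# Trace to acc ops, then build the engine with an explicit workspace cap and precision.
acc_model = acc_tracer.trace(model, inputs)
interp = TRTInterpreter(
    acc_model,
    InputTensorSpec.from_tensors(inputs),
    explicit_batch_dimension=True,
)
result = interp.run(
    max_workspace_size=1 << 30,            # 1 GiB of temporary buffer
    lower_precision=LowerPrecision.FP16,   # TensorRT picks fp16 kernels where they help
)
trt_mod = TRTModule(*result)
out = trt_mod(*inputs)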
