[FX] refactor the fx path in compile function #1141

Merged
14 commits merged on Jun 28, 2022
84 changes: 11 additions & 73 deletions py/torch_tensorrt/_compile.py
@@ -3,9 +3,11 @@
import torch_tensorrt.ts
from torch_tensorrt import logging
import torch
from torch import fx
import torch.fx
from enum import Enum
from torch_tensorrt import fx
import torch_tensorrt.fx
from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision

class _IRType(Enum):
"""Enum to set the minimum required logging level to print a message to stdout
@@ -108,78 +110,14 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
ts_mod = torch.jit.script(module)
return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
elif target_ir == _IRType.fx:
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx import TRTInterpreter
from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt.fx.utils import LowerPrecision
acc_model = acc_tracer.trace(module, inputs)

splitter_setting = TRTSplitterSetting()
splitter_setting.use_implicit_batch_dim = False
splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
splitter.node_support_preview()
split_mod = splitter()
num_piece = 0
for name, _ in split_mod.named_children():
print(f"graph is split into {name}")
num_piece += 1

# if the graph module is split into pieces larger than 8, we consider its perf
# is not good and fall back to non-TRT
if num_piece > 8:
print(
f"The graph module is split into {num_piece} which is large than the \
threshold=8. Fall back to non-TRT module."
)
return None

if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
precision = LowerPrecision.FP16
if torch.float16 in enabled_precisions or torch_tensorrt.dtype.half in enabled_precisions:
lower_precision = LowerPrecision.FP16
elif torch.float32 in enabled_precisions or torch_tensorrt.dtype.float in enabled_precisions:
lower_precision = LowerPrecision.FP32
else:
precision = LowerPrecision.FP32

def get_submod_inputs(mod, submod, inputs):
acc_inputs = None

def get_input(self, inputs):
nonlocal acc_inputs
acc_inputs = inputs

handle = submod.register_forward_pre_hook(get_input)
mod(*inputs)
handle.remove()
return acc_inputs

for name, _ in split_mod.named_children():
if "_run_on_acc" in name:
submod = getattr(split_mod, name)
# Get submodule inputs for fx2trt
acc_inputs = get_submod_inputs(split_mod, submod, inputs)

# fx2trt replacement
interp = TRTInterpreter(
submod,
InputTensorSpec.from_tensors(acc_inputs),
explicit_batch_dimension=True,
)
r = interp.run(
max_workspace_size=20 << 30,
lower_precision=precision,
# profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
)
# For profile
# from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
# profile_trt_module("", trt_mod, acc_inputs)
trt_mod = TRTModule(*r)

setattr(split_mod, name, trt_mod)
else:
submod = getattr(split_mod, name)
return split_mod
raise ValueError(f"Precision {enabled_precisions} not supported on FX")

return lower_to_trt(module, inputs, lower_precision=lower_precision, max_batch_size=inputs[0].size(0), explicit_batch_dimension=True)
else:
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
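For reference, the refactored fx branch above delegates the whole trace/split/convert flow to lower_to_trt. A minimal sketch of the equivalent direct call for an fp16 request, assuming a CUDA-resident model and sample inputs (the names model and sample_inputs are illustrative, not part of the patch):

from torch_tensorrt.fx.lower import lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision

# model and sample_inputs are assumed to already live on the GPU
trt_model = lower_to_trt(
    model,
    sample_inputs,
    lower_precision=LowerPrecision.FP16,      # what enabled_precisions={torch.half} maps to
    max_batch_size=sample_inputs[0].size(0),  # same default the compile() path picks
    explicit_batch_dimension=True,
)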

81 changes: 52 additions & 29 deletions py/torch_tensorrt/fx/example/fx2trt_example.py
@@ -3,11 +3,11 @@
import torch
import torch.fx
import torch.nn as nn
from torch_tensorrt.fx.utils import LowerPrecision
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter


# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
# model to TensorRT via FX with existing FX based tooling. The general lowering flow
# would be like:
@@ -30,11 +30,12 @@ def forward(self, x):
x = self.linear(x)
x = self.relu(x)
x = torch.linalg.norm(x, ord=2, dim=1)
x = self.relu(x)
return x


inputs = [torch.randn(1, 10)]
model = Model().eval()
inputs = [torch.randn((1, 10), device=torch.device('cuda'))]
model = Model().cuda().eval()

# acc_tracer is a custom fx tracer that maps nodes whose targets are PyTorch operators
# to acc ops.
@@ -64,20 +65,23 @@
# Split.
split_mod = splitter()

# After split we have two submodules, _run_on_acc_0 and _run_on_gpu_1.
# After split we have three submodules: _run_on_acc_0, _run_on_gpu_1 and _run_on_acc_2.
print(split_mod.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
%_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
%_run_on_gpu_1 : [#users=1] = call_module[target=_run_on_gpu_1](args = (%_run_on_acc_0,), kwargs = {})
return _run_on_gpu_1
%_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_on_gpu_1,), kwargs = {})
return _run_on_acc_2
"""

# Take a look at what is inside each submodule. _run_on_acc_0 contains linear and relu while
# _run_on_gpu_1 contains linalg_norm which currently is not supported by fx2trt.
# _run_on_gpu_1 contains linalg_norm which is currently not supported by fx2trt. _run_on_acc_2
# is another supported submodule.
print(split_mod._run_on_acc_0.graph)
print(split_mod._run_on_gpu_1.graph)
print(split_mod._run_on_acc_2.graph)
"""
graph():
%x : [#users=1] = placeholder[target=x]
@@ -90,32 +94,51 @@ def forward(self, x):
%relu_1 : [#users=1] = placeholder[target=relu_1]
%linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
return linalg_norm_1
graph():
%linalg_norm_1 : [#users=1] = placeholder[target=linalg_norm_1]
%relu_3 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linalg_norm_1, inplace: False})
return relu_3
"""

# Now let's lower split_mod._run_on_acc_0. If we know the model can be fully lowered,
# we can skip the splitter part.
interp = TRTInterpreter(split_mod._run_on_acc_0, InputTensorSpec.from_tensors(inputs))
r = interp.run()
trt_mod = TRTModule(r.engine, r.input_names, r.output_names)
split_mod._run_on_acc_0 = trt_mod

cuda_inputs = [input.cuda() for input in inputs]
split_mod.cuda()
lowered_model_output = split_mod(*cuda_inputs)
def get_submod_inputs(mod, submod, inputs):
acc_inputs = None

def get_input(self, inputs):
nonlocal acc_inputs
acc_inputs = inputs

handle = submod.register_forward_pre_hook(get_input)
mod(*inputs)
handle.remove()
return acc_inputs

# Since the model is split into three segments, we need to lower each TRT-eligible segment.
# If we know the model can be fully lowered, we can skip the splitter part.
for name, _ in split_mod.named_children():
if "_run_on_acc" in name:
submod = getattr(split_mod, name)
# Get submodule inputs for fx2trt
acc_inputs = get_submod_inputs(split_mod, submod, inputs)

# fx2trt replacement
interp = TRTInterpreter(
submod,
InputTensorSpec.from_tensors(acc_inputs),
explicit_batch_dimension=True,
)
r = interp.run(lower_precision=LowerPrecision.FP32)
trt_mod = TRTModule(*r)
setattr(split_mod, name, trt_mod)

lowered_model_output = split_mod(*inputs)

# Save and load model
torch.save(split_mod, "trt.pt")
reload_trt_mod = torch.load("trt.pt")
reload_model_output = reload_trt_mod(*inputs)

# Make sure the results match
model.cuda()
regular_model_output = model(*cuda_inputs)
regular_model_output = model(*inputs)
torch.testing.assert_close(
lowered_model_output, regular_model_output.to(torch.float16), atol=3e-3, rtol=1e-2
reload_model_output, regular_model_output, atol=3e-3, rtol=1e-2
)

# We can utilize the trt profiler to print out the time spend on each layer.
trt_mod.enable_profiling()
trt_mod(*cuda_inputs)
"""
Reformatting CopyNode for Input Tensor 0 to LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.027392ms
LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.023072ms
PWN(ActivationType.RELU_acc_ops.relu_relu_1): 0.008928ms
"""
trt_mod.disable_profiling()
6 changes: 3 additions & 3 deletions py/torch_tensorrt/fx/example/lower_example.py
@@ -198,6 +198,6 @@ def run_configuration_benchmark(


if __name__ == "__main__":
test_model = torchvision.models.resnet101()
input = [torch.cuda.FloatTensor(1024, 3, 224, 224)] # type: ignore[attr-defined]
benchmark(test_model, input, 100, 1024)
test_model = torchvision.models.resnet18(pretrained=True)
input = [torch.rand(128, 3, 224, 224)] # type: ignore[attr-defined]
benchmark(test_model, input, 50, 128)
54 changes: 0 additions & 54 deletions py/torch_tensorrt/fx/example/test_fx2trt.py

This file was deleted.

57 changes: 57 additions & 0 deletions py/torch_tensorrt/fx/example/torch_trt_simple_example.py
@@ -0,0 +1,57 @@
import torch
import copy
import torchvision
import torch_tensorrt
from torch_tensorrt.fx import InputTensorSpec


def test_torch_tensorrt(model, inputs):
# torchscript path
model_ts = copy.deepcopy(model)
inputs_ts = copy.deepcopy(inputs)
# fp32 test
with torch.inference_mode():
ref_fp32 = model_ts(*inputs_ts)
trt_ts_module = torch_tensorrt.compile(
model_ts, inputs=inputs_ts, enabled_precisions={torch.float32}
)
result_fp32 = trt_ts_module(*inputs_ts)
assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0)>0.9999)
# fp16 test
model_ts = model_ts.half()
inputs_ts = [i.cuda().half() for i in inputs_ts]
with torch.inference_mode():
ref_fp16 = model_ts(*inputs_ts)
trt_ts_module = torch_tensorrt.compile(
model_ts, inputs=inputs_ts, enabled_precisions={torch.float16}
)
result_fp16 = trt_ts_module(*inputs_ts)
assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0)>0.99)

# FX path
model_fx = copy.deepcopy(model)
inputs_fx = copy.deepcopy(inputs)
# fp32 test
with torch.inference_mode():
ref_fp32 = model_fx(*inputs_fx)
trt_fx_module = torch_tensorrt.compile(
model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float32}
)
result_fp32 = trt_fx_module(*inputs_fx)
assert(torch.nn.functional.cosine_similarity(ref_fp32.flatten(), result_fp32.flatten(), dim=0)>0.9999)
# fp16 test
model_fx = model_fx.cuda().half()
inputs_fx = [i.cuda().half() for i in inputs_fx]
with torch.inference_mode():
ref_fp16 = model_fx(*inputs_fx)
trt_fx_module = torch_tensorrt.compile(
model_fx, ir="fx", inputs=inputs_fx, enabled_precisions={torch.float16}
)
result_fp16 = trt_fx_module(*inputs_fx)
assert(torch.nn.functional.cosine_similarity(ref_fp16.flatten(), result_fp16.flatten(), dim=0)>0.99 )


if __name__ == "__main__":
model = torchvision.models.resnet18(pretrained=True).cuda().eval()
inputs = [torch.ones((32, 3, 224, 224), device=torch.device('cuda'))] # type: ignore[attr-defined]
test_torch_tensorrt(model, inputs)
15 changes: 15 additions & 0 deletions py/torch_tensorrt/fx/fx2trt.py
@@ -164,6 +164,21 @@ def run(
timing_cache=None,
profiling_verbosity=None,
) -> TRTInterpreterResult:
"""
Build a TensorRT engine with the given configs.
Args:
max_batch_size: set accordingly for maximum batch size you will use.
max_workspace_size: set to the maximum size we can afford for temporary buffer
lower_precision: the precision the model layers run at (TensorRT will choose the best performance precision).
sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
force_fp32_output: force output to be fp32
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layers for numeric reasons.
algorithm_selector: set up algorithm selection for certain layers
timing_cache: enable timing cache for TensorRT
profiling_verbosity: TensorRT logging level
Return:
TRTInterpreterResult
"""
TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

# For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
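The arguments documented above match what the FX example and the pre-refactor _compile.py path pass when building an engine. A minimal sketch of a typical call, assuming submod and acc_inputs come from the splitter flow shown earlier and a 1 GiB workspace is acceptable:

from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.utils import LowerPrecision

interp = TRTInterpreter(
    submod,
    InputTensorSpec.from_tensors(acc_inputs),
    explicit_batch_dimension=True,
)
# ~1 GiB of builder workspace; fp16 kernels where TensorRT finds them profitable
result = interp.run(
    max_workspace_size=1 << 30,
    lower_precision=LowerPrecision.FP16,
)
trt_mod = TRTModule(*result)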