Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cherry-pick: Dynamo upgrades and bugfixes (release/1.4) #1956

Merged
merged 4 commits into from
May 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion py/torch_tensorrt/dynamo/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DEBUG,
MAX_WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
)


Expand Down Expand Up @@ -46,11 +47,14 @@ def compile(
torch_executed_modules=[],
**kwargs,
):
if debug:
logger.setLevel(logging.DEBUG)

logger.warn(
"The Dynamo backend is an experimental feature, for which only the "
+ "following arguments are supported: "
+ "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}"
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
+ "torch_executed_ops, pass_through_build_failures}"
)

if not isinstance(inputs, collections.abc.Sequence):
Expand Down Expand Up @@ -104,6 +108,7 @@ def create_backend(
workspace_size: int = MAX_WORKSPACE_SIZE,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Sequence[str] = set(),
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
**kwargs,
):
"""Create torch.compile backend given specified arguments
Expand All @@ -116,12 +121,16 @@ def create_backend(
Returns:
Backend for torch.compile
"""
if debug:
logger.setLevel(logging.DEBUG)

settings = CompilationSettings(
debug=debug,
precision=precision,
workspace_size=workspace_size,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
pass_through_build_failures=pass_through_build_failures,
)

return partial(
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
DEBUG = False
MAX_WORKSPACE_SIZE = 20 << 30
MIN_BLOCK_SIZE = 5
PASS_THROUGH_BUILD_FAILURES = False
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
DEBUG,
MAX_WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
)


Expand All @@ -17,3 +18,4 @@ class CompilationSettings:
workspace_size: int = MAX_WORKSPACE_SIZE
min_block_size: int = MIN_BLOCK_SIZE
torch_executed_ops: Sequence[str] = field(default_factory=set)
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
28 changes: 20 additions & 8 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from typing import Sequence
import torch
import traceback
from functools import partial
import torch._dynamo as td

Expand All @@ -19,6 +19,9 @@
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler


logger = logging.getLogger(__name__)


@td.register_backend(name="torch_tensorrt")
@fake_tensor_unsupported
def torch_tensorrt_backend(
Expand Down Expand Up @@ -52,6 +55,7 @@ def aot_torch_tensorrt_aten_backend(
)


@fake_tensor_unsupported
def _pretraced_backend(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
Expand All @@ -74,12 +78,22 @@ def _pretraced_backend(
)
return trt_compiled
except:
traceback.print_exc()
print(
logger.error(
"FX2TRT conversion failed on the subgraph. See trace above. "
+ "Returning GraphModule forward instead."
+ "Returning GraphModule forward instead.",
exc_info=True,
)
return gm.forward

if not settings.pass_through_build_failures:
return gm.forward
else:
raise AssertionError(
"Halting compilation on build failure since "
+ "pass_through_build_failures was specified as True. "
+ "To return the default Torch implementation and avoid "
+ "halting compilation on engine build failures, "
+ "specify pass_through_build_failures=False."
)


def _compile_module(
Expand Down Expand Up @@ -120,9 +134,7 @@ def _compile_module(
trt_mod = convert_module(
submodule,
submodule_inputs,
debug=settings.debug,
workspace_size=settings.workspace_size,
precision=settings.precision,
settings=settings,
)

# Replace FX Module with TRT Module
Expand Down
18 changes: 7 additions & 11 deletions py/torch_tensorrt/dynamo/backend/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,41 @@
import torch
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt import TRTModuleNext
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.fx.fx2trt import (
InputTensorSpec,
TRTInterpreter,
)
from torch_tensorrt.fx.utils import LowerPrecision

import tensorrt as trt


def convert_module(
module: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
debug: bool = False,
workspace_size: int = 20 << 30,
precision: LowerPrecision = LowerPrecision.FP32,
settings: CompilationSettings = CompilationSettings(),
) -> Union[TRTModuleNext, TRTModule]:
"""Convert an FX module to a TRT module
Args:
module: FX GraphModule to convert
inputs: Sequence of Tensors representing inputs to the module
debug: Whether to print out verbose debugging information
workspace_size: Maximum workspace TRT is allowed to use for the module
precision: Model Layer precision
settings: Compilation settings
Returns:
TRTModule or TRTModuleNext
"""
interp = TRTInterpreter(
module,
InputTensorSpec.from_tensors(inputs),
explicit_batch_dimension=True,
logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING),
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
)

r = interp.run(
max_workspace_size=workspace_size,
lower_precision=precision,
max_workspace_size=settings.workspace_size,
lower_precision=settings.precision,
profiling_verbosity=(
trt.ProfilingVerbosity.VERBOSE
if debug
if settings.debug
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
),
)
Expand Down
13 changes: 8 additions & 5 deletions py/torch_tensorrt/dynamo/backend/lowering/_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,18 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
f"\nNumber of TensorRT-Accelerated Engines Generated: {num_trt_blocks}"
)

logger.debug("\nSupported Nodes:")
# Reformat support messages for debugger to print node overview as a single string
supported_nodes_str = "\nSupported Nodes:\n"
for node_name in self.supported_operators:
logger.debug("-", node_name)
supported_nodes_str += f"- {node_name}\n"

logger.debug(supported_nodes_str)

if len(self.unsupported_operators) != 0:
logger.debug("\nUnsupported or Excluded Nodes:")
unsupported_nodes_str = "\nUnsupported or Excluded Nodes:\n"
for node_name in self.unsupported_operators:
logger.debug("-", node_name)
logger.debug("\n")
unsupported_nodes_str += f"- {node_name}\n"
logger.debug(unsupported_nodes_str)
else:
logger.debug("\nAll Nodes Supported\n")

Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def lower_graph_testing(
torch_executed_ops: Sequence[str] = set(),
testing_partitioning: bool = False,
):
"""Helper function to assist with graph lowering for testing of Dynamo torch_compile
"""Helper function to assist with graph lowering for testing of Dynamo compile

Args:
fx_graph: Graph to lower
Expand Down
Empty file.
9 changes: 6 additions & 3 deletions py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

from transformers import BertModel

from utils import COSINE_THRESHOLD, cosine_similarity
from torch_tensorrt.dynamo.common_utils.test_utils import (
COSINE_THRESHOLD,
cosine_similarity,
)


@pytest.mark.unit
Expand All @@ -30,7 +33,7 @@ def test_resnet18(ir):
cos_sim = cosine_similarity(model(input), trt_mod(input))
assert (
cos_sim > COSINE_THRESHOLD,
f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
Expand Down Expand Up @@ -163,7 +166,7 @@ def test_resnet18_half(ir):
cos_sim = cosine_similarity(model(input), trt_mod(input))
assert (
cos_sim > COSINE_THRESHOLD,
f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/converters/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def aten_ops_cat(
) -> Union[TRTTensor, Sequence[TRTTensor]]:
kwargs_new = {
"tensors": args[0],
"dim": args[1],
"dim": args[1] if len(args) >= 2 else 0,
}
return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name)

Expand Down
35 changes: 35 additions & 0 deletions py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,41 @@ def forward(self, x, y):
expected_ops={torch.ops.aten.cat.default},
)

def test_cat_no_dim(self):
class Cat(nn.Module):
def forward(self, x, y, z):
return torch.cat((x, y, z))

inputs = [torch.randn(2, 1, 3), torch.randn(1, 1, 3), torch.randn(3, 1, 3)]
self.run_test(
Cat(),
inputs,
expected_ops={torch.ops.aten.cat.default},
)

def test_cat_dynamic_shape_no_dim(self):
class Cat(nn.Module):
def forward(self, x, y):
return torch.cat((x, y))

input_specs = [
InputTensorSpec(
shape=(-1, 16, 3),
dtype=torch.float32,
shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
),
InputTensorSpec(
shape=(-1, 16, 3),
dtype=torch.float32,
shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
),
]
self.run_test_with_dynamic_shape(
Cat(),
input_specs,
expected_ops={torch.ops.aten.cat.default},
)


if __name__ == "__main__":
run_tests()