fix: Add support for passing through build issues
- Add support for `pass_through_build_failures` keyword arg
- Add failure pass through testing to all e2e tests to validate feature
- Add minor typo fixes
gs-olive committed May 25, 2023
1 parent 618615b commit 0f35954
Showing 8 changed files with 39 additions and 10 deletions.
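For reference, a minimal usage sketch of the new option through the top-level torchtrt.compile API, mirroring the e2e test specs further down this page; the model, input shape, inputs entry, and ir string are illustrative assumptions rather than lines taken from the diff:

    import torch
    import torch_tensorrt as torchtrt
    import torchvision.models as models

    model = models.resnet18(pretrained=True).eval().to("cuda")
    input = torch.randn((1, 3, 224, 224)).to("cuda")

    compile_spec = {
        "inputs": [input],
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": "torch_compile",  # assumption: whichever ir string selects the Dynamo backend
        # New in this commit: raise on engine-build failure instead of silently
        # falling back to the unmodified Torch module.
        "pass_through_build_failures": True,
    }

    trt_mod = torchtrt.compile(model, **compile_spec)
    print(trt_mod(input).shape)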
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/backend/__init__.py
@@ -16,6 +16,7 @@
DEBUG,
MAX_WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
)


@@ -52,7 +53,8 @@ def compile(
logger.warn(
"The Dynamo backend is an experimental feature, for which only the "
+ "following arguments are supported: "
+ "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}"
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
+ "torch_executed_ops, pass_through_build_failures}"
)

if not isinstance(inputs, collections.abc.Sequence):
@@ -106,6 +108,7 @@ def create_backend(
workspace_size: int = MAX_WORKSPACE_SIZE,
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Sequence[str] = set(),
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
**kwargs,
):
"""Create torch.compile backend given specified arguments
@@ -124,6 +127,7 @@
workspace_size=workspace_size,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
pass_through_build_failures=pass_through_build_failures,
)

return partial(
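A lower-level sketch of the same option through create_backend, assuming, per its docstring, that the returned partial can be handed directly to torch.compile as the backend; the module and tensor shapes are hypothetical, and only the keyword names come from the diff above:

    import torch
    from torch_tensorrt.dynamo.backend import create_backend

    # Build a backend callable carrying the new setting; unspecified arguments
    # keep their defaults from _defaults.py.
    backend = create_backend(
        min_block_size=5,
        torch_executed_ops=set(),
        pass_through_build_failures=True,  # keyword added in this commit
    )

    # Hypothetical module purely for illustration.
    model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU()).eval().cuda()
    compiled = torch.compile(model, backend=backend)
    compiled(torch.randn(8, 32).cuda())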
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/_defaults.py
@@ -5,3 +5,4 @@
DEBUG = False
MAX_WORKSPACE_SIZE = 20 << 30
MIN_BLOCK_SIZE = 5
PASS_THROUGH_BUILD_FAILURES = False
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/_settings.py
@@ -7,6 +7,7 @@
DEBUG,
MAX_WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
)


@@ -17,3 +18,4 @@ class CompilationSettings:
workspace_size: int = MAX_WORKSPACE_SIZE
min_block_size: int = MIN_BLOCK_SIZE
torch_executed_ops: Sequence[str] = field(default_factory=set)
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
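Since CompilationSettings is a plain dataclass, the new field can also be set when the settings object is constructed directly; a small sketch with illustrative values (the op identifier is hypothetical):

    from torch_tensorrt.dynamo.backend._settings import CompilationSettings

    # Unspecified fields fall back to the _defaults.py constants, so
    # pass_through_build_failures stays False unless overridden.
    settings = CompilationSettings(
        min_block_size=3,
        torch_executed_ops={"torch.ops.aten.add.Tensor"},  # illustrative op identifier
        pass_through_build_failures=True,
    )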
23 changes: 18 additions & 5 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -1,6 +1,6 @@
import logging
from typing import Sequence
import torch
import traceback
from functools import partial
import torch._dynamo as td

@@ -19,6 +19,9 @@
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler


logger = logging.getLogger(__name__)


@td.register_backend(name="torch_tensorrt")
@fake_tensor_unsupported
def torch_tensorrt_backend(
@@ -75,12 +78,22 @@ def _pretraced_backend(
)
return trt_compiled
except:
traceback.print_exc()
print(
logger.error(
"FX2TRT conversion failed on the subgraph. See trace above. "
+ "Returning GraphModule forward instead."
+ "Returning GraphModule forward instead.",
exc_info=True,
)
return gm.forward

if not settings.pass_through_build_failures:
return gm.forward
else:
raise AssertionError(
"Halting compilation on build failure since "
+ "pass_through_build_failures was specified as True. "
+ "To return the default Torch implementation and avoid "
+ "halting compilation on engine build failures, "
+ "specify pass_through_build_failures=False."
)


def _compile_module(
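A hedged sketch of how the two failure modes surface to a caller, and of routing the trace that is now emitted through the module logger; the module, shapes, ir string, and logging setup are illustrative, and only the AssertionError-versus-fallback behavior comes from the hunk above:

    import logging
    import torch
    import torch_tensorrt as torchtrt

    # _pretraced_backend now reports conversion failures via
    # logger.error(..., exc_info=True) rather than print/traceback.print_exc,
    # so the trace follows standard logging configuration.
    logging.basicConfig(level=logging.ERROR)

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU()).eval().cuda()
    inputs = [torch.randn(4, 16).cuda()]

    try:
        trt_mod = torchtrt.compile(
            model,
            inputs=inputs,
            ir="torch_compile",  # assumption, as in the sketch near the top of this page
            pass_through_build_failures=True,
        )
    except AssertionError as err:
        # Raised when pass-through is enabled and the engine build fails; with the
        # default False, the failure is only logged and the unmodified GraphModule
        # forward is returned instead.
        print(f"Compilation halted on build failure: {err}")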
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/test/utils.py
@@ -124,7 +124,7 @@ def lower_graph_testing(
torch_executed_ops: Sequence[str] = set(),
testing_partitioning: bool = False,
):
"""Helper function to assist with graph lowering for testing of Dynamo torch_compile
"""Helper function to assist with graph lowering for testing of Dynamo compile
Args:
fx_graph: Graph to lower
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/common_utils/__init__.py
@@ -0,0 +1 @@
from .test_utils import *
File renamed without changes.
14 changes: 11 additions & 3 deletions py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
@@ -7,7 +7,10 @@

from transformers import BertModel

from utils import COSINE_THRESHOLD, cosine_similarity
from torch_tensorrt.dynamo.common_utils.test_utils import (
COSINE_THRESHOLD,
cosine_similarity,
)


@pytest.mark.unit
@@ -24,13 +27,14 @@ def test_resnet18(ir):
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
assert (
cos_sim > COSINE_THRESHOLD,
f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
@@ -54,6 +58,7 @@ def test_mobilenet_v2(ir):
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
@@ -84,6 +89,7 @@ def test_efficientnet_b0(ir):
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
@@ -123,6 +129,7 @@ def test_bert_base_uncased(ir):
"enabled_precisions": {torch.float},
"truncate_long_and_double": True,
"ir": ir,
"pass_through_build_failures": True,
}
trt_mod = torchtrt.compile(model, **compile_spec)

@@ -157,13 +164,14 @@ def test_resnet18_half(ir):
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.half},
"ir": ir,
"pass_through_build_failures": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
assert (
cos_sim > COSINE_THRESHOLD,
f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
