-
Notifications
You must be signed in to change notification settings - Fork 356
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
187 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
from torch_tensorrt.dynamo.backend.lowering._decompositions import ( | ||
from ._decompositions import ( | ||
get_decompositions, | ||
) | ||
from torch_tensorrt.dynamo.backend.lowering._partition import ( | ||
partition, | ||
get_submod_inputs, | ||
from ._pre_aot_lowering import ( | ||
MODULE_SUBSTITUTION_REGISTRY, | ||
module_substitution, | ||
) | ||
from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS | ||
from .module_substitutions import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 1 addition & 0 deletions
1
py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .maxpool1d import * |
75 changes: 75 additions & 0 deletions
75
py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from typing import Dict, Tuple | ||
import torch | ||
from torch._custom_op import custom_op | ||
from torch.fx.node import Argument, Target | ||
|
||
from torch_tensorrt.fx.converter_registry import tensorrt_converter | ||
from torch_tensorrt.fx.converters import acc_ops_converters | ||
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor | ||
|
||
from torch_tensorrt.dynamo.backend.lowering import module_substitution | ||
|
||
|
||
# Registers a new custom operator under the "tensorrt" namespace with the
# PyTorch dispatcher. The schema mirrors torch.nn.functional.max_pool1d;
# `int[1]` parameters accept a single int or a one-element list, and the
# empty-list defaults mean "not supplied" (resolved by the implementation).
@custom_op(
    "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
    ns="tensorrt",
)
def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
    # Defines operator schema, name, namespace, and function header.
    # The body is intentionally empty — concrete implementations are
    # attached separately via `maxpool1d.impl(...)` in this file.
    ...
|
||
|
||
@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
def maxpool1d_generic(*args, **kwargs):
    """Eager-mode implementation of the tensorrt::maxpool1d custom op.

    Used by AOT Autograd for shape analysis/propagation; delegates
    directly to the existing ATen max_pool1d kernel.
    """
    eager_maxpool1d = torch.nn.functional.max_pool1d
    return eager_maxpool1d(*args, **kwargs)
|
||
|
||
@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
    """TensorRT converter for the tensorrt::maxpool1d custom operator.

    Repackages the positional schema arguments into the keyword form
    expected by the pre-existing acc_ops max_pool1d converter, then
    delegates to it.
    """
    # Positional schema order: (x, kernel_size, stride, padding,
    # dilation, ceil_mode) — ceil_mode is optional and defaults to False.
    remapped_kwargs = {
        "input": args[0],
        "kernel_size": args[1],
        "stride": args[2],
        "padding": args[3],
        "dilation": args[4],
        "ceil_mode": args[5] if len(args) >= 6 else False,
    }

    # The acc_ops converter ignores its positional-args slot, so None is
    # passed there and everything travels through remapped_kwargs.
    return acc_ops_converters.acc_ops_max_pool1d(
        network, target, None, remapped_kwargs, name
    )
|
||
|
||
@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    """Insertion function replacing an nn.MaxPool1d call node with a
    call to the tensorrt::maxpool1d custom op in the FX graph.

    Returns the newly created call_function node.
    """
    # Pull the pooling hyperparameters off the original module instance;
    # the input tensor args are forwarded unchanged from the original node.
    pooling_config = {
        "kernel_size": submodule.kernel_size,
        "stride": submodule.stride,
        "padding": submodule.padding,
        "dilation": submodule.dilation,
        "ceil_mode": submodule.ceil_mode,
    }

    return gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs=pooling_config,
    )
55 changes: 55 additions & 0 deletions
55
py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import torch | ||
from utils import lower_graph_testing | ||
from torch.testing._internal.common_utils import run_tests, TestCase | ||
from torch_tensorrt.dynamo import compile | ||
|
||
|
||
class TestMaxPool1D(TestCase):
    def test_pre_aot_lowering_maxpool1d(self):
        """Verify nn.MaxPool1d is substituted with the tensorrt::maxpool1d
        custom op during pre-AOT lowering, and that compiled TRT outputs
        match the original Torch model."""

        class MaxPool1D(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.maxpool = torch.nn.MaxPool1d(2)

            def forward(self, x):
                return self.maxpool(x)

        # Operations expected to be included in the traced graph after decompositions
        expected_ops = {torch.ops.tensorrt.maxpool1d.default}

        inputs = [
            torch.rand(
                9,
                16,
                2,
            ),
        ]

        fx_graph = torch.fx.symbolic_trace(MaxPool1D())
        _, expected_ops_unseen = lower_graph_testing(
            fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
        )

        # Fix: assertEquals is a deprecated alias (removed in Python 3.12);
        # use assertEqual.
        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = compile(
            fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        # Coerce the 0-d tensor to a plain float so unittest's rounding in
        # assertAlmostEqual behaves predictably.
        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        # Fix: the third positional parameter of assertAlmostEqual is
        # `places`, not `msg` — the original passed the message string
        # positionally, which would raise a TypeError instead of reporting
        # the failure message. Pass it as the `msg` keyword.
        self.assertAlmostEqual(
            max_diff,
            0,
            msg="Maxpool1d TRT outputs don't match with the original model.",
        )
|
||
|
||
# Entry point: run this test file directly via torch's test runner.
if __name__ == "__main__":
    run_tests()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters