
Commit

fix: Refactor code and add testing
gs-olive committed May 25, 2023
1 parent 70530a3 commit c3c266d
Showing 7 changed files with 187 additions and 86 deletions.
10 changes: 6 additions & 4 deletions py/torch_tensorrt/dynamo/backend/lowering/__init__.py
@@ -1,7 +1,9 @@
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
from ._decompositions import (
get_decompositions,
)
from torch_tensorrt.dynamo.backend.lowering._partition import (
partition,
get_submod_inputs,
from ._pre_aot_lowering import (
MODULE_SUBSTITUTION_REGISTRY,
module_substitution,
)
from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
from .module_substitutions import *
12 changes: 10 additions & 2 deletions py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -1,9 +1,10 @@
import logging
from typing import Dict, List, Optional, Sequence
from typing import Dict, List, Optional, Sequence, Set

import torch

from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name
@@ -14,6 +15,11 @@

logger = logging.getLogger(__name__)

DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
"torch.ops." + str(module.new_operator)
for module in MODULE_SUBSTITUTION_REGISTRY.values()
)


class TRTPartitioner(CapabilityBasedPartitioner):
"""Partitioner to split an FX graph into subgraphs based on operator support
@@ -35,7 +41,9 @@ def __init__(
operator_support: OperatorSupport,
*,
non_compute_ops: Optional[Sequence[str]] = None,
allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
allowed_single_node_partition_ops: Optional[
Sequence[str]
] = DEFAULT_SINGLE_NODE_PARTITIONS,
min_block_size=MIN_BLOCK_SIZE,
) -> None:
super().__init__(
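
As a rough illustration (not part of the diff): DEFAULT_SINGLE_NODE_PARTITIONS is derived from the substitution registry, so every substituted operator is permitted to form a single-node TRT partition by default. Assuming the package is importable and only the MaxPool1d substitution shown later in this commit is registered, it would evaluate roughly as follows:

from torch_tensorrt.dynamo.backend.lowering import (
    DEFAULT_SINGLE_NODE_PARTITIONS,
    MODULE_SUBSTITUTION_REGISTRY,
)

# The registry maps each replaced nn.Module class to its ModuleReplacement entry
print(list(MODULE_SUBSTITUTION_REGISTRY.keys()))  # expected to contain torch.nn.MaxPool1d

# The derived set holds the qualified name of each replacement operator
print(DEFAULT_SINGLE_NODE_PARTITIONS)  # expected: {"torch.ops.tensorrt.maxpool1d"}
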
115 changes: 35 additions & 80 deletions py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py
@@ -1,82 +1,12 @@
from dataclasses import dataclass
import traceback
from typing import Callable, Dict, Tuple
from typing import Any, Callable, Dict
import torch
from torch._custom_op import custom_op
from torch.fx.node import Argument, Target
import logging

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

logger = logging.getLogger(__name__)


@custom_op(
"(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
ns="tensorrt",
)
def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
# Defines operator schema, name, namespace, and function header
...


@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
def maxpool1d_generic(
*args,
**kwargs,
):
# Defines a converter implementation for Autograd to use for shape analysis/propagation
return torch.nn.functional.max_pool1d(
*args,
**kwargs,
)


def maxpool1d_insertion_fn(
gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
# Defines insertion function for new node
new_node = gm.graph.call_function(
torch.ops.tensorrt.maxpool1d,
args=node.args,
kwargs={
"kernel_size": submodule.kernel_size,
"stride": submodule.stride,
"padding": submodule.padding,
"dilation": submodule.dilation,
"ceil_mode": submodule.ceil_mode,
},
)

return new_node


@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
# Defines converter replacing the default operator for this function
kwargs_new = {
"input": args[0],
"kernel_size": args[1],
"stride": args[2],
"padding": args[3],
"dilation": args[4],
"ceil_mode": False if len(args) < 6 else args[5],
}

return acc_ops_converters.acc_ops_max_pool1d(
network, target, None, kwargs_new, name
)


@dataclass(frozen=True)
class ModuleReplacement:
"""Class to store key functionality for module replacement"""
@@ -93,12 +23,37 @@ class ModuleReplacement:


# Dictionary mapping module to ModuleReplacement instance
MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = {
torch.nn.MaxPool1d: ModuleReplacement(
new_operator=torch.ops.tensorrt.maxpool1d,
subgraph_insertion_fn=maxpool1d_insertion_fn,
),
}
MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict()


def module_substitution(
module_to_replace: torch.nn.Module,
new_operator: torch._ops.OpOverload,
enabled: bool = True,
) -> Callable[[Any], Any]:
"""Decorator to register subgraph insertion functions
Args:
module_to_replace: nn.Module to replace
new_operator: Custom torch operator to replace with
enabled: Whether the substitution is enabled or disabled
Returns:
torch.fx.GraphModule
"""

def register_substitution(subgraph_insertion_fn):
"""Function for use if substitution is enabled"""
module_replacement = ModuleReplacement(
new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
)
MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
return subgraph_insertion_fn

def disable_substitution(subgraph_insertion_fn):
"""Function for use if substitution is disabled"""
return subgraph_insertion_fn

return register_substitution if enabled else disable_substitution


def pre_aot_module_replacement(gm: torch.fx.GraphModule):
@@ -144,7 +99,7 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
id(type(submodule))
)

# Replace all original node uses and delete node
# Replace all original node uses and clean up graph
n.replace_all_uses_with(new_node)
gm.graph.eliminate_dead_code()
gm.recompile()
@@ -153,9 +108,9 @@
# be replaced
except Exception:
logger.debug(
f"Encountered the following error while replacing {type(submodule)}"
f"Encountered error while replacing {type(submodule)}",
exc_info=True,
)
logger.debug(traceback.format_exc())
continue

# Perform cleanup and recompilation before returning module
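
For illustration, a sketch of the new registration flow (not part of this commit; it assumes a custom operator such as torch.ops.tensorrt.hardswish has already been declared elsewhere via @custom_op): registering a new substitution now only requires decorating its insertion function.

import torch
from torch_tensorrt.dynamo.backend.lowering import module_substitution


# Hypothetical substitution: torch.ops.tensorrt.hardswish is an assumed custom op
@module_substitution(torch.nn.Hardswish, torch.ops.tensorrt.hardswish)
def hardswish_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    # Insert a call to the custom operator, reusing the original node's arguments
    new_node = gm.graph.call_function(torch.ops.tensorrt.hardswish, args=node.args)
    return new_node

The decorator records the mapping in MODULE_SUBSTITUTION_REGISTRY, which pre_aot_module_replacement and DEFAULT_SINGLE_NODE_PARTITIONS then pick up automatically; the committed MaxPool1d registration below follows the same pattern.
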
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py
@@ -0,0 +1 @@
from .maxpool1d import *
75 changes: 75 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py
@@ -0,0 +1,75 @@
from typing import Dict, Tuple
import torch
from torch._custom_op import custom_op
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.dynamo.backend.lowering import module_substitution


@custom_op(
"(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
ns="tensorrt",
)
def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
# Defines operator schema, name, namespace, and function header
...


@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
def maxpool1d_generic(
*args,
**kwargs,
):
# Defines a converter implementation for AOT Autograd to use for shape analysis/propagation
return torch.nn.functional.max_pool1d(
*args,
**kwargs,
)


@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
# Defines converter replacing the default operator for this function
kwargs_new = {
"input": args[0],
"kernel_size": args[1],
"stride": args[2],
"padding": args[3],
"dilation": args[4],
"ceil_mode": False if len(args) < 6 else args[5],
}

return acc_ops_converters.acc_ops_max_pool1d(
network, target, None, kwargs_new, name
)


@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
# Defines insertion function for new node
new_node = gm.graph.call_function(
torch.ops.tensorrt.maxpool1d,
args=node.args,
kwargs={
"kernel_size": submodule.kernel_size,
"stride": submodule.stride,
"padding": submodule.padding,
"dilation": submodule.dilation,
"ceil_mode": submodule.ceil_mode,
},
)

return new_node
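
A minimal sketch of how this substitution is exercised end-to-end (assuming a local torch_tensorrt build; the module and variable names here are illustrative, not from the commit):

import torch
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    pre_aot_module_replacement,
)


class Pool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.maxpool = torch.nn.MaxPool1d(2)

    def forward(self, x):
        return self.maxpool(x)


gm = torch.fx.symbolic_trace(Pool())
gm = pre_aot_module_replacement(gm)
# The call_module node for nn.MaxPool1d should now be a call_function node
# targeting torch.ops.tensorrt.maxpool1d, as asserted by the test below.
print(gm.graph)
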
55 changes: 55 additions & 0 deletions py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py
@@ -0,0 +1,55 @@
import torch
from utils import lower_graph_testing
from torch.testing._internal.common_utils import run_tests, TestCase
from torch_tensorrt.dynamo import compile


class TestMaxPool1D(TestCase):
def test_pre_aot_lowering_maxpool1d(self):
class MaxPool1D(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.maxpool = torch.nn.MaxPool1d(2)

def forward(self, x):
return self.maxpool(x)

# Operations expected to be included in the traced graph after decompositions
expected_ops = {torch.ops.tensorrt.maxpool1d.default}

inputs = [
torch.rand(
9,
16,
2,
),
]

fx_graph = torch.fx.symbolic_trace(MaxPool1D())
_, expected_ops_unseen = lower_graph_testing(
fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
)

self.assertEquals(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = compile(
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
self.assertAlmostEqual(
max_diff, 0, f"Maxpool1d TRT outputs don't match with the original model."
)


if __name__ == "__main__":
run_tests()
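
Assuming a CUDA-enabled torch_tensorrt build is installed, the new test file can be executed directly, since it calls run_tests() under __main__:

python py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py
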
5 changes: 5 additions & 0 deletions py/torch_tensorrt/dynamo/backend/test/utils.py
@@ -8,6 +8,9 @@
from torch_tensorrt.dynamo.backend.lowering._partition import (
partition,
)
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
pre_aot_module_replacement,
)

from torch._dynamo.backends.common import fake_tensor_unsupported

@@ -31,6 +34,8 @@ def fx_dynamo_testing_backend(
torch_executed_ops=torch_executed_ops,
)

gm = pre_aot_module_replacement(gm)

# Invoke AOTAutograd to translate operators to aten
return aot_module_simplified(
gm,
