feat: a lowering pass to re-compose ops into aten.linear #2411

Merged 5 commits on Nov 9, 2023
28 changes: 28 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1957,3 +1957,31 @@ def aten_ops_argmax(
        dim=args_bounds_check(args, 1),
        keep_dim=args_bounds_check(args, 2, False),
    )


@dynamo_tensorrt_converter(torch.ops.aten.addmm.default)  # type: ignore[misc]
@enforce_tensor_types(
    {
        0: (TRTTensor,),
        1: (np.ndarray, torch.Tensor, TRTTensor),
        2: (np.ndarray, torch.Tensor, TRTTensor),
    }
)  # type: ignore[misc]
def aten_ops_addmm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.addmm.addmm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
        args[2],
        beta=kwargs.get("beta", 1),
        alpha=kwargs.get("alpha", 1),
    )
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -2,8 +2,9 @@

from . import (
    activation,
-    attention,
+    addmm,
    argmax,
+    attention,
    cast,
    cat,
    condition,
34 changes: 34 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/addmm.py
@@ -0,0 +1,34 @@
from typing import Optional, Union

import numpy as np
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.types import TRTTensor


def addmm(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    mat1: Union[TRTTensor, torch.Tensor, np.ndarray],
    mat2: Union[TRTTensor, torch.Tensor, np.ndarray],
    *,
    beta: Union[float, int],
    alpha: Union[float, int],
) -> TRTTensor:
    mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2)
    if alpha != 1:
        mm = impl.elementwise.mul(
            ctx, target, SourceIR.ATEN, f"{name}_mul_alpha", mm, alpha
        )
    if beta != 1:
        input = impl.elementwise.mul(
            ctx, target, SourceIR.ATEN, f"{name}_mul_beta", input, beta
        )

    return impl.elementwise.add(ctx, target, source_ir, f"{name}_add", input, mm)
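
For reference, aten.addmm computes beta * input + alpha * (mat1 @ mat2), with input broadcast against the matmul result — exactly the structure the impl above builds out of TRT matmul, mul, and add layers. A minimal sanity-check sketch in plain PyTorch (shapes chosen arbitrarily):

import torch

# addmm semantics mirrored by impl.addmm.addmm above:
#   out = beta * input + alpha * (mat1 @ mat2)
inp = torch.randn(2, 3)
mat1 = torch.randn(2, 4)
mat2 = torch.randn(4, 3)

out = torch.addmm(inp, mat1, mat2, beta=0.5, alpha=2.0)
ref = 0.5 * inp + 2.0 * (mat1 @ mat2)
assert torch.allclose(out, ref, atol=1e-6)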
16 changes: 0 additions & 16 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -105,22 +105,6 @@ def alias_replacement(x: torch.Tensor) -> torch.Tensor:
    return x


-@register_torch_trt_decomposition(
-    torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS
-)
-def addmm_replacement(
-    input_: torch.Tensor,
-    mat1: torch.Tensor,
-    mat2: torch.Tensor,
-    *,
-    beta: int = 1,
-    alpha: int = 1,
-) -> torch.Tensor:
-    return torch.add(
-        torch.mul(input_, beta), torch.mul(torch.matmul(mat1, mat2), alpha)
-    )


@register_torch_trt_decomposition(
    torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS
)
py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -6,6 +6,7 @@
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_efficient_attention import lower_efficient_attention
from .lower_linear import lower_linear
from .pass_manager import DynamoPassManager
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
@@ -17,6 +18,7 @@
    constant_fold,
    repair_input_as_output,
    lower_efficient_attention,
    lower_linear,
    fuse_prims_broadcast,
    replace_max_pool_with_indices,
]
py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py
@@ -5,7 +5,6 @@
import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
-    get_tensor_placeholders,
)

logger = logging.getLogger(__name__)
@@ -36,34 +35,13 @@ def efficient_attention_replacement(
):
    """Constructs the original and replacement functions for efficient attention"""

-    # Empty boilerplate function taking in three Tensors and returning one
-    def boilerplate(
-        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> torch.Tensor:
-        ...
-
-    # Trace boilerplate function and extract placeholder and output nodes
-    orig = torch.fx.symbolic_trace(boilerplate)
-    q, k, v = get_tensor_placeholders(orig)
-    output = [node for node in orig.graph.nodes if node.op == "output"][0]
-
-    # Graph types to replace are those which use the _scaled_dot_product_efficient_attention
-    # function and extract only the first element
-    with orig.graph.inserting_before(output):
-        att = orig.graph.call_function(
-            torch.ops.aten._scaled_dot_product_efficient_attention.default,
-            args=(q, k, v, None, False),
-        )
-        out = orig.graph.call_function(
-            operator.getitem,
-            args=(att, 0),
-        )
-
-    # Assign the output of the graph to be the single getitem output
-    output.args = (out,)
-
-    orig.graph.lint()
-    orig.recompile()
+    # Original graph
+    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
+            q, k, v, None, False
+        )
+        out = operator.getitem(outputs, 0)
+        return out

    # Replacement graph consists of the functional version of scaled_dot_product_attention
    def replacement(
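The simplification above works because torch.fx.subgraph_rewriter.replace_pattern accepts plain callables and traces them internally, so the hand-built boilerplate graph (placeholder extraction, call_function insertion, lint, recompile) is unnecessary. A minimal sketch of that mechanism, using a hypothetical sigmoid-mul -> silu rewrite instead of the attention ops:

import torch
import torch.fx

# replace_pattern traces callable patterns itself; no manual node
# insertion, graph.lint(), or recompile() is required.
def pattern(x: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(x) * x

def replacement(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(x)

class SiLU(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(x) * x

gm = torch.fx.symbolic_trace(SiLU())
torch.fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(gm.code)  # the sigmoid-mul pair is now a single silu call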
47 changes: 47 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py
@@ -0,0 +1,47 @@
import logging
from typing import Callable, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def lower_linear(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Re-compose decomposed linear ops (permute + addmm) into aten.linear, which can be directly converted to TRT"""
    orig, replacement = linear_replacement()

    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


def linear_replacement() -> (
    Tuple[
        torch.fx.GraphModule,
        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    ]
):
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
    ) -> torch.Tensor:
        W_T = torch.ops.aten.permute.default(weight, [1, 0])
        out = torch.ops.aten.addmm.default(bias, input, W_T)
        return out

    # Replacement graph
    def replacement(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
    ) -> torch.Tensor:
        return torch.ops.aten.linear.default(input, weight, bias)

    return orig, replacement
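
As a standalone illustration of what this pass does (a sketch outside the Torch-TRT pass manager, with hypothetical function names): tracing the decomposed permute + addmm pattern and applying the same replace_pattern call re-composes it into a single aten.linear node.

import torch
import torch.fx

# Standalone sketch: re-compose permute + addmm into aten.linear,
# mirroring lower_linear above.
def decomposed(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    W_T = torch.ops.aten.permute.default(weight, [1, 0])
    return torch.ops.aten.addmm.default(bias, input, W_T)

def recomposed(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    return torch.ops.aten.linear.default(input, weight, bias)

gm = torch.fx.symbolic_trace(decomposed)
torch.fx.subgraph_rewriter.replace_pattern(gm, decomposed, recomposed)
print(gm.code)  # a single call to torch.ops.aten.linear.default remains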
65 changes: 65 additions & 0 deletions tests/py/dynamo/conversion/test_addmm_aten.py
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestAddmmConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ((2, 2), (2, 3), (3, 2)),
            ((4, 6), (4, 5), (5, 6)),
            ((2, 1), (2, 3), (3, 1)),
            ((4, 1), (4, 1), (1, 1)),
            ((1, 2), (1, 3), (3, 2)),
        ]
    )
    def test_addmm(self, input_shape, mat1_shape, mat2_shape):
        class Addmm(nn.Module):
            def forward(self, input, mat1, mat2):
                return torch.ops.aten.addmm.default(input, mat1, mat2)

        inputs = [
            torch.randn(input_shape),
            torch.randn(mat1_shape),
            torch.randn(mat2_shape),
        ]

        self.run_test(
            Addmm(),
            inputs,
        )

    @parameterized.expand(
        [
            ((2, 2), (2, 3), (3, 2), 1.0, 1.0),
            ((4, 6), (4, 5), (5, 6), 1.2, 0.8),
            ((2, 1), (2, 3), (3, 1), 3, 2),
            ((4, 1), (4, 1), (1, 1), 1, 1),
            ((1, 2), (1, 3), (3, 2), 2, 1.0),
            ((1, 2), (1, 3), (3, 2), 1, 2.0),
        ]
    )
    def test_addmm_scale(self, input_shape, mat1_shape, mat2_shape, beta, alpha):
        class Addmm(nn.Module):
            def forward(self, input, mat1, mat2):
                return torch.ops.aten.addmm.default(
                    input, mat1, mat2, beta=beta, alpha=alpha
                )

        inputs = [
            torch.randn(input_shape),
            torch.randn(mat1_shape),
            torch.randn(mat2_shape),
        ]

        self.run_test(
            Addmm(),
            inputs,
        )


if __name__ == "__main__":
    run_tests()
108 changes: 108 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -267,5 +267,113 @@ def forward(self, q, k, v):
        torch._dynamo.reset()


class TestLowerLinear(TestCase):
    def test_lower_linear(self):
        class Linear(torch.nn.Module):
            def forward(self, input, weight, bias):
                out = torch.ops.aten.linear.default(input, weight, bias)
                return out

        inputs = [
            torch.rand((3, 32)).cuda(),
            torch.rand((64, 32)).cuda(),
            torch.rand((64,)).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(Linear())
        expected_ops = {torch.ops.aten.linear.default}
        unexpected_ops = {
            torch.ops.aten.permute.default,
            torch.ops.aten.addmm.default,
        }

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEquals(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEquals(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )
        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
        )
        optimized_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
        )
        torch_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
        )

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            msg="Linear TRT outputs don't match with the original model.",
        )
        torch._dynamo.reset()

    def test_lower_linear_batch(self):
        class Linear(torch.nn.Module):
            def forward(self, input, weight, bias):
                out = torch.ops.aten.linear.default(input, weight, bias)
                return out

        inputs = [
            torch.rand((2, 2, 32)).cuda(),
            torch.rand((64, 32)).cuda(),
            torch.rand((64,)).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(Linear())

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
        )
        optimized_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
        )
        torch_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
        )

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            msg="Linear TRT outputs don't match with the original model.",
        )
        torch._dynamo.reset()


if __name__ == "__main__":
    run_tests()