Skip to content

Commit

Permalink
fix: Implement aten.mean.default and aten.mean.dim converters (#1810
Browse files Browse the repository at this point in the history
)
  • Loading branch information
gs-olive authored Apr 21, 2023
1 parent 5f77f56 commit b3f433a
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 11 deletions.
35 changes: 24 additions & 11 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def aten_ops_add(
return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.mean.dim)
@tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
@tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
def aten_ops_adaptive_avg_poolnd(
Expand All @@ -51,24 +50,38 @@ def aten_ops_adaptive_avg_poolnd(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
if target == torch.ops.aten.mean.dim:

if list(args[1]) != [-1, -2]:
raise RuntimeError(f"We do not support {target} has dim={args[1]}")
else:
output_size = [1, 1]
else:
output_size = args[1]

kwargs_new = {
"input": args[0],
"output_size": output_size,
"output_size": args[1],
}
return acc_ops_converters.acc_ops_adaptive_avg_poolnd(
network, target, None, kwargs_new, name
)


@tensorrt_converter(torch.ops.aten.mean.default)
@tensorrt_converter(torch.ops.aten.mean.dim)
def aten_ops_mean(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> TRTTensor:
# Default invocation of aten.mean only uses first argument and
# averages over all elements (all dimensions)
# aten.mean.dim invocation allows specification of dimensions to average
# over, as well at the option to keep the dimension or not
kwargs_new = {
"input": args[0],
"dim": args[1] if len(args) >= 2 else list(range(len(args[0].shape))),
"keepdim": args[2] if len(args) >= 3 else False,
}
return add_reduce_layer(
network, target, args, kwargs_new, trt.ReduceOperation.AVG, name
)


@tensorrt_converter(torch.ops.aten.batch_norm)
def aten_ops_batch_norm(
network: TRTNetwork,
Expand Down
84 changes: 84 additions & 0 deletions py/torch_tensorrt/fx/test/converters/aten_op/test_mean_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec


class TestMeanDimConverter(DispatchTestCase):
def test_mean_dim_keepdims(self):
class TestModule(nn.Module):
def forward(self, x):
return torch.mean(x, dim=[0, 1], keepdim=True)

inputs = [torch.randn(1, 10)]
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim})

def test_mean_dim_keepdims_with_dynamic_shape(self):
class TestModule(nn.Module):
def forward(self, x):
return torch.mean(x, dim=[0, 1, 2], keepdim=True)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
),
]
self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim}
)

def test_mean_dim_keepdims_false(self):
class TestModule(nn.Module):
def forward(self, x):
return torch.mean(x, dim=0, keepdim=False)

inputs = [torch.randn(3, 5, 7)]
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim})

def test_mean_dim_keepdims_false_with_dynamic_shape(self):
class TestModule(nn.Module):
def forward(self, x):
return torch.mean(x, dim=-1, keepdim=False)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
),
]
self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim}
)


class TestMeanConverter(DispatchTestCase):
def test_mean(self):
class TestModule(nn.Module):
def forward(self, x):
return torch.mean(x)

inputs = [torch.randn(3, 8, 5, 7, 1)]
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.default})

def test_mean_with_dynamic_shape(self):
class TestModule(nn.Module):
def forward(self, x):
return torch.mean(x)

input_specs = [
InputTensorSpec(
shape=(-1, -1, -1),
dtype=torch.float32,
shape_ranges=[((1, 1, 1), (1, 5, 8), (3, 10, 10))],
),
]
self.run_test_with_dynamic_shape(
TestModule(), input_specs, expected_ops={torch.ops.aten.mean.default}
)


if __name__ == "__main__":
run_tests()

0 comments on commit b3f433a

Please sign in to comment.