From 57c66bfc7c25596579a392a353911dcd5519db2a Mon Sep 17 00:00:00 2001
From: HolyWu
Date: Mon, 4 Nov 2024 19:40:09 +0800
Subject: [PATCH] Use INormalizationLayer for GroupNorm

---
 .../dynamo/conversion/aten_ops_converters.py  |  38 +--
 .../conversion/impl/normalization/ops.py      | 257 ++++----------------
 .../dynamo/conversion/test_group_norm_aten.py | 147 ++--------
 3 files changed, 73 insertions(+), 369 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 07c8c03697..778cb16fbd 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -175,14 +175,14 @@ def aten_ops_layer_norm(
         0: (TRTTensor,),
     }
 )
-def aten_ops_native_group_norm(
+def aten_ops_group_norm(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.normalization.native_group_norm(
+    return impl.normalization.group_norm(
         ctx,
         target,
         SourceIR.ATEN,
@@ -198,40 +198,6 @@ def aten_ops_native_group_norm(
     )
 
 
-@dynamo_tensorrt_converter(
-    torch.ops.aten.group_norm.default,
-    supports_dynamic_shapes=True,
-)
-@dynamo_tensorrt_converter(
-    torch.ops.aten.group_norm,
-    supports_dynamic_shapes=True,
-)
-@enforce_tensor_types(
-    {
-        0: (TRTTensor,),
-    }
-)
-def aten_ops_group_norm(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.normalization.group_norm(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        input=args[0],
-        num_groups=args[1],
-        weight=args_bounds_check(args, 2, None),
-        bias=args_bounds_check(args, 3, None),
-        eps=args_bounds_check(args, 4, 1e-05),
-        cudnn_enabled=args_bounds_check(args, 5, True),
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
 def aten_ops_cat(
     ctx: ConversionContext,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 4f39a6d5d9..58fdf1bca6 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from typing import List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import tensorrt as trt
@@ -16,7 +16,6 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
-    to_numpy,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
@@ -204,240 +203,80 @@ def layer_norm(
     return layer_norm.get_output(0)
 
 
-def native_group_norm(
+def group_norm(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     N: int,
     C: int,
     HxW: int,
     group: int,
     eps: float,
-    return_mean_rstd: bool = True,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    # TODO: Ask TRT team about the usage of INormalization Layer usage with num_groups and update the implementation
-    # with INormalization Layer
+) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
     assert (
         len(input.shape) >= 3
-    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+    ), f"Expected at least 3 dimensions for input tensor but got {len(input.shape)}"
 
-    B = input.shape[0]
-    # if C is provided, it must be as same as the channel from the input shape,
-    # else if C is zero, we should get the channel from the input shape
-    if C == 0:
-        C = input.shape[1]
     assert (
         C == input.shape[1]
-    ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
-    # Groups are a subdivision of the channel dimension.
-    assert (
-        C % group == 0
-    ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
-    input = get_trt_tensor(ctx, input, f"{name}_input")
-
-    shape = list(input.shape)
-
-    for i, s in enumerate(shape):
-        if i == 0 and s > 0:
-            shape[i] = B * group
-        elif i == 1:
-            shape[i] = C // group
-        elif i > 1 and s == -1:
-            shape[i] = 0
-
-    # Normalize every group.
-    reshaped_input = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_input",
-        input,
-        shape,
-    )
-
-    if weight is None:
-        weight = to_numpy(1.0)
+    ), f"num_channels ({C}) must be equal to number of channels in input ({input.shape[1]})"
 
-    if bias is None:
-        bias = to_numpy(0.0)
+    weight_one = get_trt_tensor(ctx, 1.0, f"{name}_weight_one", input.dtype)
+    bias_zero = get_trt_tensor(ctx, 0.0, f"{name}_bias_zero", input.dtype)
 
-    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
-    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
-
-    dims = list(range(1, len(input.shape)))
-
-    # E[X]
-    mean_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean",
-        reshaped_input,
-        dims,
-        True,
-    )
-
-    mean_trt = impl.slice.expand(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_expand_mean_trt",
-        mean_trt,
-        reshaped_input.shape,
-    )
-
-    # X - E[X]
-    sub_trt = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sub",
-        reshaped_input,
-        mean_trt,
-    )
-
-    # variance
-    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
-    pow_var = impl.elementwise.pow(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_pow",
-        sub_trt,
-        pow_trt,
-    )
-
-    var_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean_var",
-        pow_var,
-        dims,
-        True,
-    )
-
-    var_trt = impl.slice.expand(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_expand_var_trt",
-        var_trt,
-        reshaped_input.shape,
-    )
-
-    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
-    add_trt = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add",
-        var_trt,
-        eps_trt,
-    )
+    shape = [1, group] + [1] * (len(input.shape) - 2)
 
-    sqrt_trt = impl.unary.sqrt(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sqrt",
-        add_trt,
+    expanded_weight_one = impl.slice.expand(
+        ctx, target, source_ir, f"{name}_expand_weight_one", weight_one, shape
     )
-
-    # y = (X - E[X]) / sqrt((var + eps))
-    output = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_div",
-        sub_trt,
-        sqrt_trt,
+    expanded_bias_zero = impl.slice.expand(
+        ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
     )
 
-    shape = list(output.shape)
-    for i, s in enumerate(shape):
-        if i == 0 and s > 0:
-            shape[i] = B
-        elif i == 1:
-            shape[i] = C
-        elif i > 1 and s == -1:
-            shape[i] = 0
+    axes = get_axes_for_reduce_op([i for i in range(2, len(input.shape))])
 
-    reshaped_output = impl.shuffle.reshape(
-        ctx, target, source_ir, f"{name}_reshape_output", output, shape
-    )
-    reshaped_gamma = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
f"{name}_reshape_gamma", - weight, - weight_bias_shape, + # INormalizationLayer scales the normalized output per-group, but PyTorch scales the normalized output per-channel, + # hence causing diverse result. Let TensorRT does no-op for scaling here, and do scaling ourselves later + layer = ctx.net.add_normalization( + input, expanded_weight_one, expanded_bias_zero, axes ) + layer.epsilon = eps + layer.num_groups = group + set_layer_name(layer, target, name, source_ir) + output = layer.get_output(0) - reshaped_output = impl.elementwise.mul( - ctx, - target, - source_ir, - f"{name}_mul_gamma", - reshaped_output, - reshaped_gamma, - ) + shape[1] = C - reshaped_bias = impl.shuffle.reshape( - ctx, - target, - source_ir, - f"{name}_reshape_beta", - bias, - weight_bias_shape, - ) - reshaped_output = impl.elementwise.add( - ctx, - target, - source_ir, - f"{name}_add_beta", - reshaped_output, - reshaped_bias, - ) - if return_mean_rstd: - # return fake mean and rstd for now - return reshaped_output, None, None - return reshaped_output + if weight is not None: + weight = get_trt_tensor(ctx, weight, f"{name}_weight") + weight = cast_trt_tensor( + ctx, weight, input.dtype, f"{name}_cast_weight", target, source_ir + ) + weight = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_weight", weight, shape + ) + output = impl.elementwise.mul( + ctx, target, source_ir, f"{name}_mul_weight", output, weight + ) + if bias is not None: + bias = get_trt_tensor(ctx, bias, f"{name}_bias") + bias = cast_trt_tensor( + ctx, bias, input.dtype, f"{name}_cast_bias", target, source_ir + ) + bias = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_bias", bias, shape + ) + output = impl.elementwise.add( + ctx, target, source_ir, f"{name}_add_bias", output, bias + ) -def group_norm( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, - num_groups: int, - weight: Optional[Union[torch.Tensor, np.ndarray]], - bias: Optional[Union[torch.Tensor, np.ndarray]], - eps: float, - cudnn_enabled: bool, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return native_group_norm( - ctx, - target, - source_ir, - name, - input, - weight, - bias, - 0, - 0, - 0, - num_groups, - eps, - return_mean_rstd=False, - ) + # return fake mean and rstd for now + return output, None, None def softmax( diff --git a/tests/py/dynamo/conversion/test_group_norm_aten.py b/tests/py/dynamo/conversion/test_group_norm_aten.py index 617166d0c4..2caf17c8d4 100644 --- a/tests/py/dynamo/conversion/test_group_norm_aten.py +++ b/tests/py/dynamo/conversion/test_group_norm_aten.py @@ -6,155 +6,56 @@ from .harness import DispatchTestCase -class TestGroupNormConverter(DispatchTestCase): - def test_groupnorm1d(self): - class GroupNorm(torch.nn.Module): - def forward(self, x): - return torch.ops.aten.group_norm.default( - x, - 2, - torch.ones((6,)), - torch.zeros((6,)), - 1e-05, - True, - ) - - inputs = [torch.randn(3, 6, 224)] - self.run_test( - GroupNorm(), - inputs, - ) - - def test_groupnorm2d(self): - class GroupNorm(torch.nn.Module): - def forward(self, x): - return torch.ops.aten.group_norm.default( - x, - 2, - torch.randn((6,)), - torch.randn((6,)), - 1e-05, - True, - ) - - inputs = [torch.randn(3, 6, 224, 224)] - with torch.no_grad(): - self.run_test( - GroupNorm(), - inputs, - ) - - def test_groupnorm_with_dynamic_shape(self): - class GroupNorm(torch.nn.Module): - def forward(self, x): - return torch.ops.aten.group_norm.default( - x, - 2, - torch.randn((6,)), - torch.randn((6,)), - 
-                    1e-05,
-                    True,
-                )
-
-        input_specs = [
-            Input(
-                dtype=torch.float32,
-                min_shape=(3, 6, 24, 24),
-                opt_shape=(5, 6, 24, 24),
-                max_shape=(8, 6, 48, 24),
-            ),
-        ]
-        self.run_test_with_dynamic_shape(
-            GroupNorm(),
-            input_specs,
-        )
-
-
 class TestNativeGroupNormConverter(DispatchTestCase):
-    def test_groupnorm1d(self):
+    def test_groupnorm_1d(self):
         class GroupNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_group_norm.default(
-                    x,
-                    torch.ones((6,)),
-                    torch.zeros((6,)),
-                    3,
-                    6,
-                    224,
-                    2,
-                    1e-05,
+                    x, None, None, 3, 6, 224, 2, 1e-05
                 )[0]
 
         inputs = [torch.randn(3, 6, 224)]
-        self.run_test(
-            GroupNorm(),
-            inputs,
-        )
+        self.run_test(GroupNorm(), inputs, use_dynamo_tracer=True)
 
-    def test_groupnorm2d(self):
+    def test_groupnorm_2d(self):
         class GroupNorm(torch.nn.Module):
-            def forward(self, x):
+            def forward(self, x, weight, bias):
                 return torch.ops.aten.native_group_norm.default(
-                    x,
-                    torch.ones((6,)),
-                    torch.zeros((6,)),
-                    3,
-                    6,
-                    224 * 224,
-                    2,
-                    1e-05,
+                    x, weight, bias, 3, 6, 224 * 224, 2, 1e-05
                 )[0]
 
-        inputs = [torch.randn(3, 6, 224, 224)]
-        with torch.no_grad():
-            self.run_test(
-                GroupNorm(),
-                inputs,
-            )
+        inputs = [torch.randn(3, 6, 224, 224), torch.ones(6), torch.zeros(6)]
+        self.run_test(GroupNorm(), inputs, use_dynamo_tracer=True)
 
     def test_groupnorm_sd(self):
         class GroupNorm(torch.nn.Module):
-            def forward(self, x):
+            def forward(self, x, weight, bias):
                 return torch.ops.aten.native_group_norm.default(
-                    x,
-                    torch.randn((320,)).half(),
-                    torch.randn((320,)).half(),
-                    2,
-                    320,
-                    4096,
-                    32,
-                    1e-05,
+                    x, weight, bias, 2, 320, 64 * 64, 32, 1e-05
                 )[0]
 
-        inputs = [torch.randn(2, 320, 64, 64).half()]
-        with torch.no_grad():
-            self.run_test(
-                GroupNorm(),
-                inputs,
-            )
+        inputs = [
+            torch.randn(2, 320, 64, 64, dtype=torch.half),
+            torch.randn(320, dtype=torch.half),
+            torch.randn(320, dtype=torch.half),
+        ]
+        self.run_test(GroupNorm(), inputs, precision=torch.half, use_dynamo_tracer=True)
 
     @parameterized.expand(
         [
             (5, 4, 4, 2, (2, 4, 2), (3, 4, 2), (5, 4, 4)),
             (5, 4, 2 * 2, 2, (2, 4, 2, 2), (3, 4, 2, 2), (5, 4, 2, 2)),
             (5, 9, 6 * 3, 3, (3, 9, 3, 3), (4, 9, 3, 3), (5, 9, 6, 3)),
-            (8, 9, 6 * 6, 3, (3, 9, 2, 3, 2), (5, 9, 3, 3, 2), (8, 9, 6, 3, 2)),
+            (8, 9, 6 * 3 * 2, 3, (3, 9, 2, 3, 2), (5, 9, 3, 3, 2), (8, 9, 6, 3, 2)),
         ]
     )
     def test_groupnorm_with_dynamic_shape(
-        self, N, C, HxW, groups, min_shape, opt_shape, max_shape
+        self, N, C, HxW, group, min_shape, opt_shape, max_shape
     ):
         class GroupNorm(torch.nn.Module):
-            def forward(self, x):
+            def forward(self, x, weight, bias):
                 return torch.ops.aten.native_group_norm.default(
-                    x,
-                    torch.ones((C,)),
-                    torch.zeros((C,)),
-                    N,
-                    C,
-                    HxW,
-                    groups,
-                    1e-5,
+                    x, weight, bias, N, C, HxW, group, 1e-05
                 )[0]
 
         input_specs = [
@@ -164,12 +65,10 @@ def forward(self, x):
                 opt_shape=opt_shape,
                 max_shape=max_shape,
             ),
+            Input(dtype=torch.float32, shape=(C,)),
+            Input(dtype=torch.float32, shape=(C,)),
         ]
-        self.run_test_with_dynamic_shape(
-            GroupNorm(),
-            input_specs,
-            check_dtype=False,
-        )
+        self.run_test_with_dynamic_shape(GroupNorm(), input_specs, check_dtype=False)
 
 
 if __name__ == "__main__":
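
Usage sketch (not part of the patch): the renamed converter is reached whenever the dynamo
frontend lowers torch.nn.GroupNorm to aten.native_group_norm. The module, shapes, and compile
settings below are illustrative assumptions for exercising that path, not code from this change.

import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 6 channels split into 2 groups, mirroring the shapes used in the tests above.
        self.gn = torch.nn.GroupNorm(2, 6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gn(x)


model = Model().eval().cuda()
inputs = [torch.randn(3, 6, 224, 224, device="cuda")]

# The dynamo frontend decomposes nn.GroupNorm into aten.native_group_norm, which the
# aten_ops_group_norm converter then maps onto TensorRT's INormalizationLayer.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

with torch.no_grad():
    # Compare the TensorRT engine against eager PyTorch as a quick sanity check.
    print(torch.max(torch.abs(trt_model(*inputs) - model(*inputs))))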