diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 204d94507f..3644d95752 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -134,100 +134,6 @@ def layer_norm(
     return layer_norm.get_output(0)
 
 
-# def native_group_norm(
-#     ctx: ConversionContext,
-#     target: Target,
-#     source_ir: Optional[SourceIR],
-#     name: str,
-#     input: TRTTensor,
-#     weight: Optional[Union[torch.Tensor, np.ndarray]],
-#     bias: Optional[Union[torch.Tensor, np.ndarray]],
-#     N: int,
-#     C: int,
-#     HxW: int,
-#     group: int,
-#     eps: float,
-#     return_mean_rstd: bool = True,
-# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-#     assert (
-#         len(input.shape) >= 3
-#     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
-
-#     B = input.shape[0]
-#     # if C is provided, it must be as same as the channel from the input shape,
-#     # else if C is zero, we should get the channel from the input shape
-#     if C == 0:
-#         C = input.shape[1]
-#     assert (
-#         C == input.shape[1]
-#     ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
-#     # Groups are a subdivision of the channel dimension.
-#     assert (
-#         C % group == 0
-#     ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
-#     input = get_trt_tensor(ctx, input, f"{name}_input")
-
-#     shape = list(input.shape)
-
-#     for i, s in enumerate(shape):
-#         if i == 0 and s > 0:
-#             shape[i] = B * group
-#         elif i == 1:
-#             shape[i] = C // group
-#         elif i > 1 and s == -1:
-#             shape[i] = 0
-
-#     # Normalize every group.
-#     reshaped_input = impl.shuffle.reshape(
-#         ctx,
-#         target,
-#         source_ir,
-#         f"{name}_reshape_input",
-#         input,
-#         shape,
-#     )
-
-#     weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-#     bias = get_trt_tensor(ctx, bias, f"{name}_bias")
-#     if tuple(reshaped_input.shape) != tuple(weight.shape):
-#         weight = impl.slice.expand(
-#             ctx,
-#             target,
-#             source_ir,
-#             f"{name}_expand_weight",
-#             weight,
-#             reshaped_input.shape,
-#         )
-#     if tuple(reshaped_input.shape) != tuple(bias.shape):
-#         bias = impl.slice.expand(
-#             ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
-#         )
-#     dims = list(range(1, len(input.shape)))
-#     axes = get_axes_for_reduce_op(dims)
-#     group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
-#     group_norm.epsilon = eps
-#     group_norm.compute_precision = input.dtype
-#     set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-#     output = group_norm.get_output(0)
-
-#     shape = list(output.shape)
-#     for i, s in enumerate(shape):
-#         if i == 0 and s > 0:
-#             shape[i] = B
-#         elif i == 1:
-#             shape[i] = C
-#         elif i > 1 and s == -1:
-#             shape[i] = 0
-
-#     reshaped_output = impl.shuffle.reshape(
-#         ctx, target, source_ir, f"{name}_reshape_output", output, shape
-#     )
-#     if return_mean_rstd:
-#         # return fake mean and rstd for now
-#         return reshaped_output, None, None
-#     return reshaped_output
-
-
 def native_group_norm(
     ctx: ConversionContext,
     target: Target,
@@ -243,6 +149,8 @@ def native_group_norm(
     eps: float,
     return_mean_rstd: bool = True,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # TODO: Ask the TRT team how INormalization Layer should be used with num_groups,
+    # then update this implementation to use INormalization Layer
     assert (
         len(input.shape) >= 3
     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
@@ -286,36 +194,94 @@ def native_group_norm(
     weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
 
     dims = list(range(1, len(input.shape)))
-    axes = get_axes_for_reduce_op(dims)
-    # Use dummy weight since the normalization layer cannot well handle the scale and shift of group norm due to shape mismatch
-    # TODO: check with TRT the correct way to use 'num_groups' to implement group norm
-    dummy_weight = get_trt_tensor(
-        ctx, np.array([1.0]), f"{name}_dummy_weight", input.dtype
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean",
+        reshaped_input,
+        dims,
+        True,
     )
-    dummy_weight = impl.slice.expand(
+
+    mean_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_weight",
-        dummy_weight,
+        f"{name}_expand_mean_trt",
+        mean_trt,
         reshaped_input.shape,
     )
-    dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias", input.dtype)
-    dummy_bias = impl.slice.expand(
+
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
+
+    # variance
+    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
+    pow_var = impl.elementwise.pow(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_trt,
+    )
+
+    var_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean_var",
+        pow_var,
+        dims,
+        True,
+    )
+
+    var_trt = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_expand_dummy_bias",
-        dummy_bias,
+        f"{name}_expand_var_trt",
+        var_trt,
         reshaped_input.shape,
     )
-    group_norm = ctx.net.add_normalization(
-        reshaped_input, dummy_weight, dummy_bias, axes
+
+    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
+    add_trt = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_trt,
+    )
+
+    sqrt_trt = impl.unary.sqrt(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # y = (X - E[X]) / sqrt(var + eps)
+    output = impl.elementwise.div(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
     )
-    group_norm.epsilon = eps
-    group_norm.compute_precision = input.dtype
-    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-    output = group_norm.get_output(0)
 
     shape = list(output.shape)
     for i, s in enumerate(shape):
@@ -329,12 +295,11 @@ def native_group_norm(
     reshaped_output = impl.shuffle.reshape(
         ctx, target, source_ir, f"{name}_reshape_output", output, shape
     )
-
-    weight = impl.shuffle.reshape(
+    reshaped_gamma = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_weight",
+        f"{name}_reshape_gamma",
         weight,
         weight_bias_shape,
     )
@@ -343,16 +308,16 @@ def native_group_norm(
         ctx,
         target,
         source_ir,
-        f"{name}_mul_weight",
+        f"{name}_mul_gamma",
         reshaped_output,
-        weight,
+        reshaped_gamma,
     )
 
-    bias = impl.shuffle.reshape(
+    reshaped_bias = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape_bias",
+        f"{name}_reshape_beta",
         bias,
         weight_bias_shape,
     )
@@ -360,9 +325,9 @@ def native_group_norm(
         ctx,
         target,
         source_ir,
-        f"{name}_add_bias",
+        f"{name}_add_beta",
         reshaped_output,
-        bias,
+        reshaped_bias,
     )
     if return_mean_rstd:
         # return fake mean and rstd for now
diff --git a/tests/py/dynamo/conversion/test_group_norm_aten.py b/tests/py/dynamo/conversion/test_group_norm_aten.py
index c1668de3f9..cf5dd48f0a 100644
--- a/tests/py/dynamo/conversion/test_group_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_group_norm_aten.py
@@ -112,27 +112,26 @@ def forward(self, x):
             inputs,
         )
 
-    # TODO: Half precision has accuracy issue for now.
-    # def test_groupnorm_sd(self):
-    #     class GroupNorm(torch.nn.Module):
-    #         def forward(self, x):
-    #             return torch.ops.aten.native_group_norm.default(
-    #                 x,
-    #                 torch.randn((320,)).half(),
-    #                 torch.randn((320,)).half(),
-    #                 2,
-    #                 320,
-    #                 4096,
-    #                 32,
-    #                 1e-05,
-    #             )[0]
-
-    #     inputs = [torch.randn(2, 320, 64, 64).half()]
-    #     with torch.no_grad():
-    #         self.run_test(
-    #             GroupNorm(),
-    #             inputs,
-    #         )
+    def test_groupnorm_sd(self):
+        class GroupNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_group_norm.default(
+                    x,
+                    torch.randn((320,)).half(),
+                    torch.randn((320,)).half(),
+                    2,
+                    320,
+                    4096,
+                    32,
+                    1e-05,
+                )[0]
+
+        inputs = [torch.randn(2, 320, 64, 64).half()]
+        with torch.no_grad():
+            self.run_test(
+                GroupNorm(),
+                inputs,
+            )
 
     @parameterized.expand(
         [
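
For reference, the converter hunks above build group norm out of primitive TensorRT layers: reshape to fold groups into the batch dimension, reduce-mean for E[X], an elementwise chain for the biased variance and sqrt(var + eps), then a per-channel gamma/beta affine using weight_bias_shape. A minimal eager-mode PyTorch sketch of the same math follows; it is illustrative only and not part of the patch (the helper name group_norm_decomposed is ours), checked against torch.nn.functional.group_norm, which also uses the biased variance.

import torch
import torch.nn.functional as F

def group_norm_decomposed(x, weight, bias, num_groups, eps=1e-5):
    N, C = x.shape[0], x.shape[1]
    # Fold groups into the batch dim: (N, C, *spatial) -> (N * G, C // G, *spatial)
    y = x.reshape(N * num_groups, C // num_groups, *x.shape[2:])
    dims = tuple(range(1, y.dim()))  # reduce over everything but dim 0
    mean = y.mean(dim=dims, keepdim=True)                 # E[X]
    var = ((y - mean) ** 2).mean(dim=dims, keepdim=True)  # biased Var[X]
    y = (y - mean) / torch.sqrt(var + eps)                # y = (X - E[X]) / sqrt(var + eps)
    y = y.reshape(x.shape)                                # back to (N, C, *spatial)
    # Per-channel affine, broadcast via (1, C, 1, ...) like weight_bias_shape
    affine_shape = (1, C) + (1,) * (x.dim() - 2)
    return y * weight.reshape(affine_shape) + bias.reshape(affine_shape)

x = torch.randn(2, 320, 64, 64)
w, b = torch.randn(320), torch.randn(320)
ref = F.group_norm(x, 32, w, b, eps=1e-5)
assert torch.allclose(group_norm_decomposed(x, w, b, 32), ref, atol=1e-4)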