From 8649c5bcf9900e763905c7de152a912456b597b8 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 17 Jul 2024 12:34:39 -0700 Subject: [PATCH] Previous commit still fails on Stable Diffusion 1.5. Changed to use decomposed ops instead of using INormalization Layer. Supported dynamic shape --- .../conversion/impl/normalization/ops.py | 250 ++++++++++++++++-- .../dynamo/conversion/test_group_norm_aten.py | 41 ++- 2 files changed, 243 insertions(+), 48 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 204d94507f..a6b4e99b1c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -227,6 +227,143 @@ def layer_norm( # return reshaped_output, None, None # return reshaped_output +# def native_group_norm( +# ctx: ConversionContext, +# target: Target, +# source_ir: Optional[SourceIR], +# name: str, +# input: TRTTensor, +# weight: Optional[Union[torch.Tensor, np.ndarray]], +# bias: Optional[Union[torch.Tensor, np.ndarray]], +# N: int, +# C: int, +# HxW: int, +# group: int, +# eps: float, +# return_mean_rstd: bool = True, +# ) -> Union[TRTTensor, Sequence[TRTTensor]]: +# assert ( +# len(input.shape) >= 3 +# ), f"The input dimension should not be less than 3, got {len(input.shape)}!" + +# B = input.shape[0] +# # if C is provided, it must be as same as the channel from the input shape, +# # else if C is zero, we should get the channel from the input shape +# if C == 0: +# C = input.shape[1] +# assert ( +# C == input.shape[1] +# ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}" +# # Groups are a subdivision of the channel dimension. +# assert ( +# C % group == 0 +# ), f"The num of channels ({C}) should be divisible by num_groups ({group})!" 
+# input = get_trt_tensor(ctx, input, f"{name}_input") + +# shape = list(input.shape) + +# for i, s in enumerate(shape): +# if i == 0 and s > 0: +# shape[i] = B * group +# elif i == 1: +# shape[i] = C // group +# elif i > 1 and s == -1: +# shape[i] = 0 + +# # Normalize every group. +# reshaped_input = impl.shuffle.reshape( +# ctx, +# target, +# source_ir, +# f"{name}_reshape_input", +# input, +# shape, +# ) + +# weight = get_trt_tensor(ctx, weight, f"{name}_weight") +# bias = get_trt_tensor(ctx, bias, f"{name}_bias") +# weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2) + +# dims = list(range(1, len(input.shape))) +# axes = get_axes_for_reduce_op(dims) +# dummy_weight = get_trt_tensor(ctx, np.array([1.0]), f"{name}_dummy_weight") +# dummy_weight = impl.slice.expand( +# ctx, target, source_ir, f"{name}_expand_dummy_weight", dummy_weight, reshaped_input.shape +# ) +# dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias") +# dummy_bias = impl.slice.expand( +# ctx, target, source_ir, f"{name}_expand_dummy_bias", dummy_bias, reshaped_input.shape +# ) +# group_norm = ctx.net.add_normalization(reshaped_input, dummy_weight, dummy_bias, axes) +# group_norm.epsilon = eps +# group_norm.compute_precision = input.dtype +# set_layer_name(group_norm, target, f"{name}_group_norm", source_ir) +# output = group_norm.get_output(0) + +# shape = list(output.shape) +# for i, s in enumerate(shape): +# if i == 0 and s > 0: +# shape[i] = B +# elif i == 1: +# shape[i] = C +# elif i > 1 and s == -1: +# shape[i] = 0 + +# reshaped_output = impl.shuffle.reshape( +# ctx, target, source_ir, f"{name}_reshape_output", output, shape +# ) + + +# # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze1", weight, (0)) +# # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze2", weight, (2)) +# # weight = impl.slice.expand( +# # ctx, target, source_ir, f"{name}_expand_weight", weight, shape +# # ) +# # bias = 
impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze1", bias, (0)) +# # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze2", bias, (2)) +# # bias = impl.slice.expand( +# # ctx, target, source_ir, f"{name}_expand_bias", bias, shape +# # ) + +# reshaped_gamma = impl.shuffle.reshape( +# ctx, +# target, +# source_ir, +# f"{name}_reshape_gamma", +# weight, +# weight_bias_shape, +# ) + +# reshaped_output = impl.elementwise.mul( +# ctx, +# target, +# source_ir, +# f"{name}_mul_gamma", +# reshaped_output, +# reshaped_gamma, +# ) + +# reshaped_bias = impl.shuffle.reshape( +# ctx, +# target, +# source_ir, +# f"{name}_reshape_beta", +# bias, +# weight_bias_shape, +# ) +# reshaped_output = impl.elementwise.add( +# ctx, +# target, +# source_ir, +# f"{name}_add_beta", +# reshaped_output, +# reshaped_bias, +# ) +# if return_mean_rstd: +# # return fake mean and rstd for now +# return reshaped_output, None, None +# return reshaped_output + def native_group_norm( ctx: ConversionContext, @@ -243,6 +380,8 @@ def native_group_norm( eps: float, return_mean_rstd: bool = True, ) -> Union[TRTTensor, Sequence[TRTTensor]]: + # TODO: Ask TRT team about the usage of INormalization Layer usage with num_groups and update the implementation + # with INormalization Layer assert ( len(input.shape) >= 3 ), f"The input dimension should not be less than 3, got {len(input.shape)}!" 
@@ -286,36 +425,94 @@ def native_group_norm( weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2) dims = list(range(1, len(input.shape))) - axes = get_axes_for_reduce_op(dims) - # Use dummy weight since the normalization layer cannot well handle the scale and shift of group norm due to shape mismatch - # TODO: check with TRT the correct way to use 'num_groups' to implement group norm - dummy_weight = get_trt_tensor( - ctx, np.array([1.0]), f"{name}_dummy_weight", input.dtype + + # E[X] + mean_trt = impl.reduce.mean( + ctx, + target, + source_ir, + f"{name}_mean", + reshaped_input, + dims, + True, ) - dummy_weight = impl.slice.expand( + + mean_trt = impl.slice.expand( ctx, target, source_ir, - f"{name}_expand_dummy_weight", - dummy_weight, + f"{name}_expand_mean_trt", + mean_trt, reshaped_input.shape, ) - dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias", input.dtype) - dummy_bias = impl.slice.expand( + + # X - E[X] + sub_trt = impl.elementwise.sub( + ctx, + target, + source_ir, + f"{name}_sub", + reshaped_input, + mean_trt, + ) + + # variance + pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32) + pow_var = impl.elementwise.pow( + ctx, + target, + source_ir, + f"{name}_pow", + sub_trt, + pow_trt, + ) + + var_trt = impl.reduce.mean( + ctx, + target, + source_ir, + f"{name}_mean_var", + pow_var, + dims, + True, + ) + + var_trt = impl.slice.expand( ctx, target, source_ir, - f"{name}_expand_dummy_bias", - dummy_bias, + f"{name}_expand_var_trt", + var_trt, reshaped_input.shape, ) - group_norm = ctx.net.add_normalization( - reshaped_input, dummy_weight, dummy_bias, axes + + eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32) + add_trt = impl.elementwise.add( + ctx, + target, + source_ir, + f"{name}_add", + var_trt, + eps_trt, + ) + + sqrt_trt = impl.unary.sqrt( + ctx, + target, + source_ir, + f"{name}_sqrt", + add_trt, + ) + + # y = (X - E[X]) / sqrt((var + eps)) + output = impl.elementwise.div( + ctx, + target, + source_ir, 
+ f"{name}_div", + sub_trt, + sqrt_trt, ) - group_norm.epsilon = eps - group_norm.compute_precision = input.dtype - set_layer_name(group_norm, target, f"{name}_group_norm", source_ir) - output = group_norm.get_output(0) shape = list(output.shape) for i, s in enumerate(shape): @@ -329,12 +526,11 @@ def native_group_norm( reshaped_output = impl.shuffle.reshape( ctx, target, source_ir, f"{name}_reshape_output", output, shape ) - - weight = impl.shuffle.reshape( + reshaped_gamma = impl.shuffle.reshape( ctx, target, source_ir, - f"{name}_weight", + f"{name}_reshape_gamma", weight, weight_bias_shape, ) @@ -343,16 +539,16 @@ def native_group_norm( ctx, target, source_ir, - f"{name}_mul_weight", + f"{name}_mul_gamma", reshaped_output, - weight, + reshaped_gamma, ) - bias = impl.shuffle.reshape( + reshaped_bias = impl.shuffle.reshape( ctx, target, source_ir, - f"{name}_reshape_bias", + f"{name}_reshape_beta", bias, weight_bias_shape, ) @@ -360,9 +556,9 @@ def native_group_norm( ctx, target, source_ir, - f"{name}_add_bias", + f"{name}_add_beta", reshaped_output, - bias, + reshaped_bias, ) if return_mean_rstd: # return fake mean and rstd for now diff --git a/tests/py/dynamo/conversion/test_group_norm_aten.py b/tests/py/dynamo/conversion/test_group_norm_aten.py index c1668de3f9..cf5dd48f0a 100644 --- a/tests/py/dynamo/conversion/test_group_norm_aten.py +++ b/tests/py/dynamo/conversion/test_group_norm_aten.py @@ -112,27 +112,26 @@ def forward(self, x): inputs, ) - # TODO: Half precision has accuracy issue for now. 
- # def test_groupnorm_sd(self): - # class GroupNorm(torch.nn.Module): - # def forward(self, x): - # return torch.ops.aten.native_group_norm.default( - # x, - # torch.randn((320,)).half(), - # torch.randn((320,)).half(), - # 2, - # 320, - # 4096, - # 32, - # 1e-05, - # )[0] - - # inputs = [torch.randn(2, 320, 64, 64).half()] - # with torch.no_grad(): - # self.run_test( - # GroupNorm(), - # inputs, - # ) + def test_groupnorm_sd(self): + class GroupNorm(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.native_group_norm.default( + x, + torch.randn((320,)).half(), + torch.randn((320,)).half(), + 2, + 320, + 4096, + 32, + 1e-05, + )[0] + + inputs = [torch.randn(2, 320, 64, 64).half()] + with torch.no_grad(): + self.run_test( + GroupNorm(), + inputs, + ) @parameterized.expand( [