Use INormalizationLayer for GroupNorm
HolyWu committed Nov 4, 2024
1 parent 8e2c82d commit 57c66bf
Showing 3 changed files with 73 additions and 369 deletions.
38 changes: 2 additions & 36 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -175,14 +175,14 @@ def aten_ops_layer_norm(
0: (TRTTensor,),
}
)
def aten_ops_native_group_norm(
def aten_ops_group_norm(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.native_group_norm(
return impl.normalization.group_norm(
ctx,
target,
SourceIR.ATEN,
@@ -198,40 +198,6 @@ def aten_ops_native_group_norm(
)


@dynamo_tensorrt_converter(
torch.ops.aten.group_norm.default,
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
torch.ops.aten.group_norm,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_group_norm(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.group_norm(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
num_groups=args[1],
weight=args_bounds_check(args, 2, None),
bias=args_bounds_check(args, 3, None),
eps=args_bounds_check(args, 4, 1e-05),
cudnn_enabled=args_bounds_check(args, 5, True),
)


@dynamo_tensorrt_converter(torch.ops.aten.cat.default, supports_dynamic_shapes=True)
def aten_ops_cat(
ctx: ConversionContext,
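For reference, a minimal sketch (not part of the diff) of the aten.native_group_norm overload this converter targets. The positional argument order shown below is what the converter presumably unpacks from args before calling impl.normalization.group_norm; dropping the standalone torch.ops.aten.group_norm converter is consistent with aten.group_norm normally being decomposed into native_group_norm before conversion, though that rationale is an assumption here.

# Sketch only: illustrates the aten.native_group_norm signature handled by the converter.
# Argument order: (input, weight, bias, N, C, HxW, group, eps) -> (output, mean, rstd).
import torch

x = torch.randn(2, 6, 4, 4)   # N=2, C=6, H*W=16
weight = torch.ones(6)        # per-channel gamma
bias = torch.zeros(6)         # per-channel beta

out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, 2, 6, 16, 3, 1e-5   # N, C, HxW, group, eps
)
print(out.shape)               # same shape as the input
print(mean.shape, rstd.shape)  # per (batch, group) statistics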
257 changes: 48 additions & 209 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -1,5 +1,5 @@
import logging
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
@@ -16,7 +16,6 @@
get_trt_tensor,
has_dynamic_shape,
set_layer_name,
to_numpy,
)
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
@@ -204,240 +203,80 @@ def layer_norm(
return layer_norm.get_output(0)


def native_group_norm(
def group_norm(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
weight: Optional[Union[torch.Tensor, np.ndarray]],
bias: Optional[Union[torch.Tensor, np.ndarray]],
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
N: int,
C: int,
HxW: int,
group: int,
eps: float,
return_mean_rstd: bool = True,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
# TODO: Ask TRT team about the usage of INormalization Layer with num_groups and update the implementation
# to use INormalization Layer
) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
assert (
len(input.shape) >= 3
), f"The input dimension should not be less than 3, got {len(input.shape)}!"
), f"Expected at least 3 dimensions for input tensor but got {len(input.shape)}"

B = input.shape[0]
# if C is provided, it must be the same as the number of channels in the input shape;
# if C is zero, get the number of channels from the input shape
if C == 0:
C = input.shape[1]
assert (
C == input.shape[1]
), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
# Groups are a subdivision of the channel dimension.
assert (
C % group == 0
), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
input = get_trt_tensor(ctx, input, f"{name}_input")

shape = list(input.shape)

for i, s in enumerate(shape):
if i == 0 and s > 0:
shape[i] = B * group
elif i == 1:
shape[i] = C // group
elif i > 1 and s == -1:
shape[i] = 0

# Normalize every group.
reshaped_input = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_input",
input,
shape,
)

if weight is None:
weight = to_numpy(1.0)
), f"num_channels ({C}) must be equal to number of channels in input ({input.shape[1]})"

if bias is None:
bias = to_numpy(0.0)
weight_one = get_trt_tensor(ctx, 1.0, f"{name}_weight_one", input.dtype)
bias_zero = get_trt_tensor(ctx, 0.0, f"{name}_bias_zero", input.dtype)

weight = get_trt_tensor(ctx, weight, f"{name}_weight")
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)

dims = list(range(1, len(input.shape)))

# E[X]
mean_trt = impl.reduce.mean(
ctx,
target,
source_ir,
f"{name}_mean",
reshaped_input,
dims,
True,
)

mean_trt = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_mean_trt",
mean_trt,
reshaped_input.shape,
)

# X - E[X]
sub_trt = impl.elementwise.sub(
ctx,
target,
source_ir,
f"{name}_sub",
reshaped_input,
mean_trt,
)

# variance
pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
pow_var = impl.elementwise.pow(
ctx,
target,
source_ir,
f"{name}_pow",
sub_trt,
pow_trt,
)

var_trt = impl.reduce.mean(
ctx,
target,
source_ir,
f"{name}_mean_var",
pow_var,
dims,
True,
)

var_trt = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_var_trt",
var_trt,
reshaped_input.shape,
)

eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
add_trt = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add",
var_trt,
eps_trt,
)
shape = [1, group] + [1] * (len(input.shape) - 2)

sqrt_trt = impl.unary.sqrt(
ctx,
target,
source_ir,
f"{name}_sqrt",
add_trt,
expanded_weight_one = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_weight_one", weight_one, shape
)

# y = (X - E[X]) / sqrt((var + eps))
output = impl.elementwise.div(
ctx,
target,
source_ir,
f"{name}_div",
sub_trt,
sqrt_trt,
expanded_bias_zero = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
)

shape = list(output.shape)
for i, s in enumerate(shape):
if i == 0 and s > 0:
shape[i] = B
elif i == 1:
shape[i] = C
elif i > 1 and s == -1:
shape[i] = 0
axes = get_axes_for_reduce_op([i for i in range(2, len(input.shape))])

reshaped_output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_output", output, shape
)
reshaped_gamma = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_gamma",
weight,
weight_bias_shape,
# INormalizationLayer scales the normalized output per-group, but PyTorch scales the normalized output per-channel,
# which would give diverging results. Let TensorRT do a no-op scaling here, and apply the per-channel scaling ourselves afterwards
layer = ctx.net.add_normalization(
input, expanded_weight_one, expanded_bias_zero, axes
)
layer.epsilon = eps
layer.num_groups = group
set_layer_name(layer, target, name, source_ir)
output = layer.get_output(0)

reshaped_output = impl.elementwise.mul(
ctx,
target,
source_ir,
f"{name}_mul_gamma",
reshaped_output,
reshaped_gamma,
)
shape[1] = C

reshaped_bias = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_beta",
bias,
weight_bias_shape,
)
reshaped_output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add_beta",
reshaped_output,
reshaped_bias,
)
if return_mean_rstd:
# return fake mean and rstd for now
return reshaped_output, None, None
return reshaped_output
if weight is not None:
weight = get_trt_tensor(ctx, weight, f"{name}_weight")
weight = cast_trt_tensor(
ctx, weight, input.dtype, f"{name}_cast_weight", target, source_ir
)
weight = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_weight", weight, shape
)
output = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_mul_weight", output, weight
)

if bias is not None:
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
bias = cast_trt_tensor(
ctx, bias, input.dtype, f"{name}_cast_bias", target, source_ir
)
bias = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_bias", bias, shape
)
output = impl.elementwise.add(
ctx, target, source_ir, f"{name}_add_bias", output, bias
)

def group_norm(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
num_groups: int,
weight: Optional[Union[torch.Tensor, np.ndarray]],
bias: Optional[Union[torch.Tensor, np.ndarray]],
eps: float,
cudnn_enabled: bool,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return native_group_norm(
ctx,
target,
source_ir,
name,
input,
weight,
bias,
0,
0,
0,
num_groups,
eps,
return_mean_rstd=False,
)
# return fake mean and rstd for now
return output, None, None


def softmax(
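A minimal sketch (not part of the diff) of the equivalence the new implementation relies on: normalize per group with unit scale and zero bias (what the INormalization layer is asked to do above), then apply the per-channel weight and bias as a separate multiply/add. It is written with plain PyTorch for clarity; the broadcast shapes mirror the [1, group, 1, ...] and [1, C, 1, ...] reshapes in the code above, and the remark on the TensorRT axes bit mask is an assumption rather than something stated in this diff.

import torch
import torch.nn.functional as F

x = torch.randn(2, 6, 4, 4)
num_groups, C, eps = 3, 6, 1e-5
weight = torch.randn(C)   # gamma
bias = torch.randn(C)     # beta

# Reference: fused per-channel affine inside group_norm.
ref = F.group_norm(x, num_groups, weight, bias, eps)

# Two-step version mirroring the converter: normalize with weight=1, bias=0,
# then scale/shift per channel with broadcastable [1, C, 1, 1] tensors.
normalized = F.group_norm(x, num_groups, None, None, eps)
out = normalized * weight.view(1, C, 1, 1) + bias.view(1, C, 1, 1)

print(torch.allclose(ref, out, atol=1e-6))  # expected: True

# In the TRT path, the normalization axes are passed as a bit mask over the
# non-batch, non-channel dims (e.g. dims 2 and 3 of a 4-D input -> (1 << 2) | (1 << 3)),
# which is presumably what get_axes_for_reduce_op computes here.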