Group norm bug fix #3014

Merged (4 commits) on Jul 26, 2024
151 changes: 132 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -149,6 +149,8 @@ def native_group_norm(
eps: float,
return_mean_rstd: bool = True,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
# TODO: Ask the TRT team about using INormalization Layer with num_groups and update
# the implementation to use INormalization Layer
assert (
len(input.shape) >= 3
), f"The input dimension should not be less than 3, got {len(input.shape)}!"
@@ -187,28 +189,105 @@ def native_group_norm(
shape,
)

if weight is None:
weight = to_numpy(1.0)

if bias is None:
bias = to_numpy(0.0)

weight = get_trt_tensor(ctx, weight, f"{name}_weight")
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
cehongwang marked this conversation as resolved.
if tuple(reshaped_input.shape) != tuple(weight.shape):
weight = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_weight",
weight,
reshaped_input.shape,
)
if tuple(reshaped_input.shape) != tuple(bias.shape):
bias = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
)
weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)

dims = list(range(1, len(input.shape)))
axes = get_axes_for_reduce_op(dims)
group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
group_norm.epsilon = eps
group_norm.compute_precision = input.dtype
set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
output = group_norm.get_output(0)

# E[X]
mean_trt = impl.reduce.mean(
ctx,
target,
source_ir,
f"{name}_mean",
reshaped_input,
dims,
True,
)

mean_trt = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_mean_trt",
mean_trt,
reshaped_input.shape,
)

# X - E[X]
sub_trt = impl.elementwise.sub(
ctx,
target,
source_ir,
f"{name}_sub",
reshaped_input,
mean_trt,
)

# variance
pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
pow_var = impl.elementwise.pow(
ctx,
target,
source_ir,
f"{name}_pow",
sub_trt,
pow_trt,
)

var_trt = impl.reduce.mean(
ctx,
target,
source_ir,
f"{name}_mean_var",
pow_var,
dims,
True,
)

var_trt = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_var_trt",
var_trt,
reshaped_input.shape,
)

eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
Collaborator: Why is this converted to np.float32?

Collaborator (author): It is 1e-6 in most cases.

Collaborator (author): float16 cannot handle values that close to 0.
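To illustrate why eps is materialized in float32, a minimal NumPy sketch (illustration only, not part of this diff); the printed values are approximate:

```python
import numpy as np

eps = 1e-6
# float16 normals bottom out around 6.1e-5; 1e-6 only survives as a subnormal
# with reduced precision, while float32 represents it cleanly.
print(np.finfo(np.float16).tiny)  # ~6.104e-05, smallest normal float16
print(np.float16(eps))            # ~1.013e-06, rounded subnormal
print(np.float32(eps))            # 1e-06
```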

add_trt = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add",
var_trt,
eps_trt,
)

sqrt_trt = impl.unary.sqrt(
ctx,
target,
source_ir,
f"{name}_sqrt",
add_trt,
)

# y = (X - E[X]) / sqrt((var + eps))
output = impl.elementwise.div(
ctx,
target,
source_ir,
f"{name}_div",
sub_trt,
sqrt_trt,
)

Collaborator: Wanted to clarify whether this div would require any mode, e.g. trunc. Are the data types always compatible with the output types?

Collaborator (author): I don't think so. The previous implementation from Evan did not include any mode.

Collaborator: Just FYI: we have impl.elementwise.floor_divide and impl.elementwise.trunc_div helper functions, if needed.
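For the question about division modes, a small PyTorch comparison (illustration only, not part of this diff): group norm divides floating-point tensors, so the default true division is the intended behavior, while trunc/floor rounding would distort the normalized values.

```python
import torch

x = torch.tensor([7.0, -7.0])
d = torch.tensor([2.0, 2.0])
print(torch.div(x, d))                         # tensor([ 3.5000, -3.5000]) -- true division
print(torch.div(x, d, rounding_mode="trunc"))  # tensor([ 3., -3.]) -- rounds toward zero
print(torch.div(x, d, rounding_mode="floor"))  # tensor([ 3., -4.]) -- rounds down
```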

shape = list(output.shape)
for i, s in enumerate(shape):
@@ -222,6 +301,40 @@ def native_group_norm(
reshaped_output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_output", output, shape
)
reshaped_gamma = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_gamma",
weight,
weight_bias_shape,
)

reshaped_output = impl.elementwise.mul(
ctx,
target,
source_ir,
f"{name}_mul_gamma",
reshaped_output,
reshaped_gamma,
)

reshaped_bias = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_beta",
bias,
weight_bias_shape,
)
reshaped_output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add_beta",
reshaped_output,
reshaped_bias,
)
cehongwang marked this conversation as resolved.
if return_mean_rstd:
# return fake mean and rstd for now
return reshaped_output, None, None
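For reference, a minimal eager-mode PyTorch sketch of the same decomposition the converter builds (an added illustration, not code from this PR): per-group E[X] and Var[X], normalization with eps, then the per-channel gamma/beta reshaped to (1, C, 1, ..., 1). It can be checked against torch.ops.aten.native_group_norm:

```python
import torch

def group_norm_reference(x, weight, bias, num_groups, eps=1e-5):
    N, C = x.shape[0], x.shape[1]
    # Collapse each group so mean/var reduce over (C // num_groups) * spatial elements.
    grouped = x.reshape(N, num_groups, -1)
    mean = grouped.mean(dim=-1, keepdim=True)                 # E[X]
    var = ((grouped - mean) ** 2).mean(dim=-1, keepdim=True)  # biased Var[X]
    normalized = ((grouped - mean) / torch.sqrt(var + eps)).reshape_as(x)
    # Affine parameters are applied per channel, matching weight_bias_shape above.
    affine_shape = (1, C) + (1,) * (x.dim() - 2)
    return normalized * weight.reshape(affine_shape) + bias.reshape(affine_shape)

x = torch.randn(2, 6, 8, 8)
w, b = torch.randn(6), torch.randn(6)
out = group_norm_reference(x, w, b, num_groups=2)
# native_group_norm args: input, weight, bias, N, C, HxW, groups, eps
expected = torch.ops.aten.native_group_norm.default(x, w, b, 2, 6, 64, 2, 1e-5)[0]
print(torch.allclose(out, expected, atol=1e-4))  # True within numerical tolerance
```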
29 changes: 25 additions & 4 deletions tests/py/dynamo/conversion/test_group_norm_aten.py
@@ -31,8 +31,8 @@ def forward(self, x):
return torch.ops.aten.group_norm.default(
x,
2,
torch.ones((6,)),
torch.zeros((6,)),
torch.randn((6,)),
torch.randn((6,)),
1e-05,
True,
)
@@ -50,8 +50,8 @@ def forward(self, x):
return torch.ops.aten.group_norm.default(
x,
2,
torch.ones((6,)),
torch.zeros((6,)),
torch.randn((6,)),
torch.randn((6,)),
1e-05,
True,
)
@@ -112,6 +112,27 @@ def forward(self, x):
inputs,
)

def test_groupnorm_sd(self):
class GroupNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_group_norm.default(
x,
torch.randn((320,)).half(),
torch.randn((320,)).half(),
2,
320,
4096,
32,
1e-05,
)[0]

inputs = [torch.randn(2, 320, 64, 64).half()]
with torch.no_grad():
self.run_test(
GroupNorm(),
inputs,
)

@parameterized.expand(
[
(5, 4, 4, 2, (2, 4, 2), (3, 4, 2), (5, 4, 4)),