
Commit

Still fails on Stable Diffusion 1.5. Revert back to decomposed ops instead of using INormalization Layer
cehongwang committed Jul 17, 2024
1 parent a8e2618 commit e8d942f
Showing 2 changed files with 243 additions and 48 deletions.
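
For context on the revert: the new converter builds group norm out of reduce, elementwise, and shuffle ops rather than a single INormalization layer. Below is a minimal PyTorch sketch of that decomposition, checked against `torch.nn.functional.group_norm`; `decomposed_group_norm` is an illustrative helper name, not code from this repository.

```python
import torch

def decomposed_group_norm(x, num_groups, weight, bias, eps=1e-5):
    # Reshape (N, C, *) -> (N * num_groups, C // num_groups, *) so each group
    # is normalized independently, mirroring the converter's reshape step.
    n, c = x.shape[0], x.shape[1]
    grouped = x.reshape(n * num_groups, c // num_groups, *x.shape[2:])

    # E[X] and Var[X] over every axis except the (batch * group) axis.
    dims = tuple(range(1, grouped.dim()))
    mean = grouped.mean(dim=dims, keepdim=True)
    var = ((grouped - mean) ** 2).mean(dim=dims, keepdim=True)

    # y = (X - E[X]) / sqrt(Var[X] + eps)
    normalized = (grouped - mean) / torch.sqrt(var + eps)

    # Reshape back and apply the per-channel affine parameters.
    out = normalized.reshape(x.shape)
    affine_shape = (1, c) + (1,) * (x.dim() - 2)
    return out * weight.reshape(affine_shape) + bias.reshape(affine_shape)

x = torch.randn(2, 320, 8, 8)
w, b = torch.randn(320), torch.randn(320)
ref = torch.nn.functional.group_norm(x, 32, w, b, eps=1e-5)
print(torch.allclose(decomposed_group_norm(x, 32, w, b), ref, atol=1e-5))  # True
```
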
250 changes: 223 additions & 27 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -227,6 +227,143 @@ def layer_norm(
# return reshaped_output, None, None
# return reshaped_output

# def native_group_norm(
# ctx: ConversionContext,
# target: Target,
# source_ir: Optional[SourceIR],
# name: str,
# input: TRTTensor,
# weight: Optional[Union[torch.Tensor, np.ndarray]],
# bias: Optional[Union[torch.Tensor, np.ndarray]],
# N: int,
# C: int,
# HxW: int,
# group: int,
# eps: float,
# return_mean_rstd: bool = True,
# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
# assert (
# len(input.shape) >= 3
# ), f"The input dimension should not be less than 3, got {len(input.shape)}!"

# B = input.shape[0]
# # If C is provided, it must match the number of channels in the input shape;
# # else if C is zero, infer the channel count from the input shape
# if C == 0:
# C = input.shape[1]
# assert (
# C == input.shape[1]
# ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
# # Groups are a subdivision of the channel dimension.
# assert (
# C % group == 0
# ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
# input = get_trt_tensor(ctx, input, f"{name}_input")

# shape = list(input.shape)

# for i, s in enumerate(shape):
# if i == 0 and s > 0:
# shape[i] = B * group
# elif i == 1:
# shape[i] = C // group
# elif i > 1 and s == -1:
# shape[i] = 0

# # Normalize every group.
# reshaped_input = impl.shuffle.reshape(
# ctx,
# target,
# source_ir,
# f"{name}_reshape_input",
# input,
# shape,
# )

# weight = get_trt_tensor(ctx, weight, f"{name}_weight")
# bias = get_trt_tensor(ctx, bias, f"{name}_bias")
# weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)

# dims = list(range(1, len(input.shape)))
# axes = get_axes_for_reduce_op(dims)
# dummy_weight = get_trt_tensor(ctx, np.array([1.0]), f"{name}_dummy_weight")
# dummy_weight = impl.slice.expand(
# ctx, target, source_ir, f"{name}_expand_dummy_weight", dummy_weight, reshaped_input.shape
# )
# dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias")
# dummy_bias = impl.slice.expand(
# ctx, target, source_ir, f"{name}_expand_dummy_bias", dummy_bias, reshaped_input.shape
# )
# group_norm = ctx.net.add_normalization(reshaped_input, dummy_weight, dummy_bias, axes)
# group_norm.epsilon = eps
# group_norm.compute_precision = input.dtype
# set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
# output = group_norm.get_output(0)

# shape = list(output.shape)
# for i, s in enumerate(shape):
# if i == 0 and s > 0:
# shape[i] = B
# elif i == 1:
# shape[i] = C
# elif i > 1 and s == -1:
# shape[i] = 0

# reshaped_output = impl.shuffle.reshape(
# ctx, target, source_ir, f"{name}_reshape_output", output, shape
# )


# # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze1", weight, (0))
# # weight = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_weight_unsqueeze2", weight, (2))
# # weight = impl.slice.expand(
# # ctx, target, source_ir, f"{name}_expand_weight", weight, shape
# # )
# # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze1", bias, (0))
# # bias = impl.unsqueeze.unsqueeze(ctx, target, source_ir, f"{name}_bias_unsqueeze2", bias, (2))
# # bias = impl.slice.expand(
# # ctx, target, source_ir, f"{name}_expand_bias", bias, shape
# # )

# reshaped_gamma = impl.shuffle.reshape(
# ctx,
# target,
# source_ir,
# f"{name}_reshape_gamma",
# weight,
# weight_bias_shape,
# )

# reshaped_output = impl.elementwise.mul(
# ctx,
# target,
# source_ir,
# f"{name}_mul_gamma",
# reshaped_output,
# reshaped_gamma,
# )

# reshaped_bias = impl.shuffle.reshape(
# ctx,
# target,
# source_ir,
# f"{name}_reshape_beta",
# bias,
# weight_bias_shape,
# )
# reshaped_output = impl.elementwise.add(
# ctx,
# target,
# source_ir,
# f"{name}_add_beta",
# reshaped_output,
# reshaped_bias,
# )
# if return_mean_rstd:
# # return fake mean and rstd for now
# return reshaped_output, None, None
# return reshaped_output


def native_group_norm(
ctx: ConversionContext,
@@ -243,6 +380,8 @@ def native_group_norm(
eps: float,
return_mean_rstd: bool = True,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
# TODO: Ask the TRT team about the usage of INormalization Layer with num_groups and update the implementation
# with INormalization Layer
assert (
len(input.shape) >= 3
), f"The input dimension should not be less than 3, got {len(input.shape)}!"
@@ -286,36 +425,94 @@ def native_group_norm(
weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)

dims = list(range(1, len(input.shape)))
axes = get_axes_for_reduce_op(dims)
# Use a dummy weight since the normalization layer cannot handle the scale and shift of group norm well due to the shape mismatch
# TODO: check with TRT the correct way to use 'num_groups' to implement group norm
dummy_weight = get_trt_tensor(
ctx, np.array([1.0]), f"{name}_dummy_weight", input.dtype

# E[X]
mean_trt = impl.reduce.mean(
ctx,
target,
source_ir,
f"{name}_mean",
reshaped_input,
dims,
True,
)
dummy_weight = impl.slice.expand(

mean_trt = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_dummy_weight",
dummy_weight,
f"{name}_expand_mean_trt",
mean_trt,
reshaped_input.shape,
)
dummy_bias = get_trt_tensor(ctx, np.array([0.0]), f"{name}_dummy_bias", input.dtype)
dummy_bias = impl.slice.expand(

# X - E[X]
sub_trt = impl.elementwise.sub(
ctx,
target,
source_ir,
f"{name}_sub",
reshaped_input,
mean_trt,
)

# variance
pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
pow_var = impl.elementwise.pow(
ctx,
target,
source_ir,
f"{name}_pow",
sub_trt,
pow_trt,
)

var_trt = impl.reduce.mean(
ctx,
target,
source_ir,
f"{name}_mean_var",
pow_var,
dims,
True,
)

var_trt = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_dummy_bias",
dummy_bias,
f"{name}_expand_var_trt",
var_trt,
reshaped_input.shape,
)
group_norm = ctx.net.add_normalization(
reshaped_input, dummy_weight, dummy_bias, axes

eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
add_trt = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add",
var_trt,
eps_trt,
)

sqrt_trt = impl.unary.sqrt(
ctx,
target,
source_ir,
f"{name}_sqrt",
add_trt,
)

# y = (X - E[X]) / sqrt(var + eps)
output = impl.elementwise.div(
ctx,
target,
source_ir,
f"{name}_div",
sub_trt,
sqrt_trt,
)
group_norm.epsilon = eps
group_norm.compute_precision = input.dtype
set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
output = group_norm.get_output(0)

shape = list(output.shape)
for i, s in enumerate(shape):
@@ -329,12 +526,11 @@
reshaped_output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_output", output, shape
)

weight = impl.shuffle.reshape(
reshaped_gamma = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_weight",
f"{name}_reshape_gamma",
weight,
weight_bias_shape,
)
@@ -343,26 +539,26 @@
ctx,
target,
source_ir,
f"{name}_mul_weight",
f"{name}_mul_gamma",
reshaped_output,
weight,
reshaped_gamma,
)

bias = impl.shuffle.reshape(
reshaped_bias = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_bias",
f"{name}_reshape_beta",
bias,
weight_bias_shape,
)
reshaped_output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add_bias",
f"{name}_add_beta",
reshaped_output,
bias,
reshaped_bias,
)
if return_mean_rstd:
# return fake mean and rstd for now
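
Note that both the removed INormalization path and the decomposed path return placeholder `(None, None)` for mean and rstd. For reference, here is a small sketch of what ATen itself reports for those two outputs, assuming standard `aten.native_group_norm` semantics; the shapes and tolerances are only illustrative.

```python
import torch

x = torch.randn(2, 320, 64, 64)
w, b = torch.ones(320), torch.zeros(320)
out, mean, rstd = torch.ops.aten.native_group_norm.default(
    x, w, b, 2, 320, 64 * 64, 32, 1e-05
)

# mean and rstd are reported per (batch, group): shape (2, 32) here.
grouped = x.reshape(2, 32, -1)
ref_rstd = (grouped.var(dim=-1, unbiased=False) + 1e-05).rsqrt()
print(torch.allclose(mean, grouped.mean(dim=-1), atol=1e-5))  # True
print(torch.allclose(rstd, ref_rstd, atol=1e-5))              # True
```
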
41 changes: 20 additions & 21 deletions tests/py/dynamo/conversion/test_group_norm_aten.py
@@ -112,27 +112,26 @@ def forward(self, x):
inputs,
)

# TODO: Half precision has accuracy issues for now.
# def test_groupnorm_sd(self):
# class GroupNorm(torch.nn.Module):
# def forward(self, x):
# return torch.ops.aten.native_group_norm.default(
# x,
# torch.randn((320,)).half(),
# torch.randn((320,)).half(),
# 2,
# 320,
# 4096,
# 32,
# 1e-05,
# )[0]

# inputs = [torch.randn(2, 320, 64, 64).half()]
# with torch.no_grad():
# self.run_test(
# GroupNorm(),
# inputs,
# )
def test_groupnorm_sd(self):
class GroupNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_group_norm.default(
x,
torch.randn((320,)).half(),
torch.randn((320,)).half(),
2,
320,
4096,
32,
1e-05,
)[0]

inputs = [torch.randn(2, 320, 64, 64).half()]
with torch.no_grad():
self.run_test(
GroupNorm(),
inputs,
)

@parameterized.expand(
[
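
With the converter reverted to decomposed ops, the half-precision Stable Diffusion-shaped test above is re-enabled. A hedged end-to-end sketch of the same shapes through the dynamo path follows; it assumes a CUDA device with TensorRT installed, and the compile options shown are illustrative rather than the test harness's exact settings.

```python
import torch
import torch_tensorrt

class GroupNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Per-channel affine parameters for 320 channels, as in the test above.
        self.weight = torch.nn.Parameter(torch.randn(320).half())
        self.bias = torch.nn.Parameter(torch.randn(320).half())

    def forward(self, x):
        return torch.ops.aten.native_group_norm.default(
            x, self.weight, self.bias, 2, 320, 4096, 32, 1e-05
        )[0]

model = GroupNorm().eval().cuda()
inputs = [torch.randn(2, 320, 64, 64).half().cuda()]

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.half},
)
print(trt_model(*inputs).shape)  # expected: torch.Size([2, 320, 64, 64])
```
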
