From 619cd9f5c300a876455411bcacc470bd94c923be Mon Sep 17 00:00:00 2001
From: majing
Date: Mon, 3 Jun 2024 09:49:57 +0800
Subject: [PATCH] Fix LayerNorm issue for undefined grad_input (#4274) (#4317)

* Fix LayerNorm issue for undefined grad_input

* Change if condition order

* Add ut

Signed-off-by: majing
---
 csrc/gpu/aten/operators/LayerNorm.cpp |  7 ++++---
 tests/gpu/examples/test_layer_norm.py | 19 +++++++++++++++++++
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/csrc/gpu/aten/operators/LayerNorm.cpp b/csrc/gpu/aten/operators/LayerNorm.cpp
index fe1b19c08..9a38101c6 100644
--- a/csrc/gpu/aten/operators/LayerNorm.cpp
+++ b/csrc/gpu/aten/operators/LayerNorm.cpp
@@ -821,8 +821,9 @@ std::tuple<Tensor, Tensor, Tensor> native_layer_norm_backward(
         grad_output_, input_, mean, rstd, weight_, 1e-5);
   } else {
     Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input;
-    Tensor grad_input_ =
-        (grad_input.dim() == 1) ? grad_input.reshape({M, N}) : grad_input;
+    Tensor grad_input_ = grad_input_mask[0] && (grad_input.dim() == 1)
+        ? grad_input.reshape({M, N})
+        : grad_input;
     Tensor grad_output_ =
         (grad_output.dim() == 1) ? grad_output.reshape({M, N}) : grad_output;

@@ -850,7 +851,7 @@ std::tuple<Tensor, Tensor, Tensor> native_layer_norm_backward(
     }
   }
   return std::make_tuple(
-      grad_input.reshape(input.sizes()),
+      grad_input_mask[0] ? grad_input.reshape(input.sizes()) : grad_input,
       grad_input_mask[1] ? grad_weight.reshape(weight.sizes()) : grad_weight,
       grad_input_mask[2] ? grad_bias.reshape(bias.sizes()) : grad_bias);
 }
diff --git a/tests/gpu/examples/test_layer_norm.py b/tests/gpu/examples/test_layer_norm.py
index 3320a76a0..561b8b113 100644
--- a/tests/gpu/examples/test_layer_norm.py
+++ b/tests/gpu/examples/test_layer_norm.py
@@ -304,3 +304,22 @@ def test_layer_norm_no_bias(self, dtype=torch.float):
         y_dpcpp = layernorm(x_dpcpp_i)
         z_dpcpp = y_dpcpp.mean().backward()
         self.assertEqual(z_cpu, z_dpcpp, atol=1e-5, rtol=1e-5)
+
+    def test_layer_norm_leaf(self, dtype=torch.float):
+        x_i = torch.randn(64, dtype=dtype, device=cpu_device)
+        x_i.requires_grad_(False)
+        x_dpcpp_i = x_i.to(dpcpp_device).to(dtype)
+        x_dpcpp_i.requires_grad_(False)
+
+        layernorm = torch.nn.LayerNorm(
+            64, elementwise_affine=True, bias=False, eps=1e-6
+        )
+        y_cpu = layernorm(x_i)
+        y_cpu.mean().backward()
+        grad_wei = layernorm.weight.grad.clone()
+        layernorm.zero_grad()
+        layernorm.to(dpcpp_device).to(dtype)
+        y_dpcpp = layernorm(x_dpcpp_i)
+        y_dpcpp.mean().backward()
+        grad_wei_dpcpp = layernorm.weight.grad.clone()
+        self.assertEqual(grad_wei, grad_wei_dpcpp.cpu(), atol=1e-5, rtol=1e-5)
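
Note (not part of the patch): the failure mode addressed here is a LayerNorm backward in which only the affine weight needs a gradient, so grad_input_mask[0] is false and autograd hands the kernel an undefined grad_input; the old code reshaped it unconditionally. Below is a minimal standalone sketch of that scenario, mirroring the new test_layer_norm_leaf unit test. It assumes an XPU build of intel-extension-for-pytorch; the "xpu" device string and the import are assumptions for illustration, not something stated in the patch.

    # Sketch only: assumes an XPU build of intel-extension-for-pytorch ("xpu" device).
    import torch
    import intel_extension_for_pytorch  # noqa: F401  (registers the "xpu" backend)

    x = torch.randn(64, device="xpu")   # leaf tensor, requires_grad stays False
    ln = torch.nn.LayerNorm(
        64, elementwise_affine=True, bias=False, eps=1e-6
    ).to("xpu")

    y = ln(x)
    y.mean().backward()          # only ln.weight needs a gradient, so grad_input_mask[0]
                                 # is false and grad_input is undefined in the backward;
                                 # before this fix the unconditional reshape of grad_input failed
    print(ln.weight.grad.shape)  # torch.Size([64])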