Fix LayerNorm issue for undefined grad_input (#4274) (#4317)
* Fix LayerNorm issue for undefined grad_input

* Change if condition order

* Add unit test

Signed-off-by: majing <Jing1.Ma@intel.com>
majing921201 authored Jun 3, 2024
1 parent 5c252a1 commit 619cd9f
Showing 2 changed files with 23 additions and 3 deletions.
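
Background on the failure this commit addresses: when the LayerNorm input is a leaf tensor that does not require grad while the module's affine weight does, autograd calls native_layer_norm_backward with grad_input_mask[0] set to false, so grad_input arrives as an undefined Tensor; the previous code still called grad_input.dim() and, at the final reshape, grad_input.reshape() on it unconditionally. The sketch below shows the triggering call pattern; it mirrors the new unit test and assumes an Intel Extension for PyTorch build with an available "xpu" device (the import line and device string are assumptions, not part of this commit).

import torch
import intel_extension_for_pytorch  # noqa: F401  -- assumed available; registers the "xpu" device

# Leaf input that does NOT require grad: only the affine weight needs a gradient,
# so the backward kernel is asked for grad_weight/grad_bias but not grad_input,
# i.e. grad_input_mask[0] is false and grad_input is left undefined.
x = torch.randn(64, device="xpu")
x.requires_grad_(False)

layernorm = torch.nn.LayerNorm(64, elementwise_affine=True).to("xpu")
y = layernorm(x)
y.mean().backward()            # exercised the undefined-grad_input path guarded by this fix
print(layernorm.weight.grad)   # populated
print(x.grad)                  # None, as expected for a leaf that does not require grad
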
7 changes: 4 additions & 3 deletions csrc/gpu/aten/operators/LayerNorm.cpp
@@ -821,8 +821,9 @@ std::tuple<Tensor, Tensor, Tensor> native_layer_norm_backward(
       grad_output_, input_, mean, rstd, weight_, 1e-5);
   } else {
     Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input;
-    Tensor grad_input_ =
-        (grad_input.dim() == 1) ? grad_input.reshape({M, N}) : grad_input;
+    Tensor grad_input_ = grad_input_mask[0] && (grad_input.dim() == 1)
+        ? grad_input.reshape({M, N})
+        : grad_input;
     Tensor grad_output_ = (grad_output.dim() == 1)
         ? grad_output.reshape({M, N})
         : grad_output;
@@ -850,7 +851,7 @@ std::tuple<Tensor, Tensor, Tensor> native_layer_norm_backward(
     }
   }
   return std::make_tuple(
-      grad_input.reshape(input.sizes()),
+      grad_input_mask[0] ? grad_input.reshape(input.sizes()) : grad_input,
       grad_input_mask[1] ? grad_weight.reshape(weight.sizes()) : grad_weight,
       grad_input_mask[2] ? grad_bias.reshape(bias.sizes()) : grad_bias);
 }
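
For context on the new guard: grad_input_mask is the triple of booleans telling the backward kernel which of (grad_input, grad_weight, grad_bias) the caller actually requested; an entry that is false means the corresponding tensor is left undefined, which is why the final reshape was already guarded for indices 1 and 2 and is now guarded for index 0 as well. A minimal, device-agnostic sketch (names are illustrative) of how autograd ends up requesting only the weight gradient:

import torch

x = torch.randn(4, 64)                 # the input itself does not require grad
ln = torch.nn.LayerNorm(64, elementwise_affine=True)
y = ln(x)

# Request only the weight gradient; the layer-norm backward is then asked for
# grad_weight but not grad_input, i.e. grad_input_mask[0] is false.
(grad_w,) = torch.autograd.grad(y.mean(), ln.weight)
print(grad_w.shape)                    # torch.Size([64])
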
19 changes: 19 additions & 0 deletions tests/gpu/examples/test_layer_norm.py
@@ -304,3 +304,22 @@ def test_layer_norm_no_bias(self, dtype=torch.float):
         y_dpcpp = layernorm(x_dpcpp_i)
         z_dpcpp = y_dpcpp.mean().backward()
         self.assertEqual(z_cpu, z_dpcpp, atol=1e-5, rtol=1e-5)
+
+    def test_layer_norm_leaf(self, dtype=torch.float):
+        x_i = torch.randn(64, dtype=dtype, device=cpu_device)
+        x_i.requires_grad_(False)
+        x_dpcpp_i = x_i.to(dpcpp_device).to(dtype)
+        x_dpcpp_i.requires_grad_(False)
+
+        layernorm = torch.nn.LayerNorm(
+            64, elementwise_affine=True, bias=False, eps=1e-6
+        )
+        y_cpu = layernorm(x_i)
+        y_cpu.mean().backward()
+        grad_wei = layernorm.weight.grad.clone()
+        layernorm.zero_grad()
+        layernorm.to(dpcpp_device).to(dtype)
+        y_dpcpp = layernorm(x_dpcpp_i)
+        y_dpcpp.mean().backward()
+        grad_wei_dpcpp = layernorm.weight.grad.clone()
+        self.assertEqual(grad_wei, grad_wei_dpcpp.cpu(), atol=1e-5, rtol=1e-5)
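
Presumably the new case can be run on a machine with an XPU device with something like: pytest tests/gpu/examples/test_layer_norm.py -k test_layer_norm_leaf (the exact invocation depends on the repository's test harness).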
