cherry-pick fix sdp bwd page fault with no grad bias (#4439)(#4428)
YizhouZ authored Jul 9, 2024
1 parent bb1c6e9 commit d015f00
Showing 2 changed files with 59 additions and 12 deletions.
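In short: when the attention bias is passed with requires_grad_(False), no bias-gradient buffer is allocated and args.dB_ptr is null, yet the backward kernel still initialized mem_desc_dBij and wrote dSij back to it, which triggered a GPU page fault. The fix guards both the descriptor setup and the epilogue write-back on args.dB_ptr. Below is a minimal, hypothetical C++ sketch of that guard pattern (plain host code with illustrative names, not the XeTLA kernel itself): treat the bias-gradient buffer as optional and skip the write-back entirely when it was never allocated.

// Hypothetical standalone sketch of the guard pattern; names are illustrative only.
#include <cstdio>
#include <vector>

struct KernelArgs {
  float* dB_ptr;  // bias-gradient buffer; nullptr when bias.requires_grad() is false
  int rows;
  int cols;
};

// Write dSij into the bias-gradient buffer only if one was allocated.
void write_back_dBij(const KernelArgs& args, const std::vector<float>& dSij) {
  if (!args.dB_ptr) {
    return;  // no grad buffer -> an unconditional write here would fault
  }
  for (int r = 0; r < args.rows; ++r) {
    for (int c = 0; c < args.cols; ++c) {
      args.dB_ptr[r * args.cols + c] = dSij[r * args.cols + c];
    }
  }
}

int main() {
  std::vector<float> dSij(4, 1.0f);

  KernelArgs no_grad{nullptr, 2, 2};      // bias does not require grad
  write_back_dBij(no_grad, dSij);         // safely skipped

  std::vector<float> dB(4, 0.0f);
  KernelArgs with_grad{dB.data(), 2, 2};  // bias requires grad
  write_back_dBij(with_grad, dSij);
  std::printf("dB[0] = %.1f\n", dB[0]);   // prints 1.0
  return 0;
}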
csrc/gpu/aten/operators/xetla/kernels/SDP/fmha_backward.cpp — 28 changes: 16 additions & 12 deletions
@@ -413,12 +413,14 @@ class fmha_backward_t {
       mem_desc_Bij.init(
           args.B_ptr, {end_x, end_y, args.uMT}, {start_x, start_y});

-      start_y = gid * args.uF + startF;
-      end_y = start_y + kBr;
-      boundary_y = (gid + 1) * args.uF;
-      end_y = end_y > boundary_y ? boundary_y : end_y;
-      mem_desc_dBij.init(
-          args.dB_ptr, {end_x, end_y, args.uMT}, {start_x, start_y});
+      if (args.dB_ptr) {
+        start_y = gid * args.uF + startF;
+        end_y = start_y + kBr;
+        boundary_y = (gid + 1) * args.uF;
+        end_y = end_y > boundary_y ? boundary_y : end_y;
+        mem_desc_dBij.init(
+            args.dB_ptr, {end_x, end_y, args.uMT}, {start_x, start_y});
+      }
     }
   }

@@ -707,12 +709,14 @@ class fmha_backward_t {
     // Add bias if needed
     if constexpr (kUseBias) {
       // dSij = dBij
-      using epilogue_t = group::epilogue_write_back_t<
-          group::epilogue_policy_default<gpu_arch::Xe>,
-          tile_shape_BrBc,
-          mem_desc_dBij_t>;
-      epilogue_t epilogue;
-      epilogue(ctx.g_brbc, *matAcc_dSij, ctx.mem_desc_dBij);
+      if (args.dB_ptr) {
+        using epilogue_t = group::epilogue_write_back_t<
+            group::epilogue_policy_default<gpu_arch::Xe>,
+            tile_shape_BrBc,
+            mem_desc_dBij_t>;
+        epilogue_t epilogue;
+        epilogue(ctx.g_brbc, *matAcc_dSij, ctx.mem_desc_dBij);
+      }
     }
   }

tests/gpu/examples/test_sdp_backward.py — 43 changes: 43 additions & 0 deletions
@@ -101,6 +101,49 @@ def test_sdp_backward_with_bias(self, dtype=torch.bfloat16):
         )
         self.assertEqual(bias.grad, bias_xpu.grad.cpu().float(), atol=1e-2, rtol=1e-1)

+    @pytest.mark.skipif(not torch.xpu.has_xetla(), reason="fallback is required")
+    def test_sdp_backward_with_bias_no_grad(self, dtype=torch.bfloat16):
+        query_states = torch.randn((b, n, seq_len, head_size))
+        key_states = torch.randn((b, n, seq_len, head_size))
+        value_states = torch.randn((b, n, seq_len, head_size))
+        bias = torch.randn((b, 1, seq_len, seq_len))
+
+        grad = torch.randn((b, n, seq_len, head_size))
+
+        query_states_xpu = query_states.bfloat16().xpu()
+        key_states_xpu = key_states.bfloat16().xpu()
+        value_states_xpu = value_states.bfloat16().xpu()
+        bias_xpu = bias.bfloat16().xpu()
+        grad_xpu = grad.bfloat16().xpu()
+
+        query_states.requires_grad_(True)
+        key_states.requires_grad_(True)
+        value_states.requires_grad_(True)
+        bias.requires_grad_(False)
+
+        query_states_xpu.requires_grad_(True)
+        key_states_xpu.requires_grad_(True)
+        value_states_xpu.requires_grad_(True)
+        bias_xpu.requires_grad_(False)
+        r_cpu = torch.nn.functional.scaled_dot_product_attention(
+            query_states, key_states, value_states, bias
+        )
+        r_xpu = torch.nn.functional.scaled_dot_product_attention(
+            query_states_xpu, key_states_xpu, value_states_xpu, bias_xpu
+        )
+        r_cpu.backward(grad)
+        r_xpu.backward(grad_xpu)
+
+        self.assertEqual(
+            query_states.grad, query_states_xpu.grad.cpu().float(), atol=1e-2, rtol=1e-2
+        )
+        self.assertEqual(
+            key_states.grad, key_states_xpu.grad.cpu().float(), atol=1e-2, rtol=1e-2
+        )
+        self.assertEqual(
+            value_states.grad, value_states_xpu.grad.cpu().float(), atol=1e-2, rtol=1e-2
+        )
+
     @pytest.mark.skipif(not torch.xpu.has_xetla(), reason="fallback is required")
     def test_sdp_backward_with_bias_512(self, dtype=torch.bfloat16):
         head_size = 512
