[XPU] refine flash attention ut (PaddlePaddle#60474)
* [XPU] refine flash attention ut

* refine tolerance
houj04 authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent fc5bc0a commit 6612d36
Showing 1 changed file with 69 additions and 44 deletions.
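The commit folds the former float16/bfloat16 subclasses (each carrying its own class-level self.rtol/self.atol) into a single run_case(dtype, tolerance, tolerance_dv) helper that test_all drives with per-dtype thresholds. As a quick illustration of what those thresholds mean (this snippet is not part of the commit), np.testing.assert_allclose passes when |actual - desired| <= atol + rtol * |desired| element-wise, so lower-precision dtypes need looser settings:

# Hedged sketch, not from the repository: shows why inputs in [-1, 1] stored in
# float16 (and, even more so, bfloat16) need tolerances around 5e-4 or larger.
import numpy as np

desired = np.random.uniform(-1.0, 1.0, (1, 128, 2, 32)).astype(np.float32)
actual = desired.astype(np.float16).astype(np.float32)  # simulate fp16 rounding

# Passes: fp16 relative rounding error is at most about 2**-11 (roughly 4.9e-4).
np.testing.assert_allclose(actual, desired, rtol=5e-4, atol=5e-4)

# A float32-level threshold such as rtol=atol=1e-7 would raise AssertionError
# here, which is why the commit assigns per-dtype tolerances instead of one value.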
113 changes: 69 additions & 44 deletions test/xpu/test_flash_attention_op_xpu.py
@@ -72,43 +72,45 @@ class TestFlashAttentionAPI(unittest.TestCase):
     def setUp(self):
         self.place = paddle.XPUPlace(0)
         self.shape = (1, 128, 2, 32)
-        self.dtype = 'float32'
         self.dropout = 0.0
         self.causal = True
         self.return_softmax = False
-        self.rtol = 1e-3
-        self.atol = 1e-3
 
     def test_all(self):
+        self.run_case(dtype="float32", tolerance=5e-4, tolerance_dv=5e-4)
+        self.run_case(dtype="float16", tolerance=5e-4, tolerance_dv=1e-3)
+        self.run_case(dtype="bfloat16", tolerance=5e-3, tolerance_dv=1e-2)
+
+    def run_case(self, dtype, tolerance, tolerance_dv):
         # TODO(houj04) remove debug codes after correctness check
-        print(f"Test case shape {self.shape} dtype {self.dtype}")
+        print(f"Test case shape {self.shape} dtype {dtype}")
 
         # test dynamic
         paddle.disable_static()
 
         np.random.seed(2023)
-        query = np.random.random(self.shape)
-        key = np.random.random(self.shape)
-        value = np.random.random(self.shape)
+        query = np.random.uniform(-1.0, 1.0, self.shape)
+        key = np.random.uniform(-1.0, 1.0, self.shape)
+        value = np.random.uniform(-1.0, 1.0, self.shape)
 
         q = paddle.to_tensor(
-            query, place=self.place, dtype=self.dtype, stop_gradient=False
+            query, place=self.place, dtype=dtype, stop_gradient=False
         )
         k = paddle.to_tensor(
-            key, place=self.place, dtype=self.dtype, stop_gradient=False
+            key, place=self.place, dtype=dtype, stop_gradient=False
         )
         v = paddle.to_tensor(
-            value, place=self.place, dtype=self.dtype, stop_gradient=False
+            value, place=self.place, dtype=dtype, stop_gradient=False
         )
 
         q_ = paddle.to_tensor(
-            query, place=self.place, dtype=self.dtype, stop_gradient=False
+            query, place=self.place, dtype=dtype, stop_gradient=False
        )
         k_ = paddle.to_tensor(
-            key, place=self.place, dtype=self.dtype, stop_gradient=False
+            key, place=self.place, dtype=dtype, stop_gradient=False
         )
         v_ = paddle.to_tensor(
-            value, place=self.place, dtype=self.dtype, stop_gradient=False
+            value, place=self.place, dtype=dtype, stop_gradient=False
         )
 
         out, _ = flash_attention(
@@ -125,8 +127,17 @@ def test_all(self):
         float_out_ = paddle.cast(out_, "float32")
 
         np.testing.assert_allclose(
-            float_out, float_out_, rtol=self.rtol, atol=self.atol
+            float_out, float_out_, rtol=tolerance, atol=tolerance
         )
+        # TODO(houj04) remove debug codes after correctness check
+        max_diff_forward = np.max(
+            np.abs(float_out.numpy() - float_out_.numpy())
+        )
+        mean_diff_forward = np.mean(
+            np.abs(float_out.numpy() - float_out_.numpy())
+        )
+        print("max_diff_forward:", max_diff_forward)
+        print("mean_diff_forward:", mean_diff_forward)
 
         # backward shape
         self.assertEqual(q.grad.shape, q.shape)
@@ -173,40 +184,54 @@ def test_all(self):
         print("mean_diff_v_grad:", mean_diff_v_grad)
 
         np.testing.assert_allclose(
-            float_q_grad, float_q_grad_, rtol=self.rtol, atol=self.atol
+            float_q_grad, float_q_grad_, rtol=tolerance, atol=tolerance
         )
         np.testing.assert_allclose(
-            float_k_grad, float_k_grad_, rtol=self.rtol, atol=self.atol
+            float_k_grad, float_k_grad_, rtol=tolerance, atol=tolerance
         )
         np.testing.assert_allclose(
-            float_v_grad, float_v_grad_, rtol=self.rtol, atol=self.atol
-        )
-
-
-class TestFlashAttentionAPITestFP16(TestFlashAttentionAPI):
-    def setUp(self):
-        self.place = paddle.XPUPlace(0)
-        self.shape = (1, 128, 2, 32)
-        self.dtype = 'float16'
-        self.dropout = 0.0
-        self.causal = True
-        self.return_softmax = False
-        # TODO(houj04) fix ut threshold after correctness check
-        self.rtol = 5e-3
-        self.atol = 5e-3
-
-
-class TestFlashAttentionAPITestBF16(TestFlashAttentionAPI):
-    def setUp(self):
-        self.place = paddle.XPUPlace(0)
-        self.shape = (1, 128, 2, 32)
-        self.dtype = 'bfloat16'
-        self.dropout = 0.0
-        self.causal = True
-        self.return_softmax = False
-        # TODO(houj04) fix ut threshold after correctness check
-        self.rtol = 1e-1
-        self.atol = 1e-1
+            float_v_grad, float_v_grad_, rtol=tolerance_dv, atol=tolerance_dv
+        )
+
+
+# TODO(houj04) un-comment following DEBUG cases after correctness check
+# class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
+#     def setUp(self):
+#         self.place = paddle.XPUPlace(0)
+#         self.shape = (2, 128, 1, 32)
+#         self.dropout = 0.0
+#         self.causal = True
+#         self.return_softmax = False
+
+
+# TODO(houj04) un-comment following REAL cases after correctness check
+# class TestFlashAttentionAPITestEB(TestFlashAttentionAPI):
+#     def setUp(self):
+#         self.place = paddle.XPUPlace(0)
+#         self.shape = (4, 4096, 4, 128)
+#         self.dropout = 0.0
+#         self.causal = True
+#         self.return_softmax = False
+
+
+# TODO(houj04) un-comment following REAL cases after correctness check
+# class TestFlashAttentionAPITestLlama7B(TestFlashAttentionAPI):
+#     def setUp(self):
+#         self.place = paddle.XPUPlace(0)
+#         self.shape = (2, 2048, 16, 128)
+#         self.dropout = 0.0
+#         self.causal = True
+#         self.return_softmax = False
+
+
+# TODO(houj04) un-comment following REAL cases after correctness check
+# class TestFlashAttentionAPITestLlama65B(TestFlashAttentionAPI):
+#     def setUp(self):
+#         self.place = paddle.XPUPlace(0)
+#         self.shape = (2, 8192, 8, 128)
+#         self.dropout = 0.0
+#         self.causal = True
+#         self.return_softmax = False
 
 
 if __name__ == '__main__':

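For context on what the forward comparison above measures, run_case casts both outputs to float32 and compares them with the per-dtype tolerances; the second output (out_, computed earlier in the file) is not shown in this diff. A generic float32 reference for causal attention over the (batch, seq_len, num_heads, head_dim) layout used here might look like the sketch below; it is an assumption for illustration, not code from test_flash_attention_op_xpu.py:

# Hedged sketch, not from the repository: a plain NumPy causal-attention
# reference over tensors shaped (batch, seq_len, num_heads, head_dim),
# matching the (1, 128, 2, 32) case exercised by the test.
import numpy as np

def naive_causal_attention(q, k, v):
    b, s, h, d = q.shape
    qt = np.transpose(q, (0, 2, 1, 3))                      # (b, h, s, d)
    kt = np.transpose(k, (0, 2, 1, 3))
    vt = np.transpose(v, (0, 2, 1, 3))
    scores = qt @ np.transpose(kt, (0, 1, 3, 2)) / np.sqrt(d)
    mask = np.triu(np.ones((s, s), dtype=bool), 1)           # hide future keys
    scores = np.where(mask, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)             # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ vt                                       # (b, h, s, d)
    return np.transpose(out, (0, 2, 1, 3))                   # back to (b, s, h, d)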