[XPU] update XHPC date and refine FA ut (#60598)
* [XPU] update XHPC date

* update comments for ut
houj04 authored Jan 10, 2024
1 parent b8b175c commit 25a7b2b
Showing 1 changed file with 13 additions and 16 deletions.
test/xpu/test_flash_attention_op_xpu.py: 29 changes (13 additions & 16 deletions)
Expand Up @@ -126,9 +126,6 @@ def run_case(self, dtype, tolerance, tolerance_dv):
float_out = paddle.cast(out, "float32")
float_out_ = paddle.cast(out_, "float32")

np.testing.assert_allclose(
float_out, float_out_, rtol=tolerance, atol=tolerance
)
# TODO(houj04) remove debug codes after correctness check
max_diff_forward = np.max(
np.abs(float_out.numpy() - float_out_.numpy())
@@ -139,6 +136,10 @@ def run_case(self, dtype, tolerance, tolerance_dv):
         print("max_diff_forward:", max_diff_forward)
         print("mean_diff_forward:", mean_diff_forward)
 
+        np.testing.assert_allclose(
+            float_out, float_out_, rtol=tolerance, atol=tolerance
+        )
+
         # backward shape
         self.assertEqual(q.grad.shape, q.shape)
         self.assertEqual(q_.grad.shape, q.shape)
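
Aside (not part of the diff): the two hunks above move the forward tolerance check after the debug prints, so max_diff_forward and mean_diff_forward are still reported when the assertion fails and aborts the test. A minimal NumPy-only sketch of that ordering follows; the check_forward helper and the plain NumPy inputs are hypothetical stand-ins for the relevant part of run_case, not code from the Paddle test file.

import numpy as np

def check_forward(float_out, float_out_, tolerance):
    # Debug diagnostics are printed first, mirroring the TODO(houj04) debug block,
    # so they remain visible even if the assertion below raises.
    diff = np.abs(np.asarray(float_out) - np.asarray(float_out_))
    print("max_diff_forward:", diff.max())
    print("mean_diff_forward:", diff.mean())
    # The tolerance check now runs after the debug output.
    np.testing.assert_allclose(
        float_out, float_out_, rtol=tolerance, atol=tolerance
    )

# Example usage with a small perturbation that stays within tolerance.
a = np.random.rand(2, 128, 1, 32).astype("float32")
check_forward(a, a + 1e-6, tolerance=1e-3)
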
@@ -194,17 +195,17 @@ def run_case(self, dtype, tolerance, tolerance_dv):
         )
 
 
-# TODO(houj04) un-comment following DEBUG cases after correctness check
-# class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
-#     def setUp(self):
-#         self.place = paddle.XPUPlace(0)
-#         self.shape = (2, 128, 1, 32)
-#         self.dropout = 0.0
-#         self.causal = True
-#         self.return_softmax = False
+class TestFlashAttentionAPITest1(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.XPUPlace(0)
+        self.shape = (2, 128, 1, 32)
+        self.dropout = 0.0
+        self.causal = True
+        self.return_softmax = False
 
+
+# The following three REAL unit tests are disabled because they take a VERY LONG time to run, although they all pass under XHPC v20240105.
 
-# TODO(houj04) un-comment following REAL cases after correctness check
 # class TestFlashAttentionAPITestEB(TestFlashAttentionAPI):
 #     def setUp(self):
 #         self.place = paddle.XPUPlace(0)
@@ -213,8 +214,6 @@ def run_case(self, dtype, tolerance, tolerance_dv):
 #         self.causal = True
 #         self.return_softmax = False
 
-
-# TODO(houj04) un-comment following REAL cases after correctness check
 # class TestFlashAttentionAPITestLlama7B(TestFlashAttentionAPI):
 #     def setUp(self):
 #         self.place = paddle.XPUPlace(0)
@@ -223,8 +222,6 @@ def run_case(self, dtype, tolerance, tolerance_dv):
 #         self.causal = True
 #         self.return_softmax = False
 
-
-# TODO(houj04) un-comment following REAL cases after correctness check
 # class TestFlashAttentionAPITestLlama65B(TestFlashAttentionAPI):
 #     def setUp(self):
 #         self.place = paddle.XPUPlace(0)
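
For context, a minimal usage sketch of the API surface that the re-enabled TestFlashAttentionAPITest1 exercises. This is an illustration under stated assumptions, not part of the commit: it presumes an XPU build of Paddle, the public paddle.nn.functional.flash_attention.flash_attention entry point, and float16 inputs; the real test additionally compares the result against a naive attention reference and checks gradients, which is omitted here.

import paddle
from paddle.nn.functional.flash_attention import flash_attention

# Assumes an XPU build of Paddle; corresponds to paddle.XPUPlace(0) in setUp.
paddle.set_device("xpu:0")

# self.shape = (2, 128, 1, 32): (batch_size, seq_len, num_heads, head_dim)
shape = (2, 128, 1, 32)
q = paddle.cast(paddle.randn(shape), "float16")
k = paddle.cast(paddle.randn(shape), "float16")
v = paddle.cast(paddle.randn(shape), "float16")

# dropout=0.0, causal=True, return_softmax=False mirror the test's setUp values.
out, softmax = flash_attention(q, k, v, dropout=0.0, causal=True, return_softmax=False)
print(out.shape)  # [2, 128, 1, 32]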
