Skip to content

Commit

Permalink
[PIR] add all paddle.dot pir test (PaddlePaddle#58081)
Browse files Browse the repository at this point in the history
* add dot new ir test

* supply full dot op test

* replace check_new_ir
  • Loading branch information
MarioLulab authored Oct 17, 2023
1 parent 7f61beb commit 4607791
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions test/legacy_test/test_dot_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,28 +230,30 @@ def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=0.125)
self.check_output_with_place(place, atol=0.125, check_pir=True)

def test_check_grad_normal(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(place, ['X', 'Y'], 'Out')
self.check_grad_with_place(
place, ['X', 'Y'], 'Out', check_pir=True
)

def test_check_grad_ingore_x(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place, ['Y'], 'Out', no_grad_set=set("X")
place, ['Y'], 'Out', no_grad_set=set("X"), check_pir=True
)

def test_check_grad_ingore_y(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place, ['X'], 'Out', no_grad_set=set("Y")
place, ['X'], 'Out', no_grad_set=set("Y"), check_pir=True
)

def init_input_output(self):
Expand Down Expand Up @@ -302,7 +304,7 @@ def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_bfloat16_supported(place):
self.check_output_with_place(place, atol=0.5)
self.check_output_with_place(place, atol=0.5, check_pir=True)

def test_check_grad_normal(self):
if core.is_compiled_with_cuda():
Expand All @@ -313,6 +315,7 @@ def test_check_grad_normal(self):
['X', 'Y'],
'Out',
user_defined_grads=[self.inputs['Y'], self.inputs['X']],
check_pir=True,
)

def test_check_grad_ingore_x(self):
Expand All @@ -325,6 +328,7 @@ def test_check_grad_ingore_x(self):
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.inputs['X']],
check_pir=True,
)

def test_check_grad_ingore_y(self):
Expand All @@ -337,6 +341,7 @@ def test_check_grad_ingore_y(self):
'Out',
no_grad_set=set("Y"),
user_defined_grads=[self.inputs['Y']],
check_pir=True,
)

def init_input_output(self):
Expand Down Expand Up @@ -374,6 +379,7 @@ def test_check_grad_normal(self):
self.y / self.y.shape[0],
self.x / self.x.shape[0],
],
check_pir=True,
)

def test_check_grad_ingore_x(self):
Expand All @@ -386,6 +392,7 @@ def test_check_grad_ingore_x(self):
'Out',
no_grad_set=set("X"),
user_defined_grads=[self.x / self.x.shape[0]],
check_pir=True,
)

def test_check_grad_ingore_y(self):
Expand All @@ -398,6 +405,7 @@ def test_check_grad_ingore_y(self):
'Out',
no_grad_set=set("Y"),
user_defined_grads=[self.y / self.y.shape[0]],
check_pir=True,
)


Expand Down

0 comments on commit 4607791

Please sign in to comment.