Skip to content

Commit

Permalink
[CINN] Fix TestGelu unittest of CINN
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 committed May 16, 2023
1 parent 69161a9 commit 3e97d2d
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2083,7 +2083,6 @@ def setUp(self):
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = gelu(x, approximate)
self.enable_cinn = False

self.inputs = {'X': x}
self.outputs = {'Out': out}
Expand All @@ -2093,6 +2092,9 @@ def setUp(self):
# cpu device, lower threshold to support 1e-8 for pass the unittest
self.rev_comp_rtol = 1e-8
self.rev_comp_atol = 1e-8
# Cumulative error occurs between comp and cinn, so that we also set cinn_rtol to 1e-8 as rev_comp_rtol = 1e-8
self.cinn_rtol = 1e-8
self.cinn_atol = 1e-8

def test_check_output(self):
self.check_output(check_prim=True)
Expand Down Expand Up @@ -2125,9 +2127,12 @@ def setUp(self):
# cpu, lower threshold to support 1e-8 for pass the unittest
self.rev_comp_rtol = 1e-8
self.rev_comp_atol = 1e-8
# Cumulative error occurs between comp and cinn, so that we also set cinn_rtol to 1e-8 as rev_comp_rtol = 1e-8
self.cinn_rtol = 1e-8
self.cinn_atol = 1e-8

def if_enable_cinn(self):
self.enable_cinn = False
self.enable_cinn = True

def test_check_output(self):
self.check_output(check_prim=True)
Expand Down Expand Up @@ -4028,9 +4033,11 @@ def test_check_grad(self):
create_test_act_fp16_class(
TestGelu,
check_prim=True,
enable_cinn=False,
enable_cinn=True,
rev_comp_rtol=1e-3,
rev_comp_atol=1e-3,
cinn_rtol=1e-3,
cinn_atol=1e-3,
)
create_test_act_fp16_class(TestBRelu)
create_test_act_fp16_class(TestRelu6)
Expand Down Expand Up @@ -4141,9 +4148,11 @@ def test_check_grad(self):
create_test_act_bf16_class(
TestGelu,
check_prim=True,
enable_cinn=False,
enable_cinn=True,
rev_comp_rtol=1e-2,
rev_comp_atol=1e-2,
cinn_rtol=1e-2,
cinn_atol=1e-2,
)
create_test_act_bf16_class(TestBRelu)
create_test_act_bf16_class(TestRelu6)
Expand Down

0 comments on commit 3e97d2d

Please sign in to comment.