Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kangguangli committed Aug 24, 2023
1 parent 971f4fb commit 33df9c9
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions test/legacy_test/test_i0_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from scipy import special

import paddle
from paddle.fluid import core

np.random.seed(100)
paddle.seed(100)
Expand All @@ -33,6 +34,78 @@ def ref_i0_grad(x, dout):
return dout * gradx


class TestI0API(unittest.TestCase):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5]

def setUp(self):
self.x = np.array(self.DATA).astype(self.DTYPE)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def test_api_static(self):
def run(place):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(
name="x", shape=self.x.shape, dtype=self.DTYPE
)
out = paddle.i0(x)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={"x": self.x},
fetch_list=[out],
)
out_ref = output_i0(self.x)
np.testing.assert_allclose(res[0], out_ref, rtol=1e-5)
paddle.disable_static()

for place in self.place:
run(place)

def test_api_dygraph(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor(self.x)
out = paddle.i0(x)

out_ref = output_i0(self.x)
np.testing.assert_allclose(out.numpy(), out_ref, rtol=1e-5)
paddle.enable_static()

for place in self.place:
run(place)

def test_empty_input_error(self):
for place in self.place:
paddle.disable_static(place)
x = None
self.assertRaises(ValueError, paddle.i0, x)
paddle.enable_static()


class TestI0Float32Zero2EightCase(TestI0API):
DTYPE = "float32"
DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]


class TestI0Float32OverEightCase(TestI0API):
DTYPE = "float32"
DATA = [9, 10, 11, 12]


class TestI0Float64Zero2EightCase(TestI0API):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]


class TestI0Float64OverEightCase(TestI0API):
DTYPE = "float64"
DATA = [9, 10, 11, 12]


class TestI0Op(OpTest):
def setUp(self) -> None:
self.op_type = "i0"
Expand Down

0 comments on commit 33df9c9

Please sign in to comment.