Skip to content

Commit

Permalink
【PIR API adaptor No.21】paddle.bernoulli (PaddlePaddle#58877)
Browse files Browse the repository at this point in the history
  • Loading branch information
Liyulingyue authored and SecretXV committed Nov 28, 2023
1 parent 5e0dadb commit 9319cdd
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/paddle/tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def bernoulli(x, name=None):
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.bernoulli(x)
else:
check_variable_and_dtype(
Expand Down
9 changes: 5 additions & 4 deletions test/legacy_test/test_bernoulli_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def output_hist(out):
Expand All @@ -31,6 +32,7 @@ def output_hist(out):

class TestBernoulliOp(OpTest):
def setUp(self):
self.python_api = paddle.bernoulli
self.op_type = "bernoulli"
self.init_dtype()
self.init_test_case()
Expand All @@ -46,7 +48,7 @@ def init_test_case(self):
self.out = np.zeros((1000, 784)).astype(self.dtype)

def test_check_output(self):
self.check_output_customized(self.verify_output)
self.check_output_customized(self.verify_output, check_pir=True)

def verify_output(self, outs):
hist, prob = output_hist(np.array(outs[0]))
Expand All @@ -62,13 +64,12 @@ def test_dygraph(self):
hist, prob = output_hist(out.numpy())
np.testing.assert_allclose(hist, prob, rtol=0, atol=0.01)

@test_with_pir_api
def test_static(self):
x = paddle.rand([1024, 1024])
out = paddle.bernoulli(x)
exe = paddle.static.Executor(paddle.CPUPlace())
out = exe.run(
paddle.static.default_main_program(), fetch_list=[out.name]
)
out = exe.run(paddle.static.default_main_program(), fetch_list=[out])
hist, prob = output_hist(out[0])
np.testing.assert_allclose(hist, prob, rtol=0, atol=0.01)

Expand Down

0 comments on commit 9319cdd

Please sign in to comment.