Skip to content

Commit

Permalink
Merge branch 'PaddlePaddle:develop' into fix_fusegemm
Browse files Browse the repository at this point in the history
  • Loading branch information
gongshaotian authored Nov 10, 2023
2 parents a21f824 + ec729e2 commit 2db49f0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def mode(x, axis=-1, keepdim=False, name=None):
[2, 1]]))
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.mode(x, axis, keepdim)
else:
helper = LayerHelper("mode", **locals())
Expand Down
10 changes: 6 additions & 4 deletions test/legacy_test/test_mode_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def _mode1D(a):
Expand Down Expand Up @@ -112,12 +113,12 @@ def init_numeric_grads(self):

def test_check_output(self):
paddle.enable_static()
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
paddle.enable_static()
grad = self.init_numeric_grads()
self.check_grad({'X'}, 'Out', user_defined_grads=[grad])
self.check_grad({'X'}, 'Out', user_defined_grads=[grad], check_pir=True)


@unittest.skipIf(
Expand Down Expand Up @@ -148,7 +149,7 @@ def test_check_output(self):
place = core.CUDAPlace(0)
paddle.enable_static()
if core.is_bfloat16_supported(place):
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad(self):
place = core.CUDAPlace(0)
Expand All @@ -157,7 +158,7 @@ def test_check_grad(self):

if core.is_bfloat16_supported(place):
self.check_grad_with_place(
place, {'X'}, 'Out', user_defined_grads=[grad]
place, {'X'}, 'Out', user_defined_grads=[grad], check_pir=True
)


Expand Down Expand Up @@ -243,6 +244,7 @@ def setUp(self):
np.random.random((2, 10, 10)) * 1000, dtype=np.float64
)

@test_with_pir_api
def test_run_static(self):
paddle.enable_static()
with paddle.static.program_guard(
Expand Down

0 comments on commit 2db49f0

Please sign in to comment.