Skip to content

Commit

Permalink
[PIR]Migrate maximum into pir (#57929)
Browse files Browse the repository at this point in the history
* [PIR]Migrate maximum into pir

* Polish code
  • Loading branch information
0x45f authored Oct 10, 2023
1 parent 1e3212f commit 24701ef
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
12 changes: 12 additions & 0 deletions python/paddle/pir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


from functools import wraps

import paddle


Expand Down Expand Up @@ -95,3 +97,13 @@ def _switch_to_old_ir(self):
"IrGuard._switch_to_old_ir only work when paddle.framework.in_pir_mode() is false, \
please set FLAGS_enable_pir_api = false"
)


def test_with_pir_api(func):
@wraps(func)
def impl(*args, **kwargs):
func(*args, **kwargs)
with IrGuard():
func(*args, **kwargs)

return impl
2 changes: 1 addition & 1 deletion python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ def maximum(x, y, name=None):
Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
[5. , 3. , inf.])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.maximum(x, y)
else:
return _elementwise_op(LayerHelper('elementwise_max', **locals()))
Expand Down
6 changes: 6 additions & 0 deletions test/legacy_test/test_maximum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


class ApiMaximumTest(unittest.TestCase):
Expand All @@ -39,6 +40,7 @@ def setUp(self):
self.np_expected3 = np.maximum(self.input_a, self.input_c)
self.np_expected4 = np.maximum(self.input_b, self.input_c)

@test_with_pir_api
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(
Expand Down Expand Up @@ -119,3 +121,7 @@ def test_dynamic_api(self):
res = paddle.maximum(b, c)
res = res.numpy()
np.testing.assert_allclose(res, self.np_expected4, rtol=1e-05)


if __name__ == '__main__':
unittest.main()

0 comments on commit 24701ef

Please sign in to comment.