[PIR API adaptor No.225, 227, 197, 187, 152] Migrate tanhshrink/thresholded_relu/Selu/RRelu/maxout into pir #58429

Merged: 11 commits, Nov 13, 2023
10 changes: 5 additions & 5 deletions python/paddle/nn/functional/activation.py
@@ -715,7 +715,7 @@ def rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=True, name=None):

is_test = not training

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.rrelu(x, lower, upper, is_test)
else:
check_variable_and_dtype(
@@ -889,7 +889,7 @@ def maxout(x, groups, axis=1, name=None):
[0.42400089, 0.40641287, 0.97020894, 0.74437362],
[0.51785129, 0.73292869, 0.97786582, 0.92382854]]]])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.maxout(x, groups, axis)
else:
check_variable_and_dtype(
@@ -1010,7 +1010,7 @@ def selu(
f"The alpha must be no less than zero. Received: {alpha}."
)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.selu(x, scale, alpha)
else:
check_variable_and_dtype(
@@ -1536,7 +1536,7 @@ def tanhshrink(x, name=None):
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
[-0.02005100, -0.00262472, 0.00033201, 0.00868741])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.tanh_shrink(x)
else:
check_variable_and_dtype(
@@ -1586,7 +1586,7 @@ def thresholded_relu(x, threshold=1.0, name=None):
[2., 0., 0.])
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.thresholded_relu(x, threshold)
else:
check_variable_and_dtype(
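All five activations in this file end up with the same dispatch skeleton after the change: eager execution and PIR static graphs call straight into the _C_ops bindings, while the legacy static graph keeps its check_variable_and_dtype / LayerHelper path. A minimal sketch of that skeleton, using tanhshrink as the example (a rough sketch, not Paddle source; the import paths are assumptions and may differ between Paddle versions):

# Sketch only: the shared dispatch shape of the migrated activations.
# Import paths are assumed and may vary across Paddle versions.
from paddle import _C_ops
from paddle.base.data_feeder import check_variable_and_dtype
from paddle.base.framework import in_dynamic_or_pir_mode
from paddle.base.layer_helper import LayerHelper


def tanhshrink_like(x, name=None):
    if in_dynamic_or_pir_mode():
        # Dynamic (eager) mode and the new PIR static graph both dispatch
        # directly to the C++ operator bindings.
        return _C_ops.tanh_shrink(x)
    else:
        # The legacy static graph keeps the old LayerHelper/append_op path.
        check_variable_and_dtype(
            x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'tanhshrink'
        )
        helper = LayerHelper('tanh_shrink', **locals())
        out = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='tanh_shrink', inputs={'X': x}, outputs={'Out': out}
        )
        return out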
6 changes: 5 additions & 1 deletion test/legacy_test/test_activation_op.py
@@ -1135,7 +1135,7 @@ def setUp(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class TestTanhshrink_ZeroDim(TestTanhshrink):
@@ -1154,6 +1154,7 @@ def setUp(self):
else paddle.CPUPlace()
)

@test_with_pir_api
def test_static_api(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
@@ -1177,6 +1178,7 @@ def test_dygraph_api(self):
for r in [out1, out2]:
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)

@test_with_pir_api
def test_errors(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
@@ -4227,6 +4229,7 @@ def setUp(self):
else paddle.CPUPlace()
)

@test_with_pir_api
def test_static_api(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
@@ -4250,6 +4253,7 @@ def test_dygraph_api(self):
for r in [out1, out2]:
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)

@test_with_pir_api
def test_errors(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
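The @test_with_pir_api decorator imported from paddle.pir_utils is what lets a single static-graph test body cover both program representations. A self-contained sketch of such a test (the class and the check are illustrative, not taken from the test file; the assumed decorator behavior, re-running the body with PIR enabled, is not confirmed here):

import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle.pir_utils import test_with_pir_api


class TanhshrinkStaticSketch(unittest.TestCase):
    # Illustrative test: tanhshrink(x) = x - tanh(x), so the static-graph
    # result can be checked against a plain numpy expression.
    @test_with_pir_api
    def test_static_api(self):
        paddle.enable_static()
        x_np = np.random.uniform(-1.0, 1.0, [4, 8]).astype('float32')
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data('X', x_np.shape, x_np.dtype)
            out = F.tanhshrink(x)
            exe = paddle.static.Executor(paddle.CPUPlace())
            (res,) = exe.run(feed={'X': x_np}, fetch_list=[out])
        np.testing.assert_allclose(res, x_np - np.tanh(x_np), rtol=1e-05)
        paddle.disable_static()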
8 changes: 6 additions & 2 deletions test/legacy_test/test_maxout_op.py
@@ -20,6 +20,7 @@
import paddle
import paddle.nn.functional as F
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()
np.random.seed(1)
@@ -57,10 +58,10 @@ def set_attrs(self):
pass

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class TestMaxOutOpAxis0(TestMaxOutOp):
@@ -95,6 +96,7 @@ def setUp(self):
else paddle.CPUPlace()
)

@test_with_pir_api
def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
@@ -122,6 +124,7 @@ def test_dygraph_api(self):
np.testing.assert_allclose(out3_ref, out3.numpy(), rtol=1e-05)
paddle.enable_static()

@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
@@ -161,6 +164,7 @@ def setUp(self):
self.axis = 1
self.place = paddle.CUDAPlace(0)

@test_with_pir_api
def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
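As a reminder while reading the maxout checks: the op reduces every group of groups consecutive channels to their maximum, so the channel dimension shrinks by a factor of groups. A hedged numpy reference (illustrative; not the helper the test file itself uses, and assuming a non-negative axis):

import numpy as np


def maxout_ref(x, groups, axis=1):
    # Split the channel axis into (channels // groups, groups) and take the
    # maximum over each group of consecutive channels.
    shape = list(x.shape)
    assert shape[axis] % groups == 0, "channel dim must be divisible by groups"
    shape[axis:axis + 1] = [shape[axis] // groups, groups]
    return x.reshape(shape).max(axis=axis + 1)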
21 changes: 15 additions & 6 deletions test/legacy_test/test_rrelu_op.py
@@ -21,6 +21,7 @@
import paddle.nn.functional as F
from paddle import base
from paddle.base import core, dygraph
from paddle.pir_utils import test_with_pir_api

paddle.seed(102)
np.random.seed(102)
@@ -87,10 +88,12 @@ def check_static_result(self, place):
)
np.testing.assert_allclose(fetches[0], res_np2, rtol=1e-05)

@test_with_pir_api
def test_static(self):
for place in self.places:
self.check_static_result(place=place)

@test_with_pir_api
def test_static_graph_functional(self):
'''test_static_graph_functional'''

@@ -134,6 +137,7 @@ def test_static_graph_functional(self):
check_output(self.x_np, res_3[0], self.lower_1, self.upper_1)
)

@test_with_pir_api
def test_static_graph_layer(self):
'''test_static_graph_layer'''

@@ -214,6 +218,7 @@ def test_dygraph(self):
)
paddle.enable_static()

@test_with_pir_api
def test_error_functional(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
@@ -351,10 +356,10 @@ def convert_input_output(self):
pass

def test_check_output(self):
self.check_output(no_check_set=['Noise'])
self.check_output(no_check_set=['Noise'], check_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class RReluTrainingTest(RReluTest):
@@ -394,11 +399,13 @@ def convert_input_output(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, no_check_set=['Noise'])
self.check_output_with_place(
place, no_check_set=['Noise'], check_pir=True
)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)


class RReluTrainingTestFP16OP(RReluTrainingTest):
@@ -425,11 +432,13 @@ def convert_input_output(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, no_check_set=['Noise'])
self.check_output_with_place(
place, no_check_set=['Noise'], check_pir=True
)

def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)


if __name__ == "__main__":
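RReLU is only stochastic while training; with is_test=True the negative slope is fixed at the midpoint of [lower, upper], following the usual RReLU definition. A hedged numpy sketch of that eval-mode reference (illustrative; not the test file's own helper):

import numpy as np


def rrelu_eval_ref(x, lower=1.0 / 8.0, upper=1.0 / 3.0):
    # In eval mode the random slope collapses to its expectation,
    # the midpoint of [lower, upper].
    slope = (lower + upper) / 2.0
    return np.where(x >= 0, x, slope * x)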
14 changes: 10 additions & 4 deletions test/legacy_test/test_selu_op.py
@@ -21,6 +21,7 @@
import paddle.nn.functional as F
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def ref_selu(
@@ -79,10 +80,10 @@ def init_dtype(self):
self.dtype = np.float64

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class SeluTestFP16OP(SeluTest):
@@ -100,10 +101,12 @@ def init_dtype(self):
self.dtype = np.uint16

def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0))
self.check_output_with_place(core.CUDAPlace(0), check_pir=True)

def test_check_grad(self):
self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')
self.check_grad_with_place(
core.CUDAPlace(0), ['X'], 'Out', check_pir=True
)


class TestSeluAPI(unittest.TestCase):
@@ -121,6 +124,7 @@ def setUp(self):
else paddle.CPUPlace()
)

@test_with_pir_api
def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
@@ -144,6 +148,7 @@ def test_dygraph_api(self):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
paddle.enable_static()

@test_with_pir_api
def test_base_api(self):
with base.program_guard(base.Program()):
x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
@@ -153,6 +158,7 @@ def test_errors(self):
out_ref = ref_selu(self.x_np, self.scale, self.alpha)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)

@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
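For quick reference, ref_selu in this file presumably implements the standard SELU formula: scale * x for positive inputs and scale * alpha * (exp(x) - 1) otherwise. An equivalent hedged numpy sketch (the default constants are the commonly used SELU values, rounded; not copied from the test file):

import numpy as np


def selu_ref(x, scale=1.0507009873554805, alpha=1.6732632423543772):
    # expm1 on the clipped negative branch avoids overflow warnings that
    # np.where would otherwise trigger for large positive inputs.
    return scale * np.where(x > 0, x, alpha * np.expm1(np.minimum(x, 0)))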