From fd926dbb4fa047f3579e72c670a9880b2a1d59f7 Mon Sep 17 00:00:00 2001
From: kevincheng2
Date: Mon, 30 Oct 2023 12:39:02 +0000
Subject: [PATCH 1/2] relu forward sink

---
 .../decomp_interface_gen_op_list.py | 4 +-
 paddle/fluid/primitive/composite/composite.h | 5 +++
 test/legacy_test/test_activation_op.py | 2 +-
 test/prim/pir_prim/test_sink_decomp.py | 40 +++++++++++++++++++
 4 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
index 334d410e7dab31..5ef9bbac186f27 100644
--- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
+++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
@@ -18,8 +18,8 @@
 
 # come into effect in generated file pd_op.h
-decomp_interface_declare_gen_op_list = ['mean']
+decomp_interface_declare_gen_op_list = ['mean', 'relu']
 
 # come into effect in generated file op_decomp.cc
 # manual decomp interface implementation are located in manual_op_decomp.cc
-decomp_interface_implementation_gen_op_list = ["mean"]
+decomp_interface_implementation_gen_op_list = ["mean", "relu"]
diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h
index e0da626ef4c938..acca71fbf5c700 100644
--- a/paddle/fluid/primitive/composite/composite.h
+++ b/paddle/fluid/primitive/composite/composite.h
@@ -62,6 +62,11 @@ Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) {
   }
 }
 
+template <typename T>
+Tensor relu_decomp(const Tensor& x) {
+  return maximum<T>(x, full<T>(phi::vectorize(x.dims()), 0.0, x.dtype()));
+}
+
 }  // namespace details
 
 }  // namespace primitive
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index f616828e0e69c5..7cbe4c2f2679fd 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -2361,7 +2361,7 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out', check_prim=True, check_pir=True)
 
     def test_check_output(self):
-        self.check_output(check_prim=True, check_pir=True)
+        self.check_output(check_prim=True, check_pir=True, check_prim_pir=True)
 
     def if_enable_cinn(self):
         pass
diff --git a/test/prim/pir_prim/test_sink_decomp.py b/test/prim/pir_prim/test_sink_decomp.py
index d1a14987123ee9..e9154eba60976f 100644
--- a/test/prim/pir_prim/test_sink_decomp.py
+++ b/test/prim/pir_prim/test_sink_decomp.py
@@ -17,6 +17,7 @@
 import numpy as np
 
 import paddle
+import paddle.nn.functional as F
 from paddle.autograd.ir_backward import grad
 from paddle.base import core
 from paddle.decomposition import decompose
@@ -109,5 +110,44 @@ def test_has_decomp(self):
         self.assertEqual(core.has_decomp(op), True)
 
 
+class TestReluSink(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.shape_x = [8, 16, 32, 64]
+        self.x = np.random.random(self.shape_x).astype("float32")
+        self.prog = None
+
+    def base_net(self, flag=None):
+        if flag == "forward":
+            core._set_prim_forward_enabled(True)
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program):
+            x = paddle.static.data('x', self.shape_x, dtype='float32')
+            x.stop_gradient = False
+            sum_out = F.relu(x)
+            [new_out] = decompose(main_program, [sum_out])
+            gradients = grad(new_out, x)
+
+            exe = paddle.static.Executor()
+            [fwd, dx] = exe.run(
+                feed={'x': self.x}, fetch_list=[new_out, gradients]
+            )
+
+        whole_ops = [op.name() for op in main_program.global_block().ops]
+        self.prog = main_program
+        if flag == "forward":
+            core._set_prim_forward_enabled(False)
+            assert 'pd_op.relu' not in whole_ops
+        else:
+            assert 'pd_op.relu' in whole_ops
+        return fwd, dx
+
+    def test_relu_forward(self):
+        res_ref = self.base_net()
+        res = self.base_net("forward")
+        for ref, actual in zip(res_ref, res):
+            np.testing.assert_equal(ref, actual)
+
+
 if __name__ == "__main__":
     unittest.main()

From 738c0fee0fe29c8006190f5df4c675d5308c94a0 Mon Sep 17 00:00:00 2001
From: kevincheng2
Date: Wed, 1 Nov 2023 11:50:40 +0000
Subject: [PATCH 2/2] conf test relu

---
 test/legacy_test/test_activation_op.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 7cbe4c2f2679fd..efcdc088b528f9 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -4726,7 +4726,11 @@ def test_check_grad(self):
 create_test_act_fp16_class(TestAtanh)
 create_test_act_fp16_class(TestRound, grad_check=False, check_pir=True)
 create_test_act_fp16_class(
-    TestRelu, check_prim=True, enable_cinn=True, check_pir=True
+    TestRelu,
+    check_prim=True,
+    enable_cinn=True,
+    check_pir=True,
+    check_prim_pir=True,
 )
 create_test_act_fp16_class(
     TestGelu,
@@ -4793,6 +4797,7 @@ def create_test_act_bf16_class(
     check_prim=False,
     enable_cinn=False,
     check_pir=False,
+    check_prim_pir=False,
     grad_atol=1e-2,
     **kwargs
 ):
@@ -4825,6 +4830,7 @@ def test_check_output(self):
             atol=atol,
             check_prim=check_prim,
             check_pir=check_pir,
+            check_prim_pir=check_prim_pir,
         )
 
     def test_check_grad(self):
@@ -4837,6 +4843,7 @@ def test_check_grad(self):
             max_relative_error=grad_atol,
             check_prim=check_prim,
             check_pir=check_pir,
+            check_prim_pir=check_prim_pir,
         )
 
     cls_name = "{}_{}".format(parent.__name__, "BF16OP")
@@ -4879,7 +4886,9 @@ def test_check_grad(self):
 create_test_act_bf16_class(TestAsinh)
 create_test_act_bf16_class(TestAtanh)
 create_test_act_bf16_class(TestRound, grad_check=False, check_pir=True)
-create_test_act_bf16_class(TestRelu, check_prim=True, check_pir=True)
+create_test_act_bf16_class(
+    TestRelu, check_prim=True, check_pir=True, check_prim_pir=True
+)
 create_test_act_bf16_class(
     TestGelu,
     check_prim=True,