From ee871c1cb505b4f7ebabdc61192098a02b8ea0e5 Mon Sep 17 00:00:00 2001
From: JiabinYang <360788950@qq.com>
Date: Thu, 16 Mar 2023 05:59:35 +0000
Subject: [PATCH] support relu custom vjp

---
 paddle/fluid/prim/api/api.yaml           |   1 +
 .../composite_backward_api.h             |  11 ++
 paddle/phi/api/yaml/backward.yaml        |   1 +
 .../test_composite_relu_custom_vjp.py    | 122 ++++++++++++++++++
 .../incubate/autograd/composite_rules.py |   1 +
 5 files changed, 136 insertions(+)
 create mode 100644 python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_relu_custom_vjp.py

diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml
index 529d024b8b8a08..cd83e227ea3d25 100644
--- a/paddle/fluid/prim/api/api.yaml
+++ b/paddle/fluid/prim/api/api.yaml
@@ -33,3 +33,4 @@
 - put_along_axis
 - greater_than
 - less_equal
+- where
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index bcd6f459b8dc37..d7ea52a4dad02f 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -31,6 +31,17 @@ using IntArray = paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
 // This function should have as same signature as phi, which defined in
 // paddle/phi/api/backward/backward_api.h
 template <typename T>
+void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    auto condition = greater_than<T>(
+        out, full<T>(phi::vectorize(out.dims()), 0.0, out.dtype()));
+    auto res = where<T>(condition,
+                        out_grad,
+                        full<T>(phi::vectorize(out.dims()), 0.0, out.dtype()));
+    set_output<T>(res, x_grad);
+  }
+}
+template <typename T>
 void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
   if (x_grad) {
     auto res = cast<T>(out_grad, dtype);
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 9b2f48aac6c395..2f884799e77143 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1142,6 +1142,7 @@
   kernel :
     func : relu_grad
   backward: relu_double_grad
+  composite: relu_grad(out, out_grad, x_grad)
   inplace : (out_grad -> x_grad)
 
 - backward_op : renorm_grad
diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_relu_custom_vjp.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_relu_custom_vjp.py
new file mode 100644
index 00000000000000..8113ddee89a29c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_relu_custom_vjp.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from utils import TOLERANCE
+
+import paddle
+import paddle.nn.functional as F
+from paddle.fluid import core
+
+
+def generate_data(shape, dtype="float32"):
+    np_data = np.random.random(shape).astype(dtype)
+    return np_data
+
+
+class Attr:
+    def __init__(self) -> None:
+        self.dtype = None
+        self.shape = None
+
+    def set_dtype(self, dtype) -> None:
+        self.dtype = dtype
+        return
+
+    def set_shape(self, shape) -> None:
+        self.shape = shape
+        return
+
+    def get_rtol(self, flag):
+        rtol = TOLERANCE[self.dtype][flag].get("rtol")
+        return rtol
+
+    def get_atol(self, flag):
+        atol = TOLERANCE[self.dtype][flag].get("atol")
+        return atol
+
+
+attrs = Attr()
+
+
+def fn(x):
+    return F.relu(x)
+
+
+def expect_grad(inputs):
+    paddle.disable_static()
+    inputs.stop_gradient = False
+    res = fn(inputs)
+
+    gradients = paddle.grad(res, inputs)
+    return gradients
+
+
+class TestCompositeReluPrimBackward(unittest.TestCase):
+    "test composite relu and prim backward"
+
+    def setUp(self):
+        core._set_prim_backward_enabled(True)
+        self.dtypes = ["float16", "float32", "float64"]
+        self.shapes = [[2, 3, 4], [2, 3]]
+
+    def cal_composite_grad(self, inputs):
+        paddle.enable_static()
+        core._set_prim_all_enabled(True)
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program, startup_program):
+            x = paddle.static.data(
+                'x', shape=inputs.shape, dtype=str(inputs.dtype)
+            )
+            x.stop_gradient = False
+            y = fn(x)
+            blocks = main_program.blocks
+            z = paddle.static.gradients([y], x)
+            paddle.incubate.autograd.primapi.to_prim(blocks)
+
+        exe = paddle.static.Executor()
+        exe.run(startup_program)
+        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
+        paddle.disable_static()
+        core._set_prim_all_enabled(False)
+        return res
+
+    def compare_backward(self):
+        np_data = generate_data(attrs.shape)
+        tensor_data = paddle.to_tensor(np_data)
+
+        expect = expect_grad(tensor_data)[0].numpy()
+        actual = self.cal_composite_grad(np_data)[0]
+
+        assert expect.dtype == actual.dtype
+        np.testing.assert_allclose(
+            expect,
+            actual,
+            rtol=attrs.get_rtol("prim_backward"),
+            atol=attrs.get_atol("prim_backward"),
+        )
+
+    def test_prim_backward(self):
+        for j in self.dtypes:
+            for t in self.shapes:
+                attrs.set_dtype(j)
+                attrs.set_shape(t)
+                self.compare_backward()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index 28c87609ae133b..e79d86ef627fef 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -97,6 +97,7 @@ def composite_batchnorm(
     batch_mean = zeros(run_mean.shape, run_mean.dtype)
     batch_var = zeros(run_var.shape, run_var.dtype)
    if not use_run_stat:
+        batch_mean = mean(x, reduce_axes, keepdim=True)
         temp = mean(x * x, reduce_axes, keepdim=True)
         batch_var = temp - batch_mean * batch_mean
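
For reference, the rule the new composite relu_grad encodes is dx = where(out > 0, dout, 0), built only from primitives already exposed to the prim backend (where, greater_than, full). The following is a minimal eager-mode sketch of that identity checked against F.relu's own gradient; it is not part of the patch, and the tensor shape and variable names below are illustrative assumptions.

# Standalone sanity check (assumed names/shapes, not taken from the patch):
# the composite rule dx = where(out > 0, dout, 0) matches relu's gradient.
import numpy as np
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor(
    np.random.randn(2, 3, 4).astype("float32"), stop_gradient=False
)
out = F.relu(x)
dout = paddle.ones_like(out)

# Reference gradient from Paddle's own autograd.
(dx_ref,) = paddle.grad(out, x, grad_outputs=dout)

# Composite form: keep dout where the forward output is positive, else zero.
zeros = paddle.zeros_like(out)
dx_composite = paddle.where(out > zeros, dout, zeros)

np.testing.assert_allclose(dx_ref.numpy(), dx_composite.numpy())

Composing the gradient from these primitives is what lets the prim backward pass lower relu_grad without a dedicated kernel, which is also why `where` is added to paddle/fluid/prim/api/api.yaml above.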