From ba39d242fec7b71b5f7d1abff370235e87bc3757 Mon Sep 17 00:00:00 2001 From: xuyongsheng Date: Mon, 20 Feb 2023 18:06:24 +0800 Subject: [PATCH 01/14] Add flatten composite rule --- .../composite_ops/test_composite_flatten.py | 144 ++++++++++++ .../test_composite_flatten_grad.py | 219 ++++++++++++++++++ .../incubate/autograd/composite_rules.py | 31 +++ 3 files changed, 394 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py create mode 100644 python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten_grad.py diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py new file mode 100644 index 0000000000000..ad03ef588b872 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py @@ -0,0 +1,144 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from utils import TOLERANCE + +import paddle +from paddle.fluid import core + + +def generate_data(shape, dtype="float32"): + np_data = np.random.random(shape).astype(dtype) + return np_data + + +class Attr: + def __init__(self) -> None: + self.dtype = "float32" + self.shape = None + self.start_axi = None + self.stop_axi = None + + def set_dtype(self, dtype) -> None: + self.dtype = dtype + return + + def set_shape(self, shape) -> None: + self.shape = shape + return + + def set_start_axi(self, start_axi) -> None: + self.start_axi = start_axi + return + + def set_stop_axi(self, stop_axi) -> None: + self.stop_axi = stop_axi + return + + def get_rtol(self, flag): + rtol = TOLERANCE[self.dtype][flag].get("rtol") + return rtol + + def get_atol(self, flag): + atol = TOLERANCE[self.dtype][flag].get("atol") + return atol + + +attrs = Attr() + + +def fn(x): + return paddle.flatten( + x, start_axis=attrs.start_axi, stop_axis=attrs.stop_axi + ) + + +def expect_forward(inputs): + return fn(inputs) + + +class TestCompositeFlatten(unittest.TestCase): + def setUp(self): + # self.dtypes = ["float16", "float32", "float64"] + self.dtypes = ["float32", "float64"] + self.shapes = [ + [16, 16, 64, 64, 10], + [2, 3, 4, 6, 8, 2, 3, 4], + [2, 3, 5, 1, 2], + [2, 3, 4, 5, 6, 7], + ] + self.start_axis = [0, 1, 2] + self.stop_axis = [-1, 2, 3, 4] + + def cal_composite(self, inputs): + paddle.enable_static() + core._set_prim_forward_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + y = fn(x) + blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that flatten in original block + self.assertTrue('flatten_contiguous_range' in fwd_ops) + + paddle.incubate.autograd.to_prim(blocks) + + fwd_ops_new = [op.type for op in 
blocks[0].ops] + # Ensure that flatten is splitted into small ops + self.assertTrue('flatten_contiguous_range' not in fwd_ops_new) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y]) + paddle.disable_static() + core._set_prim_forward_enabled(False) + return res + + def compare_forward(self): + np_data = generate_data(attrs.shape, attrs.dtype) + tensor_data = paddle.to_tensor(np_data) + + expect = expect_forward(tensor_data).numpy() + actual = self.cal_composite(np_data)[0] + + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("forward"), + atol=attrs.get_atol("forward"), + ) + + def test_forward(self): + for i in self.dtypes: + for j in self.shapes: + for t in self.start_axis: + for k in self.stop_axis: + attrs.set_dtype(i) + attrs.set_shape(j) + attrs.set_start_axi(t) + attrs.set_stop_axi(k) + self.compare_forward() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten_grad.py new file mode 100644 index 0000000000000..9ab5721f56726 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten_grad.py @@ -0,0 +1,219 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from utils import TOLERANCE + +import paddle +from paddle.fluid import core + + +def generate_data(shape, dtype="float32"): + np_data = np.random.random(shape).astype(dtype) + return np_data + + +class Attr: + def __init__(self) -> None: + self.dtype = "float32" + self.shape = None + self.start_axi = None + self.stop_axi = None + + def set_dtype(self, dtype) -> None: + self.dtype = dtype + return + + def set_shape(self, shape) -> None: + self.shape = shape + return + + def set_start_axi(self, start_axi) -> None: + self.start_axi = start_axi + return + + def set_stop_axi(self, stop_axi) -> None: + self.stop_axi = stop_axi + return + + def get_rtol(self, flag): + rtol = TOLERANCE[self.dtype][flag].get("rtol") + return rtol + + def get_atol(self, flag): + atol = TOLERANCE[self.dtype][flag].get("atol") + return atol + + +attrs = Attr() + + +def fn(x): + return paddle.flatten( + x, start_axis=attrs.start_axi, stop_axis=attrs.stop_axi + ) + + +def expect_grad(inputs): + paddle.disable_static() + inputs.stop_gradient = False + res = fn(inputs) + gradients = paddle.grad(res, inputs) + return gradients + + +class TestCompositeFlatten(unittest.TestCase): + def setUp(self): + self.dtypes = ["float32", "float64"] + self.shapes = [ + [1, 2, 1, 2], + [16, 6, 6, 10], + [2, 4, 6, 8, 3], + [2, 3, 5, 1, 2], + [2, 3, 4, 5, 6, 7], + ] + self.start_axis = [0, 1, 2] + self.stop_axis = [-1, 2, 3] + + def cal_composite_grad(self, inputs): + paddle.enable_static() + core._set_prim_forward_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + y = fn(x) + blocks = main_program.blocks + + fwd_ops = [op.type for op in blocks[0].ops] + # Ensure that flatten_contiguous_range in original block + self.assertTrue('flatten_contiguous_range' in fwd_ops) + + paddle.incubate.autograd.to_prim(blocks) + + fwd_ops_new = [op.type for op in blocks[0].ops] + # Ensure that flatten_contiguous_range is splitted into small ops + self.assertTrue('flatten_contiguous_range' not in fwd_ops_new) + + z = paddle.static.gradients([y], x) + + fwd_ops_grad = [op.type for op in blocks[0].ops] + # Ensure that flatten_contiguous_range_grad not in grad block + self.assertTrue('flatten_contiguous_range_grad' not in fwd_ops_grad) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) + paddle.disable_static() + core._set_prim_forward_enabled(False) + return res + + def compare_backward(self): + np_data = generate_data(attrs.shape, attrs.dtype) + tensor_data = paddle.to_tensor(np_data) + + expect = expect_grad(tensor_data)[0].numpy() + actual = self.cal_composite_grad(np_data)[0] + + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("backward"), + atol=attrs.get_atol("backward"), + ) + + def test_backward(self): + for i in self.dtypes: + for j in self.shapes: + for t in self.start_axis: + for k in self.stop_axis: + attrs.set_dtype(i) + attrs.set_shape(j) + attrs.set_start_axi(t) + attrs.set_stop_axi(k) + self.compare_backward() + + +class TestCompositeFlattenPrimBackward(unittest.TestCase): + "test composite flatten and prim backward" + + def setUp(self): + self.dtypes = ["float32", "float64"] + self.shapes = [ + [1, 2, 1, 2], + [16, 6, 6, 10], + [2, 4, 6, 8, 3], + 
[2, 3, 5, 1, 2], + [2, 3, 4, 5, 6, 7], + ] + self.start_axis = [0, 1, 2] + self.stop_axis = [-1, 2, 3] + + def cal_composite_grad(self, inputs): + paddle.enable_static() + core._set_prim_all_enabled(True) + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.static.data( + 'x', shape=inputs.shape, dtype=str(inputs.dtype) + ) + x.stop_gradient = False + y = fn(x) + blocks = main_program.blocks + paddle.incubate.autograd.to_prim(blocks) + z = paddle.static.gradients([y], x) + + exe = paddle.static.Executor() + exe.run(startup_program) + res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) + paddle.disable_static() + core._set_prim_all_enabled(False) + return res + + def compare_backward(self): + np_data = generate_data(attrs.shape, attrs.dtype) + tensor_data = paddle.to_tensor(np_data) + + expect = expect_grad(tensor_data)[0].numpy() + actual = self.cal_composite_grad(np_data)[0] + + assert expect.dtype == actual.dtype + np.testing.assert_allclose( + expect, + actual, + rtol=attrs.get_rtol("prim_backward"), + atol=attrs.get_atol("prim_backward"), + ) + + def test_prim_backward(self): + for i in self.dtypes: + for j in self.shapes: + for t in self.start_axis: + for k in self.stop_axis: + attrs.set_dtype(i) + attrs.set_shape(j) + attrs.set_start_axi(t) + attrs.set_stop_axi(k) + self.compare_backward() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 70bb8f8b80492..591757b762ec5 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -149,3 +149,34 @@ def mean_composite(x, axis, keepdim): dtype=sum_x.dtype, ) return divide(sum_x, norm) + + +def maybe_wrap_dim(dim: int, dim_post_expr: int): + min = -dim_post_expr + max = dim_post_expr - 1 + assert not (dim < min or dim > max) + if dim < 0: + dim += dim_post_expr + return dim + + +@REGISTER_COMPOSITE('flatten_contiguous_range') +def flatten_contiguous_range_composite(x, start_axis, stop_axis): + """define composite rule of op flatten, flatten_contiguous_range -> flatten""" + shape_in = x.shape + start_dim = maybe_wrap_dim(start_axis, len(shape_in)) + end_dim = maybe_wrap_dim(stop_axis, len(shape_in)) + assert start_dim <= end_dim + if len(shape_in) == 0 or start_dim == end_dim: + return x, to_tensor(shape_in, dtype=float32) + slice_numel = 1 + for i in range(start_dim, end_dim + 1): + slice_numel *= shape_in[i] + # slice_numel = multiply_integers(shape_in[start_dim:end_dim - start_dim + 1]) + shape_out: List[int] = [] + for i in range(start_dim): + shape_out.append(shape_in[i]) + shape_out.append(slice_numel) + for i in range(end_dim + 1, len(shape_in)): + shape_out.append(shape_in[i]) + return reshape(x, shape=shape_out), to_tensor(shape_out, dtype='float32') From a80f705789a78c059ae696b2238aef2a216d51f6 Mon Sep 17 00:00:00 2001 From: xuyongsheng Date: Tue, 21 Feb 2023 11:14:40 +0800 Subject: [PATCH 02/14] get the right xshape and pass func test --- .../prim/composite_ops/test_composite_flatten.py | 1 - python/paddle/incubate/autograd/composite_rules.py | 9 ++++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py index ad03ef588b872..ebefecfd714aa 100644 --- 
a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py @@ -73,7 +73,6 @@ def expect_forward(inputs): class TestCompositeFlatten(unittest.TestCase): def setUp(self): - # self.dtypes = ["float16", "float32", "float64"] self.dtypes = ["float32", "float64"] self.shapes = [ [16, 16, 64, 64, 10], diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 591757b762ec5..b5d5c75233968 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -164,19 +164,22 @@ def maybe_wrap_dim(dim: int, dim_post_expr: int): def flatten_contiguous_range_composite(x, start_axis, stop_axis): """define composite rule of op flatten, flatten_contiguous_range -> flatten""" shape_in = x.shape + shape_x_out: List[int] = [0] + shape_x_out.extend(shape_in) + xshape = full(shape=shape_x_out, fill_value=0, dtype=x.dtype) start_dim = maybe_wrap_dim(start_axis, len(shape_in)) end_dim = maybe_wrap_dim(stop_axis, len(shape_in)) assert start_dim <= end_dim if len(shape_in) == 0 or start_dim == end_dim: - return x, to_tensor(shape_in, dtype=float32) + return reshape(x, shape=shape_in), xshape slice_numel = 1 for i in range(start_dim, end_dim + 1): slice_numel *= shape_in[i] # slice_numel = multiply_integers(shape_in[start_dim:end_dim - start_dim + 1]) - shape_out: List[int] = [] + shape_out = [] for i in range(start_dim): shape_out.append(shape_in[i]) shape_out.append(slice_numel) for i in range(end_dim + 1, len(shape_in)): shape_out.append(shape_in[i]) - return reshape(x, shape=shape_out), to_tensor(shape_out, dtype='float32') + return reshape(x, shape=shape_out), xshape From 4943544908ed55c74598bb7b513752bfab7068bd Mon Sep 17 00:00:00 2001 From: xuyongsheng Date: Tue, 21 Feb 2023 14:52:47 +0800 Subject: [PATCH 03/14] add cinn unit test --- .../test_cinn_prim_flatten.py | 188 ++++++++++++++++++ 1 file changed, 188 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_flatten.py diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_flatten.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_flatten.py new file mode 100644 index 0000000000000..5135f8dfd8ca9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_flatten.py @@ -0,0 +1,188 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import platform +import unittest + +import numpy as np + +import paddle +from paddle.fluid import core + +TOLERANCE = { + "float32": {"rtol": 1e-6, "atol": 1e-6}, + "float64": {"rtol": 1e-15, "atol": 1e-15}, +} + +start_axes = [0, 1, 2] +stop_axes = [-1, 3, 4] + + +def apply_to_static(net, use_cinn): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = use_cinn + return paddle.jit.to_static(net, build_strategy=build_strategy) + + +def generate_data(shape, dtype="float32"): + np_data = np.random.random(shape).astype(dtype) + return np_data + + +class PrimeNet( + paddle.nn.Layer, +): + def __init__(self): + super(PrimeNet, self).__init__() + self.fc = paddle.nn.Linear(4, 4) + + def forward(self, x): + out = paddle.flatten(x) + return out + + +class TestPrimForward(unittest.TestCase): + """ + This case only tests prim_forward + to_static + cinn. Thus we need to + set this flag as False to avoid prim_backward. + core.set_prim_backward(False) + """ + + def setUp(self): + paddle.seed(2022) + self.shapes = [[1, 2, 3, 4, 5], [64, 32, 16, 8, 4]] + self.dtypes = ["float32", "float64"] + + def train(self, use_prim, data): + for start in start_axes: + for stop in stop_axes: + return self._train(use_prim, data, start, stop) + + def _train(self, use_prim, data, start, stop): + paddle.seed(2022) + net = PrimeNet() + sgd = paddle.optimizer.SGD( + learning_rate=0.1, parameters=net.parameters() + ) + core._set_prim_forward_enabled(use_prim) + if use_prim: + net = apply_to_static(net, use_prim) + + res = [] + for _ in range(10): + out = net(data) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_grad() + + res.append(out.numpy()) + + self.check_prim(net, use_prim) + + return res + + def check_prim(self, net, use_prim): + if not use_prim: + return + fwd_ops = [op.type for op in net.forward.main_program.block(0).ops] + # Ensure that flatten is splitted into small ops + self.assertTrue('flatten_contiguous_range_grad' not in fwd_ops) + + def test_cinn_prim_forward(self): + for shape in self.shapes: + for dtype in self.dtypes: + data = generate_data(shape, dtype) + data_t = paddle.to_tensor(data) + data_t.stop_gradient = False + dy_res = self.train(use_prim=False, data=data_t) + cinn_res = self.train(use_prim=True, data=data_t) + + np.testing.assert_allclose( + cinn_res, + dy_res, + rtol=TOLERANCE[dtype]['rtol'], + atol=TOLERANCE[dtype]['atol'], + ) + + +class TestPrimForwardAndBackward(unittest.TestCase): + """ + Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph + """ + + def setUp(self): + paddle.seed(2022) + self.shapes = [[1, 2, 3, 4, 5], [64, 32, 16, 8, 4]] + self.dtypes = ["float32", "float64"] + + def train(self, use_prim, data): + for start in start_axes: + for stop in stop_axes: + return self._train(use_prim, data, start, stop) + + def _train(self, use_prim, data, axis, keep_dim): + paddle.seed(2022) + net = PrimeNet() + sgd = paddle.optimizer.SGD( + learning_rate=0.1, parameters=net.parameters() + ) + core._set_prim_all_enabled(use_prim) + if use_prim: + net = apply_to_static(net, use_prim) + + res = [] + for _ in range(10): + out = net(data) + loss = paddle.mean(out) + loss.backward() + sgd.step() + sgd.clear_grad() + + res.append(out.numpy()) + + self.check_prim(net, use_prim) + + return res + + def check_prim(self, net, use_prim): + if not use_prim: + return + fwd_ops = [op.type for op in net.forward.main_program.block(0).ops] + # Ensure that flatten is splitted into small ops + 
self.assertTrue('flatten_contiguous_range_grad' not in fwd_ops) + + def test_cinn_prim(self): + plat = platform.system() + if plat == "Linux": + for shape in self.shapes: + for dtype in self.dtypes: + data = generate_data(shape, dtype) + data_t = paddle.to_tensor(data) + data_t.stop_gradient = False + dy_res = self.train(use_prim=False, data=data_t) + cinn_res = self.train(use_prim=True, data=data_t) + + np.testing.assert_allclose( + cinn_res, + dy_res, + rtol=TOLERANCE[dtype]['rtol'], + atol=TOLERANCE[dtype]['atol'], + ) + else: + pass + + +if __name__ == '__main__': + unittest.main() From d9ffe5e33d34f7fd96852eac397b1a978dab5232 Mon Sep 17 00:00:00 2001 From: xuyongsheng Date: Tue, 21 Feb 2023 15:27:32 +0800 Subject: [PATCH 04/14] Remove cinn test, wait for it to be added after repair --- .../test_cinn_prim_flatten.py | 188 ------------------ 1 file changed, 188 deletions(-) delete mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_flatten.py diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_flatten.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_flatten.py deleted file mode 100644 index 5135f8dfd8ca9..0000000000000 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_flatten.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import platform -import unittest - -import numpy as np - -import paddle -from paddle.fluid import core - -TOLERANCE = { - "float32": {"rtol": 1e-6, "atol": 1e-6}, - "float64": {"rtol": 1e-15, "atol": 1e-15}, -} - -start_axes = [0, 1, 2] -stop_axes = [-1, 3, 4] - - -def apply_to_static(net, use_cinn): - build_strategy = paddle.static.BuildStrategy() - build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static(net, build_strategy=build_strategy) - - -def generate_data(shape, dtype="float32"): - np_data = np.random.random(shape).astype(dtype) - return np_data - - -class PrimeNet( - paddle.nn.Layer, -): - def __init__(self): - super(PrimeNet, self).__init__() - self.fc = paddle.nn.Linear(4, 4) - - def forward(self, x): - out = paddle.flatten(x) - return out - - -class TestPrimForward(unittest.TestCase): - """ - This case only tests prim_forward + to_static + cinn. Thus we need to - set this flag as False to avoid prim_backward. 
- core.set_prim_backward(False) - """ - - def setUp(self): - paddle.seed(2022) - self.shapes = [[1, 2, 3, 4, 5], [64, 32, 16, 8, 4]] - self.dtypes = ["float32", "float64"] - - def train(self, use_prim, data): - for start in start_axes: - for stop in stop_axes: - return self._train(use_prim, data, start, stop) - - def _train(self, use_prim, data, start, stop): - paddle.seed(2022) - net = PrimeNet() - sgd = paddle.optimizer.SGD( - learning_rate=0.1, parameters=net.parameters() - ) - core._set_prim_forward_enabled(use_prim) - if use_prim: - net = apply_to_static(net, use_prim) - - res = [] - for _ in range(10): - out = net(data) - loss = paddle.mean(out) - loss.backward() - sgd.step() - sgd.clear_grad() - - res.append(out.numpy()) - - self.check_prim(net, use_prim) - - return res - - def check_prim(self, net, use_prim): - if not use_prim: - return - fwd_ops = [op.type for op in net.forward.main_program.block(0).ops] - # Ensure that flatten is splitted into small ops - self.assertTrue('flatten_contiguous_range_grad' not in fwd_ops) - - def test_cinn_prim_forward(self): - for shape in self.shapes: - for dtype in self.dtypes: - data = generate_data(shape, dtype) - data_t = paddle.to_tensor(data) - data_t.stop_gradient = False - dy_res = self.train(use_prim=False, data=data_t) - cinn_res = self.train(use_prim=True, data=data_t) - - np.testing.assert_allclose( - cinn_res, - dy_res, - rtol=TOLERANCE[dtype]['rtol'], - atol=TOLERANCE[dtype]['atol'], - ) - - -class TestPrimForwardAndBackward(unittest.TestCase): - """ - Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph - """ - - def setUp(self): - paddle.seed(2022) - self.shapes = [[1, 2, 3, 4, 5], [64, 32, 16, 8, 4]] - self.dtypes = ["float32", "float64"] - - def train(self, use_prim, data): - for start in start_axes: - for stop in stop_axes: - return self._train(use_prim, data, start, stop) - - def _train(self, use_prim, data, axis, keep_dim): - paddle.seed(2022) - net = PrimeNet() - sgd = paddle.optimizer.SGD( - learning_rate=0.1, parameters=net.parameters() - ) - core._set_prim_all_enabled(use_prim) - if use_prim: - net = apply_to_static(net, use_prim) - - res = [] - for _ in range(10): - out = net(data) - loss = paddle.mean(out) - loss.backward() - sgd.step() - sgd.clear_grad() - - res.append(out.numpy()) - - self.check_prim(net, use_prim) - - return res - - def check_prim(self, net, use_prim): - if not use_prim: - return - fwd_ops = [op.type for op in net.forward.main_program.block(0).ops] - # Ensure that flatten is splitted into small ops - self.assertTrue('flatten_contiguous_range_grad' not in fwd_ops) - - def test_cinn_prim(self): - plat = platform.system() - if plat == "Linux": - for shape in self.shapes: - for dtype in self.dtypes: - data = generate_data(shape, dtype) - data_t = paddle.to_tensor(data) - data_t.stop_gradient = False - dy_res = self.train(use_prim=False, data=data_t) - cinn_res = self.train(use_prim=True, data=data_t) - - np.testing.assert_allclose( - cinn_res, - dy_res, - rtol=TOLERANCE[dtype]['rtol'], - atol=TOLERANCE[dtype]['atol'], - ) - else: - pass - - -if __name__ == '__main__': - unittest.main() From d3f8af73e2ba63a70544ec13369dbab9ed3c92f7 Mon Sep 17 00:00:00 2001 From: xuyongsheng Date: Wed, 22 Feb 2023 11:26:32 +0800 Subject: [PATCH 05/14] add comp test to test_flatten_contiguous_range_op.py --- .../tests/unittests/test_flatten_contiguous_range_op.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git 
a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py index df36af0f5166a..57d3e600c3a8c 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -25,21 +25,25 @@ def setUp(self): self.python_api = paddle.flatten self.python_out_sig = ["Out"] self.op_type = "flatten_contiguous_range" + self.prim_op_type = "comp" self.start_axis = 0 self.stop_axis = -1 self.init_test_case() self.inputs = {"X": np.random.random(self.in_shape).astype("float64")} self.init_attrs() + self.enable_cinn = False self.outputs = { "Out": self.inputs["X"].reshape(self.new_shape), "XShape": np.random.random(self.in_shape).astype("float32"), } def test_check_output(self): - self.check_output(no_check_set=["XShape"], check_eager=True) + self.check_output( + no_check_set=["XShape"], check_eager=True, check_prim=True + ) def test_check_grad(self): - self.check_grad(["X"], "Out", check_eager=True) + self.check_grad(["X"], "Out", check_eager=True, check_prim=True) def init_test_case(self): self.in_shape = (3, 2, 5, 4) From 4e43a7315be3e66f81f2ec94db80e69ee0c4c05f Mon Sep 17 00:00:00 2001 From: xuyongsheng Date: Wed, 22 Feb 2023 11:44:26 +0800 Subject: [PATCH 06/14] remove func test on composite_ops --- .../composite_ops/test_composite_flatten.py | 143 ------------ .../test_composite_flatten_grad.py | 219 ------------------ 2 files changed, 362 deletions(-) delete mode 100644 python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py delete mode 100644 python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten_grad.py diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py deleted file mode 100644 index ebefecfd714aa..0000000000000 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import numpy as np -from utils import TOLERANCE - -import paddle -from paddle.fluid import core - - -def generate_data(shape, dtype="float32"): - np_data = np.random.random(shape).astype(dtype) - return np_data - - -class Attr: - def __init__(self) -> None: - self.dtype = "float32" - self.shape = None - self.start_axi = None - self.stop_axi = None - - def set_dtype(self, dtype) -> None: - self.dtype = dtype - return - - def set_shape(self, shape) -> None: - self.shape = shape - return - - def set_start_axi(self, start_axi) -> None: - self.start_axi = start_axi - return - - def set_stop_axi(self, stop_axi) -> None: - self.stop_axi = stop_axi - return - - def get_rtol(self, flag): - rtol = TOLERANCE[self.dtype][flag].get("rtol") - return rtol - - def get_atol(self, flag): - atol = TOLERANCE[self.dtype][flag].get("atol") - return atol - - -attrs = Attr() - - -def fn(x): - return paddle.flatten( - x, start_axis=attrs.start_axi, stop_axis=attrs.stop_axi - ) - - -def expect_forward(inputs): - return fn(inputs) - - -class TestCompositeFlatten(unittest.TestCase): - def setUp(self): - self.dtypes = ["float32", "float64"] - self.shapes = [ - [16, 16, 64, 64, 10], - [2, 3, 4, 6, 8, 2, 3, 4], - [2, 3, 5, 1, 2], - [2, 3, 4, 5, 6, 7], - ] - self.start_axis = [0, 1, 2] - self.stop_axis = [-1, 2, 3, 4] - - def cal_composite(self, inputs): - paddle.enable_static() - core._set_prim_forward_enabled(True) - startup_program = paddle.static.Program() - main_program = paddle.static.Program() - with paddle.static.program_guard(main_program, startup_program): - x = paddle.static.data( - 'x', shape=inputs.shape, dtype=str(inputs.dtype) - ) - y = fn(x) - blocks = main_program.blocks - - fwd_ops = [op.type for op in blocks[0].ops] - # Ensure that flatten in original block - self.assertTrue('flatten_contiguous_range' in fwd_ops) - - paddle.incubate.autograd.to_prim(blocks) - - fwd_ops_new = [op.type for op in blocks[0].ops] - # Ensure that flatten is splitted into small ops - self.assertTrue('flatten_contiguous_range' not in fwd_ops_new) - - exe = paddle.static.Executor() - exe.run(startup_program) - res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y]) - paddle.disable_static() - core._set_prim_forward_enabled(False) - return res - - def compare_forward(self): - np_data = generate_data(attrs.shape, attrs.dtype) - tensor_data = paddle.to_tensor(np_data) - - expect = expect_forward(tensor_data).numpy() - actual = self.cal_composite(np_data)[0] - - assert expect.dtype == actual.dtype - np.testing.assert_allclose( - expect, - actual, - rtol=attrs.get_rtol("forward"), - atol=attrs.get_atol("forward"), - ) - - def test_forward(self): - for i in self.dtypes: - for j in self.shapes: - for t in self.start_axis: - for k in self.stop_axis: - attrs.set_dtype(i) - attrs.set_shape(j) - attrs.set_start_axi(t) - attrs.set_stop_axi(k) - self.compare_forward() - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten_grad.py deleted file mode 100644 index 9ab5721f56726..0000000000000 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_flatten_grad.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -from utils import TOLERANCE - -import paddle -from paddle.fluid import core - - -def generate_data(shape, dtype="float32"): - np_data = np.random.random(shape).astype(dtype) - return np_data - - -class Attr: - def __init__(self) -> None: - self.dtype = "float32" - self.shape = None - self.start_axi = None - self.stop_axi = None - - def set_dtype(self, dtype) -> None: - self.dtype = dtype - return - - def set_shape(self, shape) -> None: - self.shape = shape - return - - def set_start_axi(self, start_axi) -> None: - self.start_axi = start_axi - return - - def set_stop_axi(self, stop_axi) -> None: - self.stop_axi = stop_axi - return - - def get_rtol(self, flag): - rtol = TOLERANCE[self.dtype][flag].get("rtol") - return rtol - - def get_atol(self, flag): - atol = TOLERANCE[self.dtype][flag].get("atol") - return atol - - -attrs = Attr() - - -def fn(x): - return paddle.flatten( - x, start_axis=attrs.start_axi, stop_axis=attrs.stop_axi - ) - - -def expect_grad(inputs): - paddle.disable_static() - inputs.stop_gradient = False - res = fn(inputs) - gradients = paddle.grad(res, inputs) - return gradients - - -class TestCompositeFlatten(unittest.TestCase): - def setUp(self): - self.dtypes = ["float32", "float64"] - self.shapes = [ - [1, 2, 1, 2], - [16, 6, 6, 10], - [2, 4, 6, 8, 3], - [2, 3, 5, 1, 2], - [2, 3, 4, 5, 6, 7], - ] - self.start_axis = [0, 1, 2] - self.stop_axis = [-1, 2, 3] - - def cal_composite_grad(self, inputs): - paddle.enable_static() - core._set_prim_forward_enabled(True) - startup_program = paddle.static.Program() - main_program = paddle.static.Program() - with paddle.static.program_guard(main_program, startup_program): - x = paddle.static.data( - 'x', shape=inputs.shape, dtype=str(inputs.dtype) - ) - x.stop_gradient = False - y = fn(x) - blocks = main_program.blocks - - fwd_ops = [op.type for op in blocks[0].ops] - # Ensure that flatten_contiguous_range in original block - self.assertTrue('flatten_contiguous_range' in fwd_ops) - - paddle.incubate.autograd.to_prim(blocks) - - fwd_ops_new = [op.type for op in blocks[0].ops] - # Ensure that flatten_contiguous_range is splitted into small ops - self.assertTrue('flatten_contiguous_range' not in fwd_ops_new) - - z = paddle.static.gradients([y], x) - - fwd_ops_grad = [op.type for op in blocks[0].ops] - # Ensure that flatten_contiguous_range_grad not in grad block - self.assertTrue('flatten_contiguous_range_grad' not in fwd_ops_grad) - - exe = paddle.static.Executor() - exe.run(startup_program) - res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) - paddle.disable_static() - core._set_prim_forward_enabled(False) - return res - - def compare_backward(self): - np_data = generate_data(attrs.shape, attrs.dtype) - tensor_data = paddle.to_tensor(np_data) - - expect = expect_grad(tensor_data)[0].numpy() - actual = self.cal_composite_grad(np_data)[0] - - assert expect.dtype == actual.dtype - np.testing.assert_allclose( - expect, - actual, - rtol=attrs.get_rtol("backward"), - atol=attrs.get_atol("backward"), - ) - - def test_backward(self): - for i in self.dtypes: - for j in self.shapes: 
- for t in self.start_axis: - for k in self.stop_axis: - attrs.set_dtype(i) - attrs.set_shape(j) - attrs.set_start_axi(t) - attrs.set_stop_axi(k) - self.compare_backward() - - -class TestCompositeFlattenPrimBackward(unittest.TestCase): - "test composite flatten and prim backward" - - def setUp(self): - self.dtypes = ["float32", "float64"] - self.shapes = [ - [1, 2, 1, 2], - [16, 6, 6, 10], - [2, 4, 6, 8, 3], - [2, 3, 5, 1, 2], - [2, 3, 4, 5, 6, 7], - ] - self.start_axis = [0, 1, 2] - self.stop_axis = [-1, 2, 3] - - def cal_composite_grad(self, inputs): - paddle.enable_static() - core._set_prim_all_enabled(True) - startup_program = paddle.static.Program() - main_program = paddle.static.Program() - with paddle.static.program_guard(main_program, startup_program): - x = paddle.static.data( - 'x', shape=inputs.shape, dtype=str(inputs.dtype) - ) - x.stop_gradient = False - y = fn(x) - blocks = main_program.blocks - paddle.incubate.autograd.to_prim(blocks) - z = paddle.static.gradients([y], x) - - exe = paddle.static.Executor() - exe.run(startup_program) - res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z]) - paddle.disable_static() - core._set_prim_all_enabled(False) - return res - - def compare_backward(self): - np_data = generate_data(attrs.shape, attrs.dtype) - tensor_data = paddle.to_tensor(np_data) - - expect = expect_grad(tensor_data)[0].numpy() - actual = self.cal_composite_grad(np_data)[0] - - assert expect.dtype == actual.dtype - np.testing.assert_allclose( - expect, - actual, - rtol=attrs.get_rtol("prim_backward"), - atol=attrs.get_atol("prim_backward"), - ) - - def test_prim_backward(self): - for i in self.dtypes: - for j in self.shapes: - for t in self.start_axis: - for k in self.stop_axis: - attrs.set_dtype(i) - attrs.set_shape(j) - attrs.set_start_axi(t) - attrs.set_stop_axi(k) - self.compare_backward() - - -if __name__ == '__main__': - unittest.main() From 3c906bb27eecdbc7884db9f38be8504da20e3497 Mon Sep 17 00:00:00 2001 From: xuyongsheng Date: Wed, 22 Feb 2023 11:48:23 +0800 Subject: [PATCH 07/14] Add comments to maybe_wrap_dim func --- python/paddle/incubate/autograd/composite_rules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 1fd5fd2834bf9..9d67f25fb9ca2 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -181,6 +181,7 @@ def mean_composite(x, axis, keepdim): def maybe_wrap_dim(dim: int, dim_post_expr: int): + """get real dim form idx and len of dims""" min = -dim_post_expr max = dim_post_expr - 1 assert not (dim < min or dim > max) From c569f5950236be7dd02b54602caceee826abcadd Mon Sep 17 00:00:00 2001 From: xuyongsheng Date: Wed, 22 Feb 2023 11:50:42 +0800 Subject: [PATCH 08/14] remove commented code --- python/paddle/incubate/autograd/composite_rules.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 9d67f25fb9ca2..5ea46de672e95 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -205,7 +205,6 @@ def flatten_contiguous_range_composite(x, start_axis, stop_axis): slice_numel = 1 for i in range(start_dim, end_dim + 1): slice_numel *= shape_in[i] - # slice_numel = multiply_integers(shape_in[start_dim:end_dim - start_dim + 1]) shape_out = [] for i in range(start_dim): shape_out.append(shape_in[i]) From 
48547abd13195194da2e03cca11dc0787550b006 Mon Sep 17 00:00:00 2001
From: xuyongsheng
Date: Wed, 22 Feb 2023 14:36:52 +0800
Subject: [PATCH 09/14] fix the problem with 0D tensor case

---
 python/paddle/incubate/autograd/composite_rules.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index 5ea46de672e95..298edf4e6f659 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -182,6 +182,9 @@ def mean_composite(x, axis, keepdim):
 
 def maybe_wrap_dim(dim: int, dim_post_expr: int):
     """get real dim form idx and len of dims"""
+    if dim_post_expr == 0:
+        assert dim == 0 or dim == -1
+        return 0
     min = -dim_post_expr
     max = dim_post_expr - 1
     assert not (dim < min or dim > max)

From d3846981bbee39fb73157a16aead722ff4b25822 Mon Sep 17 00:00:00 2001
From: xuyongsheng
Date: Wed, 22 Feb 2023 20:03:26 +0800
Subject: [PATCH 10/14] add flatten split rule comment

---
 python/paddle/incubate/autograd/composite_rules.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index 298edf4e6f659..d5d37f2d151b4 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -195,7 +195,12 @@ def maybe_wrap_dim(dim: int, dim_post_expr: int):
 
 @REGISTER_COMPOSITE('flatten_contiguous_range')
 def flatten_contiguous_range_composite(x, start_axis, stop_axis):
-    """define composite rule of op flatten, flatten_contiguous_range -> flatten"""
+    """
+    define composite rule of op flatten, flatten_contiguous_range -> flatten.
+    xshape is the dim with 0 added to the front of x, keep the shape information of x to calculate the grad.
+    shape_out is the parameter of reshape, get from start_axis and stop_axis.
+    out = reshape(x, shape=shape_out), xshape
+    """
     shape_in = x.shape
     shape_x_out: List[int] = [0]
     shape_x_out.extend(shape_in)

From e09e5f138c0b7df777efc53da71252f1a6e69cbe Mon Sep 17 00:00:00 2001
From: xuyongsheng
Date: Thu, 23 Feb 2023 17:13:53 +0800
Subject: [PATCH 11/14] fix syntax issues

---
 python/paddle/incubate/autograd/composite_rules.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index 60be201f7566b..ae198066bbed2 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -204,7 +204,7 @@ def flatten_contiguous_range_composite(x, start_axis, stop_axis):
     out = reshape(x, shape=shape_out), xshape
     """
     shape_in = x.shape
-    shape_x_out: List[int] = [0]
+    shape_x_out = [0]
     shape_x_out.extend(shape_in)
     xshape = full(shape=shape_x_out, fill_value=0, dtype=x.dtype)
     start_dim = maybe_wrap_dim(start_axis, len(shape_in))

From 70d74536ca0040611453c864943723ca8f9857f8 Mon Sep 17 00:00:00 2001
From: xuyongsheng
Date: Fri, 24 Feb 2023 15:28:59 +0800
Subject: [PATCH 12/14] block flatten on resnet_prim_cinn

---
 .../fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py b/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py
index 1d26926445edc..379ee30fb840c 100644
--- a/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py
+++ b/python/paddle/fluid/tests/unittests/prim/model/test_resnet_prim_cinn.py
@@ -159,6 +159,7 @@ def test_cinn(self):
         not paddle.is_compiled_with_cinn(), "padle is not compiled with CINN"
     )
     def test_prim_cinn(self):
+        core._set_prim_forward_blacklist("flatten_contiguous_range")
         dy2st_prim_cinn = train(
             to_static=True, enable_prim=True, enable_cinn=True
         )

From 95f333b133a58cd523f336b720732721c41d3275 Mon Sep 17 00:00:00 2001
From: xysheng-baidu
Date: Fri, 24 Feb 2023 11:26:26 +0000
Subject: [PATCH 13/14] remove maybe_wrap_dim func

---
 .../paddle/incubate/autograd/composite_rules.py | 17 ++---------------
 1 file changed, 2 insertions(+), 15 deletions(-)

diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index 5efe689ee3c17..5b95edb53495f 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -191,19 +191,6 @@ def mean_composite(x, axis, keepdim):
     return divide(sum_x, norm)
 
 
-def maybe_wrap_dim(dim: int, dim_post_expr: int):
-    """get real dim form idx and len of dims"""
-    if dim_post_expr == 0:
-        assert dim == 0 or dim == -1
-        return 0
-    min = -dim_post_expr
-    max = dim_post_expr - 1
-    assert not (dim < min or dim > max)
-    if dim < 0:
-        dim += dim_post_expr
-    return dim
-
-
 @REGISTER_COMPOSITE('flatten_contiguous_range')
 def flatten_contiguous_range_composite(x, start_axis, stop_axis):
     """
@@ -216,8 +203,8 @@ def flatten_contiguous_range_composite(x, start_axis, stop_axis):
     shape_x_out = [0]
     shape_x_out.extend(shape_in)
     xshape = full(shape=shape_x_out, fill_value=0, dtype=x.dtype)
-    start_dim = maybe_wrap_dim(start_axis, len(shape_in))
-    end_dim = maybe_wrap_dim(stop_axis, len(shape_in))
+    start_dim = start_axis if len(shape_in) != 0 else 0
+    end_dim = stop_axis if len(shape_in) != 0 else 0
     assert start_dim <= end_dim
     if len(shape_in) == 0 or start_dim == end_dim:
         return reshape(x, shape=shape_in), xshape

From fe21b3a7a5f06c9d8ddf978be66726af5bef6026 Mon Sep 17 00:00:00 2001
From: xysheng-baidu
Date: Mon, 27 Feb 2023 10:50:48 +0000
Subject: [PATCH 14/14] Use None instead of xshape

---
 python/paddle/incubate/autograd/composite_rules.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index 5b95edb53495f..3eadc22fb635e 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -195,19 +195,16 @@ def mean_composite(x, axis, keepdim):
 def flatten_contiguous_range_composite(x, start_axis, stop_axis):
     """
     define composite rule of op flatten, flatten_contiguous_range -> flatten.
-    xshape is the dim with 0 added to the front of x, keep the shape information of x to calculate the grad.
+    CINN doesn't need xshape for the backward pass, so return None instead of xshape.
     shape_out is the parameter of reshape, get from start_axis and stop_axis.
     out = reshape(x, shape=shape_out), xshape
     """
     shape_in = x.shape
-    shape_x_out = [0]
-    shape_x_out.extend(shape_in)
-    xshape = full(shape=shape_x_out, fill_value=0, dtype=x.dtype)
     start_dim = start_axis if len(shape_in) != 0 else 0
     end_dim = stop_axis if len(shape_in) != 0 else 0
     assert start_dim <= end_dim
     if len(shape_in) == 0 or start_dim == end_dim:
-        return reshape(x, shape=shape_in), xshape
+        return reshape(x, shape=shape_in), None
     slice_numel = 1
     for i in range(start_dim, end_dim + 1):
         slice_numel *= shape_in[i]
@@ -217,7 +214,7 @@ def flatten_contiguous_range_composite(x, start_axis, stop_axis):
     shape_out.append(slice_numel)
     for i in range(end_dim + 1, len(shape_in)):
         shape_out.append(shape_in[i])
-    return reshape(x, shape=shape_out), xshape
+    return reshape(x, shape=shape_out), None
 
 
 @REGISTER_COMPOSITE('dropout')
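
For reference, the decomposition implemented by the final rule amounts to a single reshape whose target shape collapses the dims in [start_axis, stop_axis]. Below is a minimal NumPy sketch of that shape computation; flatten_shape is a hypothetical helper written only for illustration (it wraps negative axes the way the removed maybe_wrap_dim helper did) and is not part of the patch.

import numpy as np


def flatten_shape(shape_in, start_axis, stop_axis):
    # Hypothetical helper (illustration only): compute the reshape target used
    # by the composite rule, collapsing dims [start_axis, stop_axis] into one.
    ndim = len(shape_in)
    if ndim == 0:
        return list(shape_in)
    # Wrap negative axes, as the removed maybe_wrap_dim helper did.
    start = start_axis + ndim if start_axis < 0 else start_axis
    stop = stop_axis + ndim if stop_axis < 0 else stop_axis
    assert 0 <= start <= stop < ndim
    slice_numel = 1
    for i in range(start, stop + 1):
        slice_numel *= shape_in[i]
    return list(shape_in[:start]) + [slice_numel] + list(shape_in[stop + 1:])


x = np.random.random([2, 3, 4, 5]).astype("float32")
target = flatten_shape(x.shape, start_axis=1, stop_axis=2)
print(target)  # [2, 12, 5]
# Flattening axes 1..2 is equivalent to one reshape to the computed shape.
assert x.reshape(target).shape == (2, 12, 5)

As the final patch notes, the composite rule emits only this reshape and returns None as the second output, since CINN does not need xshape for the backward pass.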