From 9b8ca293b0eb72a8e1c7c39a356ffc8effed6a00 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 12 May 2021 14:43:04 +0800 Subject: [PATCH] Add param_guard in ParameterList to support @to_static --- python/paddle/fluid/dygraph/container.py | 7 +- .../dygraph_to_static/test_param_guard.py | 95 +++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_param_guard.py diff --git a/python/paddle/fluid/dygraph/container.py b/python/paddle/fluid/dygraph/container.py index c7ea412fec1b77..2938516e5bc442 100644 --- a/python/paddle/fluid/dygraph/container.py +++ b/python/paddle/fluid/dygraph/container.py @@ -15,6 +15,7 @@ from collections import OrderedDict from ..framework import Parameter from .layers import Layer +from .base import param_guard __all__ = [ 'Sequential', @@ -159,7 +160,8 @@ def __init__(self, parameters=None): self.add_parameter(str(idx), param) def __getitem__(self, idx): - return self._parameters[str(idx)] + with param_guard(self._parameters): + return self._parameters[str(idx)] def __setitem__(self, idx, param): assert isinstance(param, Parameter) @@ -169,7 +171,8 @@ def __len__(self): return len(self._parameters) def __iter__(self): - return iter(self._parameters.values()) + with param_guard(self._parameters): + return iter(self._parameters.values()) def append(self, parameter): """Appends a given parameter at the end of the list. diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_param_guard.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_param_guard.py new file mode 100644 index 00000000000000..afae480a926765 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_param_guard.py @@ -0,0 +1,95 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +import unittest + +from paddle.jit import to_static, ProgramTranslator + + +class NetWithParameterList(paddle.nn.Layer): + def __init__(self, in_size, out_size): + super(NetWithParameterList, self).__init__() + weight = self.create_parameter([in_size, out_size]) + bias = self.create_parameter([out_size], is_bias=True) + self.params = paddle.nn.ParameterList([weight, bias]) + + @to_static + def forward(self, x): + out = paddle.matmul(x, self.params[0]) + out = paddle.add(out, self.params[1]) + out = paddle.tanh(out) + return out + + +class NetWithParameterListIter(NetWithParameterList): + def __init__(self, in_size, out_size): + super(NetWithParameterListIter, self).__init__(in_size, out_size) + + @to_static + def forward(self, x): + # NOTE: manually trigger `__iter__` logic. 
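+        # Under @to_static, param_guard temporarily swaps each dygraph
+        # parameter for its static-graph counterpart while __iter__ runs,
+        # so the listed parameters are valid in the translated program.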
+        params = list(self.params.__iter__())
+        out = paddle.matmul(x, params[0])
+        out = paddle.add(out, params[1])
+        out = paddle.tanh(out)
+        return out
+
+
+class TestParameterList(unittest.TestCase):
+    def setUp(self):
+        self.seed = 2021
+        self.iter_num = 5
+        self.prog_trans = ProgramTranslator()
+
+    def train(self, is_iter, to_static):
+        paddle.seed(self.seed)
+        np.random.seed(self.seed)
+        self.prog_trans.enable(to_static)
+        # `is_iter` selects the net that exercises `__iter__` instead of `__getitem__`.
+        if is_iter:
+            net = NetWithParameterListIter(10, 3)
+        else:
+            net = NetWithParameterList(10, 3)
+        sgd = paddle.optimizer.SGD(0.1, parameters=net.parameters())
+
+        for batch_id in range(self.iter_num):
+            x = paddle.rand([4, 10], dtype='float32')
+            out = net(x)
+            loss = paddle.mean(out)
+            loss.backward()
+            sgd.step()
+            sgd.clear_grad()
+
+        return loss
+
+    def test_parameter_list(self):
+        static_loss = self.train(False, to_static=True)
+        dygraph_loss = self.train(False, to_static=False)
+        self.assertTrue(
+            np.allclose(dygraph_loss, static_loss),
+            msg='dygraph result is {}\nstatic result is {}'.format(dygraph_loss,
+                                                                   static_loss))
+
+    def test_parameter_list_iter(self):
+        static_loss = self.train(True, to_static=True)
+        dygraph_loss = self.train(True, to_static=False)
+        self.assertTrue(
+            np.allclose(dygraph_loss, static_loss),
+            msg='dygraph result is {}\nstatic result is {}'.format(dygraph_loss,
+                                                                   static_loss))
+
+
+if __name__ == '__main__':
+    unittest.main()
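
A minimal usage sketch (not part of the patch) of what this change enables:
indexing a ParameterList inside a @to_static forward. The MyLayer name and
the sizes below are illustrative only.

    import paddle

    class MyLayer(paddle.nn.Layer):
        def __init__(self, in_size, out_size):
            super(MyLayer, self).__init__()
            self.params = paddle.nn.ParameterList([
                self.create_parameter([in_size, out_size]),
                self.create_parameter([out_size], is_bias=True),
            ])

        @paddle.jit.to_static
        def forward(self, x):
            # self.params[0] goes through the patched __getitem__, which wraps
            # the lookup in param_guard so the parameter is also valid in the
            # translated static program.
            out = paddle.matmul(x, self.params[0])
            return paddle.add(out, self.params[1])

    layer = MyLayer(10, 3)
    out = layer(paddle.rand([4, 10], dtype='float32'))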