Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2st]Fix error when set buffer in forward #38540

Merged
merged 8 commits into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions python/paddle/fluid/dygraph/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def __getattr__(self, name):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in self._parameters:
if in_declarative_mode() and not framework.in_dygraph_mode():
if in_declarative_mode():
return _convert_into_variable(self._parameters[name])
return self._parameters[name]
if '_sub_layers' in self.__dict__:
Expand All @@ -1104,7 +1104,7 @@ def __getattr__(self, name):
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
if in_declarative_mode() and not framework.in_dygraph_mode():
if in_declarative_mode():
return _convert_into_variable(_buffers[name])
return _buffers[name]
return object.__getattribute__(self, name)
Expand Down Expand Up @@ -1176,11 +1176,16 @@ def _remove_if_exist(*dicts):
# but should all non-Variable _buffers[name] be re-assign? We
# should consider it in the future. I current wrote this as
# conservative code.
if _buffers[name] is None or type(_buffers[
name]) == core.VarBase:
if in_declarative_mode() and _buffers[name] is None:
raise RuntimeError(
'In Dy2stat, self.{0} is a buffer and self.{0} is '
'not allowed to be set to Variable when self.{0} is None.'.
format(name))
elif _buffers[name] is None or type(
getattr(self, name)) == core.VarBase:
_buffers[name] = assign(value)
else:
assign(value, _buffers[name])
assign(value, getattr(self, name))
elif value is not None:
raise TypeError(
"assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def __init__(self, hidden_dim=16):
self.alpha = 10.
self.constant_vars = {}

@paddle.jit.to_static
def forward(self, input):
hidden_dim = input.shape[-1]
if hidden_dim != self.hidden_dim:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,5 +408,45 @@ def test_call_non_forward(self):
paddle.enable_static()


class SetBuffersNet1(paddle.nn.Layer):
def __init__(self):
super(SetBuffersNet1, self).__init__()
self.a = paddle.to_tensor([1])

@paddle.jit.to_static
def forward(self):
self.a = self.a + 1
return self.a


class SetBuffersNet2(paddle.nn.Layer):
def __init__(self):
super(SetBuffersNet2, self).__init__()
self.b = paddle.to_tensor([2])

@paddle.jit.to_static
def forward(self):
self.b = None
self.b = paddle.to_tensor([3])
return self.b


class TestSetBuffers(unittest.TestCase):
def test_set_buffers1(self):
paddle.disable_static()
net = SetBuffersNet1()
out = net()
self.assertEqual(out.numpy().tolist(), [2])
paddle.jit.save(net, './SetBuffersNet1')
paddle.enable_static()

def test_set_buffers2(self):
paddle.disable_static()
net = SetBuffersNet2()
with self.assertRaises(RuntimeError):
out = net()
paddle.enable_static()


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -410,14 +410,17 @@ def setUp(self):
self.dyfunc = dyfunc_ifelse_ret_int4

def test_ast_to_func(self):
ProgramTranslator().enable(True)
with self.assertRaises(TypeError):
ProgramTranslator().enable(True)
static_func = paddle.jit.to_static(self.dyfunc)
out = static_func(self.x)

def __del__(self):
# Why need set `_in_declarative_mode_` here?
# In Dy2St we use `with _switch_declarative_mode_guard_()` to indicate
# that the code block is under @to_static, but in this UT
# an exception is thrown during Dy2St, making the `_in_declarative_mode_`
# a wrong value. So We need set `_in_declarative_mode_` to False manually.
paddle.fluid.dygraph.base._in_declarative_mode_ = False
ProgramTranslator().enable(False)
super(TestDy2StIfElseRetInt4, self).__del__()


if __name__ == '__main__':
Expand Down