Skip to content

Commit

Permalink
fix bug of fp16 (#38838)
Browse files Browse the repository at this point in the history
  • Loading branch information
ForFishes authored Jan 10, 2022
1 parent 3a23c1a commit 7d4ce5b
Showing 1 changed file with 30 additions and 12 deletions.
42 changes: 30 additions & 12 deletions python/paddle/fluid/dygraph/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,8 +1279,36 @@ def register_state_dict_hook(self, hook):
def _obtain_parameters_buffers(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
structured_name_prefix=""):
"""
The difference from state_dict() is that state_dict_hook will not be called,
but the original types of parameters and buffers will be maintained.
"""
if destination is None:
destination = collections.OrderedDict()
for name, data in self._parameters.items():
if data is not None:
destination[structured_name_prefix + name] = data
for name, buffer in self._buffers.items():
if buffer is not None and name not in self._non_persistable_buffer_names_set:
destination[structured_name_prefix + name] = buffer

if include_sublayers:
for layer_name, layer_item in self._sub_layers.items():
if layer_item is not None:
destination_temp = destination.copy()
destination_temp.update(
layer_item._obtain_parameters_buffers(
destination_temp, include_sublayers,
structured_name_prefix + layer_name + "."))
destination = destination_temp
return destination

def _state_dict_impl(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
"""
Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
Expand Down Expand Up @@ -1313,16 +1341,6 @@ def _obtain_parameters_buffers(self,
structured_name_prefix + layer_name + ".",
include_non_persistable_buffer))
destination = destination_temp
return destination

def _state_dict_impl(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
destination = self._obtain_parameters_buffers(
destination, include_sublayers, structured_name_prefix,
include_non_persistable_buffer)
for state_dict_hook in self._state_dict_hooks.values():
hook_result = state_dict_hook(destination)
if hook_result is not None:
Expand Down

0 comments on commit 7d4ce5b

Please sign in to comment.