Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Add _obtain_parameters_buffers in layers #38838

Merged
merged 1 commit into from
Jan 10, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions python/paddle/fluid/dygraph/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,8 +1279,36 @@ def register_state_dict_hook(self, hook):
def _obtain_parameters_buffers(self,
                               destination=None,
                               include_sublayers=True,
                               structured_name_prefix=""):
    """
    Collect the parameters and persistable buffers of this layer (and,
    optionally, of its sub-layers) into a single ordered mapping.

    The difference from state_dict() is that state_dict_hook will not be called,
    but the original types of parameters and buffers will be maintained.

    Args:
        destination: Mapping to extend. A new ``collections.OrderedDict`` is
            created when this is None. This layer's own entries are written
            into the mapping that is passed in.
        include_sublayers: When true, recurse into every non-None entry of
            ``self._sub_layers``.
        structured_name_prefix: String prepended to every key. Each level of
            recursion appends ``"<sub-layer name>."`` to it.

    Returns:
        The mapping from structured name to the stored parameter/buffer
        object. Entries whose value is None are omitted, as are buffers whose
        name is in ``self._non_persistable_buffer_names_set``.
    """
    if destination is None:
        destination = collections.OrderedDict()

    # Own parameters first; unset (None) slots are skipped.
    for name, data in self._parameters.items():
        if data is not None:
            destination[structured_name_prefix + name] = data

    # Then own buffers, leaving out the ones marked as non-persistable.
    for name, buffer in self._buffers.items():
        if buffer is not None and name not in self._non_persistable_buffer_names_set:
            destination[structured_name_prefix + name] = buffer

    if include_sublayers:
        for layer_name, layer_item in self._sub_layers.items():
            if layer_item is not None:
                # Recurse on a copy so that sub-layer entries are not written
                # into the mapping object the caller handed in; the
                # accumulated copy becomes the new result.
                destination_temp = destination.copy()
                destination_temp.update(
                    layer_item._obtain_parameters_buffers(
                        destination_temp, include_sublayers,
                        structured_name_prefix + layer_name + "."))
                destination = destination_temp
    return destination

def _state_dict_impl(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
"""
Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
Expand Down Expand Up @@ -1313,16 +1341,6 @@ def _obtain_parameters_buffers(self,
structured_name_prefix + layer_name + ".",
include_non_persistable_buffer))
destination = destination_temp
return destination

def _state_dict_impl(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
destination = self._obtain_parameters_buffers(
destination, include_sublayers, structured_name_prefix,
include_non_persistable_buffer)
for state_dict_hook in self._state_dict_hooks.values():
hook_result = state_dict_hook(destination)
if hook_result is not None:
Expand Down