Remove mixed precision hook as part of the unwrap_model #860

Merged · 4 commits · Nov 16, 2022
Changes from all commits
src/accelerate/accelerator.py (8 changes: 5 additions & 3 deletions)

@@ -1464,16 +1464,18 @@ def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
        """
        return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)

-    def unwrap_model(self, model):
+    def unwrap_model(self, model, keep_fp32_wrapper: bool = False):
        """
        Unwraps the `model` from the additional layer possible added by [`~Accelerator.prepare`]. Useful before saving
        the model.

        Args:
            model (`torch.nn.Module`):
                The model to unwrap.
+            keep_fp32_wrapper (`bool`, *optional*, defaults to `False`):
+                Whether to not remove the mixed precision hook if it was added.
        """
-        return extract_model_from_parallel(model)
+        return extract_model_from_parallel(model, keep_fp32_wrapper)

    def wait_for_everyone(self):
        """

@@ -1760,7 +1762,7 @@ def get_state_dict(self, model, unwrap=True):
        Args:
            model (`torch.nn.Module`):
                A PyTorch model sent through [`Accelerator.prepare`]
-            unwrap (`bool`, *optional*, defaults to True):
+            unwrap (`bool`, *optional*, defaults to `True`):
                Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict
        """
        is_zero_3 = False
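With this change, `Accelerator.unwrap_model` strips the fp32 output hook by default, and `keep_fp32_wrapper=True` opts back into the previous behaviour. A minimal usage sketch, assuming a CUDA device so that `mixed_precision="fp16"` actually installs the hook; the tiny model and file name are placeholders:

import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")
# prepare() wraps the model's forward so outputs come back as fp32
model = accelerator.prepare(torch.nn.Linear(8, 2))

# New default: the mixed precision hook is removed, so the returned module has a
# plain forward again, which is convenient before saving or exporting.
clean_model = accelerator.unwrap_model(model)
accelerator.save(clean_model.state_dict(), "model.bin")

# Opt out: keep the ConvertOutputsToFp32 wrapper on forward.
still_wrapped = accelerator.unwrap_model(model, keep_fp32_wrapper=True)

The flag exists for users who keep running inference on the unwrapped model and still want fp32 outputs under autocast rather than fp16.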
src/accelerate/utils/other.py (13 changes: 11 additions & 2 deletions)

@@ -21,6 +21,7 @@
from ..state import AcceleratorState
from .dataclasses import DistributedType
from .imports import is_deepspeed_available, is_tpu_available
+from .operations import ConvertOutputsToFp32


if is_deepspeed_available():

@@ -30,12 +31,15 @@
    import torch_xla.core.xla_model as xm


-def extract_model_from_parallel(model):
+def extract_model_from_parallel(model, keep_fp32_wrapper: bool = False):
    """
    Extract a model from its distributed containers.

    Args:
-        model (`torch.nn.Module`): The model to extract.
+        model (`torch.nn.Module`):
+            The model to extract.
+        keep_fp32_wrapper (`bool`, *optional*):
+            Whether to remove mixed precision hooks from the model.

    Returns:
        `torch.nn.Module`: The extracted model.

@@ -46,6 +50,11 @@ def extract_model_from_parallel(model):

    while isinstance(model, options):
        model = model.module
+
+    if not keep_fp32_wrapper:
+        forward = getattr(model, "forward")
+        if isinstance(forward, ConvertOutputsToFp32):
+            setattr(model, "forward", forward.model_forward)
    return model
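Mechanically, the new block in `extract_model_from_parallel` is just an attribute swap: mixed precision preparation replaces `model.forward` with a `ConvertOutputsToFp32` callable that keeps the original around as `model_forward`, and unwrapping puts that original back. A standalone sketch of the same idea, where `_Fp32Wrapper` is a hypothetical stand-in for the real hook class (only the `model_forward` attribute and the isinstance check are taken from the diff above):

import torch


class _Fp32Wrapper:
    """Stand-in for accelerate's ConvertOutputsToFp32: wraps a forward and keeps the original."""

    def __init__(self, model_forward):
        self.model_forward = model_forward

    def __call__(self, *args, **kwargs):
        # The real hook runs the wrapped forward and casts floating point outputs to fp32.
        return self.model_forward(*args, **kwargs).float()


model = torch.nn.Linear(4, 4)
# Roughly what prepare() does under mixed precision: shadow forward with the wrapper.
model.forward = _Fp32Wrapper(model.forward)

# What extract_model_from_parallel(..., keep_fp32_wrapper=False) now does:
forward = getattr(model, "forward")
if isinstance(forward, _Fp32Wrapper):
    setattr(model, "forward", forward.model_forward)

assert not isinstance(model.forward, _Fp32Wrapper)  # the plain bound forward is back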