Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Instead go with evaluation_mode #1540

Merged
merged 1 commit into from
Jun 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,7 @@ def prepare(self, *args, device_placement=None):

return result if len(result) > 1 else result[0]

def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, mixed_precision_only: bool = False):
def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
"""
Prepares a PyTorch model for training in any distributed setup. It is recommended to use
[`Accelerator.prepare`] instead.
Expand All @@ -1231,8 +1231,9 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, m
any kind of mixed precision
device_placement (`bool`, *optional*):
Whether or not to place the model on the proper device. Will default to `self.device_placement`.
mixed_precision_only (`bool`, *optional*, defaults to `False`):
Whether or not to *only* apply mixed precision to the model, and not do any other model wrapping.
evaluation_mode (`bool`, *optional*, defaults to `False`):
Whether or not to set the model for evaluation only, by just applying mixed precision and
`torch.compile` (if configured in the `Accelerator` object).

Example:

Expand Down Expand Up @@ -1284,7 +1285,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, m
elif device_placement and not has_hf_device_map:
model = model.to(self.device)

if not mixed_precision_only:
if not evaluation_mode:
if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU):
if any(p.requires_grad for p in model.parameters()):
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
Expand Down Expand Up @@ -1351,14 +1352,14 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, m
"or higher, compute capability of 8.9 or higher). Will use FP16 instead."
)
model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)
if not mixed_precision_only:
if not evaluation_mode:
if self.distributed_type == DistributedType.TPU and self.state.fork_launched:
model = xmp.MpModelWrapper(model).to(self.device)
# torch.compile should be called last.
if self.state.dynamo_plugin.backend != DynamoBackend.NO:
if not is_torch_version(">=", "2.0"):
raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
# torch.compile should be called last.
if self.state.dynamo_plugin.backend != DynamoBackend.NO:
if not is_torch_version(">=", "2.0"):
raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
return model

def _prepare_deepspeed(self, *args):
Expand Down