Eval mode (#1540)
muellerzr authored Jun 7, 2023
1 parent 5f21cde commit 90e9703
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/accelerate/accelerator.py
@@ -1199,7 +1199,7 @@ def prepare(self, *args, device_placement=None):

         return result if len(result) > 1 else result[0]

-    def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, mixed_precision_only: bool = False):
+    def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
         """
         Prepares a PyTorch model for training in any distributed setup. It is recommended to use
         [`Accelerator.prepare`] instead.
@@ -1210,8 +1210,9 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, m
                 any kind of mixed precision
             device_placement (`bool`, *optional*):
                 Whether or not to place the model on the proper device. Will default to `self.device_placement`.
-            mixed_precision_only (`bool`, *optional*, defaults to `False`):
-                Whether or not to *only* apply mixed precision to the model, and not do any other model wrapping.
+            evaluation_mode (`bool`, *optional*, defaults to `False`):
+                Whether or not to set the model for evaluation only, by just applying mixed precision and
+                `torch.compile` (if configured in the `Accelerator` object).

         Example:
@@ -1263,7 +1264,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, m
         elif device_placement and not has_hf_device_map:
             model = model.to(self.device)

-        if not mixed_precision_only:
+        if not evaluation_mode:
             if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU):
                 if any(p.requires_grad for p in model.parameters()):
                     kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
@@ -1330,14 +1331,14 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, m
                     "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
                 )
             model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)
-        if not mixed_precision_only:
+        if not evaluation_mode:
             if self.distributed_type == DistributedType.TPU and self.state.fork_launched:
                 model = xmp.MpModelWrapper(model).to(self.device)
-            # torch.compile should be called last.
-            if self.state.dynamo_plugin.backend != DynamoBackend.NO:
-                if not is_torch_version(">=", "2.0"):
-                    raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
-                model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
+        # torch.compile should be called last.
+        if self.state.dynamo_plugin.backend != DynamoBackend.NO:
+            if not is_torch_version(">=", "2.0"):
+                raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
+            model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
         return model

     def _prepare_deepspeed(self, *args):
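A minimal usage sketch of the renamed flag (my own example, not part of the commit; the `Linear` stand-in model is arbitrary, and `fp16` assumes a GPU is available): with `evaluation_mode=True`, `prepare_model` still handles device placement and mixed precision, but skips the distributed training wrappers.

```python
import torch
from accelerate import Accelerator

# Assumed setup: fp16 mixed precision. Passing dynamo_backend="inductor"
# would also exercise the torch.compile path, which after this commit runs
# even in evaluation mode.
accelerator = Accelerator(mixed_precision="fp16")

model = torch.nn.Linear(8, 2)  # hypothetical stand-in model

# Skips DDP/FSDP/DeepSpeed wrapping; applies device placement and autocast.
model = accelerator.prepare_model(model, evaluation_mode=True)

model.eval()
with torch.no_grad():
    preds = model(torch.randn(4, 8, device=accelerator.device))
```

The second half of the diff is the behavioral piece of the rename: the `torch.compile` call moves out of the `if not evaluation_mode:` guard, so a model prepared for evaluation still gets compiled when a dynamo backend is configured.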
