Set the state device dependent on the Accelerator device on multi-GPU #1220
Conversation
The documentation is not available anymore as the PR was closed or merged.
Mmm, if we add a new flag to the load method, we should make it default smartly. Also not sure if that new flag needs to be a string since it has two states (apart from unset): CPU or device. So maybe an optional bool would suffice?
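A rough sketch of the two shapes being weighed here, purely for illustration (the parameter names `map_location` and `load_optimizer_on_device` are placeholders, not the final API):

```python
from typing import Optional

# Option A: a tri-state string flag, defaulted "smartly" (see the diff below).
def load_state(self, input_dir: str, map_location: Optional[str] = None, **load_model_func_kwargs):
    if map_location is None:
        map_location = "cpu" if self.num_processes < 2 else "on_device"
    ...

# Option B: an optional bool, since apart from "unset" there are only two states.
def load_state(self, input_dir: str, load_optimizer_on_device: Optional[bool] = None, **load_model_func_kwargs):
    if load_optimizer_on_device is None:
        load_optimizer_on_device = self.num_processes > 1
    ...
```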
src/accelerate/accelerator.py
Outdated
@@ -2385,8 +2385,17 @@ def load_state(self, input_dir: str, **load_model_func_kwargs):
        for hook in self._load_model_state_pre_hook.values():
            hook(models, input_dir)

        optimizer_map_location = "cpu" if self.num_processes < 2 else self.device
You need something special for TPUs (but maybe TPUs don't use that path?)
Since we're thinking of having this handled with a bool instead, getting the device from `PartialState.device` is now what works. The `AcceleratorState`/`PartialState` already has the right device needed to do the move:
https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py#L110
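For reference, a minimal standalone sketch of reading the device from the shared state (assumes the process group has already been initialized, e.g. under `accelerate launch`):

```python
from accelerate.state import PartialState

state = PartialState()  # singleton; shares whatever the Accelerator already set up
map_location = "cpu" if state.num_processes < 2 else state.device
print(map_location)  # e.g. cuda:<local_rank> in a multi-GPU launch, "cpu" otherwise
```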
src/accelerate/checkpointing.py
Outdated
    if optimizer_map_location is None:
        optimizer_map_location = "cpu"
    elif optimizer_map_location == "on_device":
I still think we should default to `"on_device"` for distributed training on GPU; otherwise we require `num_processes` times the optimizer state to be available in CPU RAM.
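A sketch of the defaulting rule this argues for (where exactly the check lives, and the use of `DistributedType` to restrict it to GPU runs, are assumptions):

```python
from accelerate.state import PartialState
from accelerate.utils import DistributedType

def default_optimizer_map_location() -> str:
    state = PartialState()
    # Load straight onto each process' GPU in distributed runs so we don't hold
    # num_processes copies of the optimizer state in CPU RAM at the same time.
    if state.num_processes > 1 and state.distributed_type == DistributedType.MULTI_GPU:
        return "on_device"
    return "cpu"
```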
src/accelerate/checkpointing.py
Outdated
        if map_location != "cpu":
            models[i].to(map_location)
        models[i].load_state_dict(torch.load(input_model_file, map_location=map_location), **load_model_func_kwargs)
PyTorch will load the optimizer state based on the mapping to the model's parameters, so the model needs to be on the `map_location` first (if it's not CPU) for it to work.
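A standalone illustration of that PyTorch behaviour, independent of Accelerate (assumes a CUDA device is available):

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()
optimizer.step()  # creates optimizer state on CPU
saved = optimizer.state_dict()

# Move the model first, then load: Optimizer.load_state_dict casts the restored
# state tensors to the device of the matching parameters.
model.to("cuda")
optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(saved)

first_param = next(model.parameters())
print(optimizer.state[first_param]["exp_avg"].device)  # cuda:0, following the parameter
```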
I'd let the error pop by itself. If the models are not on the right device, there should be an error.
Thanks for iterating, left a couple more comments.
src/accelerate/checkpointing.py
Outdated
        load_model_func_kwargs (`dict`, *optional*):
            Additional arguments that can be passed to the model's `load_state_dict` method.
    """
    if map_location not in [None, "cpu", "on_device"]:
        raise TypeError(
            "Unsupported optimizer map location passed, please choose one of `None`, `cpu`, or `on_device`"
Put the quotes around the strings here please.
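Presumably this asks for the string values themselves to be quoted in the message, roughly like this (a sketch of the wording, not the final text; `_check_map_location` is just a wrapper for illustration):

```python
def _check_map_location(map_location):
    # Quote "cpu" and "on_device" so they read as string values in the error message.
    if map_location not in [None, "cpu", "on_device"]:
        raise TypeError(
            'Unsupported optimizer map location passed, please choose one of `None`, `"cpu"` or `"on_device"`'
        )
```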
Solves #1210 by setting the optimizer state to `accelerator.device` when `num_processes > 1` and calling `load_state`.
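A usage sketch of the scenario this changes (run under `accelerate launch` with more than one process to hit the multi-GPU branch; the internals follow the discussion above, not necessarily the merged code verbatim):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)

accelerator.save_state("checkpoint")
# Previously load_state always mapped the optimizer state to CPU first; with this
# change, multi-process GPU runs (num_processes > 1) load it onto accelerator.device,
# avoiding num_processes copies of the optimizer state in CPU RAM.
accelerator.load_state("checkpoint")
```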