Implement SPMDLoadPlanner to enable distributed checkpoint loading #5130
Conversation
cc @yashs97
Thanks @jonb377. This is great. Left a few comments.
torch_xla/experimental/checkpoint.py
Outdated
# Flatten the state_dict to allow separating sharded XLA tensors from
# types that can be handled by the default planner, and ensure all sharded
# tensors are wrapped in XLAShardedTensor
state_dict, self.mappings = flatten_state_dict(state_dict)
It looks like the upstream DefaultLoadPlanner doesn't always flatten the state_dict? Do you know why and what are the downsides of flattening the state_dict?
Hmm I actually don't know why it's optional in the DefaultLoadPlanner, I'll follow up on that. The flattened state_dict is easier for us to work with since we need to split the input into sharded/unsharded parts, which is difficult if the state_dict is nested.
One downside of flattening is that we aren't operating directly on the input state_dict, so non-tensor items need to be mapped back to the original state_dict.
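For illustration, a minimal sketch of the flattening round-trip, assuming the upstream private helpers in torch.distributed.checkpoint._nested_dict (the module path and the exact key format may vary across torch releases):

```python
import torch
from torch.distributed.checkpoint._nested_dict import (
    flatten_state_dict,
    unflatten_state_dict,
)

nested = {"model": {"fc1": {"weight": torch.ones(2, 2), "bias": torch.zeros(2)}}}

# Flattening produces flat FQN-style keys plus a mapping that remembers the
# original nesting, which makes it easy to split entries into sharded /
# unsharded groups by key.
flat, mappings = flatten_state_dict(nested)
print(list(flat.keys()))  # e.g. ['model.fc1.weight', 'model.fc1.bias']

# The mapping lets results be written back into the caller's original
# nested structure after loading.
restored = unflatten_state_dict(flat, mappings)
```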
But that downside is handled by the default planner helpers, right?
That's correct.
It's optional on DefaultLoadPlanner mostly to help with testing. Its default value is True, so most users will always have it enabled.
self.assertFalse(self._same_shard_data(xtensor.local_shards, old_shards))

def test_load_state_dict(self):
  dist.init_process_group(
Why do we need the process_group?
torch.distributed.checkpoint requires a process group to coordinate merging the local plans into a global plan. Even though this is just a single instance, we still need to initialize the PG to use the save_state_dict and load_state_dict APIs.
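For reference, a minimal single-process sketch of that requirement, assuming a gloo backend and placeholder MASTER_ADDR/MASTER_PORT values:

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp

# dist_cp.(save|load)_state_dict coordinate planning over a process group,
# so even a one-rank test needs a PG initialized.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group("gloo", rank=0, world_size=1)

state_dict = {"weight": torch.ones(4, 4)}
tmpdir = tempfile.mkdtemp()

dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(tmpdir),
)
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(tmpdir),
)

dist.destroy_process_group()
```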
Hmmm, is this because we are saving the unsharded CPU tensor? If it's required all the time, I feel like it's not a good UX.
A PG is always required, but I'll need to check if we can use the XLA backend - my first attempt to replace gloo with an xla group errored out when creating the global plan.
If we build a higher-level interface, this could be hidden from the user.
Most of our own distributed APIs don't require a PG, so maybe we should try to keep the consistency. Do you know why they require a PG?
It's used to centrally coordinate the global plans, e.g. all worker plans are sent to the coordinator using collectives in the process group.
Actually, I see we have the option to run without a PG by setting no_dist=True in the (save|load)_state_dict functions, but we lose out on the global coordination. This shouldn't be an issue for the LoadPlanner, and I suspect we can make the SavePlanner work without global planning as well - we'll just need to track replica ranks and dedupe tensors locally.
Thanks for bringing this up Jiewen, I'll experiment with it and update the PR.
Synced offline - when taking a checkpoint, we would lose the consistency provided by the process group (the coordinator waits until all workers are finished before writing metadata, which depends on a PG).
I'll use no_dist=True in the tests to simplify the test code, but we will still expect users to have a CPU process group when taking a checkpoint.
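A rough sketch of the no_dist=True path, assuming the same FileSystemWriter/FileSystemReader setup as in the sketch above:

```python
import tempfile

import torch
import torch.distributed.checkpoint as dist_cp

# With no_dist=True, no process group is needed and all planning happens
# locally - convenient for single-host tests, at the cost of the
# cross-worker coordination described above.
state_dict = {"weight": torch.ones(4, 4)}
tmpdir = tempfile.mkdtemp()

dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(tmpdir),
    no_dist=True,
)
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(tmpdir),
    no_dist=True,
)
```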
tmpdir = tempfile.mkdtemp()

# Save an unsharded model using the default planner in dist_cp.save_state_dict
model = self.SimpleLinear().to(xm.xla_device())
Can you illustrate more on how a saved unsharded model can be loaded into a sharded model?
This is a feature from the upstream - resharding is handled transparently for us when we use the create_read_items_for_chunk_list API to create our ReadItems. Even sharded checkpoints should be able to be restored into different device meshes or unsharded models.
That makes sense now.
@jonb377 One small thing I would suggest is to use
Generally speaking, it LGTM! Thanks for getting this complicated work done so quickly. I left a bunch of comments just for my education, to learn more in depth about how this planner is supposed to be used by the load_state_dict API.
Some of the missing pieces that I imagine load_state_dict will handle to fill the gaps around the planners are:
- Loading the actual metadata and storage.
- Loading the storage and slicing it into properly sized tensors for the ReadItems.
It would be great if you could elaborate more on the E2E flow of how the high-level APIs interact with our low-level implementation. I also commented on this in your design doc.
def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
  offsets = read_item.dest_offsets
  index = read_item.dest_index
  if index.fqn in self.sharded_state_dict:
Can you elaborate this a little bit more?
Also, I'm confused about why we want to do the narrowing.
I'll add a comment - the storage layer expects that the tensor returned from resolve_tensor matches the shape of the ReadItem's lengths field, so the tensor must be narrowed here.
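For illustration, a hedged sketch of that narrowing. The helper name is a placeholder (not the PR's code); it narrows the destination tensor to a view matching the ReadItem's dest_offsets/lengths so the storage layer can copy into it:

```python
import torch
from torch.distributed.checkpoint.planner import ReadItem


def _narrow_to_read_item(tensor: torch.Tensor,
                         read_item: ReadItem) -> torch.Tensor:
  # The storage layer copies read results into whatever resolve_tensor
  # returns, so narrow the destination per-dimension to a view whose shape
  # equals read_item.lengths. narrow() returns a view, so writes land in
  # the original shard tensor.
  for dim, (offset, length) in enumerate(
      zip(read_item.dest_offsets, read_item.lengths)):
    tensor = tensor.narrow(dim, offset, length)
  return tensor
```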
x, XLAShardedTensor) and x.sharding_type != ShardingType.REPLICATED


def _unwrap_sharded_tensor(x: Any) -> Any:
So this is only used for the replicated tensors?
Correct, that's a good point - the name is confusing. This is used to ensure the default planner is operating on torch.Tensor instead of XLAShardedTensor. Maybe _unwrap_xla_sharded_tensor?
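A possible sketch of the renamed helper, assuming XLAShardedTensor exposes the wrapped tensor as global_tensor (the import path has moved between torch_xla releases):

```python
from typing import Any

from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor


def _unwrap_xla_sharded_tensor(x: Any) -> Any:
  # Return the plain torch.Tensor carried by the wrapper so the default
  # planner never sees an XLAShardedTensor; pass everything else through.
  if isinstance(x, XLAShardedTensor):
    return x.global_tensor
  return x
```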
# Load the checkpoint using the provided load planner
for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
  self.assertFalse(torch.allclose(p1, p2))
dist_cp.load_state_dict(
It looks like the LoadPlanner should be able to handle the resharding of consolidated checkpoints. I wonder how it handles the case where the saved checkpoint has a different sharding spec than the loading model?
Seems like it's handled by create_read_items_for_chunk_list?
That's correct - we just need to generate ChunkStorageMetadata for each local shard, and the upstream API create_read_items_for_chunk_list will convert that chunk list into a list of ReadItems. The ReadItems generated will depend on the device mesh used when taking the checkpoint.
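Roughly, a hedged sketch of that flow; the helper name is a placeholder, the exact create_read_items_for_chunk_list module path/signature is assumed, and the shard is assumed to expose its position as a list of slice objects, which may not match the real XLAShard API:

```python
from typing import List

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner_helpers import (
    create_read_items_for_chunk_list,
)


def _create_read_items(fqn: str, checkpoint_md: TensorStorageMetadata,
                       shards) -> List:
  # Describe each local shard as a ChunkStorageMetadata (offset + size
  # within the global tensor).
  chunks = []
  for shard in shards:
    offsets = torch.Size([s.start for s in shard.indices])
    sizes = torch.Size([s.stop - s.start for s in shard.indices])
    chunks.append(ChunkStorageMetadata(offsets=offsets, sizes=sizes))
  # The upstream helper computes overlaps between our chunks and the chunks
  # recorded in the checkpoint, emitting one ReadItem per overlapping region.
  return create_read_items_for_chunk_list(fqn, checkpoint_md, chunks)
```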
Then is there a simple formula to calculate the number of ReadItems? Or do I have to refer to the create_read_items_for_chunk_list algorithm?
I just added some discussion in another comment, but adding here as well:
The algorithm for create_read_items_for_chunk_list just iterates across all chunks in the storage metadata, identifies overlap with the ChunkStorageMetadata describing our XLAShard, and generates a ReadItem if there is overlap.
plan = planner.create_local_plan()
parameter_count = len(list(model.parameters()))
if self.n_devices > 1:
  # When the model is sharded across devices, fc1.weight will result in
So you suggest the number of ReadItems will match the number of XLAShards for a particular XLAShardedTensor. Does this hold true for the following resharding examples?
- stored shards <= loading shards, i.e., the model is sharded 2 ways but needs to be loaded into 4 ways.
- stored shards > loading shards, i.e., the model is sharded 4 ways but needs to be loaded into 2 ways.
I can see that it holds true for case 1, but it's hard for me to imagine case 2. Can you elaborate? Commented in your design as well.
That's true - I'll clarify the comment that this comes from the unsharded checkpoint metadata used when creating the Planner in _get_load_planner.
I'm still curious about case 2, if you can explain it to me?
If we are loading into a coarser mesh, e.g. loading a checkpoint taken on a (4, 4) mesh into a model sharded across (2, 2), each local shard in the (2, 2) mesh will require multiple ReadItems. On our side, we will still generate a single ChunkStorageMetadata for each shard, and the utility function create_read_items_for_chunk_list translates these into the necessary ReadItems by finding the overlap of the ChunkStorageMetadata with the chunks in the storage metadata.
The algorithm for create_read_items_for_chunk_list just iterates across all chunks in the storage metadata, identifies overlap with the ChunkStorageMetadata describing our XLAShard, and generates a ReadItem if there is overlap.
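A hedged pseudocode sketch of that overlap rule (not the upstream implementation): two chunks overlap iff their intervals intersect in every dimension, and the intersection gives the ReadItem's offsets and lengths.

```python
from typing import Optional, Sequence, Tuple


def _chunk_overlap(offsets_a: Sequence[int], sizes_a: Sequence[int],
                   offsets_b: Sequence[int], sizes_b: Sequence[int]
                   ) -> Optional[Tuple[list, list]]:
  offsets, lengths = [], []
  for oa, sa, ob, sb in zip(offsets_a, sizes_a, offsets_b, sizes_b):
    start = max(oa, ob)
    end = min(oa + sa, ob + sb)
    if end <= start:
      return None  # no overlap in this dimension -> no ReadItem
    offsets.append(start)
    lengths.append(end - start)
  return offsets, lengths
```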
At a high level, the number of ReadItems will then be determined by max(number of shards in the loading model, number of shards in the saving model)?
We can contrive an example where this won't hold, e.g. saving from a (4, 1) mesh and loading into a (2, 2) mesh would require 8 total ReadItems based on a quick test. It depends a lot on how the storage layer handles the checkpoint, but the worst case number of ReadItems would be O(global shards when saving the model) for each shard being loaded.
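A quick sanity check of that count, assuming the checkpoint keeps one storage chunk per saved shard: a 16x16 tensor saved on a (4, 1) mesh is stored as four 4x16 row chunks, and loading on a (2, 2) mesh needs four 8x8 chunks, each of which overlaps two stored chunks.

```python
from itertools import product


def chunks(shape, mesh):
  # Evenly partition `shape` across `mesh`, returning per-chunk
  # (offset, size) pairs for each dimension.
  sizes = [d // m for d, m in zip(shape, mesh)]
  return [tuple((i * s, s) for i, s in zip(idx, sizes))
          for idx in product(*(range(m) for m in mesh))]


def overlaps(a, b):
  # Two chunks overlap iff their intervals intersect in every dimension.
  return all(max(oa, ob) < min(oa + sa, ob + sb)
             for (oa, sa), (ob, sb) in zip(a, b))


saved = chunks((16, 16), (4, 1))
loaded = chunks((16, 16), (2, 2))
read_items = sum(overlaps(s, l) for s, l in product(saved, loaded))
print(read_items)  # 8
```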
Perfect. Thanks, Jon!
This implements the LoadPlanner interface from torch.distributed.checkpoint. This implementation only directly handles sharded tensors and relies on the default planner's logic for everything else.
A high-level overview of each of the Planner interface methods:
- set_up_planner: Called with the state_dict to be restored and metadata from the checkpoint. Our implementation will split the state_dict into a sharded and unsharded portion so that we can defer to the default planner logic for the unsharded part.
- create_local_plan: ReadItems are generated for every item in the state_dict. The default planner is used for the unsharded objects, and we generate a ReadItem for each shard of each XLAShardedTensor with non-REPLICATED sharding type.
- create_global_plan: The coordinator process makes any global decisions for the restoration. There is no custom logic here.
- finish_plan: The process can adjust its plan after global coordination. Again, no custom logic here.
- load_bytes: This is how non-tensor data is restored. We defer to the default planner's logic.
- resolve_tensor: This function returns a tensor to store the read result of the associated ReadItem. If the ReadItem doesn't correspond to a sharded tensor, we defer to the default planner logic. Otherwise, we return the unpadded data of the local shard associated with the ReadItem.
- commit_tensor: This is called after the data has been loaded into the tensor. This is a no-op in the default planner, but for sharded tensors we track each shard that has been committed. Once all shards are committed for a tensor, they are loaded into the XLAShardedTensor.

This PR depends on #5128 for some utility functions.
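For readers unfamiliar with the interface, a rough skeleton of how these methods fit together. This is not the PR's actual implementation; it simply delegates to a DefaultLoadPlanner and marks the sharded-tensor branches as placeholders, and it is written against the torch.distributed.checkpoint planner interface as of this PR, so signatures may differ in newer releases.

```python
from typing import Any, Dict, List

import torch
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner, ReadItem


class SketchSPMDLoadPlanner(LoadPlanner):

  def __init__(self):
    self._default = DefaultLoadPlanner()
    self.sharded_state_dict: Dict[str, Any] = {}

  def set_up_planner(self, state_dict, metadata: Metadata,
                     is_coordinator: bool) -> None:
    # Placeholder split: sharded XLA tensors would be peeled off into
    # self.sharded_state_dict here; everything else goes to the default
    # planner.
    unsharded = {k: v for k, v in state_dict.items()
                 if k not in self.sharded_state_dict}
    self._default.set_up_planner(unsharded, metadata, is_coordinator)

  def create_local_plan(self) -> LoadPlan:
    # A real implementation would append one ReadItem per local shard of
    # each non-replicated XLAShardedTensor (see the chunk-list sketch above).
    return self._default.create_local_plan()

  def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
    return global_plan  # no custom global coordination

  def finish_plan(self, central_plan: LoadPlan) -> LoadPlan:
    return central_plan  # no adjustment after global coordination

  def load_bytes(self, read_item: ReadItem, value) -> None:
    self._default.load_bytes(read_item, value)  # non-tensor data

  def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor:
    # Sharded case: return the narrowed, unpadded local shard data instead.
    return self._default.resolve_tensor(read_item)

  def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
    # Sharded case: record the committed shard; once every shard of a tensor
    # is committed, load them back into the XLAShardedTensor.
    self._default.commit_tensor(read_item, tensor)
```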