
Implement SPMDLoadPlanner to enable distributed checkpoint loading #5130

Merged
2 commits merged into master from jonbolin-load-planner on Jun 13, 2023

Conversation

jonb377 (Collaborator) commented Jun 7, 2023

This PR implements the LoadPlanner interface from torch.distributed.checkpoint. The implementation directly handles only sharded tensors and relies on the default planner's logic for everything else.

A high-level overview of each of the Planner interface methods:

  • set_up_planner: Called with the state_dict to be restored and metadata from the checkpoint. Our implementation will split the state_dict into a sharded and unsharded portion so that we can defer to the default planner logic for the unsharded part.
  • create_local_plan: ReadItems are generated for every item in the state_dict. The default planner is used for the unsharded objects, and we generate a ReadItem for each shard of each XLAShardedTensor with non-REPLICATED sharding type.
  • create_global_plan: The coordinator process makes any global decisions for the restoration. There is no custom logic here.
  • finish_plan: The process can adjust its plan after global coordination. Again, no custom logic here.
  • load_bytes: This is how non-tensor data is restored. We defer to the default planner's logic.
  • resolve_tensor: This function returns a tensor to store the read result of the associated ReadItem. If the ReadItem doesn't correspond to a sharded tensor, we defer to the default planner logic. Otherwise, we return the unpadded data of the local shard associated with the ReadItem.
  • commit_tensor: This is called after the data has been loaded into the tensor. This is a no-op in the default planner, but for sharded tensors we track each shard that has been committed. Once all shards are committed for a tensor, they are loaded into the XLAShardedTensor.
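
For orientation, a minimal sketch of how a planner like this is passed to the upstream entry point (the checkpoint path is a placeholder, and the model and SPMDLoadPlanner are assumed to be in scope):

```python
import torch.distributed.checkpoint as dist_cp

# The state_dict may contain XLAShardedTensors; the planner routes those
# through the sharded-tensor path and defers everything else to the
# default planner logic.
state_dict = model.state_dict()

dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/path/to/checkpoint"),
    planner=SPMDLoadPlanner(),  # planner introduced in this PR
)
# Loading is in-place: the tensors in state_dict now hold the restored values.
```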

This PR depends on #5128 for some utility functions.

jonb377 (Collaborator, Author) commented Jun 7, 2023

cc @yashs97

@jonb377 force-pushed the jonbolin-checkpoint-restore branch from d27834a to 318074b on June 7, 2023 02:10
@jonb377 force-pushed the jonbolin-load-planner branch from 75e89bf to f043082 on June 7, 2023 02:11
@jonb377 force-pushed the jonbolin-checkpoint-restore branch from 318074b to cfcd622 on June 8, 2023 00:24
@jonb377 force-pushed the jonbolin-load-planner branch from f043082 to 0fd0174 on June 8, 2023 00:25
alanwaketan (Collaborator) left a comment

Thanks @jonb377. This is great. Left a few comments.

# Flatten the state_dict to allow separating sharded XLA tensors from
# types that can be handled by the default planner, and ensure all sharded
# tensors are wrapped in XLAShardedTensor
state_dict, self.mappings = flatten_state_dict(state_dict)
Collaborator

It looks like the upstream DefaultLoadPlanner doesn't always flatten the state_dict? Do you know why and what are the downsides of flattening the state_dict?

Collaborator Author

Hmm, I actually don't know why it's optional in the DefaultLoadPlanner; I'll follow up on that. The flattened state_dict is easier for us to work with, since we need to split the input into sharded/unsharded parts, which is difficult if the state_dict is nested.

One downside of flattening is that we aren't operating directly on the input state_dict, so non-tensor items need to be mapped back to the original state_dict.
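
As a rough illustration of what the flattening produces (the exact key and mapping formats of the upstream flatten_state_dict helper may differ):

```python
import torch

w = torch.randn(4, 4)
nested = {"model": {"fc1": {"weight": w}}, "step": 3}

# After flattening, keys are fully-qualified names, and the mapping records
# how to place each value back into the original nested structure.
flat = {"model.fc1.weight": w, "step": 3}
mappings = {
    "model.fc1.weight": ("model", "fc1", "weight"),
    "step": ("step",),
}
```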

Collaborator

But that downside is handled by the default planner helpers, right?

Collaborator Author

That's correct.

Another commenter

It's optional on DefaultLoadPlanner mostly to help with testing. Its default value is True, so most users will always have it enabled.
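
For reference, the flag being discussed (the constructor signature here reflects the upstream planner at the time and should be treated as an assumption):

```python
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

# flatten_state_dict defaults to True; it is exposed mainly so tests can
# exercise the non-flattened path.
planner = DefaultLoadPlanner(flatten_state_dict=True)
```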

    self.assertFalse(self._same_shard_data(xtensor.local_shards, old_shards))

  def test_load_state_dict(self):
    dist.init_process_group(
Collaborator

Why do we need the process_group?

Collaborator Author

torch.distributed.checkpoint requires a process group to coordinate merging the local plans into a global plan. Even though this is just a single instance, we still need to initialize the PG to use the save_state_dict and load_state_dict APIs.

Collaborator

Hmmm, is this because we are saving the unsharded CPU tensor? If it's required all the time, I feel like it's not a good UX.

Collaborator Author

A PG is always required, but I'll need to check if we can use the XLA backend - my first attempt to replace gloo with an xla group errored out when creating the global plan.

If we build a higher-level interface, this could be hidden from the user.

Collaborator

Most of our own distributed APIs don't require a PG, so maybe we should try to keep that consistent. Do you know why they require a PG?

jonb377 (Collaborator, Author) commented Jun 8, 2023

It's used to centrally coordinate the global plans, e.g. all worker plans are sent to the coordinator using collectives in the process group.

Actually, I see we have the option to run without a PG by setting no_dist=True in the (save|load)_state_dict functions, but we lose out on the global coordination. This shouldn't be an issue for the LoadPlanner, and I suspect we can make the SavePlanner work without global planning as well - we'll just need to track replica ranks and dedupe tensors locally.

Thanks for bringing this up Jiewen, I'll experiment with it and update the PR.

Collaborator Author

Synced offline - when taking a checkpoint, we would lose the consistency provided by the process group (the coordinator waits until all workers are finished before writing metadata, which depends on a PG).

I'll use no_dist=True in the tests to simplify the test code, but we will still expect users to have a CPU process group when taking a checkpoint.
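
To summarize the two modes discussed above with a sketch (the paths and the gloo backend choice are illustrative, and SPMDLoadPlanner is assumed to be in scope):

```python
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp

# Test / single-process path: skip global plan coordination entirely.
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/tmp/ckpt"),
    planner=SPMDLoadPlanner(),
    no_dist=True,  # no process group is required or used
)

# Multi-worker path: a CPU-capable process group (e.g. gloo) lets the
# coordinator gather local plans and write metadata only after all
# workers have finished.
# dist.init_process_group(backend="gloo", init_method="env://")
```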

tmpdir = tempfile.mkdtemp()

# Save an unsharded model using the default planner in dist_cp.save_state_dict
model = self.SimpleLinear().to(xm.xla_device())
Collaborator

Can you explain in a bit more detail how a saved unsharded model can be loaded into a sharded model?

jonb377 (Collaborator, Author) commented Jun 8, 2023

This is a feature from the upstream - resharding is handled transparently for us when we use the create_read_items_for_chunk_list API to create our ReadItems. Even sharded checkpoints should be able to be restored into different device meshes or unsharded models.

Collaborator

That makes sense now.

@jonb377 force-pushed the jonbolin-load-planner branch 2 times, most recently from cd3d043 to 363ad31 on June 8, 2023 22:36
Base automatically changed from jonbolin-checkpoint-restore to master June 9, 2023 21:37
@jonb377 force-pushed the jonbolin-load-planner branch 6 times, most recently from 3476c8d to ad1647f on June 12, 2023 17:56
@jonb377 changed the title from "Implement XLALoadPlanner to enable distributed checkpoint loading" to "Implement SPMDLoadPlanner to enable distributed checkpoint loading" on Jun 12, 2023
@jonb377 force-pushed the jonbolin-load-planner branch from ad1647f to 9bcd879 on June 12, 2023 18:12
@jonb377 jonb377 marked this pull request as ready for review June 12, 2023 18:22
alanwaketan (Collaborator)

@jonb377 One small thing I would suggest: use git commit instead of git commit --amend, so that I can review just your new diffs instead of the whole PR again after your updates.

alanwaketan (Collaborator) left a comment

Generally speaking, it LGTM! Thanks for getting this complicated work done so quickly. I left a bunch of comments, mostly for my own education, to learn in more depth how this planner is supposed to be used by the load_state_dict API.

Some of the missing pieces that I imagine load_state_dict will handle to fill the gaps around the planners are:

  1. Loading the actual metadata and storage.
  2. Reading the storage and slicing it into tensors of the proper size for the ReadItems.

It would be great if you could elaborate on the E2E flow of how the high-level APIs interact with our low-level implementation. I also commented on this in your design doc.
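
For the E2E picture, a schematic and heavily simplified view of how the upstream driver exercises a LoadPlanner; argument lists and coordination details are elided and should not be read as the exact upstream signatures:

```python
def schematic_load(state_dict, storage_reader, planner):
  # 1. Read checkpoint metadata and hand it to the planner.
  metadata = storage_reader.read_metadata()
  planner.set_up_planner(state_dict, metadata)

  # 2. Each rank builds a local plan of ReadItems; the coordinator merges
  #    them into a global plan, and each rank finalizes its own portion.
  local_plan = planner.create_local_plan()
  global_plans = planner.create_global_plan([local_plan])  # coordinator only
  final_plan = planner.finish_plan(global_plans[0])

  # 3. The storage layer executes the plan, calling back into the planner
  #    (resolve_tensor / load_bytes / commit_tensor) for each ReadItem.
  storage_reader.read_data(final_plan, planner)
```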

def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
  offsets = read_item.dest_offsets
  index = read_item.dest_index
  if index.fqn in self.sharded_state_dict:
Collaborator

Can you elaborate this a little bit more?

Collaborator

Also, I'm confused about why we want to do the narrowing.

Collaborator Author

I'll add a comment - the storage layer expects that the tensor returned from resolve_tensor matches the shape of the ReadItem's lengths field, so the tensor must be narrowed here.
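
Roughly what that narrowing amounts to (a sketch with a hypothetical helper name; the default planner performs the equivalent for unsharded tensors):

```python
import torch

def _narrow_to_read_item(tensor: torch.Tensor, read_item) -> torch.Tensor:
  # Narrow each dimension so the returned view starts at
  # read_item.dest_offsets and has shape read_item.lengths.
  for dim, (offset, length) in enumerate(
      zip(read_item.dest_offsets, read_item.lengths)):
    tensor = tensor.narrow(dim, offset, length)
  return tensor
```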

x, XLAShardedTensor) and x.sharding_type != ShardingType.REPLICATED


def _unwrap_sharded_tensor(x: Any) -> Any:
Collaborator

So this is only used for the replicated tensors?

Collaborator Author

Correct, that's a good point - the name is confusing. This is used to ensure the default planner is operating on torch.Tensor instead of XLAShardedTensor. Maybe _unwrap_xla_sharded_tensor?
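
Under the suggested name, the helper is essentially the following (a sketch; unwrapping to global_tensor and the import path are assumptions based on XLAShardedTensor being a wrapper type):

```python
from typing import Any

# Import path may differ across torch_xla versions.
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor

def _unwrap_xla_sharded_tensor(x: Any) -> Any:
  # Give the default planner a plain torch.Tensor to work with.
  if isinstance(x, XLAShardedTensor):
    return x.global_tensor
  return x
```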

# Load the checkpoint using the provided load planner
for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
  self.assertFalse(torch.allclose(p1, p2))
dist_cp.load_state_dict(
Collaborator

It looks like the LoadPlanner should be able to handle resharding of consolidated checkpoints. I wonder how it handles the case where the saved checkpoint has a different sharding spec than the loading model?

Collaborator

Seems like it's handled by create_read_items_for_chunk_list?

Collaborator Author

That's correct - we just need to generate ChunkStorageMetadata for each local shard, and the upstream API create_read_items_for_chunk_list will convert that chunk list into a list of ReadItem. The ReadItems generated will depend on the device mesh used when taking the checkpoint.
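
For concreteness, roughly how that looks on our side (shard offsets/sizes are illustrative; fqn and tensor_md are assumed to come from the checkpoint metadata for the parameter, and the helper's import path and signature reflect the upstream module at the time):

```python
import torch
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
from torch.distributed.checkpoint.planner_helpers import (
    create_read_items_for_chunk_list,
)

# One chunk per local XLAShard: its offset and extent within the global tensor.
local_chunks = [
    ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([64, 128])),
]

# The helper intersects these chunks with the chunks recorded in the
# checkpoint's metadata and emits one ReadItem per overlapping region.
read_items = create_read_items_for_chunk_list(fqn, tensor_md, local_chunks)
```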

Collaborator

Then is there a simple formula to calculate the number of ReadItems, or do I have to refer to the create_read_items_for_chunk_list algorithm?

Collaborator Author

I just added some discussion in another comment, but adding here as well:

The algorithm for create_read_items_for_chunk_list just iterates across all chunks in the storage metadata, identifies overlap with the ChunkStorageMetadata describing our XLAShard, and generates a ReadItem if there is overlap.

plan = planner.create_local_plan()
parameter_count = len(list(model.parameters()))
if self.n_devices > 1:
  # When the model is sharded across devices, fc1.weight will result in
Collaborator

So you're suggesting that the number of ReadItems will match the number of XLAShards for a particular XLAShardedTensor. Does this hold true for the following resharding examples?

  1. stored shards <= loading shards, i.e., the model is sharded 2 ways but needs to be loaded 4 ways.
  2. stored shards > loading shards, i.e., the model is sharded 4 ways but needs to be loaded 2 ways.

I can see that this holds for case 1, but it's hard for me to imagine for case 2. Can you elaborate? I commented in your design doc as well.

Collaborator Author

That's true. I'll clarify the comment: this comes from the unsharded checkpoint metadata used when creating the Planner in _get_load_planner.

Collaborator

I'm still curious about case 2, if you can explain it to me?

Collaborator Author

If we are loading into a coarser mesh, e.g. loading a checkpoint taken on a (4, 4) mesh into a model sharded across (2, 2), each local shard in the (2, 2) mesh will require multiple ReadItems. On our side, we will still generate a single ChunkStorageMetadata for each shard, and the utility function create_read_items_for_chunk_list translates these into the ReadItems necessary by finding the overlap of ChunkStorageMetadata with the chunks in the storage metadata.

The algorithm for create_read_items_for_chunk_list just iterates across all chunks in the storage metadata, identifies overlap with the ChunkStorageMetadata describing our XLAShard, and generates a ReadItem if there is overlap.

Collaborator

At a high level, the number of ReadItems will then be determined by max(number of shards in the loading model, number of shards in the saved model)?

Collaborator Author

We can contrive an example where this won't hold, e.g. saving from a (4, 1) mesh and loading into a (2, 2) mesh would require 8 total ReadItems based on a quick test. It depends a lot on how the storage layer handles the checkpoint, but the worst case number of ReadItems would be O(global shards when saving the model) for each shard being loaded.
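
To make the (4, 1) -> (2, 2) count concrete, a small self-contained check (the global shape and chunk layout are made up for illustration):

```python
# Saved on a (4, 1) mesh: 4 chunks, each covering 1/4 of the rows, all columns.
saved_row_chunks = [(0, 2), (2, 2), (4, 2), (6, 2)]  # (row offset, row length)
# Loaded on a (2, 2) mesh: 4 shards, each covering 1/2 of the rows and columns.
load_shards = [(r, 4) for r in (0, 4) for _ in range(2)]  # (row offset, row length)

def overlaps(a_off, a_len, b_off, b_len):
  return max(a_off, b_off) < min(a_off + a_len, b_off + b_len)

# Column ranges always overlap (saving did not shard columns), so counting
# row overlaps is enough: each loading shard overlaps 2 saved chunks.
read_items = sum(1 for (l_off, l_len) in load_shards
                 for (s_off, s_len) in saved_row_chunks
                 if overlaps(l_off, l_len, s_off, s_len))
print(read_items)  # 8
```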

Collaborator

Perfect. Thanks, Jon!

@jonb377 merged commit 50821da into master Jun 13, 2023
@jonb377 deleted the jonbolin-load-planner branch June 13, 2023 18:22
3 participants