Add StatefulDataLoader support #31441

Open
yzhangcs opened this issue Jun 16, 2024 · 14 comments · May be fixed by #34205

@yzhangcs commented Jun 16, 2024

Feature request

Add official support for StatefulDataLoader as in torchdata and datasets.

Motivation

The StatefulDataLoader from the torchdata package provides a convenient way to recover an interrupted dataset iterator without having to skip the first batches via a naive for loop, which can be time-consuming for extremely large datasets. As of v2.20.0, the datasets package officially supports stateful IterableDataset and its combination with StatefulDataLoader.

Example usage:

from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader

iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
dataloader = StatefulDataLoader(iterable_dataset, batch_size=32, num_workers=4)
# checkpoint
state_dict = dataloader.state_dict()  # uses iterable_dataset.state_dict() under the hood
# resume from checkpoint
dataloader.load_state_dict(state_dict)  # uses iterable_dataset.load_state_dict() under the hood

Adding official support for StatefulDataLoader to the Trainer would let users recover from interruptions and resume training from checkpoints without wasting time re-iterating over already processed batches.
It would make handling large datasets seamless and the training process more robust and efficient.
We kindly ask the development team to consider adding official support for these features in the Trainer, as it would be a valuable addition to the library and benefit the wider community.

yzhangcs added the "Feature request" label on Jun 16, 2024
@amyeroberts (Collaborator)

cc @muellerzr @lhoestq

@byi8220 (Contributor) commented Jun 17, 2024

Hey, just giving my 2 cents: unless I'm missing something, this seems extremely simple to implement, since StatefulDataLoader is a drop-in replacement for DataLoader (i.e. literally just replace DataLoader construction with StatefulDataLoader construction in trainer.py).

If it's simple enough I could probably take a shot at implementing it?

The only caveat is it seems torchdata.stateful_dataloader is a beta feature only available in the nightly build. Does it make sense to officially support unreleased features?
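
For illustration, the swap itself would be roughly the following (a minimal sketch with a toy dataset, not trainer.py code): StatefulDataLoader takes the same constructor arguments as torch.utils.data.DataLoader and just adds state_dict()/load_state_dict().

from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader

class ToyDataset(Dataset):  # stand-in for any map-style dataset
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        return idx

# Constructed exactly like a torch.utils.data.DataLoader ...
loader = StatefulDataLoader(ToyDataset(), batch_size=32)

for step, batch in enumerate(loader):
    if step == 10:
        snapshot = loader.state_dict()  # ... plus a mid-epoch snapshot
        break

# A fresh loader restored from the snapshot resumes after batch 10
# instead of starting from batch 0.
resumed = StatefulDataLoader(ToyDataset(), batch_size=32)
resumed.load_state_dict(snapshot)
next(iter(resumed))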

@yzhangcs (Author)

@byi8220 Hi, as far as I can see, the HF Trainer uses the accelerate library internally to prepare the dataloader. This process involves returning self-defined classes like DataLoaderShard to handle cases involving distributed data dispatch. I think it might be challenging to directly combine the Trainer with StatefulDataLoader without delving into the intricate details of the Trainer implementation.

@byi8220 (Contributor) commented Jun 17, 2024

Hm, maybe I misunderstand the problem. My understanding is that what we are focused on is that when the Trainer is loading from a checkpoint, it calls skip_first_batches to skip past the beginning of the dataset until the DataLoader iterator is pointing to where it was at that checkpoint.

For an IterableDataset, the way this is done under the hood is to manually loop over and discard items until that point. StatefulDataLoader would solve this by letting the Trainer call load_state_dict while loading the checkpoint, and write the StatefulDataLoader's state dict out when saving the checkpoint.

This process involves returning self-defined classes like DataLoaderShard to handle cases involving distributed data dispatch.

Yes, it seems like DataLoaderShard and DataLoaderDispatcher are created in the prepare_data_loader function and the skip_first_batches function in the accelerate library. These classes are both subclasses of DataLoader, so they would likely need to be modified or copied to extend StatefulDataLoader instead.

So IIUC, it seems maybe the implementation of this feature would involve the following steps?

  1. In the accelerate library, either refactor DataLoaderShard and DataLoaderDispatcher to compose a StatefulDataLoader, or add new variants that inherit from it.
  2. In the Trainer class, allow dropping in StatefulDataLoader instead of a regular DataLoader
  3. Also in the Trainer class, support loading and saving the state_dict to and from the checkpoint

Thanks for pointing this out. I still might not be understanding correctly. Maybe it's a lot more complicated than this.
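
To make the contrast concrete, here's a rough sketch of the two recovery paths being discussed (names are illustrative, not the actual Trainer internals):

from typing import Callable, Iterator

from torchdata.stateful_dataloader import StatefulDataLoader

def resume_by_skipping(dataloader, batches_seen: int) -> Iterator:
    # What skip_first_batches effectively costs for an IterableDataset:
    # re-create the iterator and throw away everything already consumed.
    it = iter(dataloader)
    for _ in range(batches_seen):
        next(it)  # wasted work, potentially hours on a huge streaming dataset
    return it

def resume_by_state_dict(make_loader: Callable[[], StatefulDataLoader], saved_state: dict) -> Iterator:
    # What the stateful path would do instead: rebuild the loader and
    # restore its position directly from the checkpointed state dict.
    loader = make_loader()
    loader.load_state_dict(saved_state)  # no re-iteration over earlier batches
    return iter(loader)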

@muellerzr (Contributor)

Correct, we need to:

  1. Support the StatefulDataLoader in accelerate and use it as an optional alternative in the DataLoaderConfiguration
  2. Then we can move it to the Trainer!

@byi8220 (Contributor) commented Jun 18, 2024

Makes sense. It also seems like there's a related issue raised in accelerate: huggingface/accelerate#2859

Regarding using it in the Trainer, it feels a bit awkward. IIUC, the desired behavior is that if a StatefulDataLoader is being used and we are loading from a checkpoint, the Trainer should not call skip_first_batches at all, unless the state dict and checkpoint are also passed into that function. But IMO skip_first_batches and "restore from checkpoint" are two separate concepts.

@yzhangcs (Author) commented Jun 18, 2024

Thank you for your responses @byi8220 @muellerzr.

Yes, I agree with you that if we properly manage the states of dataloaders in the Trainer, we no longer need to use the accelerate skip_first_batches option for recovery.

As a workaround, I bypass accelerate to prepare my dataloaders by hacking the Trainer class to support stateful ones:

import transformers

# DPAwareDataLoader, HuggingFaceDataset, DataCallback and `logger` are self-defined
# (the loader and dataset follow torchtitan-style implementations, see below).


class Trainer(transformers.Trainer):

    def get_train_dataloader(self) -> DPAwareDataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        logger.info(f"Split the dataset for the node at rank {self.args.process_index} / {self.args.world_size}.")
        train_dataset = HuggingFaceDataset(self.train_dataset,
                                           self.tokenizer,
                                           self.args.context_length,
                                           self.args.process_index,
                                           self.args.world_size)
        loader = DPAwareDataLoader(rank=self.args.process_index,
                                   dataset=train_dataset,
                                   batch_size=self.args.train_batch_size,
                                   collate_fn=self.data_collator,
                                   num_workers=self.args.dataloader_num_workers,
                                   pin_memory=self.args.dataloader_pin_memory,
                                   persistent_workers=self.args.dataloader_persistent_workers)
        data_callback = DataCallback(loader)
        self.add_callback(data_callback)

        return loader

The DPAwareDataLoader is borrowed from torchtitan's implementation; that package is developing similar ideas. A self-defined callback is then used to save/load the loader states:

import json
import os

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.trainer_callback import ExportableState
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, get_last_checkpoint


class DataCallback(TrainerCallback, ExportableState):

    def __init__(self, loader):
        self.loader = loader

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        output_dir = None
        if isinstance(args.resume_from_checkpoint, bool):
            if args.resume_from_checkpoint:
                output_dir = get_last_checkpoint(args.output_dir)
        elif args.resume_from_checkpoint is not None:
            output_dir = args.resume_from_checkpoint
        if output_dir is not None:
            if args.world_size <= 1:
                data_state_pth = os.path.join(output_dir, "data_state.json")
            else:
                data_state_pth = os.path.join(output_dir, f"data_state_{args.process_index}.json")
            with open(data_state_pth, "r") as f:
                self.loader.load_state_dict(json.loads(f.read()))

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        if args.world_size <= 1:
            data_state_pth = os.path.join(output_dir, "data_state.json")
        else:
            data_state_pth = os.path.join(output_dir, f"data_state_{args.process_index}.json")
        with open(data_state_pth, "w") as f:
            f.write(json.dumps(self.state(), indent=2, sort_keys=True) + "\n")

    def state(self) -> dict:
        return self.loader.state_dict()

skip_first_batches is disabled via --ignore_data_skip.
I ran some minimal unit tests, and the states were successfully recovered without any perceptible delay.

This approach can be extremely useful when performing online tokenization with IterableDataset.
Some people have conducted benchmarks and observed even faster speeds than pre-tokenization in https://github.com/XinDongol/on-the-fly-tokenization-profiling.
I've tried using stateful loaders with the hacky code above to train a Mamba model on subsets of the 627B-token SlimPajama data, reducing the total training time from ~173h to ~170h.
This could also save ~3TB of disk space compared to pre-tokenized map-style data.

So I'm really looking forward to your official implementation; very happy to hear about any progress :D

@litagin02 commented Oct 15, 2024

Any update? I really want this feature.
It seems that StatefulDataLoader is now supported in accelerate, and I managed to use it by writing my own train loop for a personal project, but I would be very happy if the 🤗 Trainer officially supported this option, especially for very big datasets.

@byi8220 (Contributor) commented Oct 15, 2024

It seems that StatefulDataLoader has been supported on accelerate

Yes, that landed in a recent accelerate release, so it should now be possible to integrate it into transformers.

If you are getting it to work with no problem, I think it's stable enough that I can work on getting a PR out for the transformers side now.
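
For reference, my understanding of the accelerate-side opt-in that just shipped is roughly the following (a sketch based on the released DataLoaderConfiguration API, worth double-checking against the accelerate docs):

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration
from torch.utils.data import DataLoader

# Opt in at the accelerate level; prepare() then wraps the loader so that it
# is backed by a StatefulDataLoader.
accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True)
)

train_loader = accelerator.prepare(DataLoader(list(range(1024)), batch_size=32))

# At checkpoint time:
dl_state = train_loader.state_dict()

# On resume, instead of skip_first_batches:
train_loader.load_state_dict(dl_state)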

@muellerzr (Contributor)

@byi8220 feel free to have a go at it! Implementation I'd like to see: Add use_stateful_dataloader to the AcceleratorConfig dataclass to be passed in and converted to the DataLoaderConfiguration args :)

(So as a result, we don't add a new training argument for it explicitly, just update the docs and allow a passthrough)
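
From the user's side it would then presumably look something like this (hypothetical until the PR lands; the flag name follows the proposal above):

from transformers import TrainingArguments

# Hypothetical once the passthrough exists: no new training argument; the flag
# is forwarded through accelerator_config to accelerate's DataLoaderConfiguration.
args = TrainingArguments(
    output_dir="out",
    accelerator_config={"use_stateful_dataloader": True},
)

# A Trainer built with these args would prepare StatefulDataLoaders, and
# trainer.train(resume_from_checkpoint=...) could restore the loader state
# instead of calling skip_first_batches.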

@byi8220 (Contributor) commented Oct 15, 2024

Makes sense; hopefully this is the much easier part of the implementation.

@muellerzr (Contributor)

The only thing that should be a bit tricky is the saving, since we don't use accelerator.save_state. (But for enabling it alone, that's the core part needed.)

@byi8220 (Contributor) commented Oct 15, 2024

The only thing that should be a bit tricky is the saving, since we don't use accelerator.save_state.

At a glance, the implementation for this appears to be fairly straightforward:

  1. Passthrough use_stateful_dataloader through AcceleratorConfig
  2. For saving checkpoints, have the prepared StatefulDataLoader's state_dict object be part of the state saved in TrainerState. I'm unsure whether this needs to be done for anything besides the train dataloader; I'm guessing not.
  3. For training, when getting the epoch_iterator, have a prepared StatefulDataLoader load the state dict instead of manually skipping batches.

byi8220 linked a pull request on Oct 16, 2024 that will close this issue
@byi8220 (Contributor) commented Oct 16, 2024

Created an initial PR which should fulfill this feature - #34205

This seems complete to me, although I'm unsure if I'm missing any edge cases or unit tests. I'm pretty certain the tests in TrainerIntegrationTest cover multi-GPU setups, and the test_resume_training.* test skeleton would cover the main use case (does resuming from a checkpoint yield the same results as not?)

(PS: there is a very small but well-hidden issue with unit tests, which I've separated into a separate PR - #34201)
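
For the "does resuming from a checkpoint yield the same results as not?" property, the dataloader-level behavior can also be sanity-checked in isolation with something like this (illustrative only, not one of the actual TrainerIntegrationTest cases):

from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader

class RangeDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return idx

def collect(loader, n=None):
    batches = []
    for i, batch in enumerate(loader):
        batches.append(batch.tolist())
        if n is not None and i + 1 == n:
            break
    return batches

def test_resume_matches_uninterrupted():
    # Uninterrupted run: all batches, in order.
    full = collect(StatefulDataLoader(RangeDataset(), batch_size=8))

    # Interrupted run: stop after 5 batches and checkpoint the loader state.
    loader = StatefulDataLoader(RangeDataset(), batch_size=8)
    first_half = collect(loader, n=5)
    state = loader.state_dict()

    # Resumed run: a fresh loader restored from the checkpointed state.
    resumed = StatefulDataLoader(RangeDataset(), batch_size=8)
    resumed.load_state_dict(state)
    second_half = collect(resumed)

    assert first_half + second_half == full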
