Add StatefulDataLoader support #31441
Comments
Hey, just giving my 2 cents: unless I'm missing something, this seems fairly simple to implement. If it's simple enough, I could probably take a shot at implementing it? The only caveat is it seems
@byi8220 Hi, as far as I can see, the hf Trainer uses the accelerate library internally to prepare the dataloader. This process involves returning self-defined classes like
Hm, maybe I misunderstand the problem. My understanding is that what we are focused on is that when the Trainer is loading from a checkpoint, it calls skip_first_batches to skip past the beginning of the dataset until the DataLoader iterator is pointing to where it was at that checkpoint. And for an IterableDataset, the way this is done under the hood is that it has to manually loop over the items to iterate. And StatefulDataLoader may solve this problem by allowing one to call load_state_dict somewhere in the Trainer while loading the checkpoint, and writing the StatefulDataLoader's state dict to the checkpoint.
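The difference between the two resumption strategies described above can be sketched in plain Python. This is a toy model only — `naive_resume` and `stateful_resume` are illustrative stand-ins for accelerate's `skip_first_batches` and torchdata's `StatefulDataLoader`, not the real APIs:

```python
# Toy illustration of the two resumption strategies.
# Neither function is a real library API; they only model the behavior.

class CountingDataset:
    """Iterable that records how many items were actually produced."""
    def __init__(self, n):
        self.n = n
        self.produced = 0

    def __iter__(self):
        for i in range(self.n):
            self.produced += 1
            yield i

def naive_resume(dataset, skip):
    """skip_first_batches-style: re-iterate and throw away `skip` items."""
    it = iter(dataset)
    for _ in range(skip):
        next(it)  # every skipped item is still materialized
    return it

def stateful_resume(dataset, state):
    """StatefulDataLoader-style: jump straight to the saved position."""
    return iter(range(state["next_index"], dataset.n))  # no wasted work

ds = CountingDataset(1000)
it = naive_resume(ds, 900)
assert next(it) == 900
assert ds.produced == 901  # 900 items were produced just to be discarded

ds2 = CountingDataset(1000)
it2 = stateful_resume(ds2, {"next_index": 900})
assert next(it2) == 900
assert ds2.produced == 0   # nothing re-iterated
```

The wasted work in the naive path grows linearly with how far into the dataset the checkpoint was taken, which is what makes the stateful approach attractive for very large streaming datasets.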
Yes, it seems like DataLoaderShard and DataLoaderDispatcher are created in the prepare_data_loader function and skip_first_batches function in the accelerate library. These classes are both subclasses of DataLoader, so they would likely need to be modified or copied to extend StatefulDataLoader instead. So IIUC, the implementation of this feature would involve the following steps?
Thanks for pointing this out. I still might not be understanding correctly; maybe it's a lot more complicated than this.
Correct, we need to:
Makes sense. It also seems like there's a related issue raised. Regarding using it in the
Thank you for your responses @byi8220 @muellerzr. Yes, I agree with you that if we properly manage the states of dataloaders in the Trainer, we no longer need to rely on accelerate for this. As a workaround, I bypass accelerate to prepare my dataloaders by hacking the Trainer class to support stateful ones:

```python
class Trainer(transformers.Trainer):
    def get_train_dataloader(self) -> DPAwareDataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        logger.info(f"Split the dataset for the node at rank {self.args.process_index} / {self.args.world_size}.")
        train_dataset = HuggingFaceDataset(
            self.train_dataset,
            self.tokenizer,
            self.args.context_length,
            self.args.process_index,
            self.args.world_size,
        )
        loader = DPAwareDataLoader(
            rank=self.args.process_index,
            dataset=train_dataset,
            batch_size=self.args.train_batch_size,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            persistent_workers=self.args.dataloader_persistent_workers,
        )
        data_callback = DataCallback(loader)
        self.add_callback(data_callback)
        return loader
```

The DPAwareDataLoader is borrowed from

```python
class DataCallback(TrainerCallback, ExportableState):
    def __init__(self, loader):
        self.loader = loader

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        output_dir = None
        if isinstance(args.resume_from_checkpoint, bool):
            if args.resume_from_checkpoint:
                output_dir = get_last_checkpoint(args.output_dir)
        elif args.resume_from_checkpoint is not None:
            output_dir = args.resume_from_checkpoint
        if output_dir is not None:
            if args.world_size <= 1:
                data_state_pth = os.path.join(output_dir, "data_state.json")
            else:
                data_state_pth = os.path.join(output_dir, f"data_state_{args.process_index}.json")
            with open(data_state_pth, "r") as f:
                self.loader.load_state_dict(json.loads(f.read()))

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        if args.world_size <= 1:
            data_state_pth = os.path.join(output_dir, "data_state.json")
        else:
            data_state_pth = os.path.join(output_dir, f"data_state_{args.process_index}.json")
        with open(data_state_pth, "w") as f:
            f.write(json.dumps(self.state(), indent=2, sort_keys=True) + "\n")

    def state(self) -> dict:
        return self.loader.state_dict()
```
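The save/load round trip used by a callback like the one above can be sketched in isolation. `ToyStatefulLoader` below is a hypothetical stand-in: any object exposing `state_dict()`/`load_state_dict()` (the protocol `StatefulDataLoader` implements) can be persisted as JSON this way:

```python
import json
import os
import tempfile

class ToyStatefulLoader:
    """Illustrative loader exposing the state_dict/load_state_dict protocol."""
    def __init__(self, data):
        self.data = data
        self.pos = 0

    def __iter__(self):
        while self.pos < len(self.data):
            item = self.data[self.pos]
            self.pos += 1
            yield item

    def state_dict(self):
        return {"pos": self.pos}

    def load_state_dict(self, sd):
        self.pos = sd["pos"]

loader = ToyStatefulLoader(list(range(10)))
it = iter(loader)
consumed = [next(it) for _ in range(4)]  # consume 4 items, then "checkpoint"

with tempfile.TemporaryDirectory() as d:
    # mirrors on_save: serialize the loader state next to the checkpoint
    pth = os.path.join(d, "data_state.json")
    with open(pth, "w") as f:
        f.write(json.dumps(loader.state_dict()))

    # mirrors on_train_begin: a fresh loader resumes from the saved state
    fresh = ToyStatefulLoader(list(range(10)))
    with open(pth) as f:
        fresh.load_state_dict(json.loads(f.read()))
    resumed = list(fresh)

assert consumed == [0, 1, 2, 3]
assert resumed == [4, 5, 6, 7, 8, 9]
```

Note that in the multi-process case the callback above writes one state file per rank (`data_state_{process_index}.json`), since each rank's iterator position differs.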
This approach can be extremely useful when performing online tokenization with IterableDataset. So I'm really looking forward to your official implementation; very happy to hear about any progress :D
Any update? I really want this feature.
Yes, that was a recent release and it should be possible to integrate that. If you are getting it to work with no problem, I think it's stable enough that I can work on getting a PR out for the
@byi8220 feel free to have a go at it! The implementation I'd like to see: add (so as a result, we don't add a new training argument for it explicitly, just update the docs and allow a passthrough).
Makes sense; hopefully this should be the easier part of the implementation.
The only thing that should be a bit tricky is the saving, since we don't use
At a glance, the implementation for this appears to be fairly straightforward:
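The enumerated steps did not survive extraction, but one likely building block — gating the stateful behavior on torchdata being importable, so the flag can be passed through without a hard dependency — can be sketched as follows. This is a hypothetical helper, not the actual PR code:

```python
import importlib.util

# Hedged sketch: detect whether torchdata is importable, mirroring
# an optional-dependency check (illustrative, not the real implementation).
HAS_TORCHDATA = importlib.util.find_spec("torchdata") is not None

def dataloader_base_name(use_stateful: bool) -> str:
    """Return the name of the base class a DataLoaderShard-like
    subclass would extend, failing loudly if the dependency is missing."""
    if use_stateful and not HAS_TORCHDATA:
        raise ImportError("stateful dataloaders require torchdata to be installed")
    return "StatefulDataLoader" if use_stateful else "DataLoader"

assert dataloader_base_name(False) == "DataLoader"
```

The real passthrough lives on the accelerate side (its dataloader preparation), so the Trainer itself would mostly forward a flag and handle saving/loading the extra state file at checkpoint time.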
Created an initial PR which should fulfill this feature: #34205. This seems complete to me, although I'm unsure if I'm missing any edge cases or unit tests. I'm pretty certain the tests in (PS: there is a very small but well-hidden issue with unit tests, which I've separated into a separate PR: #34201)
Feature request

Add official support for `StatefulDataLoader` as in torchdata and datasets.

Motivation

The `StatefulDataLoader` from the torchdata package provides a convenient way to recover a dataset iterator that was interrupted, without having to skip the first batches via a naive for loop, which can be time-consuming for extremely large datasets. The `datasets` package officially supports stateful `IterableDataset` and its combination with `StatefulDataLoader` as of v2.20.0.

Example usage:

To enhance the usability and efficiency of the `Trainer`, it would be highly beneficial for the community if official support for `StatefulDataLoader` could be added. This would allow users to easily recover from interruptions and resume training from checkpoints without wasting time re-iterating over already processed batches.

By integrating `StatefulDataLoader` into the `Trainer`, users can seamlessly handle large datasets and ensure a smooth training process. This feature would greatly improve the overall user experience and make the `Trainer` more robust and efficient.

We kindly request the development team to consider adding official support for these features in the `Trainer`, as it would be a valuable addition to the library and benefit the wider community.