
Trainer get_train_dataloader creates wrong batch size when using IterableDataset and multi-gpu training on single machine #21444

Closed
agossard opened this issue Feb 3, 2023 · 5 comments


agossard commented Feb 3, 2023

System Info

@sgugger

I'm not sure if I'm missing something here or not, but I am doing masked language modeling with RobertaForMaskedLM in PyTorch on an AWS machine with 8 V100s. I set args.per_device_train_batch_size=32. If I train with a regular Dataset object, the data loader produces a big batch of 32 * 8 = 256 examples each time, which is then split up and sent to each GPU in batches of 32, as expected. But if I switch to an IterableDataset, the DataLoader produces batches of 32, which get split into batches of 4 sent to each GPU.

This happens because of the code below from Trainer.get_train_dataloader. If we have an IterableDataset, we end up creating a DataLoader based on per_device_train_batch_size (which is 32). But for any other type of dataset, we create a DataLoader with self._train_batch_size (which is 256). I confess I don't know what the first if self.args.world_size > 1 block is supposed to be doing, but it doesn't get executed in my situation (running on a single machine with multiple GPUs).
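
As far as I can tell, the difference comes down to how the two sizes are derived. A rough sketch of the arithmetic on my setup (not the exact Trainer source; my understanding is that TrainingArguments.train_batch_size multiplies the per-device size by the number of visible GPUs, and self._train_batch_size is initialized from it):

    per_device_train_batch_size = 32
    n_gpu = 8  # single machine, 8 V100s, DataParallel

    # Effective "full" batch size used by the non-iterable branch below:
    _train_batch_size = per_device_train_batch_size * max(1, n_gpu)  # 256

    # Regular Dataset: the DataLoader yields batches of 256, which DataParallel
    # splits into 8 chunks of 32 per GPU (expected).
    # IterableDataset: the DataLoader yields batches of 32, which DataParallel
    # splits into 8 chunks of 4 per GPU (what I'm seeing).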

Am I doing something wrong, or is this a bug?

Thanks,
Andy

    if isinstance(train_dataset, torch.utils.data.IterableDataset):
        if self.args.world_size > 1:
            train_dataset = IterableDatasetShard(
                train_dataset,
                batch_size=self._train_batch_size,
                drop_last=self.args.dataloader_drop_last,
                num_processes=self.args.world_size,
                process_index=self.args.process_index,
            )

        return DataLoader(
            train_dataset,
batch_size=self.args.per_device_train_batch_size,  # <-- per-device size (32 here)
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    train_sampler = self._get_train_sampler()

    return DataLoader(
        train_dataset,
batch_size=self._train_batch_size,  # <-- full batch size (256 here)
        sampler=train_sampler,
        collate_fn=data_collator,
        drop_last=self.args.dataloader_drop_last,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        worker_init_fn=seed_worker,
    )

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Use a PyTorch model on a single machine with multiple GPUs
  2. Set TrainingArguments.per_device_train_batch_size=32
  3. Create a regular dataset in memory from a pandas data frame (or whatever)
  4. Put a breakpoint (or debugging statement) in the forward pass of the model to print out inputs.shape -> verify that the first dimension is 32
  5. Now create an IterableDataset and run again
  6. See that inputs.shape now has a first dimension of 4 (a minimal sketch of these steps follows below)
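
Putting these steps together, a minimal sketch of the reproduction (the model is randomly initialized, the dataset text and output_dir are placeholders, and the observed split of 4 assumes 8 GPUs under DataParallel):

    import torch
    from torch.utils.data import IterableDataset
    from transformers import (
        DataCollatorForLanguageModeling,
        RobertaConfig,
        RobertaForMaskedLM,
        RobertaTokenizerFast,
        Trainer,
        TrainingArguments,
    )

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    class StreamingTextDataset(IterableDataset):
        """Yields tokenized examples one at a time, like a streaming corpus."""
        def __iter__(self):
            for _ in range(10_000):
                yield tokenizer("some example sentence", truncation=True,
                                padding="max_length", max_length=64)

    class ShapeLoggingRoberta(RobertaForMaskedLM):
        """Same model, but prints the batch size seen by each forward pass."""
        def forward(self, input_ids=None, **kwargs):
            print("per-GPU batch size:", input_ids.shape[0])
            return super().forward(input_ids=input_ids, **kwargs)

    model = ShapeLoggingRoberta(RobertaConfig(vocab_size=tokenizer.vocab_size))

    args = TrainingArguments(
        output_dir="out",
        per_device_train_batch_size=32,
        max_steps=2,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=StreamingTextDataset(),
        data_collator=DataCollatorForLanguageModeling(tokenizer),
    )
    trainer.train()
    # With 8 GPUs and an IterableDataset this prints 4; with a regular in-memory
    # Dataset of the same examples it prints 32.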

Expected behavior

The train batch size should be the same whether using a regular Dataset or an IterableDataset.


sgugger commented Feb 3, 2023

Sounds like the self.args.per_device_train_batch_size should be self._train_batch_size indeed. Do you want to open a PR?
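
Concretely, the suggested change (a sketch against the snippet quoted above) would make the IterableDataset branch use the full batch size as well:

    return DataLoader(
        train_dataset,
        batch_size=self._train_batch_size,  # was self.args.per_device_train_batch_size
        collate_fn=data_collator,
        num_workers=self.args.dataloader_num_workers,
        pin_memory=self.args.dataloader_pin_memory,
    )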

As an aside, using DataParallel is not the recommended way to run on multiple GPUs with PyTorch; you should launch your training script with torchrun.
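
For reference, a typical single-node multi-GPU launch with torchrun looks like this (the script name and GPU count are placeholders); each spawned process then sees one GPU, so per_device_train_batch_size applies per process:

    torchrun --nproc_per_node=8 run_mlm.py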


agossard commented Feb 3, 2023

Thanks, Sylvain. I issued the pull request. My first time doing so, so I hope I did it OK!


github-actions bot commented Mar 6, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@edwardpwtsoi

Reopening for the FSDP use case.
IMO, per_device_train_batch_size should be used in the FSDP case, as the machines should be treated as a single device. Let me know if my understanding is wrong. I tried to test whether a patch could fix the issue:

    @property
    def train_batch_size(self) -> int:
        """
        The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training).
        """
        return self.per_device_train_batch_size

but the streaming dataset still fetches more items than per_device_train_batch_size.
Does anyone have insight into a possible fix for this use case?


amyeroberts commented Jul 10, 2024

@edwardpwtsoi Could you open a new issue if it's not covered by the fix in huggingface/datasets#5506? This helps us better track what has and hasn't been resolved.

My bad - I thought the above was a merged PR. Regardless, it would be useful to have a new issue with specifics about the FSDP case.
