Custom Pytorch BatchSampler does not work well with pytorch lightning #20326

Closed
dadwadw233 opened this issue Oct 9, 2024 · 0 comments · Fixed by #20327
Labels: bug (Something isn't working), priority: 1 (Medium priority task), ver: 2.4.x

dadwadw233 (Contributor) commented Oct 9, 2024

Bug description

Where is the bug❓

When I initialize a DataLoader with a custom BatchSampler and use it in a PyTorch Lightning DataModule, the shuffle setting does not take effect. No matter which sampler (random or sequential) I use to build the BatchSampler, Lightning configures the DistributedSampler it wraps around the data with its default behavior: shuffling for the training stage, and for the other stages, whatever the DataLoader's `sampler` type implies.
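Per the description above, the detected flag decides whether the injected `DistributedSampler` shuffles. As a minimal illustration, assuming two replicas (`num_replicas` and `rank` are passed explicitly so no process group is needed), this is what `torch.utils.data.DistributedSampler` yields for rank 0 with each setting:

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))

# shuffle=False: rank 0 deterministically gets every second index, in order
sequential = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
print(list(sequential))  # [0, 2, 4, 6]

# shuffle=True: rank 0 gets 4 indices taken from a permutation seeded with
# seed + epoch, so the order is reseeded each time set_epoch() is called
shuffled = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)
shuffled.set_epoch(1)
print(list(shuffled))
```

So if the shuffle state is detected incorrectly, each rank silently gets either a fixed strided order or a per-epoch permutation, regardless of which sampler the user put inside the BatchSampler.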

Analysis

The problem is in the `_is_dataloader_shuffled` function (in `pytorch_lightning.utilities.data`), which decides the shuffle state from the DataLoader's `sampler` attribute. That seems reasonable, but when a `batch_sampler` is passed, PyTorch ignores the `sampler` argument (it must be left unset) and fills the `sampler` attribute with a default `SequentialSampler`. Lightning therefore always sees a sequential sampler here, so shuffling does not behave as I expected.

I think the PyTorch implementation is questionable as well: in the latest PyTorch code, the DataLoader's `sampler` argument is mutually exclusive with `batch_sampler` (as are `batch_size`, `shuffle` and `drop_last`). So when I pass a custom BatchSampler, PyTorch only initializes a default `SequentialSampler`, which is a bit counter-intuitive. It does not produce wrong results, though, because PyTorch draws indices from the batch sampler whenever one exists; the `sampler` attribute is only used directly when automatic batching is disabled (`batch_size=None`).
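A minimal check of this PyTorch behavior (the ten-element `TensorDataset` and batch size of 4 are arbitrary choices for illustration):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False)
dl = DataLoader(dataset, batch_sampler=batch_sampler)

# The DataLoader fills its `sampler` attribute with a default SequentialSampler...
print(type(dl.sampler).__name__)  # SequentialSampler
# ...while the indices are actually drawn from the RandomSampler inside the batch sampler
print(type(dl.batch_sampler.sampler).__name__)  # RandomSampler
```

Any code that inspects only `dl.sampler` (as Lightning does) concludes the data is not shuffled, even though iteration order is random.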

Key code: [screenshot not preserved]

Suggestions

Since this PyTorch behavior does not cause a problem when Lightning is not involved, I suggest changing the Lightning code instead:

before:

```python
from torch.utils.data import IterableDataset, RandomSampler, SequentialSampler


def _is_dataloader_shuffled(dataloader: object) -> bool:
    if hasattr(dataloader, "__pl_saved_kwargs"):
        # this attribute is not part of PyTorch's DataLoader, but could have been set by
        # our `_replace_init_method` context manager
        if "shuffle" in dataloader.__pl_saved_kwargs:
            return dataloader.__pl_saved_kwargs["shuffle"]
        if "shuffle" in dataloader.__pl_saved_arg_names:
            return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
    if hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset):
        # shuffling is useless with iterable datasets
        return False
    if not hasattr(dataloader, "sampler"):
        # shuffling is enabled via a sampler. No sampler, no shuffling
        return False

    # with a custom batch_sampler, this is always PyTorch's default SequentialSampler
    sampler = dataloader.sampler
    if isinstance(sampler, SequentialSampler):
        return False
    return isinstance(sampler, RandomSampler)
```

after:

```python
from torch.utils.data import IterableDataset, RandomSampler, SequentialSampler


def _is_dataloader_shuffled(dataloader: object) -> bool:
    if hasattr(dataloader, "__pl_saved_kwargs"):
        # this attribute is not part of PyTorch's DataLoader, but could have been set by
        # our `_replace_init_method` context manager
        if "shuffle" in dataloader.__pl_saved_kwargs:
            return dataloader.__pl_saved_kwargs["shuffle"]
        if "shuffle" in dataloader.__pl_saved_arg_names:
            return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
    if hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset):
        # shuffling is useless with iterable datasets
        return False
    if not hasattr(dataloader, "sampler"):
        # shuffling is enabled via a sampler. No sampler, no shuffling
        return False

    # With a custom batch_sampler, `dataloader.sampler` is only a default SequentialSampler;
    # the indices come from the sampler wrapped by the batch sampler. For a default
    # DataLoader both are the same object. Not every batch sampler exposes `.sampler`,
    # so fall back to `dataloader.sampler` in that case.
    batch_sampler = getattr(dataloader, "batch_sampler", None)
    sampler = getattr(batch_sampler, "sampler", dataloader.sampler)

    sampler_cls = type(sampler)
    if sampler_cls not in (RandomSampler, SequentialSampler):
        # custom sampler case:
        if hasattr(sampler, "generator"):
            # maybe custom random sampler
            return True
        # we don't know, so assume no shuffling
        return False

    if isinstance(sampler, SequentialSampler):
        return False
    return isinstance(sampler, RandomSampler)
```
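To see the difference in isolation, here is a stripped-down sketch that keeps only the sampler-resolution step of the two versions. The helper names `sampler_checked_before`, `sampler_checked_after` and `looks_shuffled` are hypothetical, and the `__pl_saved_kwargs`, iterable-dataset and custom-sampler branches are omitted:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset


def sampler_checked_before(dataloader):
    # hypothetical helper: current behavior, only `dataloader.sampler` is inspected
    return dataloader.sampler


def sampler_checked_after(dataloader):
    # hypothetical helper: proposed behavior, prefer the sampler inside the batch sampler
    batch_sampler = getattr(dataloader, "batch_sampler", None)
    return getattr(batch_sampler, "sampler", dataloader.sampler)


def looks_shuffled(sampler):
    # hypothetical helper: final check shared by both versions
    return isinstance(sampler, RandomSampler)


dataset = TensorDataset(torch.arange(10))
dl = DataLoader(dataset, batch_sampler=BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False))

print(looks_shuffled(sampler_checked_before(dl)))  # False: the default SequentialSampler is inspected
print(looks_shuffled(sampler_checked_after(dl)))   # True: the RandomSampler inside the batch sampler is inspected
```

For an ordinary `DataLoader(dataset, batch_size=2, shuffle=True)` both helpers agree, because PyTorch wraps the same `RandomSampler` in its default `BatchSampler`; the change only matters when a custom `batch_sampler` is passed.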

What version are you seeing the problem on?

master

How to reproduce the bug

First, define a custom BatchSampler such as the following (the default BatchSampler also reproduces the bug):

```python
import math
import random

from torch.utils.data import BatchSampler, Sampler


class DynamicBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool, dataset):
        super().__init__(sampler, batch_size, drop_last)
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.dataset = dataset

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            # the first index of each batch may carry a randomly chosen length
            if not batch and getattr(self.dataset, 'dynamic_length', False):
                min_len, max_len = self.dataset.min_length, self.dataset.max_length
                length = random.randint(min_len, max_len)
            else:
                length = None
            batch.append((idx, length))
            if len(batch) == self.batch_size:
                print(batch)  # debug output to inspect the batch order
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / self.batch_size)
```

Then initialize the DataLoader with the BatchSampler:

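```python
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


def build_dataloader(dataset, cfg):  # hypothetical wrapper; the original snippet shows only the function body
    if cfg.shuffle:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)

    bsampler = DynamicBatchSampler(sampler, cfg.batch_size, cfg.drop_last, dataset)
    dl = DataLoader(dataset, batch_sampler=bsampler, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory)

    return dl
```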

If you use this DataLoader in a LightningDataModule (e.g. return it from `train_dataloader()`), the bug occurs.

cc @tchaton

dadwadw233 added the labels bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) on Oct 9, 2024
lantiga added the label priority: 1 (Medium priority task) and removed needs triage on Nov 12, 2024