Batch sampler #3123

Open · nguyenvannghiem0312 opened this issue Dec 9, 2024 · 3 comments

@nguyenvannghiem0312

I want to train on data samples that I have arranged in a specific order, and any shuffling would disrupt that arrangement. Does the batch sampler perform any data shuffling in this context?

BATCH_SAMPLER = "batch_sampler"
NO_DUPLICATES = "no_duplicates"
GROUP_BY_LABEL = "group_by_label"
@tomaarsen (Collaborator) commented Dec 9, 2024

Hello!

Yes, data shuffling is certainly enabled in a few of those. To be specific, the DataLoader shuffle default is False and we don't override it in the SentenceTransformerTrainer, so shuffling is left entirely to the batch sampler (these options are selected via the training arguments, as sketched after the list below):

  1. BATCH_SAMPLER: Shuffles due to a SubsetRandomSampler.
  2. NO_DUPLICATES: Definitely shuffles - it uses torch.randperm.
  3. GROUP_BY_LABEL: Definitely shuffles - it uses torch.randperm.
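
For reference, here is a minimal sketch of how those three options are chosen, assuming the BatchSamplers enum and SentenceTransformerTrainingArguments from sentence_transformers.training_args; the output directory and batch size are just placeholders:

from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments

args = SentenceTransformerTrainingArguments(
    output_dir="output/my-model",  # placeholder path
    per_device_train_batch_size=32,
    # BATCH_SAMPLER is the default; NO_DUPLICATES and GROUP_BY_LABEL are the alternatives.
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)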

If you'd like to avoid shuffling, then you can use a custom batch sampler quite easily:

import torch
from datasets import Dataset
from torch.utils.data import BatchSampler

from sentence_transformers import SentenceTransformerTrainer
from sentence_transformers.sampler import DefaultBatchSampler


class CustomSentenceTransformerTrainer(SentenceTransformerTrainer):
    def get_batch_sampler(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] | None = None,
        generator: torch.Generator | None = None,
    ) -> BatchSampler | None:
        # Hand the indices (range(len(dataset))) to DefaultBatchSampler directly,
        # without a SubsetRandomSampler, so batches are yielded in the original order.
        return DefaultBatchSampler(
            range(len(dataset)),
            batch_size=batch_size,
            drop_last=drop_last,
        )

I think this will already do it. Instead of wrapping range(len(dataset)) (i.e. the list of indices) in a SubsetRandomSampler, we pass the iterable of indices directly, so they are yielded in order.
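
As a usage sketch (the model name, the tiny in-memory dataset, and the loss below are just placeholders), the subclass is used exactly like the normal trainer:

from datasets import Dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss

model = SentenceTransformer("all-MiniLM-L6-v2")
train_dataset = Dataset.from_dict({
    "anchor": ["first query", "second query"],
    "positive": ["first passage", "second passage"],
})

trainer = CustomSentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=MultipleNegativesRankingLoss(model),
)
trainer.train()  # batches now follow the original dataset order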

  • Tom Aarsen

@yusufcakmakk

Hi Tom, can you give an example of a custom multi-dataset batch sampler?

@tomaarsen (Collaborator)

Here are two examples of the existing multi-dataset batch samplers:

from collections.abc import Iterator
from itertools import accumulate, cycle

import torch
from torch.utils.data import BatchSampler, ConcatDataset, SubsetRandomSampler

from sentence_transformers.sampler import SetEpochMixin


class RoundRobinBatchSampler(SetEpochMixin, BatchSampler):
    """
    Batch sampler that yields batches in a round-robin fashion from multiple batch samplers, until one is exhausted.
    With this sampler, it's unlikely that all samples from each dataset are used, but we do ensure that each dataset
    is sampled from equally.

    Args:
        dataset (ConcatDataset): A concatenation of multiple datasets.
        batch_samplers (List[BatchSampler]): A list of batch samplers, one for each dataset in the ConcatDataset.
        generator (torch.Generator, optional): A generator for reproducible sampling. Defaults to None.
        seed (int, optional): A seed for the generator. Defaults to None.
    """

    def __init__(
        self,
        dataset: ConcatDataset,
        batch_samplers: list[BatchSampler],
        generator: torch.Generator = None,
        seed: int = None,
    ) -> None:
        if len(dataset.datasets) != len(batch_samplers):
            raise ValueError("The number of batch samplers must match the number of datasets in the ConcatDataset.")
        super().__init__(dataset, batch_samplers[0].batch_size, batch_samplers[0].drop_last)
        self.dataset = dataset
        self.batch_samplers = batch_samplers
        self.generator = generator
        self.seed = seed

    def __iter__(self) -> Iterator[list[int]]:
        if self.generator and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)

        num_samples = [len(dataset) for dataset in self.dataset.datasets]
        sample_offsets = [0] + list(accumulate(num_samples))
        batch_samplers = [iter(sampler) for sampler in self.batch_samplers]
        for dataset_idx in cycle(range(len(batch_samplers))):
            sample_offset = sample_offsets[dataset_idx]
            try:
                yield [idx + sample_offset for idx in next(batch_samplers[dataset_idx])]
            except StopIteration:
                # current iterator is apparently exhausted
                break

    def __len__(self) -> int:
        return min(len(sampler) for sampler in self.batch_samplers) * len(self.batch_samplers)


class ProportionalBatchSampler(SetEpochMixin, BatchSampler):
    def __init__(
        self,
        dataset: ConcatDataset,
        batch_samplers: list[BatchSampler],
        generator: torch.Generator,
        seed: int,
    ) -> None:
        """
        Batch sampler that samples from each dataset in proportion to its size, until all are exhausted simultaneously.
        With this sampler, all samples from each dataset are used and larger datasets are sampled from more frequently.

        Args:
            dataset (ConcatDataset): A concatenation of multiple datasets.
            batch_samplers (List[BatchSampler]): A list of batch samplers, one for each dataset in the ConcatDataset.
            generator (torch.Generator, optional): A generator for reproducible sampling. Defaults to None.
            seed (int, optional): A seed for the generator. Defaults to None.
        """
        super().__init__(dataset, batch_samplers[0].batch_size, batch_samplers[0].drop_last)
        self.dataset = dataset
        self.batch_samplers = batch_samplers
        self.generator = generator
        self.seed = seed

    def __iter__(self) -> Iterator[list[int]]:
        self.generator.manual_seed(self.seed + self.epoch)

        num_samples = [len(dataset) for dataset in self.dataset.datasets]
        sample_offsets = [0] + list(accumulate(num_samples))

        num_batches = [len(sampler) for sampler in self.batch_samplers]
        dataset_indices = [idx for idx, length in enumerate(num_batches) for _ in range(length)]
        dataset_idx_sampler = SubsetRandomSampler(dataset_indices, generator=self.generator)

        batch_samplers = [iter(sampler) for sampler in self.batch_samplers]
        for dataset_idx in dataset_idx_sampler:
            sample_offset = sample_offsets[dataset_idx]
            try:
                yield [idx + sample_offset for idx in next(batch_samplers[dataset_idx])]
            except StopIteration:
                continue

    def __len__(self) -> int:
        return sum([len(sampler) for sampler in self.batch_samplers])

If you want to make your own, then I would recommend following this format: these __init__ arguments, an __iter__ and a __len__, and subclassing SetEpochMixin and BatchSampler.
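
As a rough sketch of that format (not from the library, just an illustration, assuming SetEpochMixin is importable from sentence_transformers.sampler), a multi-dataset sampler that simply exhausts each dataset in order, without shuffling, could look like this:

from collections.abc import Iterator
from itertools import accumulate

import torch
from torch.utils.data import BatchSampler, ConcatDataset

from sentence_transformers.sampler import SetEpochMixin


class SequentialMultiDatasetBatchSampler(SetEpochMixin, BatchSampler):
    """Yields every batch from the first dataset, then the second, and so on, in order."""

    def __init__(
        self,
        dataset: ConcatDataset,
        batch_samplers: list[BatchSampler],
        generator: torch.Generator = None,  # unused, kept for interface compatibility
        seed: int = None,                   # unused, kept for interface compatibility
    ) -> None:
        super().__init__(dataset, batch_samplers[0].batch_size, batch_samplers[0].drop_last)
        self.dataset = dataset
        self.batch_samplers = batch_samplers

    def __iter__(self) -> Iterator[list[int]]:
        # Offsets map per-dataset indices to indices in the ConcatDataset.
        num_samples = [len(dataset) for dataset in self.dataset.datasets]
        sample_offsets = [0] + list(accumulate(num_samples))
        for dataset_idx, batch_sampler in enumerate(self.batch_samplers):
            sample_offset = sample_offsets[dataset_idx]
            for batch in batch_sampler:
                yield [idx + sample_offset for idx in batch]

    def __len__(self) -> int:
        return sum(len(sampler) for sampler in self.batch_samplers)

If I remember correctly, the trainer also exposes a get_multi_dataset_batch_sampler hook (analogous to get_batch_sampler above) where a class like this can be returned.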

  • Tom Aarsen
