From 1c2bb3ac912ff45bc091cd044b30b72e695dac61 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Fri, 16 Feb 2024 20:10:39 -0800
Subject: [PATCH] modify data split to use HF api

ghstack-source-id: 489d666dd77ddcae80b139147ad82f4b1e6888da
Pull Request resolved: https://github.com/pytorch-labs/torchtrain/pull/65
---
 torchtrain/datasets/alpaca.py | 21 +++++++--------------
 1 file changed, 7 insertions(+), 14 deletions(-)

diff --git a/torchtrain/datasets/alpaca.py b/torchtrain/datasets/alpaca.py
index f52d21121..779dee5f5 100644
--- a/torchtrain/datasets/alpaca.py
+++ b/torchtrain/datasets/alpaca.py
@@ -9,6 +9,7 @@ from torchtrain.datasets.tokenizer import TokenizerIf
 
 from datasets import load_dataset
+from datasets.distributed import split_dataset_by_node
 
 
 class AlpacaDataset(IterableDataset):
@@ -44,27 +45,19 @@ def __init__(
         rank: int = 0,
         **kwargs
     ) -> None:
-        self._data = load_dataset("tatsu-lab/alpaca", split="train")
+        # TODO: This is a temporary solution for small datasets like Alpaca.
+        # For larger datasets we need to use a more scalable approach.
+        # Setting `streaming=True` works for large dataset, but the speed is slow.
+        ds = load_dataset("tatsu-lab/alpaca", split="train")
+        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
         self._tokenizer = tokenizer
-        self.data_iterator = iter(self._data)
         self.seq_len = seq_len
-        self.world_size = world_size
-        self.rank = rank
-        self.response_tag = "\n\n### Response:\n"
-
-    def __len__(self):
-        return len(self._data)
 
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for idx, sample in enumerate(self.data_iterator):
-            # select samples to pack in a round-robin fashion
-            # TODO: This is a temporary solution for small datasets like Alpaca.
-            # For larger datasets we need to use a more scalable approach.
-            if idx % self.world_size != self.rank:
-                continue
+        for sample in self.data_iterator:
             sample_text = sample["text"]
             sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
             all_tokens.extend(sample_tokens)
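
Note (not part of the patch): the sketch below, assuming the Hugging Face `datasets` package is installed and using toy `world_size`/`rank` values in place of the ones the trainer provides, shows how `split_dataset_by_node` gives each rank its own shard, which is what replaces the old `idx % world_size != rank` round-robin skip.

# Standalone sketch; assumptions: `datasets` installed, placeholder world_size/rank.
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

world_size, rank = 2, 0  # in torchtrain these come from the distributed launcher

ds = load_dataset("tatsu-lab/alpaca", split="train")
# Each rank receives a disjoint shard of roughly len(ds) / world_size rows,
# so no per-sample rank check is needed inside the iteration loop.
shard = split_dataset_by_node(ds, rank, world_size)

print(f"full dataset: {len(ds)} rows, shard for rank {rank}: {len(shard)} rows")
for sample in iter(shard):
    print(sample["text"][:80])  # Alpaca rows expose the formatted prompt in "text"
    break

One reason to route the split through split_dataset_by_node is that it also accepts datasets loaded with streaming=True, the path the TODO in the patch mentions for larger datasets.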