From 9b07bb9bb8e3f02f1c84ef62b1f5544aa8f2002a Mon Sep 17 00:00:00 2001 From: Gokul Gunasekaran Date: Tue, 21 May 2024 10:45:20 -0700 Subject: [PATCH 1/5] Integrate stateful dataloader to torchtitan Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- .ci/docker/requirements.txt | 2 +- .github/workflows/unit_test_4gpu.yaml | 1 + .github/workflows/unit_test_cpu.yaml | 1 + README.md | 1 + test/__init__.py | 5 +++ test/datasets/__init__.py | 5 +++ test/datasets/test_dataset_checkpoint.py | 56 ++++++++++++++++++++++++ torchtitan/checkpoint.py | 32 ++++++++++++++ torchtitan/datasets/hf_datasets.py | 51 +++++++++++++++++---- train.py | 1 + 10 files changed, 145 insertions(+), 10 deletions(-) create mode 100644 test/datasets/__init__.py create mode 100644 test/datasets/test_dataset_checkpoint.py diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index b82120a6..bb21293b 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -1,5 +1,5 @@ torch >= 2.2.0.dev -datasets +datasets >= 2.19.0 tomli >= 1.1.0 ; python_version < "3.11" tensorboard sentencepiece diff --git a/.github/workflows/unit_test_4gpu.yaml b/.github/workflows/unit_test_4gpu.yaml index 5759349d..6f052868 100644 --- a/.github/workflows/unit_test_4gpu.yaml +++ b/.github/workflows/unit_test_4gpu.yaml @@ -31,5 +31,6 @@ jobs: pip config --user set global.progress_bar off python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/ mkdir artifacts-to-be-uploaded python ./test_runner.py artifacts-to-be-uploaded diff --git a/.github/workflows/unit_test_cpu.yaml b/.github/workflows/unit_test_cpu.yaml index dd318dbb..2482bd51 100644 --- a/.github/workflows/unit_test_cpu.yaml +++ b/.github/workflows/unit_test_cpu.yaml @@ -25,4 +25,5 @@ jobs: pip config --user set global.progress_bar off pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly pytest test --cov=. --cov-report=xml --durations=20 -vv diff --git a/README.md b/README.md index 21634d0b..a8d1fcc4 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ git clone https://github.com/pytorch/torchtitan cd torchtitan pip install -r requirements.txt pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118 +pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly ``` ### Downloading a tokenizer diff --git a/test/__init__.py b/test/__init__.py index e69de29b..2e41cd71 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/datasets/__init__.py b/test/datasets/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/test/datasets/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
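[Editor's note] The torchdata nightly added above provides StatefulDataLoader, whose state_dict()/load_state_dict() pair is what the rest of this patch builds on. Before the torchtitan-specific changes below, here is a minimal sketch of that save/resume flow; ToyDataset and the batch counts are illustrative stand-ins and are not part of the patch.

# Sketch only: how StatefulDataLoader checkpoints and resumes iteration
# progress. ToyDataset is a hypothetical stand-in for a real dataset.
import torch
from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader


class ToyDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return idx


loader = StatefulDataLoader(ToyDataset(), batch_size=4)
it = iter(loader)
for _ in range(5):
    next(it)                      # consume a few batches
state = loader.state_dict()       # capture iteration progress
expected = next(it)               # the batch right after the checkpoint

resumed = StatefulDataLoader(ToyDataset(), batch_size=4)
resumed.load_state_dict(state)    # restore progress before creating the iterator
assert torch.equal(next(iter(resumed)), expected)
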
diff --git a/test/datasets/test_dataset_checkpoint.py b/test/datasets/test_dataset_checkpoint.py new file mode 100644 index 00000000..e4be71d6 --- /dev/null +++ b/test/datasets/test_dataset_checkpoint.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchtitan.checkpoint import DataLoaderWrapper +from torchtitan.datasets.hf_datasets import build_hf_data_loader +from torchtitan.datasets.tokenizer import create_tokenizer + + +class TestDatasetCheckpoint: + def test_c4_resumption(self): + dataset_name = "c4_mini" + dataset_path = "./torchtitan/datasets/c4_mini" + batch_size = 1 + seq_len = 1024 + world_size = 4 + rank = 0 + + dl_wrapper = self._create_dataloader_wrapper( + dataset_name, dataset_path, batch_size, seq_len, world_size, rank + ) + + it = iter(dl_wrapper.dataloader) + for _ in range(250): + next(it) + state = dl_wrapper.state_dict() + expected_input_ids, expected_labels = next(it) + + # Create new dataloader, restore checkpoint, and check if next data yielded is the same as above + dl_wrapper = self._create_dataloader_wrapper( + dataset_name, dataset_path, batch_size, seq_len, world_size, rank + ) + dl_wrapper.load_state_dict(state) + input_ids, labels = next(iter(dl_wrapper.dataloader)) + + assert torch.equal(input_ids, expected_input_ids) + assert torch.equal(labels, expected_labels) + + def _create_dataloader_wrapper( + self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank + ): + tokenizer_type = "tiktoken" + tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model") + dataloader = build_hf_data_loader( + dataset_name=dataset_name, + dataset_path=dataset_path, + tokenizer=tokenizer, + batch_size=1, + seq_len=1024, + world_size=4, + rank=0, + ) + return DataLoaderWrapper(dataloader) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 81bdf592..61dce0c2 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -6,6 +6,7 @@ import enum import os +import pickle import re import time from multiprocessing import get_context @@ -22,6 +23,8 @@ set_optimizer_state_dict, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import DataLoader +from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import init_logger, logger @@ -60,6 +63,33 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict) +class DataLoaderWrapper(Stateful): + def __init__(self, dataloader: DataLoader) -> None: + self.dataloader = dataloader + # Use global rank for now even though dataloader state could be same across dp groups + self.rank_id = str( + dist.get_rank() if (dist.is_available() and dist.is_initialized()) else 0 + ) + + def state_dict(self) -> Dict[str, Any]: + if isinstance(self.dataloader, StatefulDataLoader): + return {self.rank_id: pickle.dumps(self.dataloader.state_dict())} + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + if isinstance(self.dataloader, StatefulDataLoader): + # State is empty + if not state_dict: + return + + if self.rank_id not in state_dict: + logger.warning(f"DataLoader state is empty for rank {self.rank_id}. 
") + return + + # Load state for the current rank + self.dataloader.load_state_dict(pickle.loads(state_dict[self.rank_id])) + + class Terminate: pass @@ -103,6 +133,7 @@ def __init__( model: nn.Module, optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler, + dataloader: DataLoader, states: Dict[str, Any], job_config: JobConfig, ) -> None: @@ -118,6 +149,7 @@ def __init__( "model": ModelWrapper(model), "optimizer": OptimizerWrapper(model, optimizer), "lr_scheduler": lr_scheduler, + "dataloader": DataLoaderWrapper(dataloader), } ) diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index f6d09faa..a98c9467 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -7,7 +7,9 @@ from typing import List, Optional import torch -from torch.utils.data import DataLoader, IterableDataset +from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import IterableDataset +from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.datasets.tokenizer import Tokenizer from torchtitan.logging_utils import logger @@ -23,7 +25,7 @@ } -class HuggingFaceDataset(IterableDataset): +class HuggingFaceDataset(IterableDataset, Stateful): """PyTorch Representation of the HuggingFace Dataset. Args: @@ -99,32 +101,63 @@ def __init__( self.seq_len = seq_len self.infinite = infinite + # variables for checkpointing + self._sample_idx = 0 + self._all_tokens: List[int] = [] + def __iter__(self): max_buffer_token_len = 1 + self.seq_len - all_tokens: List[int] = [] while True: - for sample in iter(self._data): + for sample in self._get_data_iter(): sample_text = sample["text"] sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True) - all_tokens.extend(sample_tokens) + self._all_tokens.extend(sample_tokens) + self._sample_idx += 1 - while len(all_tokens) >= max_buffer_token_len: - x = torch.LongTensor(all_tokens[:max_buffer_token_len]) + while len(self._all_tokens) >= max_buffer_token_len: + x = torch.LongTensor(self._all_tokens[:max_buffer_token_len]) # update tokens to the remaining tokens - all_tokens = all_tokens[max_buffer_token_len:] + self._all_tokens = self._all_tokens[max_buffer_token_len:] input = x[:-1] label = x[1:] yield input, label + if not self.infinite: logger.warning(f"Dataset {self.dataset_name} has run out of data.") break else: + # Reset offset for the next iteration + self._sample_idx = 0 logger.warning( f"Dataset {self.dataset_name} is being re-looped. " "Loss related metrics might be misleading." 
) + def _get_data_iter(self): + if self._sample_idx == 0: + return iter(self._data) + + # Skip samples + if isinstance(self._data, IterableDataset): + it = iter(self._data) + # Naively iterate through the samples as skip may not be supported + for _ in range(self._sample_idx): + next(it) + return it + + # As skipping to the end throws an error in case of map-style dataset, return an empty iterator + if self._sample_idx == len(self._data): + return iter([]) + return iter(self._data.skip(self._sample_idx)) + + def load_state_dict(self, state_dict): + self._sample_idx = state_dict["sample_idx"] + self._all_tokens = state_dict["token_buffer"] + + def state_dict(self): + return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} + def build_hf_data_loader( dataset_name: str, @@ -140,4 +173,4 @@ def build_hf_data_loader( dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite ) - return DataLoader(hf_ds, batch_size=batch_size) + return StatefulDataLoader(hf_ds, batch_size=batch_size) diff --git a/train.py b/train.py index 318c7174..a0bb337e 100644 --- a/train.py +++ b/train.py @@ -245,6 +245,7 @@ def loss_fn(pred, labels): model=model, optimizer=optimizer, lr_scheduler=scheduler, + dataloader=data_loader, states={"train_state": train_state}, job_config=job_config, ) From c1a49fbe06f93fcd1fa9533ca0a746890cdb8572 Mon Sep 17 00:00:00 2001 From: Gokul Gunasekaran Date: Tue, 21 May 2024 12:59:57 -0700 Subject: [PATCH 2/5] Store state only for dp rank Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- ...taset_checkpoint.py => test_checkpoint.py} | 20 +++++------ torchtitan/checkpoint.py | 33 ++----------------- torchtitan/datasets/hf_datasets.py | 32 ++++++++++++++++-- 3 files changed, 41 insertions(+), 44 deletions(-) rename test/datasets/{test_dataset_checkpoint.py => test_checkpoint.py} (74%) diff --git a/test/datasets/test_dataset_checkpoint.py b/test/datasets/test_checkpoint.py similarity index 74% rename from test/datasets/test_dataset_checkpoint.py rename to test/datasets/test_checkpoint.py index e4be71d6..6f04dd23 100644 --- a/test/datasets/test_dataset_checkpoint.py +++ b/test/datasets/test_checkpoint.py @@ -5,12 +5,11 @@ # LICENSE file in the root directory of this source tree. 
import torch -from torchtitan.checkpoint import DataLoaderWrapper from torchtitan.datasets.hf_datasets import build_hf_data_loader from torchtitan.datasets.tokenizer import create_tokenizer -class TestDatasetCheckpoint: +class TestCheckpoint: def test_c4_resumption(self): dataset_name = "c4_mini" dataset_path = "./torchtitan/datasets/c4_mini" @@ -19,32 +18,32 @@ def test_c4_resumption(self): world_size = 4 rank = 0 - dl_wrapper = self._create_dataloader_wrapper( + dl = self._build_dataloader( dataset_name, dataset_path, batch_size, seq_len, world_size, rank ) - it = iter(dl_wrapper.dataloader) + it = iter(dl) for _ in range(250): next(it) - state = dl_wrapper.state_dict() + state = dl.state_dict() expected_input_ids, expected_labels = next(it) # Create new dataloader, restore checkpoint, and check if next data yielded is the same as above - dl_wrapper = self._create_dataloader_wrapper( + dl = self._build_dataloader( dataset_name, dataset_path, batch_size, seq_len, world_size, rank ) - dl_wrapper.load_state_dict(state) - input_ids, labels = next(iter(dl_wrapper.dataloader)) + dl.load_state_dict(state) + input_ids, labels = next(iter(dl)) assert torch.equal(input_ids, expected_input_ids) assert torch.equal(labels, expected_labels) - def _create_dataloader_wrapper( + def _build_dataloader( self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank ): tokenizer_type = "tiktoken" tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model") - dataloader = build_hf_data_loader( + return build_hf_data_loader( dataset_name=dataset_name, dataset_path=dataset_path, tokenizer=tokenizer, @@ -53,4 +52,3 @@ def _create_dataloader_wrapper( world_size=4, rank=0, ) - return DataLoaderWrapper(dataloader) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 61dce0c2..692a3183 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -6,7 +6,6 @@ import enum import os -import pickle import re import time from multiprocessing import get_context @@ -23,9 +22,8 @@ set_optimizer_state_dict, ) from torch.distributed.checkpoint.stateful import Stateful -from torch.utils.data import DataLoader -from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torch.utils.data import DataLoader from torchtitan.logging_utils import init_logger, logger @@ -63,33 +61,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict) -class DataLoaderWrapper(Stateful): - def __init__(self, dataloader: DataLoader) -> None: - self.dataloader = dataloader - # Use global rank for now even though dataloader state could be same across dp groups - self.rank_id = str( - dist.get_rank() if (dist.is_available() and dist.is_initialized()) else 0 - ) - - def state_dict(self) -> Dict[str, Any]: - if isinstance(self.dataloader, StatefulDataLoader): - return {self.rank_id: pickle.dumps(self.dataloader.state_dict())} - return {} - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - if isinstance(self.dataloader, StatefulDataLoader): - # State is empty - if not state_dict: - return - - if self.rank_id not in state_dict: - logger.warning(f"DataLoader state is empty for rank {self.rank_id}. 
") - return - - # Load state for the current rank - self.dataloader.load_state_dict(pickle.loads(state_dict[self.rank_id])) - - class Terminate: pass @@ -149,7 +120,7 @@ def __init__( "model": ModelWrapper(model), "optimizer": OptimizerWrapper(model, optimizer), "lr_scheduler": lr_scheduler, - "dataloader": DataLoaderWrapper(dataloader), + "dataloader": dataloader, } ) diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index a98c9467..3710b1fd 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional +import pickle +from typing import Any, Dict, List, Optional import torch from torch.distributed.checkpoint.stateful import Stateful @@ -159,6 +160,33 @@ def state_dict(self): return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} +class DpAwareDataLoader(StatefulDataLoader, Stateful): + """ + A wrapper around the StatefulDataLoader that ensures that the state is stored only once for DP ranks. + """ + + def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int): + super().__init__(hf_ds, batch_size) + self._dp_rank = dp_rank + self._rank_id = f"dp_rank_{dp_rank}" + + def state_dict(self) -> Dict[str, Any]: + # Store state only for dp rank to avoid replicating the same state across other dimensions + return {self._rank_id: pickle.dumps(super().state_dict())} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + # State being empty is valid, don't log a warning + if not state_dict: + return + + if self._rank_id not in state_dict: + logger.warning( + f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}." 
+ ) + return + super().load_state_dict(pickle.loads(state_dict[self._rank_id])) + + def build_hf_data_loader( dataset_name: str, dataset_path: Optional[str], @@ -173,4 +201,4 @@ def build_hf_data_loader( dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite ) - return StatefulDataLoader(hf_ds, batch_size=batch_size) + return DpAwareDataLoader(rank, hf_ds, batch_size=batch_size) From 63bd006c02308659696fd97995f2435c82d93bdf Mon Sep 17 00:00:00 2001 From: Gokul Gunasekaran Date: Tue, 21 May 2024 13:10:10 -0700 Subject: [PATCH 3/5] lint --- torchtitan/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 692a3183..fb7c41c8 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -22,8 +22,8 @@ set_optimizer_state_dict, ) from torch.distributed.checkpoint.stateful import Stateful -from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torch.utils.data import DataLoader +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import init_logger, logger From a62f3173c63d861158a67eb75dbf20d12dd2963b Mon Sep 17 00:00:00 2001 From: Gokul Gunasekaran Date: Tue, 21 May 2024 14:45:23 -0700 Subject: [PATCH 4/5] Address PR comments --- .github/workflows/integration_test_periodic.yaml | 1 + pyproject.toml | 2 +- torchtitan/datasets/hf_datasets.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/integration_test_periodic.yaml b/.github/workflows/integration_test_periodic.yaml index bc717cd1..488fc4da 100644 --- a/.github/workflows/integration_test_periodic.yaml +++ b/.github/workflows/integration_test_periodic.yaml @@ -34,6 +34,7 @@ jobs: - name: Install dependencies run: | pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly python -m pip install -r requirements.txt python -m pip install -r dev-requirements.txt - name: Run test_runner.py diff --git a/pyproject.toml b/pyproject.toml index 2a8f9557..a5c1b72f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ keywords = ["pytorch", "training", "llm"] dependencies = [ # Hugging Face integrations - "datasets", + "datasets>=2.19.0", # Tokenization "blobfile", diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 3710b1fd..52c13697 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -160,7 +160,7 @@ def state_dict(self): return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} -class DpAwareDataLoader(StatefulDataLoader, Stateful): +class DPAwareDataLoader(StatefulDataLoader, Stateful): """ A wrapper around the StatefulDataLoader that ensures that the state is stored only once for DP ranks. 
""" @@ -201,4 +201,4 @@ def build_hf_data_loader( dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite ) - return DpAwareDataLoader(rank, hf_ds, batch_size=batch_size) + return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size) From f22a97712e7122fc5310a3b7ff3011c6232b841a Mon Sep 17 00:00:00 2001 From: Gokul Gunasekaran Date: Tue, 21 May 2024 15:45:37 -0700 Subject: [PATCH 5/5] minor doc change --- torchtitan/datasets/hf_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 52c13697..d0306663 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -162,7 +162,7 @@ def state_dict(self): class DPAwareDataLoader(StatefulDataLoader, Stateful): """ - A wrapper around the StatefulDataLoader that ensures that the state is stored only once for DP ranks. + A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank. """ def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):