Skip to content

Commit

Permalink
Use stateful dataloader to checkpoint data iteration order and token …
Browse files Browse the repository at this point in the history
…buffer (#279)

Summary: 

Use the stateful_dataloader from torchdata
(https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader)
for storing the token buffer and iteration data order. It requires a
dependency on the nightly build of torchdata >= 20240426.

Also make sure the dataloader state has a different key per rank.

Test Plan:

Tested locally by first running 30 steps (checkpointing every 5 steps)
and capturing all the loss values. Then deleting the last 3 checkpoints
and then re-run the training and the loss values from step 16-30 match
with what we had earlier in the first run. Note that this requires
changes in the train.py to enable a deterministic run.

Reviewers: @tianyu-l 

Subscribers: @andrewkho 

Tasks:

Tags:
  • Loading branch information
gokulavasan authored May 21, 2024
1 parent 3bd14ec commit 99a73dd
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torch >= 2.2.0.dev
datasets
datasets >= 2.19.0

This comment has been minimized.

Copy link
@wconstab

wconstab May 22, 2024

Contributor

should torchdata be added to requirements too? or how would users know to install it other than looking at our CI yaml?

tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
sentencepiece
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/integration_test_periodic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
python -m pip install -r requirements.txt
python -m pip install -r dev-requirements.txt
- name: Run test_runner.py
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/unit_test_4gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ jobs:
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded
1 change: 1 addition & 0 deletions .github/workflows/unit_test_cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ jobs:
pip config --user set global.progress_bar off
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
pytest test --cov=. --cov-report=xml --durations=20 -vv
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
```

### Downloading a tokenizer
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ authors = [
keywords = ["pytorch", "training", "llm"]
dependencies = [
# Hugging Face integrations
"datasets",
"datasets>=2.19.0",

# Tokenization
"blobfile",
Expand Down
5 changes: 5 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
5 changes: 5 additions & 0 deletions test/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
54 changes: 54 additions & 0 deletions test/datasets/test_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer


class TestCheckpoint:
def test_c4_resumption(self):
dataset_name = "c4_mini"
dataset_path = "./torchtitan/datasets/c4_mini"
batch_size = 1
seq_len = 1024
world_size = 4
rank = 0

dl = self._build_dataloader(
dataset_name, dataset_path, batch_size, seq_len, world_size, rank
)

it = iter(dl)
for _ in range(250):
next(it)
state = dl.state_dict()
expected_input_ids, expected_labels = next(it)

# Create new dataloader, restore checkpoint, and check if next data yielded is the same as above
dl = self._build_dataloader(
dataset_name, dataset_path, batch_size, seq_len, world_size, rank
)
dl.load_state_dict(state)
input_ids, labels = next(iter(dl))

assert torch.equal(input_ids, expected_input_ids)
assert torch.equal(labels, expected_labels)

def _build_dataloader(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer_type = "tiktoken"
tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
return build_hf_data_loader(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
batch_size=1,
seq_len=1024,
world_size=4,
rank=0,
)
3 changes: 3 additions & 0 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import init_logger, logger

Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
model: nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
dataloader: DataLoader,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
Expand All @@ -118,6 +120,7 @@ def __init__(
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
"lr_scheduler": lr_scheduler,
"dataloader": dataloader,
}
)

Expand Down
81 changes: 71 additions & 10 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional
import pickle
from typing import Any, Dict, List, Optional

import torch
from torch.utils.data import DataLoader, IterableDataset
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
Expand All @@ -23,7 +26,7 @@
}


class HuggingFaceDataset(IterableDataset):
class HuggingFaceDataset(IterableDataset, Stateful):
"""PyTorch Representation of the HuggingFace Dataset.
Args:
Expand Down Expand Up @@ -99,32 +102,90 @@ def __init__(
self.seq_len = seq_len
self.infinite = infinite

# variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len
all_tokens: List[int] = []

while True:
for sample in iter(self._data):
for sample in self._get_data_iter():
sample_text = sample["text"]
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
all_tokens.extend(sample_tokens)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1

while len(all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(all_tokens[:max_buffer_token_len])
while len(self._all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
# update tokens to the remaining tokens
all_tokens = all_tokens[max_buffer_token_len:]
self._all_tokens = self._all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
yield input, label

if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data.")
break
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(
f"Dataset {self.dataset_name} is being re-looped. "
"Loss related metrics might be misleading."
)

def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

# Skip samples
if isinstance(self._data, IterableDataset):
it = iter(self._data)
# Naively iterate through the samples as skip may not be supported
for _ in range(self._sample_idx):
next(it)
return it

# As skipping to the end throws an error in case of map-style dataset, return an empty iterator
if self._sample_idx == len(self._data):
return iter([])
return iter(self._data.skip(self._sample_idx))

def load_state_dict(self, state_dict):
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]

def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


class DPAwareDataLoader(StatefulDataLoader, Stateful):
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
"""

def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
super().__init__(hf_ds, batch_size)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"

def state_dict(self) -> Dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions
return {self._rank_id: pickle.dumps(super().state_dict())}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid, don't log a warning
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}."
)
return
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


def build_hf_data_loader(
dataset_name: str,
Expand All @@ -140,4 +201,4 @@ def build_hf_data_loader(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
)

return DataLoader(hf_ds, batch_size=batch_size)
return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size)
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def loss_fn(pred, labels):
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
dataloader=data_loader,
states={"train_state": train_state},
job_config=job_config,
)
Expand Down

0 comments on commit 99a73dd

Please sign in to comment.