Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use stateful dataloader to checkpoint data iteration order and token buffer #279

Merged
merged 5 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torch >= 2.2.0.dev
datasets
datasets >= 2.19.0
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
sentencepiece
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/unit_test_4gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ jobs:
pip config --user set global.progress_bar off

python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need this in .github/workflows/integration_test_periodic.yaml as well.

Let's create an issue, tracking that we need to put torchdata in requirements.txt and pyproject.toml after the needed changes ship in an official release.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tianyu-l Should pyproject.toml datasets dependency also enforce >= 2.19.0 version requirement?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes I think so

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created #351 to remove torchdata nightly pip install

mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded
1 change: 1 addition & 0 deletions .github/workflows/unit_test_cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ jobs:
pip config --user set global.progress_bar off

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
pytest test --cov=. --cov-report=xml --durations=20 -vv
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
```

### Downloading a tokenizer
Expand Down
5 changes: 5 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
5 changes: 5 additions & 0 deletions test/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
54 changes: 54 additions & 0 deletions test/datasets/test_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer


class TestCheckpoint:
def test_c4_resumption(self):
dataset_name = "c4_mini"
dataset_path = "./torchtitan/datasets/c4_mini"
batch_size = 1
seq_len = 1024
world_size = 4
rank = 0

dl = self._build_dataloader(
dataset_name, dataset_path, batch_size, seq_len, world_size, rank
)

it = iter(dl)
for _ in range(250):
next(it)
state = dl.state_dict()
expected_input_ids, expected_labels = next(it)

# Create new dataloader, restore checkpoint, and check if next data yielded is the same as above
dl = self._build_dataloader(
dataset_name, dataset_path, batch_size, seq_len, world_size, rank
)
dl.load_state_dict(state)
input_ids, labels = next(iter(dl))

assert torch.equal(input_ids, expected_input_ids)
assert torch.equal(labels, expected_labels)

def _build_dataloader(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer_type = "tiktoken"
tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
return build_hf_data_loader(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
batch_size=1,
seq_len=1024,
world_size=4,
rank=0,
)
3 changes: 3 additions & 0 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
gokulavasan marked this conversation as resolved.
Show resolved Hide resolved
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import init_logger, logger

Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
model: nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
dataloader: DataLoader,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
Expand All @@ -118,6 +120,7 @@ def __init__(
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
"lr_scheduler": lr_scheduler,
"dataloader": dataloader,
}
)

Expand Down
81 changes: 71 additions & 10 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional
import pickle
from typing import Any, Dict, List, Optional

import torch
from torch.utils.data import DataLoader, IterableDataset
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
Expand All @@ -23,7 +26,7 @@
}


class HuggingFaceDataset(IterableDataset):
class HuggingFaceDataset(IterableDataset, Stateful):
"""PyTorch Representation of the HuggingFace Dataset.

Args:
Expand Down Expand Up @@ -99,32 +102,90 @@ def __init__(
self.seq_len = seq_len
self.infinite = infinite

# variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len
all_tokens: List[int] = []

while True:
for sample in iter(self._data):
for sample in self._get_data_iter():
sample_text = sample["text"]
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
all_tokens.extend(sample_tokens)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1

while len(all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(all_tokens[:max_buffer_token_len])
while len(self._all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
# update tokens to the remaining tokens
all_tokens = all_tokens[max_buffer_token_len:]
self._all_tokens = self._all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
yield input, label

if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data.")
break
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(
f"Dataset {self.dataset_name} is being re-looped. "
"Loss related metrics might be misleading."
)

def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

# Skip samples
if isinstance(self._data, IterableDataset):
tianyu-l marked this conversation as resolved.
Show resolved Hide resolved
it = iter(self._data)
# Naively iterate through the samples as skip may not be supported
tianyu-l marked this conversation as resolved.
Show resolved Hide resolved
for _ in range(self._sample_idx):
next(it)
return it

# As skipping to the end throws an error in case of map-style dataset, return an empty iterator
if self._sample_idx == len(self._data):
return iter([])
return iter(self._data.skip(self._sample_idx))

def load_state_dict(self, state_dict):
gokulavasan marked this conversation as resolved.
Show resolved Hide resolved
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]

def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}
gokulavasan marked this conversation as resolved.
Show resolved Hide resolved


class DpAwareDataLoader(StatefulDataLoader, Stateful):
gokulavasan marked this conversation as resolved.
Show resolved Hide resolved
"""
A wrapper around the StatefulDataLoader that ensures that the state is stored only once for DP ranks.
"""

def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
super().__init__(hf_ds, batch_size)
self._dp_rank = dp_rank
self._rank_id = f"dp_rank_{dp_rank}"

def state_dict(self) -> Dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions
return {self._rank_id: pickle.dumps(super().state_dict())}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid, don't log a warning
if not state_dict:
return

if self._rank_id not in state_dict:
logger.warning(
f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}."
)
return
super().load_state_dict(pickle.loads(state_dict[self._rank_id]))


def build_hf_data_loader(
dataset_name: str,
Expand All @@ -140,4 +201,4 @@ def build_hf_data_loader(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
)

return DataLoader(hf_ds, batch_size=batch_size)
return DpAwareDataLoader(rank, hf_ds, batch_size=batch_size)
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def loss_fn(pred, labels):
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
dataloader=data_loader,
states={"train_state": train_state},
job_config=job_config,
)
Expand Down
Loading