Commit 74aa7d9

initial commit: add colossal llama 2 (#4784)

1 parent 4146f1c
19 files changed: +2162 -2 lines changed

applications/Colossal-LLaMA-2/README.md (+377 lines)

Large diffs are not rendered by default.
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

@@ -0,0 +1,219 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import random
from dataclasses import dataclass
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Union

import numpy as np
import torch
import torch.nn.functional as F
from datasets import dataset_dict, load_from_disk
from datasets import Dataset as HFDataset
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import ConcatDataset, DataLoader, Dataset, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer

DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]


def load_tokenized_dataset(
    dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
) -> Optional[DatasetType]:
    """
    Load a pre-tokenized dataset.
    Each instance of the dataset is a dictionary in the format
    `{'input_ids': List[int], 'labels': List[int], 'sequence': str}`.
    """
    mode_map = {"train": "train", "dev": "validation", "test": "test"}
    assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be one of {tuple(mode_map)}"

    if isinstance(dataset_paths, (str, os.PathLike)):
        dataset_paths = [dataset_paths]

    datasets = []  # `List[datasets.dataset_dict.Dataset]`
    for ds_path in dataset_paths:
        ds_path = os.path.abspath(ds_path)
        assert os.path.exists(ds_path), f"File path {ds_path} does not exist"
        ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
        if isinstance(ds_dict, HFDataset):
            datasets.append(ds_dict)
        else:
            if mode_map[mode] in ds_dict:
                datasets.append(ds_dict[mode_map[mode]])
    if len(datasets) == 0:
        return None
    if len(datasets) == 1:
        return datasets.pop()
    return ConcatDataset(datasets=datasets)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """
    Collate instances for a supervised dataset.
    Each instance is a tokenized dictionary with fields
    `input_ids` (List[int]), `labels` (List[int]) and `sequence` (str).
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 4096
    ignore_index: int = -100

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary.

        Returns:
            (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
                `input_ids`: `torch.Tensor` of shape (bsz, max_len);
                `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
                `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
        """
        assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
            f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
            f"but is currently `{self.tokenizer.pad_token_id}`"
        )

        # Truncate each sample to `max_length`, `List[torch.Tensor]`
        batch_input_ids = [
            torch.LongTensor(instance["input_ids"][: self.max_length])
            if len(instance["input_ids"]) > self.max_length
            else torch.LongTensor(instance["input_ids"])
            for instance in instances
        ]
        batch_labels = [
            torch.LongTensor(instance["labels"][: self.max_length])
            if len(instance["labels"]) > self.max_length
            else torch.LongTensor(instance["labels"])
            for instance in instances
        ]

        if self.tokenizer.padding_side == "right":
            input_ids = torch.nn.utils.rnn.pad_sequence(
                sequences=batch_input_ids,
                batch_first=True,
                padding_value=self.tokenizer.pad_token_id,
            )  # (bsz, max_len)
            labels = torch.nn.utils.rnn.pad_sequence(
                sequences=batch_labels,
                batch_first=True,
                padding_value=self.ignore_index,
            )  # (bsz, max_len)
            # pad to max
            to_pad = self.max_length - input_ids.size(1)
            input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
            labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
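        # Left padding: flip each sequence, right-pad the flipped batch, then flip back
        # along the sequence dimension, since `pad_sequence` only pads at the end.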
        elif self.tokenizer.padding_side == "left":
            reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
            reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
                sequences=reversed_input_ids,
                batch_first=True,
                padding_value=self.tokenizer.pad_token_id,
            )  # (bsz, max_len)
            input_ids = torch.flip(reversed_input_ids, dims=(1,))  # (bsz, max_len)
            reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
            reversed_labels = torch.nn.utils.rnn.pad_sequence(
                sequences=reversed_labels,
                batch_first=True,
                padding_value=self.ignore_index,
            )  # (bsz, max_len)
            labels = torch.flip(reversed_labels, dims=(1,))  # (bsz, max_len)
        else:
            raise RuntimeError(
                f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
                f"but is currently `{self.tokenizer.padding_side}`"
            )

        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)  # `torch.BoolTensor`, (bsz, max_len)

        return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)


class StatefulDistributedSampler(DistributedSampler):
    """
    Stateful distributed sampler for multi-stage training.
    """

    def __init__(
        self,
        dataset: DatasetType,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(
            dataset=dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        self.start_index = 0

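    # Drop the first `start_index` indices so a resumed run does not revisit
    # samples that were already consumed before the restart.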
    def __iter__(self) -> Iterator:
        iterator = super().__iter__()
        indices = list(iterator)
        indices = indices[self.start_index :]
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples - self.start_index

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index


def setup_distributed_dataloader(
    dataset: DatasetType,
    batch_size: int = 1,
    shuffle: bool = False,
    seed: int = 1024,
    drop_last: bool = False,
    pin_memory: bool = False,
    num_workers: int = 0,
    collate_fn: Optional[Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]]] = None,
    process_group: Optional[ProcessGroup] = None,
    **kwargs,
) -> DataLoader:
    """
    Set up a dataloader for distributed training.
    """
    _kwargs = kwargs.copy()
    process_group = process_group or _get_default_group()
    sampler = StatefulDistributedSampler(
        dataset=dataset,
        num_replicas=process_group.size(),
        rank=process_group.rank(),
        shuffle=shuffle,
        seed=seed,
        drop_last=drop_last,
    )

    # Deterministic dataloader
    def seed_worker(worker_id: int) -> None:
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=drop_last,
        worker_init_fn=seed_worker,
        **_kwargs,
    )
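
A minimal usage sketch of how these pieces fit together, assuming `torch.distributed` has already been initialized (e.g. launched with torchrun, since the default process group is used), that the dataset was tokenized beforehand and saved with `datasets.Dataset.save_to_disk`, and that the definitions above are importable. The checkpoint path, dataset path, and pad-token choice below are illustrative placeholders, not values taken from this commit.

from transformers import AutoTokenizer

# Placeholder checkpoint path; any LLaMA-2-style tokenizer directory works here.
tokenizer = AutoTokenizer.from_pretrained("<path/to/llama-2-base>")
if tokenizer.pad_token is None:
    # Assumption: reuse the unk token as pad token so `pad_token_id` is a valid non-negative id.
    tokenizer.pad_token = tokenizer.unk_token

dataset = load_tokenized_dataset(dataset_paths=["<path/to/tokenized_dataset>"], mode="train")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=4096)
dataloader = setup_distributed_dataloader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=data_collator,
)

# To resume mid-epoch, skip the samples that were already consumed before a restart:
# dataloader.sampler.set_start_index(start_index=num_consumed_samples)

for batch in dataloader:
    input_ids = batch["input_ids"]            # LongTensor, (bsz, max_len)
    attention_mask = batch["attention_mask"]  # BoolTensor, (bsz, max_len)
    labels = batch["labels"]                  # LongTensor with ignore_index at padded positions

Note that with `padding_side == "right"` the collator pads every batch out to `max_length`, while with left padding it only pads to the longest sequence in the mini-batch.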
