-
Notifications
You must be signed in to change notification settings - Fork 65
/
Copy pathdata.py
413 lines (317 loc) · 13.9 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
from dataclasses import dataclass, asdict
import random
from typing import Any, Generator, Optional, List, Dict, TypedDict, Union
import functools
from zeroband.utils.logger import get_logger
from zeroband.config import DataConfig
import torch
from torch.utils.data import IterableDataset, Dataset
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.distributed.checkpoint.stateful import Stateful
from datasets import load_dataset_builder, BuilderConfig
from pyarrow import parquet as pq
from transformers import PreTrainedTokenizer
# Vocabulary size passed to FakeTokenizedDataset by get_dataloader when data_config.fake is set.
TEST_VOCAB_SIZE = 1024
class FakeTokenizedDataset(IterableDataset):
    """Infinite dummy dataset of random token sequences, for tests and smoke runs.

    Each sample is ``{"input_ids": [...]}`` with a random length in ``[1, seq_len]`` and token ids
    drawn uniformly from ``[3, vocab_size)`` (ids 0-2 are never produced, presumably because they are
    reserved for special tokens).
    """

    def __init__(self, seq_len: int, vocab_size: int):
        assert vocab_size > 3, "Vocab size must be greater than 3"
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        # number of samples generated so far; this is the checkpointed state
        self.step = 0

    def __iter__(self) -> Generator[dict[str, Any], Any, None]:
        while True:
            len_ = random.randint(1, self.seq_len)
            input_ids = torch.randint(3, self.vocab_size, (len_,)).tolist()
            self.step += 1
            yield {"input_ids": input_ids}

    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state_dict):
        """Restore ``step`` and consume the same number of draws from the global RNGs.

        The samples come from the global ``random`` / ``torch`` generators, so the replay only
        reproduces the original stream if those were seeded identically.
        """
        target_step = state_dict["step"]
        # every next() on the generator increments self.step, so start the replay from 0 to land
        # exactly on target_step (previously the step ended up doubled)
        self.step = 0
        replay = iter(self)
        for _ in range(target_step):
            next(replay)
class BatchOutput(TypedDict):
    """Type of one packed sample yielded by ``SequencePackingDataSet.__iter__``."""

    # token ids of the packed sequence
    # NOTE(review): SequencePackingDataSet builds this tensor with dtype torch.long (int64), not the
    # int32 that IntTensor denotes — confirm which annotation is intended.
    input_ids: torch.IntTensor
    # next-token targets, aligned position by position with input_ids (same dtype note as above)
    labels: torch.IntTensor
    # length of every segment packed into this sequence
    seqlens: list[int]
@dataclass
class SequencePackingDataSetState:
    """Partially filled packing buffer of SequencePackingDataSet; saved and restored through its state_dict."""

    # input token ids accumulated so far for the sequence being packed
    inputs_ids: list[int]
    # labels matching inputs_ids position by position
    labels: list[int]
    # length of each segment already placed in the buffer
    seqlens: list[int]
class SequencePackingDataSet(IterableDataset, Stateful):
    """
    Wrap a tokenized dataset and pack its samples into sequences of exactly ``max_seq_length`` tokens.

    For every sample, ``eos_token`` is appended to ``input_ids``. The segment's inputs are the original
    tokens and its labels are the same tokens shifted left by one, so the last label is ``eos_token``.
    Segments are concatenated until the buffer holds ``max_seq_length`` tokens. The tail of a sample
    that does not fit is dropped, not carried over to the next sequence. ``seqlens`` lists the length
    of each segment in the packed sequence.

    ``state_dict`` contains the wrapped dataset's state plus the partially filled buffer, so the
    wrapped dataset must itself provide ``state_dict`` / ``load_state_dict``.
    """

    def __init__(self, dataset: Dataset, max_seq_length: int, eos_token: int):
        self.dataset = dataset
        self.max_seq_length = max_seq_length
        self.eos_token = eos_token
        self.state = SequencePackingDataSetState(inputs_ids=[], labels=[], seqlens=[])

    def __iter__(self) -> Generator[BatchOutput, Any, None]:
        for sample in self.dataset:
            tokens: list[int] = [*sample["input_ids"], self.eos_token]
            sample_inputs_ids = tokens[:-1]
            sample_labels = tokens[1:]

            token_remaining = self.max_seq_length - len(self.state.inputs_ids)
            if len(sample_inputs_ids) < token_remaining:
                # the whole sample fits and leaves room for more: keep accumulating
                self.state.inputs_ids.extend(sample_inputs_ids)
                self.state.labels.extend(sample_labels)
                self.state.seqlens.append(len(sample_inputs_ids))
                continue

            # this sample fills the buffer: take what fits, drop the rest
            self.state.inputs_ids.extend(sample_inputs_ids[:token_remaining])
            self.state.labels.extend(sample_labels[:token_remaining])
            self.state.seqlens.append(token_remaining)

            data = {
                # torch.tensor with an explicit dtype avoids materializing a float32 tensor first,
                # which would silently round token ids >= 2**24
                "input_ids": torch.tensor(self.state.inputs_ids, dtype=torch.long),
                "labels": torch.tensor(self.state.labels, dtype=torch.long),
                "seqlens": self.state.seqlens,
            }
            # reset the buffer before yielding so a state_dict() taken after this yield sees an
            # empty buffer
            self.state.inputs_ids = []
            self.state.labels = []
            self.state.seqlens = []
            yield data

    def state_dict(self):
        return {"dataset": self.dataset.state_dict(), "state": asdict(self.state)}

    def load_state_dict(self, state_dict):
        self.dataset.load_state_dict(state_dict["dataset"])
        self.state = SequencePackingDataSetState(**state_dict["state"])
def collate_fn(samples: list[dict[str, torch.LongTensor]]) -> dict[str, torch.LongTensor | list[torch.LongTensor]]:
assert samples[0].keys() == {"input_ids", "labels", "seqlens"}
inputs_ids = []
labels = []
seqlens = []
for sample in samples:
inputs_ids.append(sample["input_ids"])
labels.append(sample["labels"])
seqlens.append(torch.Tensor(sample["seqlens"]).long())
return {
"input_ids": torch.stack(inputs_ids, dim=0),
"labels": torch.stack(labels, dim=0),
"seqlens": seqlens,
}
@dataclass
class PQDatasetState:
    """Resumable read position of a ParquetDataset."""

    # parquet files this reader iterates over (already split per dataloader worker when applicable)
    files: List[str]
    # index into files of the file currently being read
    file_index: int
    # index of the next row to read from the current file
    row_index: int
    # stride between rows read by this reader (1, or the number of workers when rows are sharded)
    increment: int
    # row_index to restart from when moving on to the next file
    init_row_index: int
class ParquetDataset(IterableDataset, Stateful):
    """
    Stateful, infinite iterable over the ``text`` column of a list of parquet files, tokenized on the fly.

    Files are read one at a time. When the last file is exhausted, iteration wraps around to the first
    one, so the stream never ends. The current file/row position is exposed through ``state_dict`` /
    ``load_state_dict`` so a run can resume where it stopped.

    Inside a multi-process pytorch DataLoader, the files are split across workers on first iteration.
    If there are more workers than files, every worker reads every file but only every
    ``num_workers``-th row, starting at its worker id.

    Each sample is ``{"input_ids": tokenizer.encode(str(row))}``.
    """

    def __init__(self, files: List[str], tokenizer: PreTrainedTokenizer):
        self.arg_files = files
        self.tokenizer = tokenizer
        # created lazily in __iter__ so it can take the dataloader worker info into account
        self.state: Optional[PQDatasetState] = None

    def _lazy_init(self):
        """Create the initial read state, sharding files (or rows) across dataloader workers if any."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # main process: read every file, every row
            self.state = PQDatasetState(files=self.arg_files, file_index=0, row_index=0, increment=1, init_row_index=0)
            return

        if worker_info.num_workers > len(self.arg_files):
            get_logger().warning(
                f"dataloader rank {worker_info.id} Number of workers {worker_info.num_workers} is greater than the number of files {len(self.arg_files)}"
            )
            # not enough files to give each worker its own: shard by row instead
            self.state = PQDatasetState(
                files=self.arg_files,
                file_index=0,
                row_index=worker_info.id,
                increment=worker_info.num_workers,
                init_row_index=worker_info.id,
            )
            return

        # shard by file: worker i reads files i, i + num_workers, ...
        self.state = PQDatasetState(
            files=self.arg_files[worker_info.id :: worker_info.num_workers],
            file_index=0,
            row_index=0,
            increment=1,
            init_row_index=0,
        )

    def _advance_file(self) -> None:
        """Point the state at the start row of the next file, wrapping around so the stream is infinite."""
        self.state.row_index = self.state.init_row_index
        self.state.file_index = (self.state.file_index + 1) % len(self.state.files)

    def __iter__(self):
        # we lazy init the parquet dataset to get the worker info from dataloader multi process
        if self.state is None:
            self._lazy_init()
        if not self.state.files:
            raise ValueError("ParquetDataset has no files to read")

        # consecutive files that had no row for this reader; guards against spinning forever
        files_without_rows = 0
        while True:
            file = self.state.files[self.state.file_index]
            # only the "text" column is used, so don't load the others
            table = pq.ParquetFile(file).read(columns=["text"])["text"]

            if self.state.row_index >= len(table):
                # empty file, or a row-sharded reader whose offset lies past the end of this file
                files_without_rows += 1
                if files_without_rows > len(self.state.files):
                    raise ValueError(f"No readable rows in any of the {len(self.state.files)} parquet files")
                self._advance_file()
                continue
            files_without_rows = 0

            file_done = False
            while not file_done:
                row = table[self.state.row_index]
                self.state.row_index += self.state.increment
                # advance the state before yielding so a checkpoint taken now resumes at the next row
                file_done = self.state.row_index >= len(table)
                if file_done:
                    self._advance_file()
                yield {"input_ids": self.tokenizer.encode(str(row))}
            # leaving the inner loop makes the outer loop load the file the state now points to

    @property
    def is_empty(self):
        """True when the dataset was given no files at all."""
        return len(self.arg_files) == 0

    def state_dict(self) -> dict[str, Any]:
        # {} while the state is still uninitialized (iteration has not started)
        return asdict(self.state) if self.state is not None else {}

    def load_state_dict(self, state_dict):
        # {} is what state_dict() returns before the first iteration: keep the lazy init in that case
        self.state = PQDatasetState(**state_dict) if state_dict else None
@dataclass
class InterleaveDatasetState:
    """Checkpointable state of InterleaveDataset."""

    # number of samples drawn so far; the RNG is fast-forwarded by this many draws on restore
    current_index: int
    # seed of the random.Random that picks which dataset each sample comes from
    seed: int
class InterleaveDataset(IterableDataset, Stateful):
    """Interleave several datasets by drawing each sample from one of them at random.

    The dataset for the next sample is chosen with ``random.Random(seed).choices`` weighted by
    ``probabilities`` (uniform when ``probabilities`` is None). Datasets reporting ``is_empty`` are
    dropped with a warning, together with their weight.

    The state is checkpointable. On restore, the RNG is re-seeded and fast-forwarded by
    ``current_index`` draws (linear in the number of samples already drawn), and every kept child
    dataset restores its own state. The ``dataset_{i}`` keys index the non-empty datasets kept in
    ``self.datasets``.
    """

    def __init__(self, datasets: List[ParquetDataset], probabilities: Optional[List[float]] = None, seed: int = 42):
        assert len(datasets) > 0, "At least one dataset is required"
        if probabilities is None:
            # no ratios configured: weight every dataset equally
            probabilities = [1.0 / len(datasets)] * len(datasets)
        assert len(datasets) == len(probabilities), "The number of datasets and probabilities must be the same"

        self.probabilities = []
        self.datasets = []
        for dataset, prob in zip(datasets, probabilities):
            if dataset.is_empty:
                get_logger().warning(f"Dataset {dataset} is empty. Skipping.")
                continue
            self.datasets.append(dataset)
            self.probabilities.append(prob)
        if not self.datasets:
            # random.choices would otherwise fail with an opaque IndexError on the first draw
            raise ValueError("All datasets are empty, nothing to interleave")

        self.state = InterleaveDatasetState(current_index=0, seed=seed)
        self._init_random_state()

    def _init_random_state(self):
        """Re-seed the selection RNG and replay ``state.current_index`` draws to reach the saved position."""
        self.random_generator = random.Random(self.state.seed)
        for _ in range(self.state.current_index):
            self._get_dataset_to_yield_from()

    def _get_dataset_to_yield_from(self) -> int:
        """Draw the index (into self.datasets) of the dataset that provides the next sample."""
        return self.random_generator.choices(range(len(self.datasets)), weights=self.probabilities, k=1)[0]

    def __iter__(self):
        data_iters = [iter(dataset) for dataset in self.datasets]
        while True:
            dataset_to_yield_from = self._get_dataset_to_yield_from()
            sample = next(data_iters[dataset_to_yield_from])
            self.state.current_index += 1
            yield sample

    def state_dict(self):
        state = {"interleave_state": asdict(self.state)}
        for i, dataset in enumerate(self.datasets):
            state[f"dataset_{i}"] = dataset.state_dict()
        return state

    def load_state_dict(self, state_dict):
        self.state = InterleaveDatasetState(**state_dict["interleave_state"])
        for i, dataset in enumerate(self.datasets):
            dataset.load_state_dict(state_dict[f"dataset_{i}"])
        self._init_random_state()
def get_dataloader(
    tokenizer,
    world_size: int,
    rank: int,
    batch_size: int,
    data_config: DataConfig,
) -> StatefulDataLoader:
    """Build the training dataloader.

    The token source is FakeTokenizedDataset (with TEST_VOCAB_SIZE) when ``data_config.fake`` is set,
    otherwise the real training split from ``load_all_datasets`` sharded for this ``rank``. The source
    is packed into ``data_config.seq_length``-token sequences separated by the tokenizer's EOS id, then
    batched with ``collate_fn`` by a StatefulDataLoader.
    """
    if data_config.fake:
        token_source = FakeTokenizedDataset(data_config.seq_length, TEST_VOCAB_SIZE)
    else:
        token_source = load_all_datasets(
            data_config=data_config,
            split="train",
            tokenizer=tokenizer,
            rank=rank,
            world_size=world_size,
        )

    packed = SequencePackingDataSet(token_source, data_config.seq_length, eos_token=tokenizer.eos_token_id)
    loader_kwargs = dict(
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=data_config.num_workers,
    )
    return StatefulDataLoader(packed, **loader_kwargs)
@functools.cache
def _get_ds_config_dict(path: str, name: Optional[str] = None) -> Dict[str, BuilderConfig]:
    """Return the builder configs of a HuggingFace dataset, keyed by config name.

    Results are memoized per ``(path, name)`` (unbounded cache) so repeated lookups don't rebuild the
    dataset builder.
    """
    return load_dataset_builder(path=path, name=name).builder_configs
def _get_datafiles(path: str, name: Optional[str] = None, split: str = "train") -> List[str]:
    """Return the data files of ``split`` for the dataset at ``path``.

    When ``name`` is None or empty, the "default" config is used if it exists; otherwise the first
    config is used and a warning is logged.
    """
    configs = _get_ds_config_dict(path=path, name=name)
    config_name = name
    if not config_name:
        if "default" in configs:
            config_name = "default"
        else:
            get_logger().warning(f"Default config not found for {path}. Using first config.")
            config_name = next(iter(configs))
    return configs[config_name].data_files[split]
def _nice_print(kwargs: Dict[str, Union[str, List[str]]]) -> str:
def _foo(a):
if isinstance(a, list):
return str(a[:5]) + "..." + str(a[-5:]) if len(a) > 10 else str(a)
return str(a)
return str({k: _foo(v) for k, v in kwargs.items()})
def _load_datasets(
    dataset_names: str,
    split: str,
    tokenizer: PreTrainedTokenizer,
    data_rank: Optional[int] = None,
    data_world_size: Optional[int] = None,
    streaming: bool = True,
    probabilities: Optional[List[float]] = None,
    reverse_data_files: bool = False,
) -> Union[InterleaveDataset, ParquetDataset]:
    """Build a ParquetDataset per entry of ``dataset_names`` and interleave them if there are several.

    Args:
        dataset_names: comma-separated ``path[:config]`` entries.
        split: split whose data files are loaded.
        tokenizer: tokenizer handed to every ParquetDataset.
        data_rank / data_world_size: when both are set, each dataset keeps only every
            ``data_world_size``-th data file starting at ``data_rank``.
        streaming: only affects the debug log message.
        probabilities: sampling weights, one per dataset; uniform when None and several datasets
            are given. Ignored when there is a single dataset.
        reverse_data_files: reverse each dataset's file order before sharding.

    Returns:
        An InterleaveDataset over all datasets, or the single ParquetDataset when only one is given.
    """
    get_logger().debug(dataset_names)
    ds_args = []
    for _ds in dataset_names.split(","):
        _ds_name, _, _ds_config = _ds.partition(":")
        _ds_args: dict[str, Any] = {"path": _ds_name}
        if _ds_config:
            _ds_args["name"] = _ds_config
        _data_files = _get_datafiles(_ds_name, _ds_config, split)
        if reverse_data_files:
            _data_files = _data_files[::-1]
        if data_rank is not None and data_world_size is not None:
            _data_files = _data_files[data_rank::data_world_size]
        _ds_args["data_files"] = _data_files
        ds_args.append(_ds_args)

    get_logger().debug(f"Loading datasets{' in streaming mode' if streaming else ''}")
    datasets = [ParquetDataset(files=ds_arg["data_files"], tokenizer=tokenizer) for ds_arg in ds_args]

    if len(datasets) > 1:
        if probabilities is None:
            # no dataset ratio configured: sample every dataset with equal weight
            probabilities = [1.0 / len(datasets)] * len(datasets)
        ds = InterleaveDataset(datasets=datasets, probabilities=probabilities)
    else:
        ds = datasets[0]
    get_logger().info(f"Loaded datasets ({split})")
    return ds
def _get_probabilities(data_config: DataConfig) -> Optional[List[float]]:
    """Turn ``data_config.dataset_ratio`` (e.g. ``"1:3"``) into normalized sampling probabilities.

    Returns None when no ratio is configured.

    Raises:
        ValueError: the number of ratios differs from the number of comma-separated datasets, a ratio
            is negative or not finite, or all ratios are zero.
    """
    import math

    if data_config.dataset_ratio is None:
        return None
    ratios = data_config.dataset_ratio.split(":")
    if len(data_config.dataset_name_or_paths.split(",")) != len(ratios):
        raise ValueError("Number of datasets and dataset ratios must be the same")
    nums = [float(i) for i in ratios]
    if not all(math.isfinite(n) and n >= 0 for n in nums):
        raise ValueError(f"Dataset ratios must be finite and non-negative, got {data_config.dataset_ratio}")
    denom = sum(nums)
    if denom <= 0:
        # would otherwise be a ZeroDivisionError below
        raise ValueError(f"Dataset ratios must not all be zero, got {data_config.dataset_ratio}")
    return [i / denom for i in nums]
def load_all_datasets(
    data_config: DataConfig,
    split: str,
    tokenizer: PreTrainedTokenizer,
    rank: int,
    world_size: int,
) -> InterleaveDataset:
    """Load all configured datasets for ``split``, sharded for this process, and interleave them.

    Data files are sharded over ``world_size`` ranks. When ``data_config.split_by_data_rank`` is set
    and both ``data_rank`` and ``data_world_size`` are configured, the shard space widens to
    ``data_world_size * world_size`` and this process takes shard ``data_rank * world_size + rank``.

    NOTE: with a single configured dataset the underlying loader returns that dataset directly rather
    than an InterleaveDataset.
    """
    use_data_rank = (
        data_config.split_by_data_rank
        and data_config.data_rank is not None
        and data_config.data_world_size is not None
    )
    if use_data_rank:
        split_rank = data_config.data_rank * world_size + rank
        split_world_size = data_config.data_world_size * world_size
    else:
        split_rank, split_world_size = rank, world_size

    get_logger().info("Loading Train dataset(s)")
    loaded = _load_datasets(
        dataset_names=data_config.dataset_name_or_paths,
        split=split,
        data_rank=split_rank,
        data_world_size=split_world_size,
        probabilities=_get_probabilities(data_config),
        reverse_data_files=data_config.reverse_data_files,
        tokenizer=tokenizer,
    )
    get_logger().info(f"Train dataset: {loaded}")
    return loaded