-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathdataset_utils.py
1417 lines (1271 loc) · 55.3 KB
/
dataset_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import csv
import logging
import math
import os
import random
import time
from copy import deepcopy
from typing import Any, Callable, List, Optional, Set, Union
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import torch.utils.data as data
from transformers import AutoTokenizer # type: ignore
from fms_fsdp.utils.checkpointing_utils import get_latest
"""
The following distributed dataloaders are designed around 4 main principles:
1. Efficient, asynchronous operation. Workers on different devices do not communicate.
2. Modularity. Data loading pipeline is composed of wrapped iterators, the base iterator
loading from disk and additional layers adding levels of post-processing (shuffling,
packing, padding, etc.).
3. Seamless resumption from checkpoint. Each stage of the pipeline maintains an internal
state that can be written/read on disk via implemented recursive `state_dict()` and
`load_state_dict()` calls.
4. Rescalability. Users can save and load checkpoints to/from different numbers of workers
without losing the global state. This is accomplished by splitting state fields for each
layer into `state_params`, which are typically scalar-valued and can be discarded when
rescaling (i.e. counters, RNG states), and `reshard_params`, which are lists that can be
re-distributed over workers (i.e. buffers).
Our loaders obey the following type hierarchy:
torch.utils.data.IterableDataset -> _StatefulDataset -> _WrapperDataset.
`_StatefulDataset` implements state and checkpointing logic. A `_WrapperDataset` holds a
single `_StatefulDataset` and iterates via calling its wrapped dataset any number of times,
then applying some sort of post-processing and yielding the result. Users build data processing
pipelines by wrapping a base `_StatefulDataset` in any number of `_WrapperDataset` layers,
which is then passed to the torch DataLoader.
"""
def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]:
    """
    Split ``itemlist`` into ``worldsize`` contiguous, non-overlapping chunks and
    return the chunk that belongs to ``rank``.

    Boundaries use integer floor division, so when ``len(itemlist)`` is not a
    multiple of ``worldsize`` the chunk sizes differ by at most one.
    """
    total = len(itemlist)
    lo = (rank * total) // worldsize
    hi = ((rank + 1) * total) // worldsize
    return itemlist[lo:hi]
def _shard_inclusive(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]:
    """
    Return every item that ``rank`` owns at least partially.

    When ``len(itemlist)`` is not a multiple of ``worldsize``, an item on a chunk
    boundary may be split between neighbouring ranks. The lower bound is rounded
    down and the upper bound rounded up, so such fractionally-owned items appear
    in the result of every rank that owns part of them.
    """
    n_items = len(itemlist)
    first = math.floor(n_items * rank / worldsize)
    last = math.ceil(n_items * (rank + 1) / worldsize)
    return itemlist[first:last]
class _StatefulDataset(data.IterableDataset):
    """
    Stub for stateful datasets, extends data.IterableDataset with state_dict methods.
    All subclasses should specify the params to be considered stateful or reshardable in the
    self.state_params and self.reshard_params lists.

    - ``state_params``: restored only when the checkpoint was written with the same
      worldsize as the current one; discarded otherwise (see ``load_state_dict``).
    - ``reshard_params``: list-valued fields that are redistributed across workers when
      the worldsize changes (see ``_reshard``).
    """

    def __init__(
        self,
        datapath: str,
        rank: int,
        worldsize: int,
    ):
        """
        Args:
            datapath: Dataset directory (must be a non-empty folder), or None.
            rank: Index of this worker; must satisfy ``0 <= rank < worldsize``.
            worldsize: Total number of workers, before any multiprocess-dataloader
                adjustment performed in ``setup``.
        """
        assert rank >= 0, f"Rank {rank} must be a non-negative integer"
        assert (
            worldsize > rank
        ), f"Worldsize {worldsize} must be greater than rank {rank}"
        assert datapath is None or (
            os.path.isdir(datapath) and len(os.listdir(datapath)) > 0
        ), f"Data path {datapath} must be a non-empty folder or None"
        self.state_params: List[str] = []
        self.reshard_params: List[str] = []
        # Default fields
        self.datapath = datapath
        self.rank = rank
        self.worldsize = worldsize
        # -1 marks "not yet adjusted for dataloader worker processes" (see setup)
        self.local_worldsize = -1
        # Setup / loading flags
        # Worldsize of the checkpoint being loaded; defaults to the current worldsize
        self.load_worldsize = worldsize
        self.is_setup = False

    def setup(self):
        """
        This method should contain all setup depending on datapath or rank.
        It is called after init, but immediately before any other operation.
        Certain operations higher up in the pipeline may change rank or datapath
        after init (for example, wrapping in a subdataset sampler layer, or copying
        to worker processes), so all rank- and datapath- dependent ops are deferred to
        this function.
        Currently, this function simply adjusts rank/worldsize to account for
        multiprocess dataloaders.
        """
        if not self.is_setup:
            self.is_setup = True
            # Perform adjustment only if not already adjusted (i.e. via _WrapperDataset)
            if self.local_worldsize == -1:
                info = data.get_worker_info()
                if info is None or info.num_workers == 1:
                    # No multi-worker rank adjustment needed
                    self.local_worldsize = 1
                else:
                    # Each dataloader worker becomes its own global rank:
                    # global rank = process rank * num_workers + worker id
                    self.local_worldsize = info.num_workers
                    self.worldsize = self.worldsize * self.local_worldsize
                    self.rank = self.local_worldsize * self.rank + info.id

    def statename(self, x: str):
        """Return the state-dict key for field ``x``, prefixed with this class's name."""
        # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline
        return self.__class__.__name__ + "." + x

    def state_dict(self):
        """
        Retrieve all state and reshard flags (each worker/process saves its own state dict shard).
        On the off chance that you're saving a checkpoint with zero steps, run setup first.
        """
        self.setup()
        return {
            self.statename(flag): getattr(self, flag)
            for flag in self.state_params + self.reshard_params
        }

    def _reshard(self, sharded_list):
        """
        Sharded_list is a list of lists, where each "shard" sublist must have the same length.
        These shards should tightly span only the partition of data owned by this worker.
        (i.e. if global_list is the list of all entries, sharded_list = _shard_inclusive(global_list) ).
        Determine fractional ownership of shards, and get the flattened partition owned by this worker.
        """
        # How many shards did _shard_inclusive() drop to the left of sharded_list?
        shard_offset = math.floor(self.load_worldsize * self.rank / self.worldsize)
        # How long are the list shards?
        shard_len = len(sharded_list[0])
        for i, shard in enumerate(sharded_list):
            assert (
                len(shard) == shard_len
            ), f"Shard {i} with length {len(shard)} does not match expected {shard_len}"
        # How many list items did _shard_inclusive() drop to the left of the flattened sharded_list?
        item_offset = shard_len * shard_offset
        # How many list items are there in total?
        n_items = self.load_worldsize * shard_len
        # The indices of the flattened sharded_list that this worker owns
        my_items = range(
            int(n_items * self.rank / self.worldsize) - item_offset,
            int(n_items * (self.rank + 1) / self.worldsize) - item_offset,
        )
        # Pull out owned items
        return [sharded_list[i // shard_len][i % shard_len] for i in my_items]

    def load_state_dict(self, state_dicts, sharded_input=False):
        """
        Input state_dicts is a list of state_dicts. If sharded_input=False, this is expected to be the
        global list of states across all checkpoint shard files. If sharded_input=True, this expects
        _shard_inclusive(global_state_list). Handling reduced inputs allows for much more efficient loading.
        Workflow:
        1. Run setup to prepare dataset
        2. if sharded_inputs is false, shard the inputs.
        3. If worldsize matches checkpoint, pull state and reshard params from the given checkpoint
        shard (state_dicts is a singleton list).
        4. If worldsize does not match checkpoint, toss state params and assemble reshard params from
        across given state_dicts. In this case state_dicts may be singleton (for fractional ownership)
        or multi-element (for multiple/partitioned ownership).
        5. Return reduced input for use by downstream loading functions
        """
        self.setup()
        if not sharded_input:
            self.load_worldsize = len(state_dicts)
            state_dicts = _shard_inclusive(state_dicts, self.rank, self.worldsize)
        if self.load_worldsize == self.worldsize:
            # Same worldsize: this worker's shard is exactly state_dicts[0]
            for flag in self.state_params + self.reshard_params:
                setattr(self, flag, state_dicts[0][self.statename(flag)])
        else:
            for flag in self.reshard_params:
                reshard = self._reshard(
                    [sd[self.statename(flag)] for sd in state_dicts]
                )
                setattr(self, flag, reshard)
        return state_dicts

    def load_from_path(self, path: str):
        """
        Count shard files in the specified checkpoint folder and determine overlap with current
        rank and worldsize partition. Load only matching shardfile(s) and pass to load_state_dict.
        This is more efficient than sharding the full loaded state.
        """
        assert os.path.exists(path), "Specified checkpoint does not exist"
        assert not os.path.isfile(path), "Checkpoint should be a folder of shard states"

        def _file_rank(name: str) -> int:
            # save_to_path writes "loader_state_{rank}.pth"; take the last "_" field of the stem.
            return int(os.path.splitext(name)[0].rsplit("_", 1)[-1])

        fileshards = sorted(
            (x for x in os.listdir(path) if "loader" in x), key=_file_rank
        )
        assert (
            len(fileshards) > 0
        ), "Checkpoint directory must contain checkpoint files with 'loader' in the name"
        self.load_worldsize = len(fileshards)
        # Grab only the shard files holding data we currently own
        my_fileshards = _shard_inclusive(fileshards, self.rank, self.worldsize)
        states = [torch.load(os.path.join(path, x)) for x in my_fileshards]
        self.load_state_dict(states, True)

    def save_to_path(self, path: str):
        """
        Grab recursive shard states and save all shard states to the specified checkpoint folder
        """
        os.makedirs(path, exist_ok=True)
        state = self.state_dict()
        torch.save(state, os.path.join(path, f"loader_state_{self.rank}.pth"))
class _WrapperDataset(_StatefulDataset):
    """
    Stub for nested wrappers of _StatefulDatasets. Extends state fns with recursion.
    Requires a single instantiated sub-dataset (which may be replicated during setup fn).
    """

    def __init__(
        self,
        dataset: _StatefulDataset,
    ):
        self.dataset = dataset
        # Inherit default flags from sub-dataset
        super().__init__(
            self.dataset.datapath, self.dataset.rank, self.dataset.worldsize
        )

    def setup(self):
        """
        Datapath/rank/worldsize percolate upwards recursively during initialization, so
        now we project any desired changes downward, also recursively.
        We also project local_worldsize downward to prevent subsequent layers from
        further inflating the rank/worldsize - we only need to account for multiprocessing once!
        Any code overriding this function should still include this functionality.
        """
        if not self.is_setup:
            super().setup()
            self.dataset.datapath = self.datapath
            self.dataset.rank = self.rank
            self.dataset.worldsize = self.worldsize
            self.dataset.local_worldsize = self.local_worldsize
            self.dataset.setup()

    def load_state_dict(self, state_dicts, sharded_input=False):
        """
        Sets all specified flags at the current level, then recurses into wrapped dataset.
        Returns the (sharded) list of state dicts relevant to this worker.
        """
        self.setup()
        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
        # Propagate the checkpoint worldsize so the wrapped layer reshards the same way
        self.dataset.load_worldsize = self.load_worldsize
        self.dataset.load_state_dict(sharded_dicts, True)
        return sharded_dicts

    def state_dict(self):
        """
        Fetches state dict recursively from wrapped layers, then adds specified flags.
        Overlapping flags are overwritten with a warning.
        """
        self.setup()
        out = self.dataset.state_dict()
        state = super().state_dict()
        for flag in self.state_params + self.reshard_params:
            # Keys are namespaced via statename(), so compare namespaced keys; a bare
            # `flag in out` check can never match and `state[flag]` would KeyError.
            key = self.statename(flag)
            if key in out:
                logging.warning(
                    "Loader %s: flag %s already present in state_dict with value %s. "
                    "Overwriting with value %s",
                    self.rank,
                    key,
                    out[key],
                    state[key],
                )
        out.update(state)
        return out
#### ------------------------- FILE READERS ------------------------- ####
class _ShardFileHandler:
    """
    Base interface for readers of sharded data files in one particular format.

    Concrete readers provide ``open``, ``length``, ``get`` and ``slice``;
    ``is_legal`` may be overridden to restrict which files are accepted.
    """

    def is_legal(self, filepath: str):
        """
        Report whether ``filepath`` can be read by this handler, ideally without
        opening it. The base implementation accepts any existing regular file.
        """
        return os.path.isfile(filepath)

    def open(self, path: str):
        """
        Return a reader object for ``path`` that ``get`` can index into.
        Implementations should avoid loading whole multi-GB files.
        """
        raise NotImplementedError

    def length(self, path: str):
        """
        Return how many documents the file at ``path`` contains.
        Implementations should avoid loading whole multi-GB files.
        """
        raise NotImplementedError

    def get(self, reader, index: int, drop_tokens: Set):
        """
        Fetch document ``index`` from a reader returned by ``open``, stripping its
        first and/or last element when that element is in ``drop_tokens``.
        The result must support ``len()``. Reading only part of a long document is
        desirable, but matters less than not reading entire files.
        """
        raise NotImplementedError

    def slice(self, doc, index: int, n_pull: int) -> List:
        """
        Return ``n_pull`` consecutive items of ``doc`` starting at ``index`` as a
        plain Python list. Memory efficiency here is secondary to ``get``/``open``.
        """
        raise NotImplementedError
class ArrowHandler(_ShardFileHandler):
    """
    Reader for indexable, pre-tokenized PyArrow IPC shard files.

    A shard file is expected to hold multiple RecordBatches, one per document,
    where the column named ``col_name`` (default ``"tokens"``) holds that
    document's single token list. Because individual batches can be sliced,
    document chunks can be read without pulling the whole document or shard
    file, which handles very large documents gracefully. This is the preferred
    format, although it is not a standard data layout.
    """

    def __init__(self, col_name: str = "tokens"):
        self.col_name = col_name

    def is_legal(self, filepath: str):
        extension = os.path.splitext(filepath)[1]
        return "arrow" in extension

    def open(self, path: str):
        # Memory-map the file instead of reading it into memory up front.
        source = pa.memory_map(path)
        return pa.ipc.open_file(source)

    def length(self, path: str):
        # One RecordBatch per document.
        reader = self.open(path)
        return reader.num_record_batches

    def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
        tokens = reader.get_batch(index)[self.col_name]
        # Strip a leading token that belongs to drop_tokens.
        if len(tokens) > 0 and tokens[0].as_py() in drop_tokens:
            tokens = tokens.slice(1, len(tokens) - 1)
        # Length is checked again: a one-token document may now be empty.
        if len(tokens) > 0 and tokens[-1].as_py() in drop_tokens:
            tokens = tokens.slice(0, len(tokens) - 1)
        return tokens

    def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List:
        window = doc.slice(index, n_pull)
        return window.to_pylist()
class ParquetHandler(_ShardFileHandler):
    """
    Reader for indexable Parquet shard files, as commonly used by HF datasets.

    The whole text column of a shard is read on ``open`` and each document is
    tokenized in full on ``get``, so this reader assumes reasonably small shard
    files (<5Gb) and documents (<100k tokens). In exchange, it supports a
    standard, widely-used data format.
    """

    def __init__(self, tokenizer_path: str, col_name: str = "text"):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.col_name = col_name

    def is_legal(self, filepath: str):
        _, extension = os.path.splitext(filepath)
        return "parquet" in extension

    def open(self, path: str):
        # Load only the configured text column of the shard.
        table = pq.read_pandas(path, columns=[self.col_name], partitioning=None)
        return table[self.col_name]

    def length(self, path: str):
        # Row count is taken from the Parquet file metadata.
        metadata = pq.read_metadata(path)
        return metadata.num_rows

    def get(self, reader, index: int, drop_tokens: Set):
        text = str(reader[index])
        token_ids = self.tokenizer(text)["input_ids"]
        if len(token_ids) > 0 and token_ids[0] in drop_tokens:
            token_ids = token_ids[1:]
        # Length is checked again: a one-token document may now be empty.
        if len(token_ids) > 0 and token_ids[-1] in drop_tokens:
            token_ids = token_ids[:-1]
        return token_ids

    def slice(self, doc: List, index: int, n_pull: int) -> List:
        stop = index + n_pull
        return doc[index:stop]
class AutoHandler(_ShardFileHandler):
    """
    Reader that accepts both Arrow and Parquet shard files, delegating to an
    ``ArrowHandler`` or ``ParquetHandler`` according to the file extension.

    ``get`` and ``slice`` are forwarded to whichever handler served the most
    recent ``open`` call, so readers of different formats should not be
    interleaved.
    """

    def __init__(self, tokenizer_path: str, col_name: str = "text"):
        self.PHandler = ParquetHandler(tokenizer_path, col_name)
        self.AHandler = ArrowHandler()
        # Placeholder until the first open(); its get/slice raise NotImplementedError.
        self.current = _ShardFileHandler()

    @staticmethod
    def _extension(path: str) -> str:
        """Return the extension of ``path`` (including the leading dot)."""
        return os.path.splitext(path)[1]

    def _handler_for(self, path: str) -> _ShardFileHandler:
        """Pick the Arrow handler for '*arrow*' extensions, Parquet otherwise."""
        return self.AHandler if "arrow" in self._extension(path) else self.PHandler

    def is_legal(self, filepath: str):
        extension = self._extension(filepath)
        return any(fmt in extension for fmt in ("parquet", "arrow"))

    def open(self, path: str):
        """
        Open ``path`` with the matching handler and remember that handler for
        subsequent ``get``/``slice`` calls.
        """
        self.current = self._handler_for(path)
        return self.current.open(path)

    def length(self, path: str):
        """Count the documents in ``path`` using the matching handler."""
        return self._handler_for(path).length(path)

    def get(self, reader, index: int, drop_tokens: Set):
        """
        Fetch document ``index`` via the handler chosen by the last ``open``,
        stripping a leading/trailing token found in ``drop_tokens``.
        """
        return self.current.get(reader, index, drop_tokens)

    def slice(self, doc, index: int, n_pull: int) -> List:
        """
        Return ``n_pull`` consecutive items of ``doc`` from ``index`` as a list,
        via the handler chosen by the last ``open``.
        """
        return self.current.slice(doc, index, n_pull)
#### ------------------------- PIPELINE LAYERS ------------------------- ####
class PreprocessDataset(_WrapperDataset):
    """
    Wrapper for a _StatefulDataset that applies a specified preprocessing
    or augmentation function to dataset outputs.
    ...
    Args
    ----
    dataset : _StatefulDataset
        Fully instantiated dataset
    aug_fn : function (any -> any)
        The augmentation function to apply to each dataset item.
    """

    def __init__(
        self,
        dataset: _StatefulDataset,
        aug_fn: Callable,
    ):
        super().__init__(dataset)
        self.aug_fn = aug_fn

    def __iter__(self):
        # Iterate with a for-loop rather than bare next(): if the wrapped dataset is
        # finite, a StopIteration escaping a generator body becomes RuntimeError
        # (PEP 479), whereas the loop simply ends iteration.
        for item in self.dataset:
            yield self.aug_fn(item)
class CheckpointDataset(_WrapperDataset):
    """
    Wrapper for a _StatefulDataset that implements auto-checkpoint saving every n steps.
    Useful for setting n_workers > 0, so that workers do not rely on the master process
    for state saving (inter-process communication unsupported in PyTorch datasets).
    ...
    Args
    ----
    dataset : _StatefulDataset
        Fully instantiated dataset
    load_path : str
        Absolute path to checkpoint load directory. If a checkpoint exists, loads it.
    interval : int
        Saves a new checkpoint every interval.
    steps_per_batch : optional[int]
        Number of steps required to fill a single batch. Increments interval only
        when a full batch is formed. Defaults to 1.
    save_path : optional[str]
        Absolute path to checkpoint save directory. Defaults to load_path.
    """

    def __init__(
        self,
        dataset: _StatefulDataset,
        load_path: str,
        interval: int,
        steps_per_batch: int = 1,
        save_path: str = "",
    ):
        super().__init__(dataset)
        self.interval = interval
        self.spb = steps_per_batch
        # Checkpoints live in a "checkpoints" subfolder of the given directories
        load_path = os.path.join(load_path, "checkpoints")
        if len(save_path) == 0:
            save_path = load_path
        else:
            save_path = os.path.join(save_path, "checkpoints")
        self.load_path = load_path
        self.path = save_path
        # step: completed batches; ministep: items yielded within the current batch
        self.step = 0
        self.ministep = 0

    def setup(self):
        if not self.is_setup:
            super().setup()
            self.load_from_path(self.load_path)

    def __iter__(self):
        self.setup()
        # A for-loop (rather than bare next()) ends cleanly if the wrapped dataset is
        # finite; otherwise StopIteration would surface as RuntimeError (PEP 479).
        for item in self.dataset:
            yield item
            self.ministep += 1
            if self.ministep == self.spb:
                self.ministep = 0
                self.step += 1
                if self.step % self.interval == 0:
                    newpath = os.path.join(self.path, "step_" + str(self.step) + "_ckp")
                    self.save_to_path(newpath)

    def report(self, msg):
        """Print ``msg`` from the rank-0 worker only."""
        if self.rank == 0:
            print(msg)

    def _validate_ckp_path(self, path: str, verbose: bool = False):
        """
        Interpret path to appropriate checkpoint.
        If found, return modified path.
        If not found, return empty string.
        On success, also sets self.step from the "step_{n}_ckp" folder name.
        """
        # Is path an existing directory, and if so, is it non-empty?
        if not os.path.isdir(path) or len(os.listdir(path)) == 0:
            if verbose:
                self.report(
                    f"  Dataset: No valid checkpoint detected at {path}, dataset starting from scratch."
                )
            return ""
        # Check latest path, using ckp naming syntax ("step_{n}_ckp")
        latest = get_latest(path, key=lambda p: int(p.split("_")[-2]))
        if verbose:
            self.report(f"Checkpoint detected at {latest}")
        # If item is not a folder, exit early
        if os.path.isfile(latest):
            if verbose:
                self.report(
                    f"  Dataset: Detected checkpoint {latest} is a single file with no dataset info."
                    + " Dataset starting from scratch."
                )
            return ""
        # If item is a folder, check that it contains shard files
        if len([x for x in os.listdir(latest) if "loader" in x]) == 0:
            if verbose:
                self.report(
                    f"  Dataset: Detected checkpoint {latest} exists but contains no dataset checkpoints."
                    + " Dataset starting from scratch."
                )
            return ""
        # If item is a folder, get the step count
        self.step = int(latest.split("_")[-2])
        return latest

    def save_to_path(self, path: str):
        """Save this worker's state shard under ``path`` and report the elapsed time."""
        self.report(f"Saving dataset to {path}")
        start = time.time()
        super().save_to_path(path)
        self.report(
            f"Dataset successfully saved to {path}! Save time: {time.time() - start}"
        )

    def load_from_path(self, path: str):
        """
        Restore dataset state:
        1. If the save directory (self.path) holds a valid checkpoint, resume from it,
           keeping the step count parsed from its folder name.
        2. Otherwise load from ``path`` (self.load_path when called from setup) if it
           holds a valid checkpoint, resetting the step count to 0.
        3. Otherwise leave the dataset untouched (start from scratch).
        """
        save_path = self._validate_ckp_path(self.path, False)
        if len(save_path) > 0:
            self.report(
                f"  Dataset: Detected a checkpoint in the save directory {save_path}. Restoring from this checkpoint."
            )
            path = save_path
        else:
            load_path = self._validate_ckp_path(path, True)
            if len(load_path) == 0:
                return
            path = load_path
            # When loading from external ckp, always reset step count
            self.step = 0
        # Proceed
        start = time.time()
        self.dataset.load_from_path(path)
        self.report(f"Dataset checkpoint loaded! Load time: {time.time() - start}")
class PreloadBufferDataset(_WrapperDataset):
    """
    Wrapper for a StatefulDataset that shuffles locally through a single in/out buffer.

    While the buffer is below ``window_size`` it grows by one slot per output (two
    items pulled, one emitted); at full size each output is replaced by exactly one
    new item. Items are emitted one at a time from uniformly random slots, so data
    is mixed locally without sliding windows or large shuffle buffers. Any two
    consecutive inputs will be separated by window_size steps in expectation.

    Rescaling: the buffer is a reshardable field. A buffer that shrinks after
    resharding re-grows to ``window_size``; one that expands drains back down to it.
    ...
    Args
    ----
    dataset : _StatefulDataset
        Fully instantiated dataset
    window_size : int
        Max size of input/output buffer
    """

    def __init__(self, dataset: _StatefulDataset, window_size: int):
        super().__init__(dataset)
        assert (
            window_size > 1
        ), f"Window size {window_size} must be greater than 1 for shuffling to occur"
        self.window_size = window_size
        self.g_state = None
        # NOTE(review): seeded with the rank known at construction, i.e. before setup()
        # folds in dataloader worker ids, so workers of one process share this seed
        # unless a checkpointed g_state is loaded — confirm this is intended.
        self.generator = torch.Generator().manual_seed(self.rank)
        self.buffer: List[List[Any]] = []
        # Number of live entries at the front of self.buffer (the rest is padding)
        self.buffer_size = 0
        self.state_params = ["g_state"]
        self.reshard_params = ["buffer"]

    def __iter__(self):
        source = iter(self.dataset)
        while True:
            self._pad_buffer()
            # Still growing toward window_size: take one extra item this step.
            if self.buffer_size < self.window_size:
                self.buffer[self.buffer_size] = next(source)
                self.buffer_size += 1
            pick = torch.randint(self.buffer_size, (1,), generator=self.generator).item()
            sample = self.buffer[pick]
            if self.buffer_size <= self.window_size:
                # Normal case: refill the vacated slot from upstream.
                self.buffer[pick] = next(source)
            else:
                # Oversized (e.g. after resharding): move the last live item into
                # the hole and shrink instead of pulling new data.
                last = self.buffer_size - 1
                self.buffer[pick] = self.buffer[last]
                self.buffer_size = last
            yield sample

    def _pad_buffer(self):
        """Extend self.buffer with empty placeholders up to window_size slots."""
        missing = self.window_size - self.buffer_size
        if missing > 0:
            self.buffer += [[]] * missing

    def state_dict(self):
        """
        Capture the generator state and trim padding from the buffer so that only
        live entries are saved (and can be resharded later).
        """
        self.g_state = self.generator.get_state()
        self.buffer = self.buffer[: self.buffer_size]
        return super().state_dict()

    def load_state_dict(self, state_dicts, sharded_input=False):
        """
        Load state via the parent, then restore the generator (when a g_state is
        present) and recompute buffer_size from the loaded buffer.
        """
        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
        if self.g_state is not None:
            self.generator.set_state(self.g_state)
        self.buffer_size = len(self.buffer)
        return sharded_dicts
class BufferDataset(_WrapperDataset):
    """
    Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them
    into sequences of desired length. Input sequences are packed greedily until the buffer would
    otherwise overrun, then remaining values are filled depending on initialization flags.
    Also injects BOS/EOS into the packed output sequence if desired, and if BOS/EOS tokens are
    not already in those positions. Implements rescaling by simply dropping (buffer) state.
    ...
    Args
    ----
    dataset : _StatefulDataset
        Fully instantiated dataset
    seq_len : int
        The desired sequence length
    pack_hard : bool
        Split input sequences to fill output buffer, or use pad tokens to fill remaining space?
    bos_token : any | None
        Token to prepend to every output sequence. If None, no token is added. Type should match data type.
    eos_token : any | None
        Token to append to every output sequence. If None, no token is added. Type should match data type.
    pad_token : any | None
        Token used to fill out output sequence. Type should match data type.
    """

    def __init__(
        self,
        dataset: _StatefulDataset,
        seq_len: int,
        pack_hard: bool,
        bos_token=None,
        eos_token=None,
        pad_token=None,
    ):
        super().__init__(dataset)
        self.len = seq_len
        # Buffer args: leftover tokens carried over to the next output sequence
        self.buffer: List[Any] = []
        self.bos = bos_token
        self.eos = eos_token
        self.pad = pad_token
        self.pack_hard = pack_hard
        if not pack_hard:
            assert (
                pad_token is not None
            ), "Error: if using pads, you must supply a pad_token"
        # buffer is a state (not reshard) param, so it is dropped when rescaling
        self.state_params = ["buffer"]

    def _get_buffer(self, iterable, length, buffer):
        """
        Build one output sequence of ``length`` tokens.

        Items are pulled from ``iterable`` into ``buffer`` until the next item
        (held in ``new``) would bring the total to ``length`` or more. Then:
        - ``self.bos`` is prepended if set and not already first;
        - if ``buffer`` alone reaches ``length``, it is split and ``new`` is appended
          to the remainder;
        - otherwise, with ``pack_hard``, ``new`` is appended and the result split;
        - otherwise ``buffer`` is terminated with ``self.eos`` (if set), padded to
          ``length`` and ``new`` becomes the next buffer.
        When splitting with ``self.eos`` set, the last output token is replaced by
        eos and pushed back to the front of the remainder so no data is lost.

        Returns:
            (output sequence, remaining buffer)
        Raises:
            StopIteration if ``iterable`` is exhausted.
        """
        # Pull data until buffer is about to overrun, return exactly proper length
        new = []
        while len(buffer) + len(new) < length:
            buffer += new
            new = next(iterable)
        # Add bos if needed
        if self.bos is not None and (len(buffer) == 0 or buffer[0] != self.bos):
            buffer = [self.bos] + buffer
        # Handle buffer splitting
        if len(buffer) >= length:
            # If buffer is too long, force split
            out = buffer[:length]
            buffer = buffer[length:]
            if self.eos is not None and out[-1] != self.eos:
                buffer = [out[-1]] + buffer
                out[-1] = self.eos
            buffer = buffer + new
        else:
            if self.pack_hard:
                # Pack in as much of new sequence as will fit
                buffer = buffer + new
                out = buffer[:length]
                buffer = buffer[length:]
                if self.eos is not None and out[-1] != self.eos:
                    buffer = [out[-1]] + buffer
                    out[-1] = self.eos
            else:
                # Fill out with pads as needed. Without a bos token, an oversized
                # first item can leave buffer empty here, so guard the [-1] lookup.
                if self.eos is not None and (len(buffer) == 0 or buffer[-1] != self.eos):
                    buffer.append(self.eos)
                if self.pad is not None:
                    out = buffer + [self.pad] * (length - len(buffer))
                else:
                    out = buffer
                buffer = new
        return out, buffer

    # Fill buffer line by line, delimiters and packing/splitting as appropriate
    def __iter__(self):
        dataset = iter(self.dataset)
        while True:
            try:
                out, buffer = self._get_buffer(dataset, self.len, self.buffer)
            except StopIteration:
                # Upstream exhausted: end iteration cleanly instead of letting
                # StopIteration become RuntimeError inside this generator (PEP 479).
                return
            self.buffer = buffer
            yield out
class StreamingDocDataset(_StatefulDataset):
"""
The base distributed dataset for loading sequences/documents from file shards.
For a single dataset directory, splits shard files into x=worldsize fragments and grabs a 1/n contiguous
span of shard fragments (contiguous to limit file reads from cloud/disk).
Logs the number of documents owned from each shardfile, and relies on LCG random bijection to
map contiguous range of indices to shuffled, noncontiguous set of documents from each shard file.
Shuffles the file list deterministically to hop from file to file.
At runtime, iterates through documents in each shuffled shard file, pulling each shard on demand.
Shards are thus pulled no more than once per epoch.
Returns documents in chunks up to size max_chunksize, and handles delimiter token placement between documents.
StreamingDocDataset grabs files from a directory representing a single dataset.
This directory need not be flat.
For percentage-based sampling over multiple such subdatasets, see SamplingDataset.
When available in the parent directory, relies on a compiled metadata file to fetch shardfile lengths.
Expects csv file (first row "dataset/filename,documents,tokens", subsequent rows these values) under a 'meta' directory.
This can be removed in the future.
...
Args
----
datapath : str
Absolute path to the dataset directory. Expects directory containing shardfiles.
Directory need not be flat.
rank : int
Current worker index
worldsize : int
Total number of workers
filereader : _ShardFileReader
A file reader handling specific data shard file formats
delimiter_token : Any
Token used to indicate sequence/document breaks. Type should match data type. Required for downstream
sampling logic (can be removed later via PreProcessDataset if needed).
bos_token : Any | None
Optional token used to indicate sequence/document start. Type should match data type.
strip_tokens : set[Any]
Token values that should be removed if detected at beginning or end of document
(i.e. any eos/bos tokens already present in the data). Type should match data type.
seed : int
The random seed for deterministic shuffling/sharding
min_length : int
Documents below this length are skipped
max_chunksize : int
Maximum sequence length to return. Break long docs into chunks of this size or shorter.
verbose : bool
Track setup progress?
shuffle : bool
Shuffle shard file and document orders? (Disable for simple testing)
"""
def __init__(
    self,
    datapath: str,
    rank: int,
    worldsize: int,
    filehandler: _ShardFileHandler,
    delimiter_token: Any,
    bos_token: Optional[Any] = None,
    strip_tokens: Optional[Set[Any]] = None,
    seed: int = 42,
    min_length: int = 1,
    max_chunksize: int = 1024,
    verbose: bool = False,
):
    """
    Store configuration and initialize position/stat counters.

    All rank-dependent work (file discovery, partitioning, shuffling) is
    deferred to ``setup()``, which must run after construction.

    ``strip_tokens`` defaults to a fresh empty set per instance. Passing
    ``None`` explicitly is treated the same as omitting it.

    Raises
    ------
    AssertionError
        If ``max_chunksize`` is not a positive integer.
    """
    super().__init__(datapath, rank, worldsize)
    self.seed = seed
    self.datapath = datapath
    self.filehandler = filehandler
    self.min_length = min_length
    assert max_chunksize > 0, (
        f"Max chunksize must be a nonzero positive integer, got {max_chunksize}"
    )
    self.chunksize = max_chunksize
    self.eos = delimiter_token
    self.bos = bos_token
    # Build a new set per instance. A shared mutable default (the old
    # `= set()`) would leak any mutation of self.drop across all instances.
    self.drop = set() if strip_tokens is None else strip_tokens
    self.verbose = verbose
    # Filled by setup(): list of (shard path, min doc id, max doc id) tuples,
    # with both doc ids inclusive.
    self.docset: List[Any] = []
    # Position
    self.docset_index = 0
    # -1 presumably means "no chunk produced yet" -- TODO confirm with iterator code
    self.chunk_index = -1
    # Stats
    # -1 presumably so the first pass increments to epoch 0 -- TODO confirm
    self.epochs_seen = -1
    self.tokens_seen = 0
    self.docs_seen = 0
    self.percent_seen = 0
    # Names of attributes that make up this loader's resumable state.
    # NOTE(review): presumably consumed by the parent class for checkpointing; verify.
    self.state_params = [
        "dataset",
        "docset_index",
        "chunk_index",
        "epochs_seen",
        "tokens_seen",
        "docs_seen",
        "percent_seen",
        "lcg_state",
    ]
    # Setup flags
    self.is_setup = False
    # Number of documents owned by this worker; computed in setup().
    self._len = 0
    # Name of the dataset directory (last path component); set in setup().
    self.dataset = ""
    self.lcg_state = 0
def setup(self):
"""
All rank-dependent setup, which must occur after init
(rank assignment, data partitioning, shuffling)
"""
if not self.is_setup:
super().setup()
datapath = self.datapath
pathsplit = (datapath, "")
# May take an extra round to account for any trailing slashes
while len(pathsplit[1]) == 0:
pathsplit = os.path.split(pathsplit[0])
pardir, dataset = pathsplit
self.dataset = dataset
# Assemble document set owned by this worker:
# listdir, assemble shardfraglist (ind -> shard, frag)
shards = [
os.path.join(root, name)[len(datapath) + 1 :]
for root, dirs, files in os.walk(datapath, topdown=False)
for name in files
if self.filehandler.is_legal(os.path.join(root, name))
]
shards.sort() # Ensure consistent sharding across machines
start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
end_frag = (
(self.rank + 1) * self.worldsize * len(shards)
) // self.worldsize
shardfrags = [
(shards[i // self.worldsize], i % self.worldsize)
for i in range(start_frag, end_frag)
]
# Assemble length of each owned shard file
countfiles = []
if os.path.exists(os.path.join(pardir, "meta")):
countfiles = [
x
for x in os.listdir(os.path.join(pardir, "meta"))
if "counts" in x and "csv" in x
]
doc_counts = {}
if len(countfiles) > 0:
# Count file exists, use it
countpath = os.path.join(pardir, "meta", countfiles[0])
with open(countpath, "r") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
fullpath = row["dataset/filename"]
prefix = fullpath.find("/" + dataset) + 1
if prefix > 0:
key = fullpath[prefix + len(dataset) + 1 :]
doc_counts[key] = int(row["documents"])
else:
# Count file does not exist, touch every owned file for length
unique_shardfiles = set(shard for shard, frag in shardfrags)
doc_counts = {
shard: self.filehandler.length(os.path.join(datapath, shard))
for shard in unique_shardfiles
}
# Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
ndocs = -1
docset = {} # shardid -> (min docid, max docid)
for i, (shard, frag) in enumerate(shardfrags):
ndocs = doc_counts[shard]
doc_start = (ndocs * frag) // self.worldsize
doc_end = (
ndocs * frag + ndocs
) // self.worldsize - 1 # Inclusive upper bound
if shard not in docset:
docset[shard] = [doc_start, doc_end]
min_d, max_d = docset[shard]
if doc_start < min_d:
docset[shard][0] = doc_start
if doc_end > max_d:
docset[shard][1] = doc_end
# Add shard entries to self.docset
doccount = 0
for shardid in docset:
min_d = docset[shardid][0]
max_d = docset[shardid][1]
self.docset.append((shardid, min_d, max_d))
doccount += max_d - min_d + 1
self._len = doccount
if self.verbose:
logging.info(
f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
)
# Shuffle shard files - guaranteed inconsistent across workers
seed = self.seed + self.rank