checkpointing_.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Use to partition the activations stored for backward propagation
Therefore reduces the memory consumption
Also implements CPU checkpointing and contiguous memory checkpointing
Reduces memory consumption and memory fragmentation
Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py
b886b7bb972afe72bac0f5de4f42a4a7bae8ebef
"""
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import copy
import torch
import contextlib
from deepspeed import comm as dist
import mmap
from torch import _C
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers
from deepspeed.accelerator import get_accelerator
# DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False
import mpu as mpu_lib
# MP parameters
mpu = None
mp_rank = None
mp_size = None
mp_group = None
# Model Parameters
num_layers = None
# Checkpointing buffers
contiguous_data_buffers = []
data_offsets = []
contiguous_size_buffers = []
size_offsets = []
timers = None
# optimization flags
PARTITION_ACTIVATIONS = True
CPU_CHECKPOINT = False
CONTIGUOUS_CHECKPOINTING = False
SYNCHRONIZE = False
PROFILE_TIME = False
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
transport_stream = None
cuda_device = None
def detach_variable(inputs, device=None):
if isinstance(inputs, tuple):
out = []
for inp in inputs:
if not isinstance(inp, torch.Tensor):
out.append(inp)
continue
requires_grad = inp.requires_grad
if device is not None:
x = inp.to(device=device)
else:
x = inp
x = x.detach()
x.requires_grad = requires_grad
out.append(x)
return tuple(out)
else:
raise RuntimeError("Only a tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state) #ignore-cuda
with a single change: the input state is not cloned. Cloning caused
major performance issues for 4+ GPU cases.
"""
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with get_accelerator().device(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device(get_accelerator().device_name())
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device(get_accelerator().device_name(), device)
def cb():
idx = device.index
if idx is None:
idx = get_accelerator().current_device()
default_generator = get_accelerator().default_generator(idx)
default_generator.set_state(new_state)
get_accelerator().lazy_call(cb)
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for bookkeeping and to ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""Set to the initial state (no tracker)."""
self.states_ = {}
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
return copy.copy(self.states_)
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception('seed {} already exists'.format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception('cuda rng state {} already exists'.format(name))
# Get the current rng state.
orig_rng_state = get_accelerator().get_rng_state()
# Set the new state and store it.
get_accelerator().manual_seed(seed)
self.states_[name] = get_accelerator().get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception('cuda rng state {} is not added'.format(name))
# Store current rng state.
orig_cuda_rng_state = get_accelerator().get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = get_accelerator().get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
def get_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _CUDA_RNG_STATE_TRACKER
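# Illustrative sketch (not executed): forking the tracked model-parallel RNG state
# around an op so that it consumes different random numbers per tensor-parallel rank
# while leaving the default (data-parallel) RNG stream untouched. Assumes
# model_parallel_cuda_manual_seed() has already been called.
#
#   with get_cuda_rng_tracker().fork():
#       x = torch.nn.functional.dropout(x, p=0.1, training=True)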
def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
initialized. Also, no get_accelerator().manual_seed should be called
after this function. Basically, this is a replacement for that
function.
Two set of RNG states are tracked:
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model parallel groups. This is used for
example for dropout in the non-model-parallel regions.
model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
global mpu
tp_rank = bwc_tensor_model_parallel_rank(mpu)
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
model_parallel_seed = offset + tp_rank
# Data parallel gets the original seed.
data_parallel_seed = seed
if dist.get_rank() == 0:
logger.info(
'> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(dist.get_rank(), tp_rank,
mpu.get_data_parallel_rank(),
model_parallel_seed, data_parallel_seed), )
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
get_accelerator().manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
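# Worked example of the seeding scheme above (a sketch; numbers are illustrative):
# with seed=1234 every rank seeds the default (data-parallel) generator with 1234,
# while the tracked model-parallel state is seeded with 1234 + 2718 + tp_rank,
# so it differs across tensor-parallel ranks:
#
#   model_parallel_cuda_manual_seed(1234)
#   # tp_rank 0 -> model-parallel seed 3952, tp_rank 1 -> 3953, ...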
def get_partition_start(item):
global mp_rank, mp_size, mp_group
size = item.numel()
partition_size = size / mp_size
start = partition_size * mp_rank
logger.info(f"LOG start for rank {mp_rank} is {start}")
return int(start)
def get_partition_size(item):
global mp_rank, mp_size, mp_group
size = item.numel()
assert size % mp_size == 0, "Activation partitioning requires the tensor's numel to be divisible by mp_size"
partition_size = size / mp_size
logger.info(f"LOG partition_size for rank {mp_rank} is {partition_size} / {size}")
return int(partition_size)
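# Worked example of the partitioning arithmetic (sketch): for a tensor with
# numel() == 8 and mp_size == 4, get_partition_size() returns 2 and
# get_partition_start() returns 0, 2, 4, 6 for mp_rank 0..3, i.e. each rank
# keeps a contiguous 1/mp_size slice of the flattened activation.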
def gather_partitioned_activations(tensors, device=None):
global mp_rank, mp_size, mp_group
assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}'
inputs = []
num_args = int(len(tensors) / 2)
for i in range(num_args):
item = tensors[2 * i]
size = tensors[2 * i + 1]
if not is_activation_to_checkpoint(item):
inputs.append(item)
continue
# don't need to do all_gather if model parallel is not enabled
if mp_size == 1:
item = item.view(list(size.numpy()))
inputs.append(item)
continue
partition_size = item.numel()
tensor_size = partition_size * mp_size
if device is not None:
flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device)
else:
flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device)
partitions = []
for i in range(mp_size):
part_i = flat_tensor.narrow(0, partition_size * i, partition_size)
if i == mp_rank:
part_i.copy_(item)
partitions.append(part_i)
#
for part_i in range(mp_size):
logger.info(f"LOG partitions cksm rank {mp_rank} on part {part_i} before: {partitions[part_i].mean()}")
dist.all_gather(partitions, partitions[mp_rank], group=mp_group)
for part_i in range(mp_size):
logger.info(f"LOG partitions cksm rank {mp_rank} on part {part_i} after: {partitions[part_i].mean()}")
input_tensor = flat_tensor.view(list(size.numpy()))
item.data = input_tensor.data
inputs.append(item)
return tuple(inputs)
def extract_tensors(all_objects):
"""
Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation.
The order of tensors and non-tensors is preserved in their respective output groups.
Parameters:
all_objects (list/tuple): Objects containing tensors and non-tensors to be split.
Returns:
tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor.
"""
tensor_objects = [v for v in all_objects if torch.is_tensor(v)]
non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)]
tensor_flags = [torch.is_tensor(v) for v in all_objects]
if type(all_objects) is tuple:
return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags)
return tensor_objects, non_tensor_objects, tensor_flags
def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags):
"""
Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple).
Parameters:
tensor_objects (list/tuple): Tensors to merge.
non_tensor_objects (list/tuple): Non-tensors to merge.
tensor_flags (list/tuple): Indicates whether each position in output is a tensor.
Returns:
tuple: Merge of tensors and non-tensors
"""
merged_objects = []
tensor_idx = 0
non_tensor_idx = 0
real_tensor_flags = None
# remove the flags that are assigned to the size of the flattened tensors
if PARTITION_ACTIVATIONS:
real_tensor_flags = []
previous_flag = False
for flag in tensor_flags:
if previous_flag:
previous_flag = False
continue
previous_flag = flag
real_tensor_flags.append(flag)
else:
real_tensor_flags = tensor_flags
for is_tensor in real_tensor_flags:
if is_tensor:
merged_objects.append(tensor_objects[tensor_idx])
tensor_idx += 1
else:
merged_objects.append(non_tensor_objects[non_tensor_idx])
non_tensor_idx += 1
return tuple(merged_objects)
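# Round-trip sketch for extract_tensors()/merge_tensors() (assumes
# PARTITION_ACTIVATIONS is False, so tensor_flags is used as-is):
#
#   objs = (torch.ones(2), "tag", 3)
#   tensors, non_tensors, flags = extract_tensors(all_objects=objs)
#   restored = merge_tensors(tensors, non_tensors, flags)
#   # restored reproduces objs, preserving the ordering of tensors and non-tensors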
def is_activation_to_checkpoint(item):
"""
Is an activation to be checkpointed
"""
global mp_size
return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size
def partition_activations(args, cpu_checkpoint, contiguous_checkpoint):
global contiguous_data_buffers, data_offsets
inputs = []
num_non_fp_tensors = 0
for arg_index, item in enumerate(args):
if not is_activation_to_checkpoint(item):
inputs.append(item)
num_non_fp_tensors += 1
continue
i = arg_index - num_non_fp_tensors
partition_size = get_partition_size(item)
partition = item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), partition_size).clone()
buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device
if contiguous_checkpoint:
if i >= len(contiguous_data_buffers):
tensor_list = [
torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
for _ in range(num_layers)
]
contiguous_data_buffers.append(tensor_list)
data_offsets.append(0)
elif contiguous_data_buffers[i] is None:
tensor_list = [
torch.tensor(()).new_empty([partition_size], dtype=partition.dtype, device=buffer_device)
for _ in range(num_layers)
]
contiguous_data_buffers[i] = tensor_list
data_offsets[i] = 0
# Because the 'new_empty' returns uninitialized pages,
# the pages need to be populated during the cudaMemcpy time
# which increases the data copy time. To avoid this, we
# pre-populate these pages by simply writing 0 ahead of
# the actual cudaMemcpy operation time. Due to the
# previously launched GPU kernels, there is a small
# window of time here for CPUs to populate pages asynchronously.
contiguous_data_buffers[i][data_offsets[i]].data[range(
0, contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
int(mmap.PAGESIZE / contiguous_data_buffers[i][data_offsets[i]].data.element_size()))] = 0
contiguous_partition = contiguous_data_buffers[i][data_offsets[i]].data.copy_(partition.data)
data_offsets[i] = data_offsets[i] + 1
inputs.append(contiguous_partition)
else:
partition = partition.cpu() if CPU_CHECKPOINT else partition
inputs.append(partition)
logger.info(f"LOG before partition size {item.numel()} after {partition.shape}")
return inputs
def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint):
global contiguous_size_buffers, size_offsets
new_args = []
num_non_fp_tensors = 0
for arg_index, (arg, inp) in enumerate(zip(args, inputs)):
size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None
if not is_activation_to_checkpoint(arg):
new_args.append(arg)
new_args.append(size)
num_non_fp_tensors += 1
continue
arg.data = inp.data
logger.info(f"LOG get_partitioned_activations_for_backward inp.data.numel(): {inp.data.numel()}")
new_args.append(arg)
i = arg_index - num_non_fp_tensors
if contiguous_checkpoint:
numel = size.numel()
if i >= len(contiguous_size_buffers):
tmp = torch.tensor(())
contiguous_size_buffers.append(
tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device))
size_offsets.append(0)
elif contiguous_size_buffers[i] is None:
tmp = torch.tensor(())
contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers], dtype=size.dtype, device=size.device)
size_offsets[i] = 0
contiguous_size = contiguous_size_buffers[i].narrow(0, size_offsets[i], numel).data.copy_(size.data)
contiguous_size = contiguous_size.view_as(size)
size_offsets[i] = size_offsets[i] + numel
new_args.append(contiguous_size)
else:
new_args.append(size)
return new_args
def get_cpu_activations_for_backward(args, inputs):
new_args = []
for i, (arg, inp) in enumerate(zip(args, inputs)):
if not is_activation_to_checkpoint(arg):
new_args.append(arg)
continue
arg.data = inp.data
new_args.append(arg)
return new_args
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` #ignore-cuda
2) the states in the model parallel tracker are also properly
tracked/set/reset.
3) Performance activation partitioning, contiguous memory optimization
4) CPU Checkpointing
5) Profile forward and backward functions
"""
@staticmethod
def forward(ctx, run_function, all_outputs, *args):
global mpu, timers, SYNCHRONIZE, PROFILE_TIME
def save_args_for_backward(*all_args):
tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
ctx.deepspeed_saved_tensors = tensor_args
logger.info(f"LOG ctx.deepspeed_saved_tensors {ctx.deepspeed_saved_tensors[0].numel()}")
ctx.non_tensor_args = non_tensor_args
ctx.tensor_flags = tensor_flags
if SYNCHRONIZE:
get_accelerator().synchronize()
if timers is None and PROFILE_TIME:
timers = Timers()
if PROFILE_TIME:
timers('forward').start()
ctx.run_function = run_function
global num_layers
global mp_rank, mp_size, mp_group
global contiguous_data_buffers, contiguous_size_buffers
global data_offsets, size_offsets
mpu_ = mpu_lib
logger.info(f"----CUSTOM fw Partition Activations")
if mpu_ is not None:
if mpu_.model_parallel_is_initialized():
logger.info('CUSTOM fw model parallel is already initialized')
else:
logger.info('CUSTOM fw mpu initialize_model_parallel')
mpu_.initialize_model_parallel(tensor_model_parallel_size_=torch.distributed.get_world_size(),
pipeline_model_parallel_size_=1)
mp_rank = mpu_.get_tensor_model_parallel_rank()
mp_size = mpu_.get_tensor_model_parallel_world_size()
mp_group = mpu_.get_tensor_model_parallel_group()
logger.info(f'CUSTOM fw Initializing mp as:'
f'LOG mp_rank {mp_rank} '
f'LOG mp_size {mp_size} '
f'LOG mp_group {mp_group}')
# mp_rank = torch.distributed.get_rank() #mpu.get_tensor_model_parallel_rank()
# mp_size = torch.distributed.get_world_size() #mpu.get_tensor_model_parallel_world_size()
# mp_group = None #torch.distributed.get_world_group() #mpu.get_tensor_model_parallel_group()
#if mp_rank is None:
# if mpu is not None:
# if hasattr(mpu, 'get_tensor_model_parallel_rank'):
# mp_rank = mpu.get_tensor_model_parallel_rank()
# mp_size = mpu.get_tensor_model_parallel_world_size()
# mp_group = mpu.get_tensor_model_parallel_group()
# else:
# mp_rank = mpu.get_model_parallel_rank()
# mp_size = mpu.get_model_parallel_world_size()
# mp_group = mpu.get_model_parallel_group()
else:
mp_rank = 0
mp_size = 1
mp_group = None
global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset
if cuda_device is None:
see_memory_usage("First Forward Beginning", force=True)
if dist.get_rank() == 0:
logger.info(f"Activation Checkpointing Information")
logger.info(f"----CUSTOM Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
logger.info(
f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
logger.info(f"----Synchronization {SYNCHRONIZE}")
logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")
cuda_device = get_accelerator().current_device_name()
transport_stream = get_accelerator().Stream(device=cuda_device)
if PARTITION_ACTIVATIONS:
inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
logger.info("PARTITIONING ACTIVATIONS FW")
elif CPU_CHECKPOINT:
inputs = copy_to_device(args, device=torch.device('cpu'), criterion_func=is_activation_to_checkpoint)
# just in case something funky is happening such as reuse of inputs
inputs_cuda = copy_to_device(args, device=cuda_device, criterion_func=is_activation_to_checkpoint)
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = get_accelerator().get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
see_memory_usage("Before running forward on the layer", force=True)
# ctx.save_for_backward(*args)
with torch.no_grad():
outputs = run_function(*inputs_cuda)
see_memory_usage("After running forward on the layer", force=True)
del inputs_cuda
if PARTITION_ACTIVATIONS:
# This part saves the activations until the backward pass; partitioned checkpointing reduces the memory held here.
new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
logger.info(f"LOG BW before get_partitioned_activations_for_backward - partition size after {inputs[0].shape}")
logger.info(f"LOG BW after get_partitioned_activations_for_backward - partition size after {new_args[0].shape}")
logger.info("PARTITIONING ACTIVATIONS BW")
assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
save_args_for_backward(*new_args)
elif CPU_CHECKPOINT:
new_args = get_cpu_activations_for_backward(args, inputs)
save_args_for_backward(*new_args)
else:
save_args_for_backward(*args)
if PROFILE_TIME:
timers('forward').stop()
timers.log(['forward'])
if SYNCHRONIZE:
get_accelerator().synchronize()
# Tensors returned from forward() may not be differentiable.
if torch.is_tensor(outputs):
non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
else:
non_grad_outputs = [o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()]
ctx.mark_non_differentiable(*non_grad_outputs)
if torch.is_tensor(outputs):
all_outputs += [outputs]
return outputs
else:
all_outputs += outputs
outputs, _, _ = extract_tensors(all_objects=outputs)
return tuple(outputs)
@staticmethod
def backward(ctx, *grads):
global timers
see_memory_usage("In backward", force=True)
# removing pointers to the contiguous buffer memory
# so that they can be garbage collected once the checkpoints
# have been used
if SYNCHRONIZE:
get_accelerator().synchronize()
if PROFILE_TIME:
timers('backward').start()
if CONTIGUOUS_CHECKPOINTING:
global data_offsets, size_offsets
global contiguous_data_buffers, contiguous_size_buffers
for buffers in contiguous_data_buffers:
buffers = []
# frees up all the pointers to the checkpoints except for the ones
# stored by save for backward
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []
see_memory_usage("In backward checkpointing code", force=True)
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
global cuda_device, transport_stream, PARTITION_ACTIVATIONS
if PARTITION_ACTIVATIONS:
# with get_accelerator().stream(transport_stream):
# ctx.deepspeed_saved_tensors is the key???
inputs = gather_partitioned_activations(ctx.deepspeed_saved_tensors,
device=cuda_device if CPU_CHECKPOINT else None)
detached_inputs = detach_variable(inputs)
elif CPU_CHECKPOINT:
inputs = move_to_device(ctx.deepspeed_saved_tensors, cuda_device, is_activation_to_checkpoint)
detached_inputs = detach_variable(inputs)
else:
inputs = ctx.deepspeed_saved_tensors
detached_inputs = detach_variable(inputs)
# Add non tensor input args
detached_inputs = merge_tensors(tensor_objects=detached_inputs,
non_tensor_objects=ctx.non_tensor_args,
tensor_flags=ctx.tensor_flags)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = get_accelerator().get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
# if PARTITION_ACTIVATIONS:
# current_stream=get_accelerator().current_stream()
# current_stream.wait_stream(transport_stream)
see_memory_usage("In backward checkpointing code before forward", force=True)
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
see_memory_usage("In backward checkpointing code after forward", force=True)
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs, )
# Filter out non tensor outputs
outputs, _, _ = extract_tensors(all_objects=outputs)
# Construct arguments to autograd.backward().
# This is usually just outputs and grads, but forward() can return tensors that
# are not differentiable.
output_tensors = []
grad_tensors = []
for out, grad in zip(outputs, grads):
if out.requires_grad:
output_tensors.append(out)
grad_tensors.append(grad)
see_memory_usage("In backward checkpointing code before backward", force=True)
torch.autograd.backward(output_tensors, grad_tensors)
# Force clear our stashed tensors to prevent a memory leak in certain scenarios
ctx.deepspeed_saved_tensors = None
ctx.non_tensor_args = None
ctx.tensor_flags = None
see_memory_usage("After backward checkpointing code after backward", force=True)
if PROFILE_TIME:
timers('backward').stop()
timers.log(['backward'])
if SYNCHRONIZE:
get_accelerator().synchronize()
ret_list = [None, None] # gradients for the run_function and all_outputs arguments
for inp in detached_inputs:
if torch.is_tensor(inp):
ret_list.append(inp.grad)
else:
ret_list.append(None)
return tuple(ret_list)
def checkpoint(function, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint. """
all_outputs = []
CheckpointFunction.apply(function, all_outputs, *args)
if len(all_outputs) == 1:
return all_outputs[0]
else:
return tuple(all_outputs)
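# Minimal usage sketch (hypothetical layer and argument names): checkpoint() runs
# the function under torch.no_grad() in the forward pass and re-runs it during
# backward, storing only the (optionally partitioned / CPU-offloaded) inputs.
#
#   hidden_states = checkpoint(transformer_layer, hidden_states, attention_mask)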
def partition_activations_in_checkpoint(partition_activation):
global PARTITION_ACTIVATIONS
PARTITION_ACTIVATIONS = partition_activation
if dist.get_rank() == 0:
logger.info(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
def set_num_layers(nlayers):
global num_layers
num_layers = nlayers
def reset():
"""Resets memory buffers related to contiguous memory optimizations.
Should be called during eval when multiple forward propagations are
computed without any backward propagation that usually clears these
buffers.
Arguments:
None
Return:
None
"""
if CONTIGUOUS_CHECKPOINTING:
global data_offsets, size_offsets
global contiguous_data_buffers, contiguous_size_buffers
for buffers in contiguous_data_buffers:
buffers = []
# frees up all the pointers to the checkpoints except for the ones
# stored by save for backward
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []
def _configure_using_config_file(config, mpu=None):
global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME
config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config
if dist.get_rank() == 0:
logger.info(config.repr())
PARTITION_ACTIVATIONS = config.partition_activations
CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
num_layers = config.number_checkpoints
CPU_CHECKPOINT = config.cpu_checkpointing
SYNCHRONIZE = config.synchronize_checkpoint_boundary
PROFILE_TIME = config.profile
def _configure_defaults():
global mpu, num_layers, deepspeed_checkpointing_enabled
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME
PARTITION_ACTIVATIONS = True
CONTIGUOUS_CHECKPOINTING = False
num_layers = False
CPU_CHECKPOINT = False
SYNCHRONIZE = False
PROFILE_TIME = False
deepspeed_checkpointing_enabled = True
def configure(
mpu_,
deepspeed_config=None,
partition_activations=None,
contiguous_checkpointing=None,
num_checkpoints=None,
checkpoint_in_cpu=None,
synchronize=None,
profile=None,
):
"""Configure DeepSpeed Activation Checkpointing.
Arguments:
mpu_: Optional: An object that implements the following methods
get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size
deepspeed_config: Optional: DeepSpeed Config json file when provided will be used to
configure DeepSpeed Activation Checkpointing
partition_activations: Optional: Partitions activation checkpoint across model parallel
GPUs when enabled. By default False. Will overwrite deepspeed_config if provided
contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory
buffer. Works only with homogeneous checkpoints when partition_activations is enabled.
Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if
provided
num_checkpoints: Optional: Number of activation checkpoints stored during the forward
propagation of the model. Used to calculate the buffer size for contiguous_checkpointing
Will overwrite deepspeed_config if provided
checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with
partition_activations. Default is false. Will overwrite deepspeed_config if provided
synchronize: Optional: Performs get_accelerator().synchronize() at the beginning and end of
each call to deepspeed.checkpointing.checkpoint for both forward and backward pass.
By default false. Will overwrite deepspeed_config if provided
profile: Optional: Logs the forward and backward time for each
deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config
if provided
Returns:
None
"""
global mpu, num_layers, deepspeed_checkpointing_enabled
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME
_configure_defaults()
if mpu_ is not None:
mpu = mpu_
if deepspeed_config is not None:
_configure_using_config_file(deepspeed_config, mpu=mpu)
if partition_activations is not None:
PARTITION_ACTIVATIONS = partition_activations
if contiguous_checkpointing is not None:
CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing
if num_checkpoints is not None:
num_layers = num_checkpoints
if checkpoint_in_cpu is not None:
CPU_CHECKPOINT = checkpoint_in_cpu
if synchronize is not None:
SYNCHRONIZE = synchronize
if profile is not None:
PROFILE_TIME = profile
if CONTIGUOUS_CHECKPOINTING:
assert PARTITION_ACTIVATIONS, "Contiguous Checkpointing is only available with partitioned activations. Set partitioned activations to true in deepspeed config"
if CONTIGUOUS_CHECKPOINTING:
assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"
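# Configuration sketch (assumes an initialized Megatron-style `mpu` module and a
# torch.distributed process group); the values below are illustrative, not defaults:
#
#   configure(mpu_=mpu,
#             partition_activations=True,
#             contiguous_checkpointing=True,
#             num_checkpoints=num_transformer_layers,
#             checkpoint_in_cpu=False,
#             synchronize=False,
#             profile=True)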
def is_configured():
"""True if deepspeed activation checkpointing has been configured
by calling deepspeed.checkpointing.configure, else returns false
Arguments:
None
Return:
True if configured, else False
"""
return deepspeed_checkpointing_enabled