-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
Copy pathdistrib_parts.py
560 lines (439 loc) · 18.5 KB
/
distrib_parts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
"""
Root module for all distributed operations in Lightning.
Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
"""
from contextlib import ExitStack
import os
from abc import ABC, abstractmethod
import time
import random
import torch
from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only
# Optional backends. Each ``*_AVAILABLE`` flag records whether the matching package
# could be imported; the imported name (``amp``, ``xm``, ``hvd``) is only bound
# when the corresponding import succeeded.

# apex: provides the ``amp`` object handed to ``configure_apex`` below
try:
    from apex import amp
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True

# torch_xla: provides ``xm`` used for the TPU device and rank queries below
try:
    import torch_xla.core.xla_model as xm
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True

# horovod: provides ``hvd`` used by ``horovod_train`` below
try:
    import horovod.torch as hvd
except ImportError:
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True
class TrainerDPMixin(ABC):
    """
    Mixin holding the per-backend training entry points of the Trainer
    (single GPU, TPU, DataParallel and Horovod).

    It also offers helpers to copy trainer flags onto a model and to move a batch
    to a GPU or TPU device. The attributes and abstract methods declared below are
    placeholders only; the concrete class mixing this in has to provide them.
    """

    # this is just a summary on variables used in this abstract class,
    #  the proper values/initialisation should be done in child class
    on_gpu: bool
    use_dp: bool
    use_ddp2: bool
    use_ddp: bool
    testing: bool
    single_gpu: bool
    root_gpu: ...
    amp_level: str
    precision: ...
    global_rank: int
    tpu_local_core_rank: int
    tpu_global_core_rank: int
    use_tpu: bool
    use_native_amp: bool
    data_parallel_device_ids: ...
    progress_bar_callback: ...
    tpu_id: Optional[int]
    on_colab_kaggle: str
    save_spawn_weights: Callable

    @property
    @abstractmethod
    def use_amp(self) -> bool:
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def run_pretrain_routine(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def init_optimizers(self, *args) -> Tuple[List, List, List]:
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def get_model(self) -> LightningModule:
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def reinit_scheduler_properties(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def setup(self, *args) -> None:
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def is_function_implemented(self, *args) -> bool:
        """Warning: this is just empty shell for code implemented in other class."""

    def copy_trainer_model_properties(self, model):
        """
        Copies the trainer reference and its distributed/precision flags onto the model.

        When ``model`` is a ``LightningDataParallel`` or ``LightningDistributedDataParallel``
        wrapper, the flags are set on both the wrapper and the wrapped ``model.module``.

        Args:
            model: the (possibly wrapped) model to annotate.
        """
        if isinstance(model, (LightningDataParallel, LightningDistributedDataParallel)):
            ref_model = model.module
        else:
            ref_model = model

        for m in [model, ref_model]:
            m.trainer = self
            m.use_dp = self.use_dp
            m.use_ddp2 = self.use_ddp2
            m.use_ddp = self.use_ddp
            m.use_amp = self.use_amp
            m.testing = self.testing
            m.single_gpu = self.single_gpu
            m.use_tpu = self.use_tpu
            m.tpu_local_core_rank = self.tpu_local_core_rank
            m.tpu_global_core_rank = self.tpu_global_core_rank

    def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None):
        """
        Transfers the data to the TPU.

        Args:
            batch: A tensor or collection of tensors.
            tpu_id: The id of the TPU core. If omitted, the first available core is chosen.

        Return:
            the tensor on the TPU device.

        Raises:
            MisconfigurationException: if ``torch_xla`` could not be imported.

        See Also:
            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
        """
        if not XLA_AVAILABLE:
            raise MisconfigurationException(
                'Requested to transfer batch to TPU but XLA is not available.'
                ' Are you sure this machine has TPUs?'
            )
        device = xm.xla_device(tpu_id)
        return self.__transfer_batch_to_device(batch, device)

    def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None):
        """
        Transfers the data to the GPU.

        Args:
            batch: A tensor or collection of tensors.
            gpu_id: The id of the GPU device. If omitted, the first available GPU is chosen.

        Return:
            the tensor on the GPU device.

        See Also:
            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
        """
        device = torch.device('cuda', gpu_id)
        return self.__transfer_batch_to_device(batch, device)

    def __transfer_batch_to_device(self, batch: Any, device: torch.device):
        """
        Moves ``batch`` to ``device``.

        Delegates to the model's ``transfer_batch_to_device`` when ``get_model()`` returns
        a model, and falls back to ``move_data_to_device`` when it returns ``None``.
        """
        model = self.get_model()
        if model is not None:
            return model.transfer_batch_to_device(batch, device)
        return move_data_to_device(batch, device)

    def single_gpu_train(self, model):
        """
        Trains ``model`` on the single GPU ``self.root_gpu``.

        Calls the setup hooks, moves the model to the GPU, creates the optimizers and
        schedulers, optionally lets the model configure apex, then starts the pretrain routine.

        Args:
            model: the model to train.
        """
        # call setup
        self.setup('fit')
        if self.is_function_implemented('setup', model):
            model.setup('fit')

        model.cuda(self.root_gpu)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        # TODO: remove with dropping NVIDIA AMP support
        if self.use_amp and not self.use_native_amp:
            # let the model set up apex; it hands back the model and optimizers to use from now on
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        self.run_pretrain_routine(model)

    def tpu_train(self, tpu_core_idx, model):
        """
        Trains ``model`` on a TPU core.

        Args:
            tpu_core_idx: not read by this method (presumably the process index passed in
                by the spawning code — TODO confirm against the caller).
            model: the model to train.
        """
        # call setup after the ddp process has connected
        self.setup('fit')
        if self.is_function_implemented('setup', model):
            model.setup('fit')

        # put model on tpu
        self._device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
        model.to(self._device)

        # get the appropriate tpu ranks
        self.tpu_local_core_rank = xm.get_local_ordinal()
        self.tpu_global_core_rank = xm.get_ordinal()

        # avoid duplicating progress bar
        if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        # NOTE(review): global_rank is taken from the *local* core rank here — confirm this is intended
        self.global_rank = self.tpu_local_core_rank
        rank_zero_only.rank = self.global_rank

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        # init 16 bit for TPU
        if self.precision == 16:
            os.environ['XLA_USE_BF16'] = str(1)

        log.info(f'INIT TPU local core: {self.tpu_local_core_rank},'
                 f' global rank: {self.tpu_global_core_rank}')

        # continue training routine
        self.run_pretrain_routine(model)

        # when training ends on these platforms dump weights to get out of the main process
        if self.on_colab_kaggle:
            self.save_spawn_weights(model)

    def dp_train(self, model):
        """
        Trains ``model`` wrapped in ``LightningDataParallel``.

        With native AMP the model's ``forward`` is temporarily wrapped in
        ``torch.cuda.amp.autocast`` and put back on the same module once training ends
        (also when training raises). With apex the model configures apex itself and the
        optimizers it returns replace ``self.optimizers``.

        Args:
            model: the unwrapped model to train.

        Raises:
            MisconfigurationException: if apex is used with ``amp_level == 'O2'``.
        """
        # call setup after the ddp process has connected
        self.setup('fit')
        if self.is_function_implemented('setup', model):
            model.setup('fit')

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        model.cuda(self.root_gpu)

        # hack forward to do autocast for the user
        # `model` is rebound to the DataParallel wrapper further down, so keep a handle on the
        # module whose forward is patched; otherwise the original forward would be "restored"
        # onto the wrapper and the module itself would stay patched.
        autocast_model = model
        model_autocast_original_forward = model.forward
        forward_is_patched = self.use_amp and self.use_native_amp
        if forward_is_patched:
            # wrap the user's forward in autocast and give it back at the end
            model.forward = torch.cuda.amp.autocast()(model.forward)

        try:
            # TODO: remove with dropping NVIDIA AMP support
            # check for this bug (amp + dp + !01 doesn't work)
            # https://github.com/NVIDIA/apex/issues/227
            # NOTE(review): the comment above mentions every level but O1, yet only 'O2' is rejected
            if self.use_dp and self.use_amp and not self.use_native_amp:
                if self.amp_level == 'O2':
                    raise MisconfigurationException(
                        f'Amp level {self.amp_level} with DataParallel is not supported.'
                        f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
                        f' We recommend you switch to ddp if you want to use amp')
                else:
                    model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
                    # keep the trainer in sync with what apex handed back (as the other backends do)
                    self.optimizers = optimizers
                    self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

            # create list of device ids
            device_ids = self.data_parallel_device_ids
            if isinstance(device_ids, int):
                device_ids = list(range(device_ids))

            # set dp device
            torch.cuda.set_device(self.root_gpu)

            model = LightningDataParallel(model, device_ids=device_ids)

            self.run_pretrain_routine(model)
        finally:
            # give the user's forward back on the module that was actually patched
            if forward_is_patched:
                autocast_model.forward = model_autocast_original_forward

    def horovod_train(self, model):
        """
        Trains ``model`` with Horovod.

        Pins the GPU, scales every learning rate by ``hvd.size()``, optionally lets the model
        configure apex, broadcasts parameters and optimizer state from rank 0, wraps each
        optimizer in ``hvd.DistributedOptimizer`` and runs the pretrain routine with
        optimizer synchronization skipped. Ends with ``hvd.join()``.

        Args:
            model: the model to train.
        """
        # call setup after the ddp process has connected
        self.setup('fit')
        if self.is_function_implemented('setup', model):
            model.setup('fit')

        if torch.cuda.is_available() and self.on_gpu:
            # Horovod: pin GPU to local rank
            assert self.root_gpu == hvd.local_rank()
            torch.cuda.set_device(self.root_gpu)
            model.cuda(self.root_gpu)

        # avoid duplicating progress bar
        if hvd.rank() != 0 and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in self.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= hvd.size()

        # only go through apex when native AMP is not in use, as in `single_gpu_train`;
        # under native AMP the apex `amp` name may not even be importable
        if self.use_amp and not self.use_native_amp:
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        for optimizer in self.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        def filter_named_parameters(model, optimizer):
            # keep only the (name, parameter) pairs of the model that this optimizer updates
            opt_params = {p for group in optimizer.param_groups for p in group.get('params', [])}
            return [(name, p) for name, p in model.named_parameters() if p in opt_params]

        # Horovod: wrap optimizers to perform gradient aggregation via allreduce
        self.optimizers = [
            hvd.DistributedOptimizer(optimizer, named_parameters=filter_named_parameters(model, optimizer))
            for optimizer in self.optimizers
        ]

        # Update logger rank info from Horovod to avoid race conditions from different ranks
        # creating directories / writing files in the same locations.
        self.global_rank = hvd.rank()
        rank_zero_only.rank = self.global_rank

        with ExitStack() as stack:
            for optimizer in self.optimizers:
                # Synchronization will be performed explicitly following backward()
                stack.enter_context(optimizer.skip_synchronize())

            self.run_pretrain_routine(model)

        # Make sure all workers have finished training before returning to the user
        hvd.join()
def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
if isinstance(s, str):
if s == '-1':
return -1
else:
return [int(x.strip()) for x in s.split(',') if len(x) > 0]
else:
return s
def get_all_available_gpus() -> List[int]:
    """
    Returns:
        the indices ``0 .. N-1`` where ``N`` is ``torch.cuda.device_count()``
    """
    num_devices = torch.cuda.device_count()
    return [*range(num_devices)]
def _check_data_type(device_ids: Any) -> None:
"""
Checks that the device_ids argument is one of: None, Int, String or List.
Raises a MisconfigurationException otherwise.
Args:
device_ids: gpus/tpu_cores parameter as passed to the Trainer
"""
if device_ids is not None and (not isinstance(device_ids, (int, str, MutableSequence)) or isinstance(device_ids, bool)):
raise MisconfigurationException("Device ID's (GPU/TPU) must be int, string or sequence of ints or None.")
def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int]]) -> Optional[List[int]]:
assert gpus is not None
if isinstance(gpus, MutableSequence):
return list(gpus)
# must be an int
if not gpus: # gpus==0
return None
if gpus == -1:
return get_all_available_gpus()
return list(range(gpus))
def sanitize_gpu_ids(gpus: List[int]) -> List[int]:
    """
    Verifies the requested GPU indices against the GPUs that are actually available.

    Args:
        gpus: list of ints corresponding to GPU indices

    Returns:
        ``gpus`` itself when every index is available. When some index is not available but
        the number of requested GPUs equals the number of available ones, the list of all
        available GPUs is returned instead.

    Raises:
        MisconfigurationException: if an index is unavailable and the counts differ.
    """
    all_available_gpus = get_all_available_gpus()

    if all(gpu in all_available_gpus for gpu in gpus):
        return gpus

    # sometimes auto ddp might have different flags
    # but this is not what the user intended
    # correct for the user
    if len(gpus) != len(all_available_gpus):
        raise MisconfigurationException(f"""
You requested GPUs: {gpus}
But your machine only has: {all_available_gpus}
""")
    return all_available_gpus
def _parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
"""
Parses the GPU ids given in the format as accepted by the
:class:`~pytorch_lightning.trainer.Trainer`.
Args:
gpus: An int -1 or string '-1' indicate that all available GPUs should be used.
A list of ints or a string containing list of comma separated integers
indicates specific GPUs to use.
An int 0 means that no GPUs should be used.
Any int N > 0 indicates that GPUs [0..N) should be used.
Returns:
a list of gpus to be used or ``None`` if no GPUs were requested
If no GPUs are available but the value of gpus variable indicates request for GPUs
then a MisconfigurationException is raised.
"""
# nothing was passed into the GPUs argument
if callable(gpus):
return None
# Check that gpus param is None, Int, String or List
_check_data_type(gpus)
# Handle the case when no gpus are requested
if gpus is None or isinstance(gpus, int) and gpus == 0:
return None
# We know user requested GPUs therefore if some of the
# requested GPUs are not available an exception is thrown.
gpus = _normalize_parse_gpu_string_input(gpus)
gpus = _normalize_parse_gpu_input_to_list(gpus)
if not gpus:
raise MisconfigurationException("GPUs requested but none are available.")
gpus = sanitize_gpu_ids(gpus)
return gpus
def determine_root_gpu_device(gpus: List[int]) -> Optional[int]:
    """
    Picks the root GPU, which is the first entry of ``gpus``.

    Args:
        gpus: non-empty list of ints representing which gpus to use, or ``None``

    Returns:
        designated root GPU device id, or ``None`` when ``gpus`` is ``None``
    """
    if gpus is not None:
        assert isinstance(gpus, list), "gpus should be a list"
        assert len(gpus) > 0, "gpus should be a non empty list"
        # the first listed device acts as root
        return gpus[0]
    return None
def retry_jittered_backoff(func: Callable, num_retries: int = 5, cap_delay: float = 1.0, base_delay: float = 0.01):
    """Retry jittered backoff.

    Calls ``func`` until it returns without raising ``RuntimeError``, sleeping between
    failed attempts. The sleep time starts at ``base_delay`` and is re-drawn after every
    failure from ``uniform(base_delay, 3 * previous_delay)``, capped at ``cap_delay``.

    Based on:
    https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/

    Args:
        func: tested function
        num_retries: number of tries
        cap_delay: max sleep time
        base_delay: initial sleep time is 10ms

    Returns:
        whatever ``func`` returns on its first successful call
        (``None`` if ``num_retries`` is not positive, as ``func`` is then never called).

    Raises:
        RuntimeError: the error of the last attempt once all tries are used up.
            Other exception types are never retried and propagate immediately.
    """
    sleep_delay = base_delay  # initial sleep time is 10ms

    for i in range(num_retries):
        try:
            return func()
        except RuntimeError:
            if i == num_retries - 1:
                # out of attempts: propagate the last failure with its original traceback
                raise

        # Reached only after a failed, non-final attempt. This used to be skipped by a
        # `continue` in the handler, which made all retries fire back-to-back.
        time.sleep(sleep_delay)
        sleep_delay = min(cap_delay, random.uniform(base_delay, sleep_delay * 3))
def _parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int], int]]:
"""
Parses the tpu_cores given in the format as accepted by the
:class:`~pytorch_lightning.trainer.Trainer`.
Args:
tpu_cores: An int 1 or string '1' indicate that 1 core with multi-processing should be used
An int 8 or string '8' indicate that all 8 cores with multi-processing should be used
A list of int or a string containing list of comma separated integer
indicates specific TPU core to use.
Returns:
a list of tpu_cores to be used or ``None`` if no TPU cores were requested
"""
if callable(tpu_cores):
return None
_check_data_type(tpu_cores)
if isinstance(tpu_cores, str):
tpu_cores = _parse_tpu_cores_str(tpu_cores.strip())
if not _tpu_cores_valid(tpu_cores):
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")
return tpu_cores
def _tpu_cores_valid(tpu_cores):
return tpu_cores in (1, 8, None) or (
isinstance(tpu_cores, (list, tuple, set)) and
len(tpu_cores) == 1 and
tpu_cores[0] in range(1, 9)
)
def _parse_tpu_cores_str(tpu_cores):
if tpu_cores in ('1', '8'):
tpu_cores = int(tpu_cores)
else:
tpu_cores = [int(x.strip()) for x in tpu_cores.split(',') if len(x) > 0]
return tpu_cores
def pick_single_gpu(exclude_gpus: list):
    """
    Returns the index of the first GPU that is not excluded and accepts a small tensor.

    Args:
        exclude_gpus: GPU indices that must not be picked.

    Returns:
        the index of the first GPU on which a one-element tensor could be placed.

    Raises:
        RuntimeError: if no such GPU exists.
    """
    for i in range(torch.cuda.device_count()):
        if i not in exclude_gpus:
            # Try to allocate on device:
            device = torch.device(f"cuda:{i}")
            try:
                torch.ones(1).to(device)
            except RuntimeError:
                # this device cannot be used, move on to the next one
                pass
            else:
                return i
    raise RuntimeError("No GPUs available.")
def pick_multiple_gpus(nb):
    """
    Picks ``nb`` distinct GPUs by calling ``pick_single_gpu`` repeatedly.

    Args:
        nb: number of GPUs to pick.

    Returns:
        the list of picked GPU indices, in the order they were picked.
    """
    chosen = []
    for _ in range(nb):
        # devices chosen so far are excluded from the next pick
        next_gpu = pick_single_gpu(exclude_gpus=chosen)
        chosen.append(next_gpu)
    return chosen