Skip to content

Commit

Permalink
[Fix] Fix dist training infinite waiting issue (#1035)
Browse files Browse the repository at this point in the history
* [#1034] fix dist training infinite waiting issue

* print log_vars keys in assertion msg

* linting issue
  • Loading branch information
fingertap authored Dec 8, 2021
1 parent a357419 commit f8ed148
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions mmseg/models/segmentors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,17 @@ def _parse_losses(losses):
loss = sum(_value for _key, _value in log_vars.items()
if 'loss' in _key)

# If the loss_vars has different length, raise assertion error
# to prevent GPUs from infinite waiting.
if dist.is_available() and dist.is_initialized():
log_var_length = torch.tensor(len(log_vars), device=loss.device)
dist.all_reduce(log_var_length)
message = (f'rank {dist.get_rank()}' +
f' len(log_vars): {len(log_vars)}' + ' keys: ' +
','.join(log_vars.keys()) + '\n')
assert log_var_length == len(log_vars) * dist.get_world_size(), \
'loss log variables are different across GPUs!\n' + message

log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
Expand Down

0 comments on commit f8ed148

Please sign in to comment.