From f8ed148fb4b13ffdd1801d6bfd604e3aa3844489 Mon Sep 17 00:00:00 2001
From: Han Zhang <623606860@qq.com>
Date: Wed, 8 Dec 2021 21:02:41 +0800
Subject: [PATCH] [Fix] Fix dist training infinite waiting issue (#1035)

* [#1034] fix dist training infinite waiting issue

* print log_vars keys in assertion msg

* linting issue
---
 mmseg/models/segmentors/base.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py
index 944da0f2e4..f0f320ffbf 100644
--- a/mmseg/models/segmentors/base.py
+++ b/mmseg/models/segmentors/base.py
@@ -188,6 +188,17 @@ def _parse_losses(losses):
         loss = sum(_value for _key, _value in log_vars.items()
                    if 'loss' in _key)
 
+        # If the loss_vars has different length, raise assertion error
+        # to prevent GPUs from infinite waiting.
+        if dist.is_available() and dist.is_initialized():
+            log_var_length = torch.tensor(len(log_vars), device=loss.device)
+            dist.all_reduce(log_var_length)
+            message = (f'rank {dist.get_rank()}' +
+                       f' len(log_vars): {len(log_vars)}' + ' keys: ' +
+                       ','.join(log_vars.keys()) + '\n')
+            assert log_var_length == len(log_vars) * dist.get_world_size(), \
+                'loss log variables are different across GPUs!\n' + message
+
         log_vars['loss'] = loss
         for loss_name, loss_value in log_vars.items():
             # reduce loss when distributed training
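
Note: the sketch below shows the same guard in isolation, for readers studying the fix outside the mmseg code base. The helper name check_log_vars_consistent and the standalone setup are illustrative assumptions, not part of the patch; the idea is that summing len(log_vars) across ranks with all_reduce exposes a key mismatch immediately, instead of letting the later per-key all_reduce calls pair up incorrectly and hang every GPU.

    # Sketch only: assumes torch.distributed is already initialized
    # (e.g. launched via torchrun); the helper name is hypothetical.
    import torch
    import torch.distributed as dist


    def check_log_vars_consistent(log_vars, device):
        """Fail fast if ranks logged a different number of loss variables."""
        if dist.is_available() and dist.is_initialized():
            # Each rank contributes its own dict length; SUM is the default op.
            log_var_length = torch.tensor(len(log_vars), device=device)
            dist.all_reduce(log_var_length)
            expected = len(log_vars) * dist.get_world_size()
            message = (f'rank {dist.get_rank()} len(log_vars): {len(log_vars)} '
                       f'keys: {",".join(log_vars.keys())}')
            assert log_var_length == expected, \
                'loss log variables are different across GPUs!\n' + message

Calling such a check before reducing each loss value turns a silent NCCL deadlock into an immediate assertion that names the offending rank and its keys, which is exactly the behavior the patch adds inside _parse_losses.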