diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 3c67299bbe79..8abaf8fc6b3f 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -371,7 +371,7 @@ def _run_reduction(self): for i, sz in enumerate(bucket_store.sizes): grp = bucket_store.torch_pg if len(bucket_store.sizes) == 1 else bucket_store.torch_pg[i] flat_grads_list = list(cur_flat_grads.split(len(cur_flat_grads) // sz)) - received_grad = torch.zeros_like(flat_grads_list[0]) + received_grad = torch.empty_like(flat_grads_list[0]) if self._fp8_communication: reduce_scatter_fp8( received_grad,