diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e06cf0581e39..bdc91b51fa4a 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -549,6 +549,13 @@ def step(self, closure=None): working_param = real_working_params[group_id][idx] param_to_gather = master_param.to(device).to(self._dtype) pg = self.param_to_pg[working_param] + if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size: + buffer_tensor = torch.empty_like( + torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))]) + ) + dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg) + working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param)) + continue try: self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param) except RuntimeError: