[zero] support multiple (partial) backward passes #5596

Merged

merged 2 commits on Apr 16, 2024
2 changes: 2 additions & 0 deletions colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -11,7 +11,9 @@
class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
self.reset_all()

def reset_all(self) -> None:
# init
self.current_group_id = 0
self._num_elements_in_bucket = 0
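
The hunk above moves the bucket bookkeeping initialization out of `__init__` into a `reset_all()` method so the same state can be cleared again later; the optimizer's `zero_grad` calls it further down in this PR. A minimal sketch of the pattern, with illustrative state names only:

```python
# Minimal sketch of the refactor pattern above; the state names are illustrative,
# not the full BucketStore bookkeeping.
class BucketStoreSketch:
    def __init__(self, torch_pg):
        self.torch_pg = torch_pg
        self.reset_all()  # initialize by resetting, so the same path can be reused later

    def reset_all(self) -> None:
        # clear all per-iteration bucket bookkeeping
        self.current_group_id = 0
        self._num_elements_in_bucket = 0
        self._grad_list = []
```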
67 changes: 53 additions & 14 deletions colossalai/zero/low_level/low_level_optim.py
@@ -40,7 +40,13 @@ def __init__(
max_scale: float = 2**32,
) -> None:
super().__init__(
initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
initial_scale,
min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
max_scale,
)
self.num_working_param_groups = num_working_param_groups
self.grad_store = grad_store
@@ -273,11 +279,10 @@ def _create_master_param_current_rank(self, param_list):
# Backward Reduction Hook #
###########################

def _grad_handler(self, param, group_id, grad):
def _grad_handler(self, group_id, param):
# if run with no_sync context, would not sync grad when backward
if self.require_grad_sync:
self._add_to_bucket(param, group_id)
return grad

def _attach_reduction_hook(self):
# we iterate over the working params
@@ -286,7 +291,7 @@ def _attach_reduction_hook(self):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad:
param.register_hook(partial(self._grad_handler, param, group_id))
param.register_post_accumulate_grad_hook(partial(self._grad_handler, group_id))

#######################
# Reduction Functions #
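
The registration above switches from `Tensor.register_hook`, whose callback receives the incoming gradient every time one is computed, to `Tensor.register_post_accumulate_grad_hook` (PyTorch 2.1+), whose callback runs only after the gradient has been accumulated into `param.grad` and receives the parameter itself. That is what allows several (partial) backward passes to accumulate into the same `.grad` before the bucket reduction fires. A standalone sketch of the API difference on a toy tensor, separate from the optimizer code:

```python
# Standalone illustration of the two hook APIs (toy tensor, not the ColossalAI optimizer).
import torch

w = torch.randn(4, requires_grad=True)

# Old style: fires on every gradient computation and sees the incoming grad tensor.
w.register_hook(lambda grad: print("register_hook grad norm:", grad.norm().item()))

# New style (torch >= 2.1): fires after the grad is accumulated into w.grad
# and receives the parameter, mirroring _grad_handler(group_id, param) above.
w.register_post_accumulate_grad_hook(
    lambda param: print("accumulated .grad norm:", param.grad.norm().item())
)

(w.sum() * 2).backward()   # first partial backward
(w * w).sum().backward()   # second partial backward accumulates into w.grad
```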
@@ -415,15 +420,22 @@ def _run_reduction(self):
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
self._update_partitoned_grad(
non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
non_moe_grad_in_bucket_current_rank,
recieved_grad,
group_id,
1,
)

if len(moe_grad_list) > 0:
flat_grads_list = list(
moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
)
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
dist.reduce_scatter(
recieved_grad,
flat_grads_list,
group=self.moe_extra_dp_pg,
)
param_slice = self._world_size // self.moe_extra_dp_pg_size
recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
for split_recieved_grad in recieved_grad:
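
For context on the collective used in `_run_reduction`: `dist.reduce_scatter(output, input_list, group=...)` sums `input_list[i]` element-wise across ranks and leaves rank `i` holding the i-th reduced chunk, so each rank keeps only its own shard of the summed gradients. A minimal sketch, assuming an already-initialized process group (e.g. NCCL launched via torchrun) and illustrative chunk sizes:

```python
# Minimal reduce-scatter sketch (assumes torch.distributed is already initialized
# with a backend that supports reduce_scatter, e.g. NCCL; sizes are illustrative).
import torch
import torch.distributed as dist

world_size = dist.get_world_size()
rank = dist.get_rank()

# Each rank holds a full flat gradient, split into world_size equal chunks.
flat_grads = torch.ones(4 * world_size) * (rank + 1)
flat_grads_list = list(flat_grads.split(len(flat_grads) // world_size))

received = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(received, flat_grads_list, group=dist.group.WORLD)
# 'received' now holds the sum over ranks of chunk `rank`, i.e. this rank's gradient shard.
```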
@@ -444,14 +456,25 @@ def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List
self._add_grad(grad, self._world_size, group_id, param_id, rank)

def _update_partitoned_grad(
self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
self,
origin_grad_list: List,
flat_grad: torch.Tensor,
group_id: int,
partition_num: int,
) -> None:
sync_tensor(flat_grad, origin_grad_list)
for grad in origin_grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._add_grad(grad, partition_num, group_id, param_id)

def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
def _add_grad(
self,
grad: torch.Tensor,
partition_num: int,
group_id: int,
param_id: int,
rank: int = 0,
) -> None:
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
@@ -534,6 +557,7 @@ def zero_grad(self, set_to_none=True):
if param.grad is not None:
param.grad.detach()
param.grad.zero_()
self._bucket_store.reset_all()

####################
# Update Parameter #
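
Because `zero_grad` now also resets the bucket store, stale bucket bookkeeping cannot leak from one iteration into the next, which is what makes several partial backward passes per step safe. A hedged usage sketch of the intended training-step pattern; `model`, `batch_a`, and `batch_b` are placeholders, and `optimizer` is assumed to be a `LowLevelZeroOptimizer` instance as defined in this file:

```python
# Hypothetical training-step sketch for multiple (partial) backward passes.
# `model`, `batch_a`, `batch_b` are placeholders, not part of the PR.
optimizer.zero_grad()            # also resets the bucket store after this PR

loss_a = model(batch_a).sum()
optimizer.backward(loss_a)       # first partial backward

loss_b = model(batch_b).sum()
optimizer.backward(loss_b)       # second partial backward accumulates grads

optimizer.step()                 # reduce-scatter + parameter update happen once
```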
@@ -655,14 +679,20 @@ def step(self, closure=None):
for _ in range(self.moe_extra_dp_pg_size)
]
dist.all_gather(
all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
all_splited_param,
splited_param.to(device).to(self._dtype),
group=self.moe_extra_dp_pg,
)
else:
all_splited_param = [
torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
dist.all_gather(
all_splited_param,
splited_param.to(device).to(self._dtype),
group=self.dp_pg,
)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

@@ -685,7 +715,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
[float(total_norm)],
device=get_accelerator().get_current_device(),
dtype=torch.float,
)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item()
@@ -698,10 +730,14 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo

# Sum across all model parallel GPUs.
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
[float(total_norm_exponentiated)],
device=get_accelerator().get_current_device(),
dtype=torch.float,
)
torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
total_norm_exponentiated_cuda,
op=torch.distributed.ReduceOp.SUM,
group=self.dp_pg,
)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
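
The reformatted block above follows the usual distributed gradient-norm recipe: each rank sums |g|^p over its local gradient shards, the partial sums are all-reduced with SUM across the data-parallel group, and the p-th root of the total gives the global norm. A single-process illustration of the same arithmetic, with no process group involved:

```python
# Single-process illustration of the norm math in _compute_grad_norm (no dist calls).
import torch

norm_type = 2.0
# Pretend these are the gradient shards held by two different ranks.
shards_per_rank = [
    [torch.tensor([3.0, 0.0])],   # rank 0's shards
    [torch.tensor([0.0, 4.0])],   # rank 1's shards
]

# Each rank computes sum(|g|^p) locally; all_reduce(SUM) would add these together.
exponentiated = [
    sum(g.abs().pow(norm_type).sum().item() for g in shards)
    for shards in shards_per_rank
]
total_norm = sum(exponentiated) ** (1.0 / norm_type)
print(total_norm)  # 5.0, the norm of the full (unsharded) gradient [3, 4]
```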

@@ -920,5 +956,8 @@ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:

def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
if hasattr(self, "moe_master_to_working_map"):
return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
return {
**self._param_store.master_to_working_param,
**self.moe_master_to_working_map,
}
return self._param_store.master_to_working_param
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -8,7 +8,7 @@ click
fabric
contexttimer
ninja
torch>=1.12
torch>=2.1.0
safetensors
einops
pydantic
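
The torch floor rises from 1.12 to 2.1.0 because `Tensor.register_post_accumulate_grad_hook`, used by the new reduction hook above, first appeared in PyTorch 2.1. A small guard one could add at import time is sketched below; it is illustrative and not part of this PR:

```python
# Illustrative guard, not part of this PR: fail early when the hook API is missing.
import torch

if not hasattr(torch.Tensor, "register_post_accumulate_grad_hook"):
    raise RuntimeError(
        "PyTorch >= 2.1.0 is required for register_post_accumulate_grad_hook, "
        f"found torch {torch.__version__}"
    )
```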