[zero] support all-gather overlap #5898

Merged · merged 4 commits · Jul 11, 2024
1 change: 1 addition & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -677,6 +677,7 @@ def __init__(
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
overlap_allgather=False,
)

def sync_dp_grads(self):
50 changes: 46 additions & 4 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -2,6 +2,7 @@
import logging
import os
import warnings
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from types import MethodType
@@ -34,7 +35,10 @@
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):


class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
def __init__(self, module: nn.Module, precision: str, overlap_communication: bool = False) -> None:
super().__init__(module)
self.dtype = None
if precision == "fp16":
@@ -72,12 +76,25 @@ def __init__(self, module: nn.Module, precision: str) -> None:
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_communication = overlap_communication
if overlap_communication:
self.op_hook = ZeroOpHook()
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
p.__class__ = ColoParameter
p.__init__(p, requires_grad=True)

def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_communication else nullcontext()
with ctx:
return super().forward(*args, **kwargs)

def _force_wait_all_gather(self):
for p in self.module.parameters():
wait_all_gather_handle(p)


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
@@ -209,6 +226,7 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s

def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_unsharded_model(model, checkpoint, strict)
model.update_master_params()

@@ -221,16 +239,38 @@ def load_sharded_model(
load_sub_module: bool = True,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()

def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)

def save_sharded_model(
self,
model: ModelWrapper,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)

def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
from peft import PeftModel

assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
@@ -290,6 +330,7 @@ def __init__(
reduce_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
overlap_allgather: bool = False,
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
@@ -316,6 +357,7 @@ def __init__(
cpu_offload=cpu_offload,
master_weights=master_weights,
)
self.overlap_allgather = overlap_allgather
self.lora_enabled = False
self.verbose = verbose

@@ -431,11 +473,11 @@ def configure(
self.add_lora_params_to_optimizer(model, optimizer)

if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)
model = LowLevelZeroModel(model, self.precision, overlap_communication=self.overlap_allgather)

# TODO: Support Galore + ZeRO
zero_stage = self.stage
zero_optim_kwargs = {**self.zero_optim_kwargs}
zero_optim_kwargs = {**self.zero_optim_kwargs, "overlap_allgather": self.overlap_allgather}
dp_size = dist.get_world_size()

# Replace with the distributed implementation if exists
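For context, a minimal usage sketch of the new overlap_allgather plugin flag (not part of this diff; the model and optimizer here are placeholders, and it assumes the distributed environment is already initialized, e.g. via colossalai.launch_from_torch):

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(stage=1, precision="fp16", overlap_allgather=True)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters())
# boost() wraps the model in LowLevelZeroModel and the optimizer in LowLevelZeroOptimizer,
# forwarding overlap_allgather to both, as added in the diff above.
model, optimizer, *_ = booster.boost(model, optimizer)

# optimizer.step() now launches asynchronous all-gathers of the updated parameters;
# ZeroOpHook waits on each parameter's handle just before its first use in the next forward pass.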
50 changes: 31 additions & 19 deletions colossalai/zero/low_level/low_level_optim.py
@@ -23,6 +23,7 @@

from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, TensorBucket
from .zero_hook import set_all_gather_handle, wait_all_gather_handle


class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -83,6 +84,7 @@ def __init__(
dp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
overlap_allgather: bool = False,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)

Expand Down Expand Up @@ -121,6 +123,7 @@ def __init__(

# communication params
self._overlap_communication = overlap_communication
self._overlap_allgather = overlap_allgather
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype

@@ -145,6 +148,8 @@ def __init__(

# record the padding size of each param
self._padding_map = dict()
# padded working param is all-gather buffer and it shares the same memory with working param
self._working_param_to_padded_working_param = dict()

# mapping working param and master param
self.master_to_working_param = dict()
@@ -245,11 +250,12 @@ def _create_master_param_current_rank(self, param_list):
with torch.no_grad():
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
# reset working params' ptr when no master weights
if self._master_weights == False:
param.data = padding_param[: param.numel()].view(param.shape)
# # reset working params' ptr when no master weights
# if self._master_weights == False:
param.data = padding_param[: param.numel()].view(param.shape)
else:
padding_param = param.data.view(-1)
self._working_param_to_padded_working_param[param] = padding_param

splited_params = padding_param.split(
padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
@@ -258,7 +264,7 @@

# use fp32 when master_weights is True
if self._master_weights is True:
splited_param_current_rank = splited_params.detach().float().to(device)
splited_param_current_rank = splited_params.detach().clone().float().to(device)
else:
splited_param_current_rank = splited_params

@@ -549,22 +555,24 @@ def step(self, closure=None):
working_param = real_working_params[group_id][idx]
param_to_gather = master_param.to(device).to(self._dtype)
pg = self.param_to_pg[working_param]
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
buffer_tensor = torch.empty_like(
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
)
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
padded_working_param = self._working_param_to_padded_working_param[working_param]
if self._overlap_allgather:
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
set_all_gather_handle(working_param, handle)
else:
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)
if not self._overlap_allgather:
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)

def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
@@ -892,3 +900,7 @@ def get_working_grad_by_param_id(self, param_id: int) -> Tensor:
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
grad_store = self.pid_to_grad_store[param_id]
return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)

def _force_wait_all_gather(self):
for param in self._working_param_to_padded_working_param.keys():
wait_all_gather_handle(param)
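A condensed sketch (not from the diff; names and process-group handling are simplified) of the overlapped path added to step() above: the async all-gather writes into the padded flat buffer that aliases the working parameter's storage, so no copy-back is needed, and only the communication handle is recorded for a later wait.

import torch
import torch.distributed as dist

def overlapped_allgather(working_param, padded_working_param, local_shard, pg, handles):
    # assumes padded_working_param.numel() == dist.get_world_size(pg) * local_shard.numel()
    # and that padded_working_param shares storage with working_param
    handle = dist.all_gather_into_tensor(padded_working_param, local_shard, group=pg, async_op=True)
    handles[working_param] = handle  # the real code attaches the handle to the parameter itself

def force_wait(handles):
    # mirrors _force_wait_all_gather: drain every outstanding handle, e.g. before checkpointing
    for param, handle in list(handles.items()):
        handle.wait()
        del handles[param]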
33 changes: 33 additions & 0 deletions colossalai/zero/low_level/zero_hook.py
@@ -0,0 +1,33 @@
from typing import List

from torch._tensor import Tensor

from colossalai.tensor.param_op_hook import ColoParamOpHook

_ALL_GATHER_HANDLE = "_all_gather_handle"


def wait_all_gather_handle(p):
if hasattr(p, _ALL_GATHER_HANDLE):
handle = getattr(p, _ALL_GATHER_HANDLE)
handle.wait()
delattr(p, _ALL_GATHER_HANDLE)


def set_all_gather_handle(p, handle):
setattr(p, _ALL_GATHER_HANDLE, handle)


class ZeroOpHook(ColoParamOpHook):
def pre_forward(self, params: List[Tensor]) -> None:
for p in params:
wait_all_gather_handle(p)

def post_forward(self, params: List[Tensor]) -> None:
pass

def pre_backward(self, params: List[Tensor]) -> None:
pass

def post_backward(self, params: List[Tensor]) -> None:
pass
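A small sketch of the deferred-wait mechanics behind zero_hook.py (it uses a stand-in handle so it runs without an initialized process group; in practice the handle is the Work object returned by dist.all_gather_into_tensor with async_op=True):

import torch
from colossalai.zero.low_level.zero_hook import set_all_gather_handle, wait_all_gather_handle

class FakeWork:
    # stand-in for torch.distributed's asynchronous Work handle
    def wait(self):
        print("all-gather finished")

p = torch.nn.Parameter(torch.zeros(4))
set_all_gather_handle(p, FakeWork())   # done by LowLevelZeroOptimizer.step() after launching the async op
wait_all_gather_handle(p)              # done by ZeroOpHook.pre_forward before the param is first used
assert not hasattr(p, "_all_gather_handle")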
4 changes: 2 additions & 2 deletions examples/language/performance_evaluator.py
@@ -113,13 +113,13 @@ def on_step_start(self, step: int) -> None:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
get_accelerator().synchronize()
# get_accelerator().synchronize()
self.timer.start()

def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
get_accelerator().synchronize()
# get_accelerator().synchronize()
self.timer.end()

batch_size, seq_len = input_ids.shape
4 changes: 4 additions & 0 deletions tests/test_zero/test_low_level/test_grad_acc.py
@@ -64,8 +64,12 @@ def fwd_bwd_func(number, cur_data, check_flag):
zero1_optimizer.step()
zero2_optimizer.step()

zero1_optimizer._force_wait_all_gather()
zero2_optimizer._force_wait_all_gather()

# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert not hasattr(z1p, "_all_gather_handle")
assert torch.equal(z1p.data, z2p.data)


2 changes: 2 additions & 0 deletions tests/test_zero/test_low_level/test_zero1_2.py
@@ -177,6 +177,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
# torch ddp step
torch_optimizer.step()

zero_optimizer._force_wait_all_gather()

# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
loose_close(p, z1p, dtype=dtype)