Commit

Merge branch 'master' into olruwase/zero_multi_models
tjruwase authored Feb 11, 2025
2 parents 02477ce + 22d7fdc commit 2f84032
Showing 3 changed files with 28 additions and 22 deletions.
39 changes: 20 additions & 19 deletions deepspeed/module_inject/layers.py
@@ -124,7 +124,7 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
return None, grad_output


- class Replaced_Layer(nn.Module, ABC):
+ class TensorParallel_Layer(nn.Module, ABC):
"""
A base class for model layers with tensor parallelism support.
This class is designed to be extended by specific layers that require distributed
@@ -141,7 +141,7 @@ class Replaced_Layer(nn.Module, ABC):

def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
"""
- Initializes the Replaced_Layer with optional model parallelism group and layer name.
+ Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
Args:
mp_group (Optional[dist.ProcessGroup]): The process group for model parallelism.
@@ -177,7 +177,7 @@ def gather_params(self, params_list):
pass

@abstractmethod
- def partition(self, params_list: List[torch.Tensor]):
+ def _tp_partition(self, params_list: List[torch.Tensor]):
"""
Partitions the parameters for tensor parallelism.
It is necessary to ensure that this function only involves the logic of params partitioning.
@@ -205,7 +205,7 @@ def config_tp_params(self, weight):
setattr(weight, DS_TENSOR_MODEL_PARALLEL, True)
setattr(weight, DS_IS_REPLACED_MODULE, True)
weight.gather_params = self.gather_params
- weight.partition = self.partition
+ weight._tp_partition = self._tp_partition

def is_training_mode(self):
global DEEPSPEED_AUTOTP_MODE
@@ -294,17 +294,17 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
#TODO : Check whether there are any missing attributes.
if self.enabled:
- self.params[0].partition(self.params)
+ self.params[0]._tp_partition(self.params)


- class LinearAllreduce(Replaced_Layer):
+ class LinearAllreduce(TensorParallel_Layer):

def __init__(self, module, mp_group, **kwargs):
super(LinearAllreduce, self).__init__(mp_group, **kwargs)
self.weight = module.weight
self.bias = module.bias

- self.partition([self.weight, self.bias])
+ self._tp_partition([self.weight, self.bias])
self.support_training = True
self.config_tp_params(self.weight)
if self.bias is not None:
Expand Down Expand Up @@ -335,7 +335,7 @@ def gather_params(self, params_list):
return

@torch.no_grad()
- def partition(self, params_list):
+ def _tp_partition(self, params_list):

if not self.is_training_mode():
self.uneven_partition(params_list)
@@ -367,21 +367,22 @@ def uneven_partition(self, params_list):


#remove kwargs from partition.
- class LinearLayer(Replaced_Layer):
+ class LinearLayer(TensorParallel_Layer):

- def __init__(self, module, mp_group, skip_partition=False, **kwargs):
+ def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
super(LinearLayer, self).__init__(mp_group, **kwargs)
self.weight = module.weight
self.bias = module.bias
if not skip_partition:
- self.partition([self.weight, self.bias])
+ self._tp_partition([self.weight, self.bias])
self.support_training = True
self.config_tp_params(self.weight)
if self.bias is not None:
self.config_tp_params(self.bias)

def forward(self, input):
- input = ColumnParallel.apply(self.mp_group, input)
+ if getattr(self, 'mp_group', None) is not None:
+     input = ColumnParallel.apply(self.mp_group, input)
output = torch.matmul(input, self.weight.transpose(-1, -2))
if self.bias is not None:
output += self.bias
@@ -401,7 +402,7 @@ def gather_params(self, params_list):
params_list[idx].data = output_param.contiguous()

@torch.no_grad()
- def partition(self, params_list):
+ def _tp_partition(self, params_list):

if not self.is_training_mode():
self.uneven_partition(params_list)
@@ -466,7 +467,7 @@ def __init__(self, module, mp_group, skip_partition=False, **kwargs):
super().__init__(module, mp_group, skip_partition, **kwargs)

@torch.no_grad()
- def partition(self, params_list):
+ def _tp_partition(self, params_list):
for idx, param in enumerate(params_list):
if param is None:
return
@@ -481,7 +482,7 @@ def partition(self, params_list):
class conv_LinearLayer(LinearLayer):

@torch.no_grad()
- def partition(self, params_list):
+ def _tp_partition(self, params_list):
weight = None
bias = None
if len(params_list) == 1:
@@ -506,7 +507,7 @@ class Yuan_LinearAllreduce(LinearAllreduce):

#Yuan2
@torch.no_grad()
- def partition(self, params_list):
+ def _tp_partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, False)
params_list[0].data = weight
@@ -517,7 +518,7 @@ def partition(self, params_list):
class Yuan_LinearLayer(LinearLayer):
#Yuan2
@torch.no_grad()
- def partition(self, params_list):
+ def _tp_partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, True)
params_list[0].data = move(weight, get_accelerator().current_device_name()).detach()
@@ -528,7 +529,7 @@ def partition(self, params_list):
class GateUpPack_LinearLayer(LinearLayer):
# chatGLM2, chatGLM2
@torch.no_grad()
- def partition(self, params_list):
+ def _tp_partition(self, params_list):
weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size)
params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach()
if bias is not None:
@@ -538,7 +539,7 @@ def partition(self, params_list):
class Conv_LinearALlreduce(LinearAllreduce):

@torch.no_grad()
- def partition(self, params_list):
+ def _tp_partition(self, params_list):
for idx, param in enumerate(params_list):
if param is None:
return
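The layers.py changes above amount to two renames (Replaced_Layer -> TensorParallel_Layer, partition -> _tp_partition) plus a new guard that lets LinearLayer run without a model-parallel group. A toy stand-in (not DeepSpeed code) illustrating the effect of that guard, assuming only plain PyTorch:

import torch
import torch.nn as nn

class ToyLinearLayer(nn.Module):
    """Minimal stand-in mirroring the guarded forward() added to LinearLayer."""

    def __init__(self, module: nn.Linear, mp_group=None):
        super().__init__()
        self.weight = module.weight
        self.bias = module.bias
        self.mp_group = mp_group  # None means single-process: no collectives

    def forward(self, input):
        # The commit adds this guard so the ColumnParallel all-gather only runs
        # when a tensor-parallel group is actually present.
        if getattr(self, 'mp_group', None) is not None:
            raise NotImplementedError("tensor-parallel path not modeled in this toy")
        output = torch.matmul(input, self.weight.transpose(-1, -2))
        if self.bias is not None:
            output += self.bias
        return output

y = ToyLinearLayer(nn.Linear(16, 32))(torch.randn(4, 16))
print(y.shape)  # torch.Size([4, 32]) -- no all-gather was issued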
9 changes: 7 additions & 2 deletions deepspeed/runtime/hybrid_engine.py
@@ -290,8 +290,13 @@ def create_inference_containers(self, module, layer_id=0):

layer_id += 1
else:
- self._other_layers.append(self.inference_policies[child.__class__][0](
-     weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
+ if self.inference_policies[child.__class__][0] == LinearLayer:
+     self._other_layers.append(self.inference_policies[child.__class__][0](module=child,
+                                                                           mp_group=None,
+                                                                           skip_partition=True))
+ else:
+     self._other_layers.append(self.inference_policies[child.__class__][0](
+         weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
self._orig_modules_others.append(child)
self._orig_fwds_others.append(child.forward)
else:
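The hybrid_engine.py hunk applies the same idea at construction time: when the matched inference policy is LinearLayer, the whole child module is passed through with mp_group=None and skip_partition=True (so the original weight and bias are left untouched); every other policy keeps the old weight/bias call. A rough standalone restatement of that branch (build_other_layer is a hypothetical helper, not part of the commit):

from deepspeed.module_inject.layers import LinearLayer

def build_other_layer(policy_cls, child):
    # Hypothetical helper restating the branch added in create_inference_containers.
    if policy_cls == LinearLayer:
        # New path: hand over the module itself; skip_partition=True bypasses
        # _tp_partition, and mp_group=None keeps forward() collective-free.
        return policy_cls(module=child, mp_group=None, skip_partition=True)
    # Unchanged path: policies constructed from raw weight/bias tensors.
    return policy_cls(weight=child.weight,
                      bias=child.bias if hasattr(child, 'bias') else None)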
2 changes: 1 addition & 1 deletion tests/unit/model_parallelism/test_autotp_training.py
@@ -330,7 +330,7 @@ def test(self, layer_type):
assert total_params == params1

for name, param in tp_layer.named_parameters(recurse=False):
- param.partition([param])
+ param._tp_partition([param])

tp_params2 = sum(p.numel() for p in tp_layer.parameters())

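The one-line test change follows the rename: config_tp_params (see the -205 hunk in layers.py) now attaches the partitioning hook to each parameter as _tp_partition rather than partition, so the test calls it under the new name. Sketch of the call pattern, assuming tp_layer was prepared as in the test above:

# Re-partition each top-level parameter in place via the renamed hook.
for _, param in tp_layer.named_parameters(recurse=False):
    param._tp_partition([param])  # pre-rename spelling: param.partition([param])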
