diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index c410bf900c31..2a24c2920466 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -124,7 +124,7 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]:
         return None, grad_output
 
 
-class Replaced_Layer(nn.Module, ABC):
+class TensorParallel_Layer(nn.Module, ABC):
     """
     A base class for model layers with tensor parallelism support.
     This class is designed to be extended by specific layers that require distributed
@@ -141,7 +141,7 @@ class Replaced_Layer(nn.Module, ABC):
 
     def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
         """
-        Initializes the Replaced_Layer with optional model parallelism group and layer name.
+        Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
 
         Args:
             mp_group (Optional[dist.ProcessGroup]): The process group for model parallelism.
@@ -177,7 +177,7 @@ def gather_params(self, params_list):
         pass
 
     @abstractmethod
-    def partition(self, params_list: List[torch.Tensor]):
+    def _tp_partition(self, params_list: List[torch.Tensor]):
         """
         Partitions the parameters for tensor parallelism.
         It is necessary to ensure that this function only involves the logic of params partitioning.
@@ -205,7 +205,7 @@ def config_tp_params(self, weight):
             setattr(weight, DS_TENSOR_MODEL_PARALLEL, True)
             setattr(weight, DS_IS_REPLACED_MODULE, True)
             weight.gather_params = self.gather_params
-            weight.partition = self.partition
+            weight._tp_partition = self._tp_partition
 
     def is_training_mode(self):
         global DEEPSPEED_AUTOTP_MODE
@@ -294,17 +294,17 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
         """
         #TODO : Check whether there are any missing attributes.
         if self.enabled:
-            self.params[0].partition(self.params)
+            self.params[0]._tp_partition(self.params)
 
 
-class LinearAllreduce(Replaced_Layer):
+class LinearAllreduce(TensorParallel_Layer):
 
     def __init__(self, module, mp_group, **kwargs):
         super(LinearAllreduce, self).__init__(mp_group, **kwargs)
         self.weight = module.weight
         self.bias = module.bias
 
-        self.partition([self.weight, self.bias])
+        self._tp_partition([self.weight, self.bias])
         self.support_training = True
         self.config_tp_params(self.weight)
         if self.bias is not None:
@@ -335,7 +335,7 @@ def gather_params(self, params_list):
         return
 
     @torch.no_grad()
-    def partition(self, params_list):
+    def _tp_partition(self, params_list):
 
         if not self.is_training_mode():
             self.uneven_partition(params_list)
@@ -367,21 +367,22 @@ def uneven_partition(self, params_list):
 
 
 #remove kwargs from partition.
-class LinearLayer(Replaced_Layer):
+class LinearLayer(TensorParallel_Layer):
 
-    def __init__(self, module, mp_group, skip_partition=False, **kwargs):
+    def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
         super(LinearLayer, self).__init__(mp_group, **kwargs)
         self.weight = module.weight
         self.bias = module.bias
         if not skip_partition:
-            self.partition([self.weight, self.bias])
+            self._tp_partition([self.weight, self.bias])
         self.support_training = True
         self.config_tp_params(self.weight)
         if self.bias is not None:
             self.config_tp_params(self.bias)
 
     def forward(self, input):
-        input = ColumnParallel.apply(self.mp_group, input)
+        if getattr(self, 'mp_group', None) is not None:
+            input = ColumnParallel.apply(self.mp_group, input)
         output = torch.matmul(input, self.weight.transpose(-1, -2))
         if self.bias is not None:
             output += self.bias
@@ -401,7 +402,7 @@ def gather_params(self, params_list):
             params_list[idx].data = output_param.contiguous()
 
     @torch.no_grad()
-    def partition(self, params_list):
+    def _tp_partition(self, params_list):
 
         if not self.is_training_mode():
             self.uneven_partition(params_list)
@@ -466,7 +467,7 @@ def __init__(self, module, mp_group, skip_partition=False, **kwargs):
         super().__init__(module, mp_group, skip_partition, **kwargs)
 
     @torch.no_grad()
-    def partition(self, params_list):
+    def _tp_partition(self, params_list):
         for idx, param in enumerate(params_list):
             if param is None:
                 return
@@ -481,7 +482,7 @@ def partition(self, params_list):
 class conv_LinearLayer(LinearLayer):
 
     @torch.no_grad()
-    def partition(self, params_list):
+    def _tp_partition(self, params_list):
         weight = None
         bias = None
         if len(params_list) == 1:
@@ -506,7 +507,7 @@ class Yuan_LinearAllreduce(LinearAllreduce):
 
     #Yuan2
     @torch.no_grad()
-    def partition(self, params_list):
+    def _tp_partition(self, params_list):
         weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
                                                  self.tp_world_size, False)
         params_list[0].data = weight
@@ -517,7 +518,7 @@ def partition(self, params_list):
 class Yuan_LinearLayer(LinearLayer):
     #Yuan2
     @torch.no_grad()
-    def partition(self, params_list):
+    def _tp_partition(self, params_list):
         weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
                                                  self.tp_world_size, True)
         params_list[0].data = move(weight, get_accelerator().current_device_name()).detach()
@@ -528,7 +529,7 @@ def partition(self, params_list):
 class GateUpPack_LinearLayer(LinearLayer):
     # chatGLM2, chatGLM2
     @torch.no_grad()
-    def partition(self, params_list):
+    def _tp_partition(self, params_list):
         weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size)
         params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach()
         if bias is not None:
@@ -538,7 +539,7 @@ def partition(self, params_list):
 class Conv_LinearALlreduce(LinearAllreduce):
 
     @torch.no_grad()
-    def partition(self, params_list):
+    def _tp_partition(self, params_list):
         for idx, param in enumerate(params_list):
             if param is None:
                 return
diff --git a/deepspeed/runtime/hybrid_engine.py b/deepspeed/runtime/hybrid_engine.py
index 8a6311bb6e83..b6e417fd4764 100644
--- a/deepspeed/runtime/hybrid_engine.py
+++ b/deepspeed/runtime/hybrid_engine.py
@@ -290,8 +290,13 @@ def create_inference_containers(self, module, layer_id=0):
 
                     layer_id += 1
                 else:
-                    self._other_layers.append(self.inference_policies[child.__class__][0](
-                        weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
+                    if self.inference_policies[child.__class__][0] == LinearLayer:
+                        self._other_layers.append(self.inference_policies[child.__class__][0](module=child,
+                                                                                              mp_group=None,
+                                                                                              skip_partition=True))
+                    else:
+                        self._other_layers.append(self.inference_policies[child.__class__][0](
+                            weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
                     self._orig_modules_others.append(child)
                     self._orig_fwds_others.append(child.forward)
             else:
diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py
index fc1f0624ec87..73e61b1d3398 100644
--- a/tests/unit/model_parallelism/test_autotp_training.py
+++ b/tests/unit/model_parallelism/test_autotp_training.py
@@ -330,7 +330,7 @@ def test(self, layer_type):
         assert total_params == params1
 
         for name, param in tp_layer.named_parameters(recurse=False):
-            param.partition([param])
+            param._tp_partition([param])
 
         tp_params2 = sum(p.numel() for p in tp_layer.parameters())
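Note on the new mp_group=None path in LinearLayer.forward (layers.py hunk above): when no model-parallel group is supplied, ColumnParallel.apply is skipped and the layer reduces to a plain dense matmul, which is why the hybrid engine can now wrap non-container linears with module=child, mp_group=None, skip_partition=True. A minimal standalone sketch of that arithmetic, in plain PyTorch with illustrative names (not part of the patch):

import torch
import torch.nn as nn

# What LinearLayer.forward computes when mp_group is None and no partitioning
# has been applied: output = input @ weight^T (+ bias), i.e. an ordinary nn.Linear.
linear = nn.Linear(16, 32)
x = torch.randn(4, 16)

output = torch.matmul(x, linear.weight.transpose(-1, -2))
if linear.bias is not None:
    output += linear.bias

assert torch.allclose(output, linear(x), atol=1e-6)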