[pipeline] build bloom model and policy, revise the base class of policy (hpcaitech#4161)

* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name conflict

* add bloom model and policy, revise the base class of policy

* revise

* revision

* add bert_for_pretraining
CjhHa1 authored and ver217 committed Aug 15, 2023
1 parent 35ec13b commit 6728bc4
Showing 5 changed files with 286 additions and 80 deletions.
37 changes: 35 additions & 2 deletions colossalai/pipeline/policy/base.py
@@ -1,13 +1,15 @@
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from torch import Tensor
from torch.nn import Module, Parameter

from colossalai.lazy import LazyTensor
from colossalai.pipeline.stage_manager import PipelineStageManager


class Policy:

def __init__(self, stage_manager: PipelineStageManager) -> None:
self.stage_manager = stage_manager

@@ -93,7 +95,8 @@ def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]:
"""
raise NotImplementedError

def parallelize_model(self,
module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
"""Parallelize model for pipeline parallel
Args:
@@ -106,3 +109,33 @@ def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[
self.replace_forward(module)
shared_params = self.get_shared_params(module)
return hold_params, hold_buffers, shared_params

@staticmethod
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
"""
divide layers into stages
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages

# deal with the rest layers
if remainder > 0:
start_position = num_stages // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage

@staticmethod
def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
"""
get the start index and end index of layers for each stage.
"""
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

start_idx = num_layers_per_stage_accumulated[stage]
end_idx = num_layers_per_stage_accumulated[stage + 1]

return [start_idx, end_idx]
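
For a quick sanity check of the two helpers introduced above, here is a minimal usage sketch. It assumes a ColossalAI checkout containing this commit so that colossalai.pipeline.policy.base is importable as shown in the file header; both helpers are static methods, so no PipelineStageManager is needed just to inspect the split. The concrete numbers (5 layers, 4 stages) are illustrative only.

```python
# Minimal sketch: how Policy.distribute_layers and Policy.get_stage_index
# cooperate. Assumes a ColossalAI build that contains this commit.
from colossalai.pipeline.policy.base import Policy

# 5 encoder layers over 4 stages: the leftover layer lands near the middle.
layers_per_stage = Policy.distribute_layers(5, 4)
print(layers_per_stage)  # [1, 1, 2, 1]

# Each stage owns the half-open range [start_idx, end_idx) of layer indices.
for stage in range(4):
    start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage)
    print(f"stage {stage}: layers [{start_idx}, {end_idx})")
    # stage 0: [0, 1)   stage 1: [1, 2)   stage 2: [2, 4)   stage 3: [4, 5)
```

get_stage_index is just a prefix sum over the per-stage counts, so consecutive stages always cover contiguous, non-overlapping layer ranges.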
94 changes: 40 additions & 54 deletions colossalai/pipeline/policy/bert.py
@@ -22,25 +22,26 @@


def bert_model_forward(
self: BertModel,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
# labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
# this is from the previous stage
hidden_states: Optional[torch.FloatTensor] = None,
):
# TODO: add explanation of the output here.
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -93,6 +94,7 @@ def bert_model_forward(
batch_size, seq_length = input_shape
device = hidden_states.device

# TODO: the recorded key/value cache tensors are left as () or None for now; this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -144,7 +146,7 @@ def bert_model_forward(
else:
encoder_extended_attention_mask = None

# inherit from bert_layer
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -156,12 +158,12 @@ def bert_model_forward(
use_cache = False
next_decoder_cache = () if use_cache else None

# calculate the num_layers
num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages
start_layer = stage_manager.stage * num_layers_per_stage
end_layer = (stage_manager.stage + 1) * num_layers_per_stage

# layer_outputs
layer_outputs = hidden_states if hidden_states is not None else None
for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer):
if stage_manager.is_first_stage() and idx == 0:
@@ -206,20 +208,21 @@ def custom_forward(*inputs):
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + \
(layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

# end of a stage loop
sequence_output = layer_outputs[0] if layer_outputs is not None else None

if stage_manager.is_last_stage():
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + layer_outputs[1:]

# output of non-first and non-last stages:
if not return_dict:
return tuple(v for v in [
hidden_states,
@@ -229,7 +232,7 @@ def custom_forward(*inputs):
all_cross_attentions,
] if v is not None)

# return dict is not supported at this moment
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
@@ -243,6 +246,7 @@ def custom_forward(*inputs):
class BertModelPolicy(Policy):

def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
super().__init__(stage_manager=stage_manager)
self.stage_manager = stage_manager
self.layers_per_stage = self.distribute_layers(num_layers, num_stages)

@@ -253,11 +257,8 @@ def get_hold_layers(self, module: BertModel) -> List[Module]:
hold_layers = []
if self.stage_manager.is_first_stage():
hold_layers.append(module.embeddings)
- num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
- hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \
- [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0:
- num_layers_per_stage_accumulated[self.stage_manager.stage]])
+ start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
+ hold_layers.extend(module.encoder.layer[start_idx:end_idx])
if self.stage_manager.is_last_stage():
hold_layers.append(module.pooler)

@@ -270,23 +271,6 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]:
def replace_forward(self, module: Module) -> None:
module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model)

def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
"""
divide layers into stages
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages

# deal with the rest layers
if remainder > 0:
start_position = num_layers // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage


def bert_for_pretraining_forward(
self: BertForPreTraining,
@@ -306,8 +290,8 @@ def bert_for_pretraining_forward(
) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

- outputs = self.bert(
+ outputs = bert_model_forward(
+ self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -320,7 +304,8 @@ )
)

sequence_output, pooled_output = outputs[:2]
- prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+ if stage_manager.is_last_stage():
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

total_loss = None
if labels is not None and next_sentence_label is not None:
@@ -355,11 +340,12 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]:
hold_layers = []
if self.stage_manager.is_first_stage():
hold_layers.append(module.bert.embeddings)
- num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
- hold_layers.extend(module.bert.encoder.layer[num_layers_per_stage_accumulated \
- [self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0:
- num_layers_per_stage_accumulated[self.stage_manager.stage]])
+ start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
+ hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx])

if self.stage_manager.is_last_stage():
hold_layers.append(module.bert.pooler)
hold_layers.append(module.cls)

return hold_layers
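
To see how these policies carve a model up in practice, the sketch below instantiates BertModelPolicy for each of two stages and lists the modules each stage would hold. It assumes a ColossalAI checkout containing this commit plus transformers installed; _ToyStageManager is a hypothetical stand-in, not the real PipelineStageManager, exposing only the attributes that get_hold_layers actually reads (stage, is_first_stage, is_last_stage).

```python
# Rough sketch only: a toy stage manager standing in for PipelineStageManager.
from transformers import BertConfig, BertModel

from colossalai.pipeline.policy.bert import BertModelPolicy


class _ToyStageManager:
    """Hypothetical stand-in exposing just what get_hold_layers needs."""

    def __init__(self, stage: int, num_stages: int) -> None:
        self.stage = stage
        self.num_stages = num_stages

    def is_first_stage(self) -> bool:
        return self.stage == 0

    def is_last_stage(self) -> bool:
        return self.stage == self.num_stages - 1


model = BertModel(BertConfig(num_hidden_layers=4))

for stage in range(2):
    policy = BertModelPolicy(_ToyStageManager(stage, 2), num_layers=4, num_stages=2)
    held = policy.get_hold_layers(model)
    # stage 0 holds the embeddings plus encoder layers [0, 2);
    # stage 1 holds encoder layers [2, 4) plus the pooler.
    print(stage, [type(m).__name__ for m in held])
```

In the real pipeline the stage manager comes from ColossalAI's distributed setup, and replace_forward then binds bert_model_forward with that stage manager via MethodType(partial(...)), so each rank only executes its own slice of encoder layers.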
