[pipeline] build bloom model and policy, revise the base class of policy #4161

Merged
merged 14 commits on Jul 5, 2023
37 changes: 35 additions & 2 deletions colossalai/pipeline/policy/base.py
@@ -1,13 +1,15 @@
from typing import Any, Dict, List, Optional, Tuple

from colossalai.lazy import LazyTensor
import numpy as np
from torch import Tensor
from torch.nn import Module, Parameter

from colossalai.lazy import LazyTensor
from colossalai.pipeline.stage_manager import PipelineStageManager


class Policy:

def __init__(self, stage_manager: PipelineStageManager) -> None:
self.stage_manager = stage_manager

@@ -93,7 +95,8 @@ def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]:
"""
raise NotImplementedError

def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
def parallelize_model(self,
module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
"""Parallelize model for pipeline parallel
Args:
@@ -106,3 +109,33 @@ def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[
self.replace_forward(module)
shared_params = self.get_shared_params(module)
return hold_params, hold_buffers, shared_params

@staticmethod
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
"""
Divide layers as evenly as possible among stages, assigning any remainder to the middle stages.
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages

# deal with the rest layers
if remainder > 0:
start_position = num_stages // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage

@staticmethod
def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
"""
Return the [start, end) layer indices handled by the given stage.
"""
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

start_idx = num_layers_per_stage_accumulated[stage]
end_idx = num_layers_per_stage_accumulated[stage + 1]

return [start_idx, end_idx]
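
Taken together, the two new static helpers determine which encoder layers a stage owns. Below is a minimal usage sketch (assuming the module path shown above, colossalai.pipeline.policy.base, is importable in your environment; both helpers are static, so no stage manager is needed):

    # split 24 layers across 4 pipeline stages, then look up stage 2's slice
    from colossalai.pipeline.policy.base import Policy

    layers_per_stage = Policy.distribute_layers(num_layers=24, num_stages=4)
    print(layers_per_stage)            # [6, 6, 6, 6] for this even split
    start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage=2)
    print(start_idx, end_idx)          # 12 18, i.e. encoder.layer[12:18]

A policy's get_hold_layers can then slice module.encoder.layer[start_idx:end_idx], which is exactly what the revised BERT policies below do.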
94 changes: 40 additions & 54 deletions colossalai/pipeline/policy/bert.py
@@ -22,25 +22,26 @@


def bert_model_forward(
self: BertModel,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
#labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage
self: BertModel,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
# labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
# this is from the previous stage
hidden_states: Optional[torch.FloatTensor] = None,
):
#TODO: add explaination of the output here.
# TODO: add an explanation of the output here.
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -93,6 +94,7 @@ def bert_model_forward(
batch_size, seq_length = input_shape
device = hidden_states.device

# TODO: recorded key/value tensors are left as () or None for now; this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -144,7 +146,7 @@ def bert_model_forward(
else:
encoder_extended_attention_mask = None

#inherit from bert_layer
# inherit from bert_layer
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -156,12 +158,12 @@ def bert_model_forward(
use_cache = False
next_decoder_cache = () if use_cache else None

#calculate the num_layers
# calculate the num_layers
num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages
start_layer = stage_manager.stage * num_layers_per_stage
end_layer = (stage_manager.stage + 1) * num_layers_per_stage

#layer_outputs
# layer_outputs
layer_outputs = hidden_states if hidden_states is not None else None
for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer):
if stage_manager.is_first_stage() and idx == 0:
@@ -206,20 +208,21 @@ def custom_forward(*inputs):
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
all_cross_attentions = all_cross_attentions + \
(layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

#end of a stage loop
# end of a stage loop
sequence_output = layer_outputs[0] if layer_outputs is not None else None

if stage_manager.is_last_stage():
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + layer_outputs[1:]

#output of non-first and non-last stages:
# output of non-first and non-last stages:
if not return_dict:
return tuple(v for v in [
hidden_states,
@@ -229,7 +232,7 @@ def custom_forward(*inputs):
all_cross_attentions,
] if v is not None)

#return dict is not supported at this moment
# return dict is not supported at this moment
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
@@ -243,6 +246,7 @@ def custom_forward(*inputs):
class BertModelPolicy(Policy):

def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
super().__init__(stage_manager=stage_manager)
self.stage_manager = stage_manager
self.layers_per_stage = self.distribute_layers(num_layers, num_stages)

@@ -253,11 +257,8 @@ def get_hold_layers(self, module: BertModel) -> List[Module]:
hold_layers = []
if self.stage_manager.is_first_stage():
hold_layers.append(module.embeddings)
num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \
[self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0:
num_layers_per_stage_accumulated[self.stage_manager.stage]])

start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
hold_layers.extend(module.encoder.layer[start_idx:end_idx])
if self.stage_manager.is_last_stage():
hold_layers.append(module.pooler)

@@ -270,23 +271,6 @@ def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]:
def replace_forward(self, module: Module) -> None:
module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model)

def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
"""
divide layers into stages
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages

# deal with the rest layers
if remainder > 0:
start_position = num_layers // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage


def bert_for_pretraining_forward(
self: BertForPreTraining,
@@ -306,8 +290,8 @@ def bert_for_pretraining_forward(
) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
outputs = bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -320,7 +304,8 @@ def bert_for_pretraining_forward(
)

sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
if stage_manager.is_last_stage():
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

total_loss = None
if labels is not None and next_sentence_label is not None:
@@ -355,11 +340,12 @@ def get_hold_layers(self, module: BertForPreTraining) -> List[Module]:
hold_layers = []
if self.stage_manager.is_first_stage():
hold_layers.append(module.bert.embeddings)
num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
hold_layers.extend(module.bert.encoder.layer[num_layers_per_stage_accumulated \
[self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0:
num_layers_per_stage_accumulated[self.stage_manager.stage]])

start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx])

if self.stage_manager.is_last_stage():
hold_layers.append(module.bert.pooler)
hold_layers.append(module.cls)

return hold_layers
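
As a footnote to the replace_forward changes above: the policies swap in the pipeline-aware forward with MethodType(partial(...)). Below is a minimal, self-contained sketch of that rebinding pattern in plain Python, with ToyStageManager and ToyModel as hypothetical stand-ins (no colossalai or torch imports), just to show how the stage manager gets pre-filled while self stays bound:

    from functools import partial
    from types import MethodType


    class ToyStageManager:
        # hypothetical stand-in for PipelineStageManager; only the fields used here
        def __init__(self, stage: int, num_stages: int):
            self.stage = stage
            self.num_stages = num_stages


    class ToyModel:
        def forward(self, x):
            return x  # original forward, to be replaced


    def staged_forward(self, x, stage_manager=None):
        # the replacement sees both the module (self) and the pre-filled stage manager
        return f"stage {stage_manager.stage}/{stage_manager.num_stages} received {x}"


    model = ToyModel()
    model.forward = MethodType(partial(staged_forward, stage_manager=ToyStageManager(stage=1, num_stages=4)), model)
    print(model.forward("hidden_states"))  # stage 1/4 received hidden_states

The same shape appears in BertModelPolicy.replace_forward, where bert_model_forward plays the role of staged_forward and module.model is the bound instance.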