This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Merge pull request #56 from hpcaitech/feature/accumulate
combine two pipeline wrappers
MaruyamaAya authored May 11, 2022
2 parents 0d7caf1 + 5201e01 commit b493ee3
Showing 5 changed files with 178 additions and 32 deletions.
1 change: 0 additions & 1 deletion energon/engine/engine.py
@@ -17,7 +17,6 @@

from energon.utils import ensure_directory_exists
from energon.logging import get_dist_logger
from energon.nn import PipelineCommWrapper


class InferenceEngine(Module):
23 changes: 16 additions & 7 deletions energon/engine/gpt_pipeline_wrapper.py
@@ -60,7 +60,7 @@ def _init_tensor_meta(self, sample):
input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now
self.tensor_dim = input_tensor.dim()
self.hidden_size = input_tensor.size()[-1]
output = self.model(hidden_states=None, input_ids=input_tensor, attention_mask=sample['attention_mask'])
output = self.model(hidden_states=input_tensor, input_ids=input_tensor, attention_mask=sample['attention_mask'])
send_tensor_meta(output)
send_forward(output)

@@ -92,9 +92,15 @@ def run_without_pp(self, key, inputs):
'''

def fill_meta_tensor(self, inputs, pipe_meta):
pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]
if 'seq_lens' in inputs:
pipe_meta.get_meta_tensor()[0] = 1
pipe_meta.get_meta_tensor()[1] = 1
pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens'])
else:
pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]

pipe_meta.get_meta_tensor()[3] = self.hidden_size
pipe_meta.update_meta()

@@ -113,7 +119,8 @@ def run_with_pp(self, key, inputs):
if gpc.is_first_rank(ParallelMode.PIPELINE):
output = self.model(hidden_states=None,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'])
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)

send_forward(output)
self.lock.release()
@@ -125,7 +132,8 @@ def run_with_pp(self, key, inputs):
input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
output = self.model(hidden_states=input_tensor,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'])
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
self.lock.release()
return output, cur_key

@@ -134,7 +142,8 @@ def run_with_pp(self, key, inputs):
input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
output = self.model(hidden_states=input_tensor,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'])
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
send_forward(output)
self.lock.release()
return None
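For context on the branch added to `fill_meta_tensor` above: when a request carries `seq_lens`, padding has already been stripped upstream, so the stage announces a packed hidden-state tensor whose token dimension is `sum(seq_lens)` and whose batch dimensions collapse to 1; without `seq_lens` it announces the ordinary padded layout. A minimal, self-contained sketch of that shape computation (the helper `expected_hidden_shape` is illustrative only and simplifies the four-slot meta tensor to a plain 3-D shape):

```python
import torch

def expected_hidden_shape(inputs: dict, hidden_size: int) -> tuple:
    """Illustrative mirror of the fill_meta_tensor branch above.

    With `seq_lens`, padded tokens were removed, so the next stage expects a
    packed (1, total_valid_tokens, hidden_size) tensor; otherwise it expects
    the padded (batch_size, seq_len, hidden_size) layout.
    """
    if 'seq_lens' in inputs:
        total_tokens = int(torch.sum(inputs['seq_lens']))
        return (1, total_tokens, hidden_size)
    batch_size, seq_len = inputs['input_ids'].shape
    return (batch_size, seq_len, hidden_size)

# Example: two requests of length 3 and 5 packed together -> (1, 8, 1600)
print(expected_hidden_shape({'seq_lens': torch.tensor([3, 5])}, 1600))
```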
149 changes: 149 additions & 0 deletions energon/engine/pipeline_wrapper.py
@@ -0,0 +1,149 @@
import inspect
import threading

import torch
import torch.nn as nn
import torch.distributed as dist
from typing import List, Tuple, Union

from energon.communication import send_forward, recv_forward, send_tensor_meta, recv_tensor_meta
from energon.context import ParallelMode
from energon.core import global_context as gpc

from .pipeline_meta import PipelineMeta
from .pipeline_msg_dict import PipelineMsgDict, CircleInt # PipelineMsgPriorityQueue


# The Wrapper is only for Transformer Model.
class PipelineCommWrapper:
def __init__(self,
model: nn.Module,
max_batch_size: int = 1,
dtype=torch.float) -> None:
# TODO (dujiangsu): to make sample capability for different types. Iteration, Tensor, and others.
self.model = model
self.dtype = dtype

self.tensor_dim = 0
self.hidden_size = 0
self.max_batch_size = max_batch_size

if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
hidden_states = None
sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
self._init_tensor_meta(sample)

self.pipe_msg_queue = PipelineMsgDict()
self.lock = threading.Lock()
self.key = CircleInt()

def _init_tensor_meta(self, sample):

with torch.inference_mode():
recv_tensor_shape = None
if gpc.is_first_rank(ParallelMode.PIPELINE):
output = self.model(hidden_states=None, input_ids=sample['input_ids'],
attention_mask=sample['attention_mask']) # ([32, 512, 1600])
send_tensor_meta(output)
send_forward(output)
self.tensor_dim = output.dim()
self.hidden_size = output.size()[-1]
elif gpc.is_last_rank(ParallelMode.PIPELINE):
recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now
self.tensor_dim = input_tensor.dim()
self.hidden_size = input_tensor.size()[-1]
else:
recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now
self.tensor_dim = input_tensor.dim()
self.hidden_size = input_tensor.size()[-1]
output = self.model(hidden_states=input_tensor, input_ids=input_tensor, attention_mask=sample['attention_mask'])
send_tensor_meta(output)
send_forward(output)

def run(self, key, inputs):
if gpc.is_initialized(ParallelMode.PIPELINE):
return self.run_with_pp(key, inputs)
else:
return self.run_without_pp(key, inputs)

def run_without_pp(self, key, inputs):
pipe_meta = None
self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

self.lock.acquire()

cur_key = self.key.val
sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
self.key.addOne()
output = self.model(hidden_states=None,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'])
self.lock.release()

return output, cur_key

'''
hidden_size : ([32, 512, 1600])
For different model type, fill_meta_tensor is different
'''

def fill_meta_tensor(self, inputs, pipe_meta):
if 'seq_lens' in inputs:
pipe_meta.get_meta_tensor()[0] = 1
pipe_meta.get_meta_tensor()[1] = 1
pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens'])
else:
pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]

pipe_meta.get_meta_tensor()[3] = self.hidden_size
pipe_meta.update_meta()

def run_with_pp(self, key, inputs):
pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
self.fill_meta_tensor(inputs, pipe_meta)
self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

self.lock.acquire()
cur_key = self.key.val
sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
self.key.addOne()

with torch.inference_mode():

if gpc.is_first_rank(ParallelMode.PIPELINE):
output = self.model(hidden_states=None,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)

send_forward(output)
self.lock.release()
return None

if gpc.is_last_rank(ParallelMode.PIPELINE):

# print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}')
input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
output = self.model(hidden_states=input_tensor,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
self.lock.release()
return output, cur_key

else:

input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
output = self.model(hidden_states=input_tensor,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
send_forward(output)
self.lock.release()
return None
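Beyond merging the two wrappers, the class above serializes concurrent RPC requests: each call is enqueued into `PipelineMsgDict` under a key, and the lock plus the wrapping counter (`CircleInt`) ensure requests are popped and executed in key order regardless of which thread acquires the lock first. A hedged sketch of that pattern in plain Python (`KeyedSerializer` and its internals are illustrative stand-ins, not energon APIs):

```python
import threading

class KeyedSerializer:
    """Illustrative stand-in for the PipelineMsgDict + CircleInt + lock pattern."""

    def __init__(self, modulo: int = 10000):
        self._msgs = {}           # key -> payload, like PipelineMsgDict
        self._next_key = 0        # current key to serve, like CircleInt.val
        self._modulo = modulo
        self._lock = threading.Lock()
        self._ready = threading.Condition(self._lock)

    def submit(self, key: int, payload) -> None:
        """Enqueue a request under its key (any thread may call this)."""
        with self._ready:
            self._msgs[key] = payload
            self._ready.notify_all()

    def run_next(self, fn):
        """Run `fn` on the payload for the current key, in strict key order."""
        with self._ready:
            while self._next_key not in self._msgs:
                self._ready.wait()
            key = self._next_key
            payload = self._msgs.pop(key)
            self._next_key = (self._next_key + 1) % self._modulo  # wrap like CircleInt
            return fn(payload), key
```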
12 changes: 2 additions & 10 deletions energon/engine/rpc_worker.py
@@ -7,14 +7,7 @@
from energon.core import global_context as gpc
from energon.context import ParallelMode
from .rpc_utils import remote_cls_method, sync_cls_method, async_cls_method
from .gpt_pipeline_wrapper import GPTPipelineCommWrapper
from .bert_pipeline_wrapper import BertPipelineCommWrapper

WRAPPER_TYPES = {
"gpt": GPTPipelineCommWrapper,
"bert": BertPipelineCommWrapper,
}

from .pipeline_wrapper import PipelineCommWrapper

class ReturnDict:
def __init__(self):
@@ -41,7 +34,6 @@ def __init__(self,
self.model_config = model_config
self.dtype = dtype
self.max_batch_size = max_batch_size
self.pipe_wrapper = WRAPPER_TYPES[model_type]

self.WORKER_NAME = "wok{}"
self.model = None # call the model
@@ -62,7 +54,7 @@ def _init_self(self):
# print("Pass")
self.model.eval()

self.model = self.pipe_wrapper(model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)
self.model = PipelineCommWrapper(model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)

def run(self, key, inputs):
# print("key: {}".format(key), flush=True)
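The net effect in this file: the per-model dispatch table is gone, so instead of mapping `model_type` to `GPTPipelineCommWrapper` or `BertPipelineCommWrapper`, the worker wraps every transformer model in the single `PipelineCommWrapper`. A hedged construction sketch (the import path is inferred from the diff; `wrap_for_pipeline` is an illustrative helper, not part of the repository):

```python
import torch
import torch.nn as nn
from energon.engine.pipeline_wrapper import PipelineCommWrapper

def wrap_for_pipeline(model: nn.Module,
                      max_batch_size: int = 32,
                      dtype: torch.dtype = torch.half) -> PipelineCommWrapper:
    """Any transformer model now goes through the same wrapper, so the old
    model_type -> wrapper-class lookup in rpc_worker.py is no longer needed."""
    model.eval()
    return PipelineCommWrapper(model=model, max_batch_size=max_batch_size, dtype=dtype)
```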
25 changes: 11 additions & 14 deletions examples/gpt/gpt.py
@@ -89,7 +89,7 @@ def __init__(self,
self.softmax = nn.Softmax(dim=-1)
self.dense = Linear1D_Row(dim, dim, bias=True, dtype=dtype, parallel_input=True)

def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None):
def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None, valid_word_num=None):
qkv = self.query_key_value(x)
all_head_size = qkv.shape[-1] // 3
num_attention_heads = divide(all_head_size, self.attention_head_size) # num_heads
@@ -115,14 +115,15 @@ def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None
causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
device=get_current_device())).view(1, 1, q_len, k_len).bool()
x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))

if attention_mask is not None:
x = x + attention_mask
x = self.softmax(x)

x = torch.matmul(x, v)

if seq_lens is not None:
x = transpose_depad(x, batch_size, valid_word_num[0].item(), max_padding_size, seq_lens, num_attention_heads, self.attention_head_size)
x = transpose_depad(x, batch_size, valid_word_num, max_padding_size, seq_lens, num_attention_heads, self.attention_head_size)
else:
x = x.transpose(1, 2)

@@ -142,6 +143,7 @@ def __init__(self,
dtype: dtype = None,
bias: bool = True):
super().__init__()

intermediate_dim = int(dim * mlp_ratio)
self.dense_1 = Linear1D_Col(dim, intermediate_dim, bias=bias, dtype=dtype, gather_output=False)
self.activation = activation
@@ -179,13 +181,13 @@ def __init__(self,
self.norm2 = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon)
self.mlp = GPTMLP1D(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dtype=dtype, bias=bias)

def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None):
def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None, valid_word_num=None):
if not self.apply_post_layernorm:
residual = x
x = self.norm1(x)
if self.apply_post_layernorm:
residual = x
x = residual + self.attn(x, attention_mask, batch_size, max_padding_size, seq_lens)
x = residual + self.attn(x, attention_mask, batch_size, max_padding_size, seq_lens, valid_word_num)

if not self.apply_post_layernorm:
residual = x
@@ -295,30 +297,25 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_l
if seq_lens is not None:
hidden_states = ft_remove_padding(hidden_states, self.tmp_mask_offset,
self.mask_offset, self.valid_word_num[0].item(), self.dim)
elif seq_lens is not None:
ft_remove_padding(hidden_states, self.tmp_mask_offset,
self.mask_offset, self.valid_word_num[0].item(), self.dim)

# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# Adapted from huggingface

if attention_mask is not None:
if self.first:
batch_size = input_ids.shape[0]
else:
batch_size = hidden_states.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0


for block in self.blocks:
hidden_states = block(hidden_states, attention_mask, batch_size, max_padding_size, seq_lens)
hidden_states = block(hidden_states, attention_mask, batch_size, max_padding_size, seq_lens, self.valid_word_num[0].item())

if self.last:
if seq_lens is not None:
hidden_states = ft_rebuild_padding(hidden_states, self.mask_offset, self.valid_word_num[0].item(), self.dim, batch_size, max_padding_size)
hidden_states = ft_rebuild_padding(hidden_states, self.tmp_mask_offset[0:self.valid_word_num[0].item()], self.valid_word_num[0].item(), self.dim, batch_size, max_padding_size)
hidden_states = self.head(self.norm(hidden_states))
# res = []
# for i in range(hidden_states.shape[0]):
@@ -420,4 +417,4 @@ def gpt2_8B(**kwargs):

def gpt3(**kwargs):
model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
return _create_gpt_pipeline_model(**model_kwargs)
return _create_gpt_pipeline_model(**model_kwargs)
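For reference, the attention-mask handling in `forward` above turns a 2D 0/1 padding mask into the additive 4D mask the transformer blocks consume: `(batch, seq)` becomes `(batch, 1, 1, seq)`, kept positions contribute 0, and padded positions contribute -10000 so they vanish after the softmax. A standalone sketch of the same arithmetic, pulled out of the model for clarity:

```python
import torch

def extend_attention_mask(attention_mask: torch.Tensor,
                          dtype: torch.dtype = torch.float16) -> torch.Tensor:
    """(batch, seq) 0/1 padding mask -> (batch, 1, 1, seq) additive mask."""
    batch_size = attention_mask.shape[0]
    mask = attention_mask.view(batch_size, -1)
    mask = mask.unsqueeze(1).unsqueeze(2)   # broadcast over heads and query positions
    mask = mask.to(dtype=dtype)             # fp16 compatibility
    return (1.0 - mask) * -10000.0          # kept -> 0.0, padded -> -10000.0

# Example: batch of 1, sequence of 4 with the last position padded.
print(extend_attention_mask(torch.tensor([[1, 1, 1, 0]])))
```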
