This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Combine two pipeline wrappers #56

Merged · 1 commit · May 11, 2022
1 change: 0 additions & 1 deletion energon/engine/engine.py
@@ -17,7 +17,6 @@

from energon.utils import ensure_directory_exists
from energon.logging import get_dist_logger
from energon.nn import PipelineCommWrapper


class InferenceEngine(Module):
23 changes: 16 additions & 7 deletions energon/engine/gpt_pipeline_wrapper.py
@@ -60,7 +60,7 @@ def _init_tensor_meta(self, sample):
input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now
self.tensor_dim = input_tensor.dim()
self.hidden_size = input_tensor.size()[-1]
output = self.model(hidden_states=None, input_ids=input_tensor, attention_mask=sample['attention_mask'])
output = self.model(hidden_states=input_tensor, input_ids=input_tensor, attention_mask=sample['attention_mask'])
send_tensor_meta(output)
send_forward(output)

@@ -92,9 +92,15 @@ def run_without_pp(self, key, inputs):
'''

def fill_meta_tensor(self, inputs, pipe_meta):
pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]
if 'seq_lens' in inputs:
pipe_meta.get_meta_tensor()[0] = 1
pipe_meta.get_meta_tensor()[1] = 1
pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens'])
else:
pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]

pipe_meta.get_meta_tensor()[3] = self.hidden_size
pipe_meta.update_meta()

@@ -113,7 +119,8 @@ def run_with_pp(self, key, inputs):
if gpc.is_first_rank(ParallelMode.PIPELINE):
output = self.model(hidden_states=None,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'])
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)

send_forward(output)
self.lock.release()
@@ -125,7 +132,8 @@ def run_with_pp(self, key, inputs):
input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
output = self.model(hidden_states=input_tensor,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'])
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
self.lock.release()
return output, cur_key

@@ -134,7 +142,8 @@ def run_with_pp(self, key, inputs):
input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
output = self.model(hidden_states=input_tensor,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'])
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
send_forward(output)
self.lock.release()
return None
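
For illustration only (not part of this diff): a minimal sketch of how the four meta slots get filled on the padded and de-padded paths of fill_meta_tensor above. The helper name toy_meta, the [dim0, dim1, token_count, hidden_size] reading of the slots, and the sample values are assumptions.

import torch

def toy_meta(input_ids, seq_lens=None, hidden_size=1600):
    meta = torch.zeros(4, dtype=torch.int64)
    if seq_lens is not None:
        # de-padded path: valid tokens are packed together, so the batch-like
        # slots collapse to 1 and slot 2 carries the total valid-token count
        meta[0], meta[1], meta[2] = 1, 1, int(torch.sum(seq_lens))
    else:
        # padded path: slots mirror the [batch, max_seq_len] input shape
        meta[0], meta[1], meta[2] = input_ids.shape[0], input_ids.shape[0], input_ids.shape[1]
    meta[3] = hidden_size
    return meta

input_ids = torch.zeros(4, 512, dtype=torch.int64)               # batch of 4, padded to 512
print(toy_meta(input_ids))                                       # -> [4, 4, 512, 1600]
print(toy_meta(input_ids, seq_lens=torch.tensor([7, 3, 5, 9])))  # -> [1, 1, 24, 1600]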
149 changes: 149 additions & 0 deletions energon/engine/pipeline_wrapper.py
@@ -0,0 +1,149 @@
import inspect
import threading

import torch
import torch.nn as nn
import torch.distributed as dist
from typing import List, Tuple, Union

from energon.communication import send_forward, recv_forward, send_tensor_meta, recv_tensor_meta
from energon.context import ParallelMode
from energon.core import global_context as gpc

from .pipeline_meta import PipelineMeta
from .pipeline_msg_dict import PipelineMsgDict, CircleInt # PipelineMsgPriorityQueue


# The Wrapper is only for Transformer Model.
class PipelineCommWrapper:
def __init__(self,
model: nn.Module,
max_batch_size: int = 1,
dtype=torch.float) -> None:
# TODO (dujiangsu): to make sample capability for different types. Iteration, Tensor, and others.
self.model = model
self.dtype = dtype

self.tensor_dim = 0
self.hidden_size = 0
self.max_batch_size = max_batch_size

if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
hidden_states = None
sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
self._init_tensor_meta(sample)

self.pipe_msg_queue = PipelineMsgDict()
self.lock = threading.Lock()
self.key = CircleInt()

def _init_tensor_meta(self, sample):

with torch.inference_mode():
recv_tensor_shape = None
if gpc.is_first_rank(ParallelMode.PIPELINE):
output = self.model(hidden_states=None, input_ids=sample['input_ids'],
attention_mask=sample['attention_mask']) # ([32, 512, 1600])
send_tensor_meta(output)
send_forward(output)
self.tensor_dim = output.dim()
self.hidden_size = output.size()[-1]
elif gpc.is_last_rank(ParallelMode.PIPELINE):
recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now
self.tensor_dim = input_tensor.dim()
self.hidden_size = input_tensor.size()[-1]
else:
recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now
self.tensor_dim = input_tensor.dim()
self.hidden_size = input_tensor.size()[-1]
output = self.model(hidden_states=input_tensor, input_ids=input_tensor, attention_mask=sample['attention_mask'])
send_tensor_meta(output)
send_forward(output)

def run(self, key, inputs):
if gpc.is_initialized(ParallelMode.PIPELINE):
return self.run_with_pp(key, inputs)
else:
return self.run_without_pp(key, inputs)

def run_without_pp(self, key, inputs):
pipe_meta = None
self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

self.lock.acquire()

cur_key = self.key.val
sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
self.key.addOne()
output = self.model(hidden_states=None,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'])
self.lock.release()

return output, cur_key

'''
hidden_size : ([32, 512, 1600])
For different model type, fill_meta_tensor is different
'''

def fill_meta_tensor(self, inputs, pipe_meta):
if 'seq_lens' in inputs:
pipe_meta.get_meta_tensor()[0] = 1
pipe_meta.get_meta_tensor()[1] = 1
pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens'])
else:
pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]

pipe_meta.get_meta_tensor()[3] = self.hidden_size
pipe_meta.update_meta()

def run_with_pp(self, key, inputs):
pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
self.fill_meta_tensor(inputs, pipe_meta)
self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

self.lock.acquire()
cur_key = self.key.val
sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
self.key.addOne()

with torch.inference_mode():

if gpc.is_first_rank(ParallelMode.PIPELINE):
output = self.model(hidden_states=None,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)

send_forward(output)
self.lock.release()
return None

if gpc.is_last_rank(ParallelMode.PIPELINE):

# print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}')
input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
output = self.model(hidden_states=input_tensor,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
self.lock.release()
return output, cur_key

else:

input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
output = self.model(hidden_states=input_tensor,
input_ids=sample['input_ids'],
attention_mask=sample['attention_mask'],
seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
send_forward(output)
self.lock.release()
return None
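
A rough usage sketch of the unified wrapper (not from this diff; build_pipeline_model is a hypothetical stand-in for the model construction that rpc_worker actually performs, and the distributed/RPC setup, dtype and sizes are illustrative). run() dispatches to run_with_pp or run_without_pp depending on whether the pipeline parallel group is initialized, and with pipeline parallelism only the last stage returns the output:

import torch
from energon.engine.pipeline_wrapper import PipelineCommWrapper

model = build_pipeline_model()   # hypothetical helper: this rank's nn.Module shard, already in eval mode
wrapper = PipelineCommWrapper(model=model, max_batch_size=32, dtype=torch.half)

inputs = {
    'input_ids': torch.randint(0, 50000, (4, 128), dtype=torch.int64).cuda(),
    'attention_mask': torch.ones(4, 1, 128, dtype=torch.int64).cuda(),
    # optionally add 'seq_lens' to take the de-padded path in fill_meta_tensor
}

result = wrapper.run(key=0, inputs=inputs)   # key ordering mirrors rpc_worker's usage (assumed)
if result is not None:                       # first/middle pipeline stages return None
    output, cur_key = result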
12 changes: 2 additions & 10 deletions energon/engine/rpc_worker.py
@@ -7,14 +7,7 @@
from energon.core import global_context as gpc
from energon.context import ParallelMode
from .rpc_utils import remote_cls_method, sync_cls_method, async_cls_method
from .gpt_pipeline_wrapper import GPTPipelineCommWrapper
from .bert_pipeline_wrapper import BertPipelineCommWrapper

WRAPPER_TYPES = {
"gpt": GPTPipelineCommWrapper,
"bert": BertPipelineCommWrapper,
}

from .pipeline_wrapper import PipelineCommWrapper

class ReturnDict:
def __init__(self):
@@ -41,7 +34,6 @@ def __init__(self,
self.model_config = model_config
self.dtype = dtype
self.max_batch_size = max_batch_size
self.pipe_wrapper = WRAPPER_TYPES[model_type]

self.WORKER_NAME = "wok{}"
self.model = None # call the model
@@ -62,7 +54,7 @@ def _init_self(self):
# print("Pass")
self.model.eval()

self.model = self.pipe_wrapper(model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)
self.model = PipelineCommWrapper(model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)

def run(self, key, inputs):
# print("key: {}".format(key), flush=True)
25 changes: 11 additions & 14 deletions examples/gpt/gpt.py
@@ -89,7 +89,7 @@ def __init__(self,
self.softmax = nn.Softmax(dim=-1)
self.dense = Linear1D_Row(dim, dim, bias=True, dtype=dtype, parallel_input=True)

def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None):
def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None, valid_word_num=None):
qkv = self.query_key_value(x)
all_head_size = qkv.shape[-1] // 3
num_attention_heads = divide(all_head_size, self.attention_head_size) # num_heads
@@ -115,14 +115,15 @@ def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None
causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
device=get_current_device())).view(1, 1, q_len, k_len).bool()
x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))

if attention_mask is not None:
x = x + attention_mask
x = self.softmax(x)

x = torch.matmul(x, v)

if seq_lens is not None:
x = transpose_depad(x, batch_size, valid_word_num[0].item(), max_padding_size, seq_lens, num_attention_heads, self.attention_head_size)
x = transpose_depad(x, batch_size, valid_word_num, max_padding_size, seq_lens, num_attention_heads, self.attention_head_size)
else:
x = x.transpose(1, 2)
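
For intuition only (not from this diff, and not the fused kernel energon calls): a plain-PyTorch stand-in for what transpose_depad is conceptually doing above, i.e. transposing the attention output back to token-major order and packing only the valid, non-padded tokens. The packed [sum(seq_lens), num_heads * head_size] output layout is an assumption for illustration.

import torch

def depad_reference(x, seq_lens):
    # x: [batch, num_heads, max_padding_size, head_size]
    batch, num_heads, max_len, head_size = x.shape
    x = x.transpose(1, 2)                                  # -> [batch, max_len, num_heads, head_size]
    pieces = [x[b, :int(seq_lens[b])] for b in range(batch)]
    return torch.cat(pieces, dim=0).reshape(-1, num_heads * head_size)

x = torch.randn(2, 4, 8, 16)                               # 2 sequences, 4 heads, padded to length 8
out = depad_reference(x, seq_lens=torch.tensor([5, 3]))
print(out.shape)                                           # torch.Size([8, 64]); 8 = 5 + 3 valid tokens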

@@ -142,6 +143,7 @@ def __init__(self,
dtype: dtype = None,
bias: bool = True):
super().__init__()

intermediate_dim = int(dim * mlp_ratio)
self.dense_1 = Linear1D_Col(dim, intermediate_dim, bias=bias, dtype=dtype, gather_output=False)
self.activation = activation
@@ -179,13 +181,13 @@ def __init__(self,
self.norm2 = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon)
self.mlp = GPTMLP1D(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dtype=dtype, bias=bias)

def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None):
def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None, valid_word_num=None):
if not self.apply_post_layernorm:
residual = x
x = self.norm1(x)
if self.apply_post_layernorm:
residual = x
x = residual + self.attn(x, attention_mask, batch_size, max_padding_size, seq_lens)
x = residual + self.attn(x, attention_mask, batch_size, max_padding_size, seq_lens, valid_word_num)

if not self.apply_post_layernorm:
residual = x
@@ -295,30 +297,25 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_l
if seq_lens is not None:
hidden_states = ft_remove_padding(hidden_states, self.tmp_mask_offset,
self.mask_offset, self.valid_word_num[0].item(), self.dim)
elif seq_lens is not None:
ft_remove_padding(hidden_states, self.tmp_mask_offset,
self.mask_offset, self.valid_word_num[0].item(), self.dim)

# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# Adapted from huggingface

if attention_mask is not None:
if self.first:
batch_size = input_ids.shape[0]
else:
batch_size = hidden_states.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0


for block in self.blocks:
hidden_states = block(hidden_states, attention_mask, batch_size, max_padding_size, seq_lens)
hidden_states = block(hidden_states, attention_mask, batch_size, max_padding_size, seq_lens, self.valid_word_num[0].item())

if self.last:
if seq_lens is not None:
hidden_states = ft_rebuild_padding(hidden_states, self.mask_offset, self.valid_word_num[0].item(), self.dim, batch_size, max_padding_size)
hidden_states = ft_rebuild_padding(hidden_states, self.tmp_mask_offset[0:self.valid_word_num[0].item()], self.valid_word_num[0].item(), self.dim, batch_size, max_padding_size)
hidden_states = self.head(self.norm(hidden_states))
# res = []
# for i in range(hidden_states.shape[0]):
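
As a worked example of the 2D-to-4D mask expansion in the context lines above (HuggingFace-style, as the comment notes; values here are illustrative): a mask of 1s for real tokens and 0s for padding becomes an additive bias in which padded positions get -10000, so the softmax effectively ignores them.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                 # [batch=1, to_seq_length=4]
mask = attention_mask.view(1, -1).unsqueeze(1).unsqueeze(2)   # -> [1, 1, 1, 4], broadcasts over heads and queries
mask = mask.to(torch.float16)                                 # fp16 compatibility
mask = (1.0 - mask) * -10000.0
print(mask)   # tensor([[[[-0., -0., -0., -10000.]]]], dtype=torch.float16)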
@@ -420,4 +417,4 @@ def gpt2_8B(**kwargs):

def gpt3(**kwargs):
model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
return _create_gpt_pipeline_model(**model_kwargs)
return _create_gpt_pipeline_model(**model_kwargs)