[Inference] Fused the gate and up proj in MLP, and optimized the autograd process. #5365

Merged 9 commits on Feb 6, 2024
Changes from 4 commits
29 changes: 15 additions & 14 deletions colossalai/inference/core/engine.py
@@ -115,8 +115,9 @@ def _shardformer(
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.

Returns:
nn.Module: _description_
nn.Module: The model optimized by Shardformer.
"""

shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
@@ -149,25 +150,25 @@ def generate(
Returns:
List[str]: Inference result returned by one generation.
"""
with torch.no_grad():
self.generation_config = generation_config
if prompts is not None or prompts_token_ids is not None:
self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)

self.generation_config = generation_config
if prompts is not None or prompts_token_ids is not None:
self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)

output_seqs_list = []
output_tokens_list = []
output_seqs_list = []
output_tokens_list = []

while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()
while self.request_handler.check_unfinished_seqs():
output_seqs_list += self.step()

output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))

for seq in output_seqs_list:
output_tokens_list.append(seq.input_token_id + seq.output_token_id)
for seq in output_seqs_list:
output_tokens_list.append(seq.input_token_id + seq.output_token_id)

output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)
output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)

return output_str
return output_str

def add_request(
self,
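As a side note on the autograd change: the per-function `@torch.no_grad` decorators removed in the model and kernel files leave autograd handling to the caller, and a single `torch.no_grad()` context around the decoding loop (as in `generate` above) has the same effect. A minimal sketch — the `Linear` layer is a stand-in, not the engine's real API:

```python
import torch

# One no_grad context around the decoding loop has the same effect as
# decorating every forward helper with @torch.no_grad: nothing created inside
# the context records an autograd graph.
model = torch.nn.Linear(16, 32)        # stand-in for the sharded model
input_embeds = torch.randn(4, 16)

with torch.no_grad():
    logits = model(input_embeds)                  # no grad_fn attached
    next_tokens = torch.argmax(logits, dim=-1)

assert logits.grad_fn is None and not logits.requires_grad
```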
9 changes: 0 additions & 9 deletions colossalai/inference/modeling/layers/attention.py
@@ -6,7 +6,6 @@
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


@torch.no_grad
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
"""
Func: copy key/value into key/value cache.
@@ -41,7 +40,6 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
return cache


@torch.no_grad
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
"""
Func: convert key/value cache for calculation
@@ -81,7 +79,6 @@ class PagedAttention:
"""

@staticmethod
@torch.no_grad
def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
"""
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
@@ -97,14 +94,12 @@ def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
return padded_tensor

@staticmethod
@torch.no_grad
def generate_padding_mask(lengths, max_seq_len):
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
padding_mask = range_tensor < lengths.unsqueeze(1)
return padding_mask

@staticmethod
@torch.no_grad
def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
"""
Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
@@ -122,7 +117,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)

@staticmethod
@torch.no_grad
def nopad_context_forward(
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
@@ -191,7 +185,6 @@ def nopad_context_forward(
return attn_output

@staticmethod
@torch.no_grad
def pad_context_forward(
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
@@ -249,7 +242,6 @@ def pad_context_forward(
return attn_output

@staticmethod
@torch.no_grad
def pad_decoding_forward(
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
@@ -306,7 +298,6 @@ def pad_decoding_forward(
return attn_output

@staticmethod
@torch.no_grad
def no_pad_decoding_forward(
self,
q: torch.Tensor, # [num_tokens, num_heads, head_size]
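For reference, the `repeat_kv` helper above is documented as equivalent to `torch.repeat_interleave(x, dim=1, repeats=n_rep)`; a small self-contained sketch of that equivalence (shapes are made up):

```python
import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
    # Broadcast each kv head n_rep times, then fold the repeat dimension into
    # the head dimension — the MQA/GQA expansion described in the docstring.
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

x = torch.randn(2, 4, 8, 16)   # [batch, num_kv_heads, seq_len, head_dim]
assert torch.equal(repeat_kv_sketch(x, n_rep=3), torch.repeat_interleave(x, dim=1, repeats=3))
```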
36 changes: 15 additions & 21 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -34,7 +34,6 @@
logger.warning("Triton has not been installed yet; torch will be used to complete the attention calculation.")


@torch.no_grad()
def llama_causal_lm_forward(
self: LlamaForCausalLM,
batch: BatchInfo = None,
@@ -60,7 +59,6 @@ def llama_causal_lm_forward(
return logits


@torch.no_grad()
def llama_model_forward(
self: LlamaModel,
batch: BatchInfo = None,
@@ -117,12 +115,11 @@ def llama_model_forward(
last_token_indexs = sequence_lengths.cumsum(dim=-1)
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
norm_output = torch.empty_like(hidden_states)
hidden_states = self.norm(hidden_states, norm_output)
hidden_states, _ = self.norm(hidden_states, norm_output)

return hidden_states


@torch.no_grad()
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
hidden_states: torch.Tensor,
@@ -141,7 +138,7 @@
"""This function will replace the forward function of LlamaDecoderLayer.

Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -156,9 +153,8 @@
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
"""
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states, norm_output)
hidden_states, residual = self.input_layernorm(hidden_states, norm_output)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -176,8 +172,7 @@
)

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states, norm_output)
hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output)
hidden_states = self.mlp(hidden_states, residual)

return hidden_states
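To make the reworked residual flow easier to follow: the fused RMSNorm now returns both its output and the tensor to reuse as the residual, so the layer no longer keeps explicit `residual = hidden_states` copies. A hedged sketch with toy stand-ins for the attention and MLP blocks (in the real code the residual add is folded into the output projections):

```python
import torch

def fused_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Stand-in for the fused RMSNorm: returns (normalized_output, residual).
    var = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight, x

def decoder_layer_sketch(hidden_states, attn_fn, mlp_fn, w_norm1, w_norm2):
    normed, residual = fused_norm(hidden_states, w_norm1)
    hidden_states = attn_fn(normed) + residual      # attention block + residual
    normed, residual = fused_norm(hidden_states, w_norm2)
    return mlp_fn(normed) + residual                # MLP block + residual

h = torch.randn(8, 64)
w = torch.ones(64)
out = decoder_layer_sketch(h, torch.tanh, torch.relu, w, w)  # toy attn/mlp stand-ins
assert out.shape == h.shape
```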
@@ -242,7 +237,6 @@ def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
return attn_layer

# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
@@ -260,8 +254,8 @@
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in out_proj.
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -347,9 +341,10 @@ def __init__(
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__(config)
self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False)
self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False)
self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False)
self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
self.gate_proj = None
self.up_proj = None

@staticmethod
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
@@ -373,15 +368,14 @@ def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:

return mlp_layer

@torch.no_grad()
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in down_proj.
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj.
"""
gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight)
act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
up_proj_out = torch.mm(hidden_states, self.up_proj.weight)
tmp_out = act_out * up_proj_out
hidden_states = hidden_states.expand(2, -1, -1)
gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
tmp_out = act_out * gate_up_proj_out[1]
return torch.addmm(residual, tmp_out, self.down_proj.weight)
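A standalone sketch of the gate/up fusion used in this `forward`: the two transposed projection weights are stacked and the two matmuls are replaced by one batched matmul. Sizes below are illustrative, not the model's real dimensions:

```python
import torch

token_num, embed_dim, inter_dim = 8, 32, 64
hidden_states = torch.randn(token_num, embed_dim)
w_gate = torch.randn(embed_dim, inter_dim)   # transposed gate_proj weight
w_up = torch.randn(embed_dim, inter_dim)     # transposed up_proj weight

# Unfused reference: two independent matmuls.
reference = torch.nn.functional.silu(hidden_states @ w_gate) * (hidden_states @ w_up)

# Fused path: stack the weights and issue a single bmm over a broadcast input.
gate_up_weight = torch.stack([w_gate, w_up], dim=0)                       # [2, embed_dim, inter_dim]
gate_up_out = torch.bmm(hidden_states.expand(2, -1, -1), gate_up_weight)  # [2, token_num, inter_dim]
fused = torch.nn.functional.silu(gate_up_out[0]) * gate_up_out[1]

assert torch.allclose(reference, fused, atol=1e-4)
```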
14 changes: 4 additions & 10 deletions colossalai/inference/modeling/models/padding_llama.py
@@ -52,7 +52,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
return q_embed, k_embed


@torch.no_grad()
def llama_causal_lm_forward(
self: LlamaForCausalLM,
batch: BatchInfo = None,
@@ -78,7 +77,6 @@ def llama_causal_lm_forward(
return logits


@torch.no_grad()
def llama_model_forward(
self: LlamaModel,
batch: BatchInfo = None,
@@ -163,7 +161,6 @@ def llama_model_forward(
return hidden_states


@torch.no_grad()
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
hidden_states: torch.Tensor,
@@ -184,7 +181,7 @@ def llama_decoder_layer_forward(
"""This function will replace the forward function of LlamaDecoderLayer.

Args:
hidden_states (torch.Tensor): _description_
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
position_ids (torch.LongTensor): The position ids of input sequences.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
@@ -282,7 +279,6 @@ def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:

return attn_layer

@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
@@ -301,15 +297,15 @@
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]
position_ids (torch.LongTensor): The position ids of input sequences.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
attention_mask (torch.Tensor, optional): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)`
attention_mask (torch.Tensor, optional): The padding mask - corresponds to a tensor of size [batch_size, seq_len]
where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
@@ -418,12 +414,11 @@ def forward(
return attn_output


@torch.no_grad()
def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
"""Generate padding position_id through attention mask.

Args:
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
attention_mask (`torch.Tensor` of shape [batch_size, sequence_length]):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

Returns:
@@ -434,7 +429,6 @@ def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor:
return position_ids


@torch.no_grad()
def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
"""Convert padding input to nopad input.

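The body of `generate_padding_position_id` is not shown in this diff; for orientation, a common cumsum-based way to derive position ids from a padding mask looks like the sketch below (an assumption about the implementation, not a copy of it):

```python
import torch

def padding_position_id_sketch(attention_mask: torch.Tensor) -> torch.Tensor:
    position_ids = attention_mask.long().cumsum(-1) - 1   # running index over real tokens
    position_ids.masked_fill_(attention_mask == 0, 1)     # padding slots get a dummy id
    return position_ids

mask = torch.tensor([[0, 0, 1, 1, 1],
                     [1, 1, 1, 1, 1]])
print(padding_position_id_sketch(mask))
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```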
2 changes: 1 addition & 1 deletion colossalai/inference/sampler.py
@@ -10,7 +10,7 @@ def greedy_sample(
"""
Sample tokens greedily.
"""
results = torch.argmax(logprobs, dim=-1).cpu()
results = torch.argmax(logprobs, dim=-1)
return results


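The only change here is dropping the `.cpu()` call, so the sampled token ids stay on the same device as the logits and no host-device synchronization is forced inside the generation loop. A tiny illustration:

```python
import torch

logprobs = torch.randn(4, 32000)              # [batch_size, vocab_size], device-resident in practice
next_tokens = torch.argmax(logprobs, dim=-1)  # stays on logprobs.device; no blocking copy to CPU
assert next_tokens.device == logprobs.device
```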
9 changes: 4 additions & 5 deletions colossalai/kernel/triton/flash_decoding.py
@@ -261,6 +261,9 @@ def flash_decoding_attention(
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
grid_1 = (triton.next_power_of_2(bsz), num_heads)

_flash_decoding_fwd_kernel[grid](
q,
k_cache,
@@ -293,11 +296,7 @@ def flash_decoding_attention(
HEAD_DIM=head_dim,
)

output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output

grid = (triton.next_power_of_2(bsz), num_heads)

_flash_decoding_fwd_reduce_kernel[grid](
_flash_decoding_fwd_reduce_kernel[grid_1](
mid_output,
mid_output_lse,
output,
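For context on the launch configuration above (the output allocation and the reduce-kernel grid are now computed before the first launch), here is the grid arithmetic with made-up sizes; rounding to powers of two keeps the set of distinct grid shapes small, so Triton's compiled-kernel cache is reused across batches:

```python
import triton

bsz, num_heads = 5, 32
max_seq_len_in_batch, BLOCK_KV = 700, 256

grid = (
    triton.next_power_of_2(bsz),                                          # 8
    num_heads,                                                            # 32
    triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),  # cdiv(1024, 256) = 4
)
grid_1 = (triton.next_power_of_2(bsz), num_heads)  # reduce kernel: one program per (sequence, head)
print(grid, grid_1)
```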
1 change: 0 additions & 1 deletion colossalai/kernel/triton/fused_rotary_embedding.py
@@ -117,7 +117,6 @@ def fused_rotary_emb(
)


@torch.no_grad()
def fused_rotary_embedding(
q: torch.Tensor,
k: torch.Tensor,
1 change: 0 additions & 1 deletion colossalai/kernel/triton/no_pad_rotary_embedding.py
@@ -274,7 +274,6 @@ def fused_rotary_embedding_kernel(
)


@torch.no_grad()
def rotary_embedding(
q: torch.Tensor,
k: torch.Tensor,
3 changes: 1 addition & 2 deletions colossalai/kernel/triton/rms_layernorm.py
@@ -49,7 +49,6 @@ def _rmsnorm_kernel(
# Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask)

@torch.no_grad()
def rms_layernorm(x, weight, eps, norm_output=None):
# allocate output
y = torch.empty_like(x) if norm_output is None else norm_output
@@ -66,4 +65,4 @@ def rms_layernorm(x, weight, eps, norm_output=None):

# enqueue kernel
_rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
return y
return y, x
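A plain-PyTorch reference for the Triton kernel above, reflecting the new `(output, input)` return: callers unpack the second value and reuse it as the residual instead of keeping their own copy of the pre-norm activations (a sketch, not the kernel itself):

```python
import torch

def rms_layernorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float):
    # Normalize each row by its root mean square, scale by weight, and hand the
    # input back so the caller can use it as the residual.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    y = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return y, x

x = torch.randn(4, 64, dtype=torch.float16)
weight = torch.ones(64, dtype=torch.float16)
y, residual = rms_layernorm_reference(x, weight, eps=1e-6)
assert residual is x and y.shape == x.shape
```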
1 change: 0 additions & 1 deletion colossalai/kernel/triton/rotary_cache_copy.py
@@ -77,7 +77,6 @@ def decoding_cache_kernel(
)


@torch.no_grad()
def get_xine_cache(lengths: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, is_prompts: bool = False):
"""
Transform cos/sin cache into no pad sequence, with two different modes.
2 changes: 1 addition & 1 deletion tests/test_infer/test_ops/triton/test_rmsnorm_triton.py
@@ -32,7 +32,7 @@ def test_layer_norm(M, N):
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")

y_triton = rms_layernorm(x, weight, eps=eps)
y_triton, _ = rms_layernorm(x, weight, eps=eps)
y_llama = rms_norm.forward(x).to(dtype)

assert y_triton.shape == y_llama.shape