Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hotfix] Fix accuracy and align attention method api with Triton kernel #5229

Merged
merged 7 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 116 additions & 83 deletions colossalai/inference/modeling/layers/attention.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
"""
Func: copy key/value into key/value cache.

Args: key/value(source): shape [bsz,seq_len,num_heads,head_size]
cache: shape [num_blocks, num_heads, head_size, block_size]
cache: shape [num_blocks, num_kv_heads, head_size, block_size]
lengths: key/value lengths
block_tables
"""
num_blocks, num_heads, head_size, block_size = cache.shape
bsz, max_seq_len = block_tables.shape
bsz, max_blocks_per_seq = block_tables.shape
needed_blocks = (lengths + block_size - 1) // block_size

if type == "prefill":
Expand All @@ -29,7 +27,9 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
for block_idx in range(block_num - 1):
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
token_id += block_size
cache[block_tables[i][block_num - 1]] = source[i][token_id:seq_len].permute(1, 2, 0)
cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
1, 2, 0
)
elif type == "decoding":
assert len(source[0]) == 1, "seq_len should be equal to 1 when decoding."
source = source.squeeze(1)
Expand All @@ -40,56 +40,49 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
return cache


def convert_kvcache(source, cache, lengths, block_tables):
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
"""
Func: convert key/value cache for calculation

Args: key/value(source): shape [bsz, 1, num_heads, head_size]
cache: shape [num_blocks, num_heads, head_size, block_size]
Args: cache: shape [num_blocks, num_heads, head_size, block_size]
lengths: key/value length
block_tables
pad_id: padded_id
"""
num_blocks, num_heads, head_size, block_size = cache.shape

needed_blocks = (lengths + block_size - 1) // block_size
num_remaing_tokens = (lengths - 1) % block_size
num_remaing_tokens = lengths % block_size
num_remaing_tokens[num_remaing_tokens == 0] += block_size
bsz = block_tables.shape[0]
seq_len = max(lengths)
padded_cache = []
for i in range(bsz):
_cache = torch.cat(
(
cache[block_tables[i][: needed_blocks[i] - 1]].permute((3, 0, 1, 2)).reshape(-1, num_heads, head_size),
cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 1, 0),
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size),
cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1),
),
dim=0,
)
concat_cache = torch.cat((_cache, source[i]), dim=0)
padding = seq_len - concat_cache.size(0)
padding = seq_len - _cache.size(0)
if padding > 0:
concat_cache = F.pad(concat_cache, (0, 0, 0, 0, 0, 1))
padded_cache.append(concat_cache)

_cache = F.pad(_cache, (0, 0, 0, 0, 0, 1), value=pad_id)
padded_cache.append(_cache)
return torch.stack(padded_cache, dim=0)


class PagedAttention(nn.Module):
class PagedAttention:
"""
Pure Torch implementation version of paged_attention.
Holds different types of forward function and useful components.
"""

def __init__(self, num_heads: int, head_size: int, scale: float = 1.0, sliding_window: Optional[int] = None):
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.sliding_window = sliding_window
self._init_rope()

def _init_rope(self):
self.rotary_emb = LlamaRotaryEmbedding(self.head_size)

def pad_and_reshape(self, tensor, seq_lengths, max_seq_len, num_heads, head_size):
@staticmethod
def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
"""
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
"""
bsz = len(seq_lengths)
padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size)

Expand All @@ -100,22 +93,49 @@ def pad_and_reshape(self, tensor, seq_lengths, max_seq_len, num_heads, head_size
token_idx += seq_len
return padded_tensor

def generate_padding_mask(self, lengths, max_seq_len):
@staticmethod
def generate_padding_mask(lengths, max_seq_len):
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
padding_mask = range_tensor < lengths.unsqueeze(1)
return padding_mask

@staticmethod
def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
    """
    Duplicate each key/value head ``n_rep`` times along the head axis (needed for MQA/GQA,
    where there are fewer key/value heads than query heads).

    Produces the same result as torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1).

    Args: hidden_states: tensor of shape (batch, num_key_value_heads, seqlen, head_dim)
          n_rep: number of repetitions of every key/value head.
    Output: tensor of shape (batch, num_key_value_heads * n_rep, seqlen, head_dim);
            the input itself is returned untouched when n_rep == 1.
    """
    if n_rep != 1:
        bsz, kv_heads, length, dim = hidden_states.shape
        # Add a broadcast axis right after the head axis, expand it to n_rep copies,
        # then fold it back into the head axis so copies of one head stay adjacent.
        expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, length, dim)
        hidden_states = expanded.reshape(bsz, kv_heads * n_rep, length, dim)
    return hidden_states

@staticmethod
def nopad_context_forward(
self,
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k: torch.Tensor,
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
):
"""
NOTE: q,k,v are projected and applied rotary embedding, all aligned with triton version.
"""
# First, do shape verification
num_tokens, num_heads, head_size = q.shape
num_kv_heads = k.shape[-2]

assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
num_kv_groups = num_heads // num_kv_heads

block_size = k_cache.shape[-1]
bsz, max_blocks_per_sequence = block_tables.shape
max_seq_len = max_blocks_per_sequence * block_size
Expand All @@ -124,143 +144,156 @@ def nopad_context_forward(
assert context_lengths.shape[0] == block_tables.shape[0]
shape = (bsz, max_seq_len, num_heads, head_size)
input_shape = shape[:2]
query = self.pad_and_reshape(q, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
key = self.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)
value = self.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size).transpose(1, 2)

attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
self.generate_padding_mask(context_lengths, max_seq_len)
q = PagedAttention.pad_and_reshape(
q, context_lengths, max_seq_len, num_heads, head_size
) # bsz,seqlen,num_heads,head_size
k = PagedAttention.pad_and_reshape(k, context_lengths, max_seq_len, num_heads, head_size)
v = PagedAttention.pad_and_reshape(v, context_lengths, max_seq_len, num_heads, head_size)

position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device)
position_ids = position_ids.unsqueeze(0)
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)

cos, sin = self.rotary_emb(value, max_seq_len)
query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, max_seq_len)

copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(value.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables)
q = q.transpose(1, 2)
k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)

attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
# position_ids = torch.arange(0, max_seq_len, dtype=torch.long, device=query.device)
# position_ids = position_ids.unsqueeze(0)
# cos, sin = self.rotary_emb(value, max_seq_len)
# query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)

attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
if attn_weights.size() != (bsz, num_heads, max_seq_len, max_seq_len):
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,max_seq_len,max_seq_len)}.")

if attn_mask is not None:
attn_weights += attn_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
# attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
attn_output = torch.matmul(attn_weights, value)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)

if attn_output.size() != (bsz, num_heads, max_seq_len, head_size):
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,max_seq_len,head_size)}.")
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, max_seq_len, -1)

del attn_weights

return attn_output

@staticmethod
def pad_context_forward(
self,
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor,
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
):
# First, do shape verification
bsz, seq_len, num_heads, head_size = q.shape
num_kv_heads = k.shape[-2]
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
num_kv_groups = num_heads // num_kv_heads
block_size = k_cache.shape[-1]
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
block_tables.shape[-1] * block_size
shape = (bsz, seq_len, num_heads, head_size)
input_shape = shape[:2]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

position_ids = torch.arange(0, seq_len, dtype=torch.long, device=q.device)
position_ids = position_ids.unsqueeze(0)
cos, sin = self.rotary_emb(v, seq_len)
query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
# Copy kv to memory(rotary embedded)
copy_to_cache(k, k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(v, v_cache, lengths=context_lengths, block_tables=block_tables)

copy_to_cache(key.transpose(1, 2), k_cache, lengths=context_lengths, block_tables=block_tables)
copy_to_cache(v.transpose(1, 2), v_cache, lengths=context_lengths, block_tables=block_tables)
q = q.transpose(1, 2)
k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)

attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
attn_mask = AttentionMaskConverter._make_causal_mask(input_shape, q.dtype, q.device, past_key_values_length=0)
self.generate_padding_mask(context_lengths, seq_len)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(context_lengths, seq_len)

if attn_weights.size() != (bsz, num_heads, seq_len, seq_len):
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,seq_len,seq_len)}.")
if attn_mask is not None:
attn_weights += attn_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

# attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)

if attn_output.size() != (bsz, num_heads, seq_len, head_size):
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,seq_len,head_size)}.")

attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)

del attn_weights

return attn_output

@staticmethod
def pad_decoding_forward(
self,
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
k: torch.Tensor,
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
v_cache: torch.Tensor,
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
):
# First, do shape verification.
bsz, _, num_heads, head_size = q.shape

num_kv_heads = k.shape[-2]
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
num_kv_groups = num_heads // num_kv_heads
block_size = k_cache.shape[-1]
seq_len = max(lengths)

assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
max_seq_len = block_tables.shape[-1] * block_size
block_tables.shape[-1] * block_size

attn_mask = AttentionMaskConverter._make_causal_mask(
q.shape[:2], q.dtype, q.device, past_key_values_length=seq_len - 1
)
self.generate_padding_mask(lengths, max_seq_len)
cos, sin = self.rotary_emb(v, max_seq_len)

position_ids = lengths - 1
position_ids = position_ids.unsqueeze(1)

query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2)
attn_mask = attn_mask + PagedAttention.generate_padding_mask(lengths, seq_len).unsqueeze(1).unsqueeze(2)
# cos, sin = self.rotary_emb(v, max_seq_len)
# position_ids = lengths - 1
# position_ids = position_ids.unsqueeze(1)
# query, key = apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=2)

copy_to_cache(key, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
copy_to_cache(k, k_cache, lengths=lengths, block_tables=block_tables, type="decoding")
copy_to_cache(v, v_cache, lengths=lengths, block_tables=block_tables, type="decoding")

key = convert_kvcache(key, k_cache, lengths, block_tables) # bsz, seqlen,
value = convert_kvcache(v, v_cache, lengths, block_tables)
k = convert_kvcache(k_cache, lengths, block_tables) # bsz, seqlen,
v = convert_kvcache(v_cache, lengths, block_tables)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
q = q.transpose(1, 2)
k = PagedAttention.repeat_kv(k.transpose(1, 2), num_kv_groups)
v = PagedAttention.repeat_kv(v.transpose(1, 2), num_kv_groups)

attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_size)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
if attn_weights.size() != (bsz, num_heads, 1, seq_len):
raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")

if attn_mask is not None:
attn_weights += attn_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
# attn_weights = nn.functional.dropout(attn_weights,p=self.attention_dropout,training=False) maybe useless
attn_output = torch.matmul(attn_weights, value)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)

if attn_output.size() != (bsz, num_heads, 1, head_size):
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)

del attn_weights

return attn_output

@staticmethod
def no_pad_decoding_forward(
self,
q: torch.Tensor, # [num_tokens, num_heads, head_size]
Expand Down
3 changes: 2 additions & 1 deletion tests/test_infer/test_config_and_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import BatchInfo, Sequence
from colossalai.testing import spawn
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_config_and_inference():
Expand Down Expand Up @@ -73,6 +73,7 @@ def run_dist(rank, world_size, port):


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_config_and_inference():
spawn(run_dist, 1)

Expand Down
Loading
Loading