From ea4de024092b924bd18a60e6092ab9d3a955da5c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 25 Jul 2024 12:39:48 +0000 Subject: [PATCH 01/13] Allow vllm to still work if triton is not installed. Signed-off-by: Thomas Parnell --- requirements-cpu.txt | 1 - vllm/attention/ops/prefix_prefill.py | 1484 ++++++++--------- .../layers/fused_moe/fused_moe.py | 101 +- vllm/model_executor/layers/ops/rand.py | 26 +- vllm/model_executor/layers/ops/sample.py | 32 +- vllm/triton_utils/__init__.py | 5 +- vllm/triton_utils/custom_cache_manager.py | 67 +- vllm/triton_utils/importing.py | 26 + vllm/triton_utils/mock_tl.py | 6 + vllm/triton_utils/mock_triton.py | 13 + 10 files changed, 904 insertions(+), 857 deletions(-) create mode 100644 vllm/triton_utils/importing.py create mode 100644 vllm/triton_utils/mock_tl.py create mode 100644 vllm/triton_utils/mock_triton.py diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 754070df21c0..b3e18b1bdb60 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -4,4 +4,3 @@ # Dependencies for x86_64 CPUs torch == 2.3.1+cpu; platform_machine != "ppc64le" torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch -triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 4577d84db18a..5885aeb0d214 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -2,767 +2,700 @@ # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py import torch -import triton -import triton.language as tl from vllm.platforms import current_platform +from vllm.triton_utils import maybe_import_triton + +triton, tl = maybe_import_triton() + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = 
tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) # [N] + # [D,N] + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + # [N,D] + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). 
+ qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) -if triton.__version__ >= "2.1.0": - - @triton.jit - def _fwd_kernel( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len - - # start position inside of the query - # generally, N goes over kv, while M goes over query_len - block_start_loc = BLOCK_M * start_m - - # initialize offsets - # [N]; starts at 0 - offs_n = tl.arange(0, BLOCK_N) - # [D]; starts at 0 - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - # [M]; starts at current position in query - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # [M,D] - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, - 0).to(tl.int1) # [D] - - q = tl.load(Q + off_q, + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + # apply 
causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, + qk, -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len), - other=0.0) # [M,D] - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], - dtype=tl.float32) # [M,D] - - # compute query against context (no causal mask here) - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) # [N] - # [D,N] - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - # [N,D] - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - if SLIDING_WINDOW > 0: - # (cur_batch_ctx_len + offs_m[:, None]) are the positions of - # Q entries in sequence - # (start_n + offs_n[None, :]) are the positions of - # KV entries in sequence - # So the condition makes sure each entry in Q only attends - # to KV entries not more than SLIDING_WINDOW away. - # - # We can't use -inf here, because the - # sliding window may lead to the entire row being masked. - # This then makes m_ij contain -inf, which causes NaNs in - # exp(). 
- qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, - -10000) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) # [M] - p = tl.exp(qk - m_ij[:, None]) # [M,N] - l_ij = tl.sum(p, 1) # [M] - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) # [M] - alpha = tl.exp(m_i - m_i_new) # [M] - beta = tl.exp(m_ij - m_i_new) # [M] - l_i_new = alpha * l_i + beta * l_ij # [M] - - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - # block_mask is 0 when we're already past the current query length - block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) - - # compute query against itself (with causal mask) - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_query_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - # apply causal mask - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - if SLIDING_WINDOW > 0: - qk = tl.where( - offs_m[:, None] - - (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_query_len), - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len)) - return + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) - @triton.jit - def _fwd_kernel_flash_attn_v2( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - 
stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) + return + + +@triton.jit +def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + 
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # acc /= l_i[:, None] - # initialize pointers to output - off_o 
= ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) - return + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) - @triton.jit - def _fwd_kernel_alibi( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - Alibi_slopes, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - ): - # attn_bias[] - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - # cur_batch_seq_len: the length of prompts - # cur_batch_ctx_len: the length of prefix - # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) - - q = tl.load(Q + off_q, + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # 
-- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + +@triton.jit +def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, +): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + 
((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) - - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = 0 - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = (bn[None, :] * stride_k_cache_bs + - cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * - stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = ( - bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache 
+ off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v, allow_tf32=False) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - # init alibi - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange( - 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = cur_batch_ctx_len - # # init debugger - # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc - # offset_db_k = tl.arange(0, BLOCK_N) - # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, allow_tf32=False) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), - alibi, float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v, allow_tf32=False) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) - return + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc 
q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, allow_tf32=False) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) - @torch.inference_mode() - def context_attention_fwd(q, - k, - v, - o, - k_cache, - v_cache, - b_loc, - b_start_loc, - b_seq_len, - b_ctx_len, - max_input_len, - alibi_slopes=None, - sliding_window=None): - - cap = current_platform.get_device_capability() - BLOCK = 128 if cap[0] >= 8 else 64 - - # need to reduce num. 
blocks when using fp32 - # due to increased use of GPU shared memory - if q.dtype is torch.float32: - BLOCK = BLOCK // 2 - - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - # round up Lk to a power of 2 - this is required for Triton block size - Lk_padded = triton.next_power_of_2(Lk) - - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - num_queries_per_kv = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - # 0 means "disable" - if sliding_window is None or sliding_window <= 0: - sliding_window = 0 - - num_warps = 8 if Lk <= 64 else 8 - if alibi_slopes is not None: - _fwd_kernel_alibi[grid]( - q, - k, - v, - k_cache, - v_cache, - b_loc, - sm_scale, - b_start_loc, - b_seq_len, - b_ctx_len, - alibi_slopes, - v_cache.shape[3], - k_cache.shape[4], - o, - b_loc.stride(0), - b_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - k_cache.stride( - 4 - ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride( - 3), #[num_blocks, num_kv_heads, head_size, block_size] - num_queries_per_kv=num_queries_per_kv, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_DMODEL_PADDED=Lk_padded, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - _fwd_kernel[grid]( + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + +@torch.inference_mode() +def context_attention_fwd(q, + k, + v, + o, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=None, + sliding_window=None): + + cap = current_platform.get_device_capability() + BLOCK = 128 if cap[0] >= 8 else 64 + + # need to reduce num. 
blocks when using fp32 + # due to increased use of GPU shared memory + if q.dtype is torch.float32: + BLOCK = BLOCK // 2 + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + num_warps = 8 if Lk <= 64 else 8 + if alibi_slopes is not None: + _fwd_kernel_alibi[grid]( q, k, v, @@ -773,6 +706,7 @@ def context_attention_fwd(q, b_start_loc, b_seq_len, b_ctx_len, + alibi_slopes, v_cache.shape[3], k_cache.shape[4], o, @@ -806,8 +740,56 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, - SLIDING_WINDOW=sliding_window, num_warps=num_warps, num_stages=1, ) return + + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + SLIDING_WINDOW=sliding_window, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 413c0b6d0924..b3c1917d7f2f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -5,53 +5,54 @@ from typing import Any, Dict, Optional, Tuple import torch -import triton -import triton.language as tl import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.triton_utils import maybe_import_triton + +triton, tl = maybe_import_triton() logger = init_logger(__name__) @triton.jit def fused_moe_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - a_scale_ptr, - b_scale_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N, - K, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). 
- stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, - use_fp8: tl.constexpr, + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, # type: ignore + BLOCK_SIZE_N: tl.constexpr, # type: ignore + BLOCK_SIZE_K: tl.constexpr, # type: ignore + GROUP_SIZE_M: tl.constexpr, # type: ignore + MUL_ROUTED_WEIGHT: tl.constexpr, # type: ignore + top_k: tl.constexpr, # type: ignore + compute_type: tl.constexpr, # type: ignore + use_fp8: tl.constexpr, # type: ignore ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -220,16 +221,22 @@ def moe_align_block_size( return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, top_k: int, - config: Dict[str, Any], compute_type: tl.dtype, - use_fp8: bool) -> None: +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type: tl.dtype, # type: ignore + use_fp8: bool) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 diff --git a/vllm/model_executor/layers/ops/rand.py b/vllm/model_executor/layers/ops/rand.py index 4a429e329567..b24f73aedd37 100644 --- a/vllm/model_executor/layers/ops/rand.py +++ b/vllm/model_executor/layers/ops/rand.py @@ -1,8 +1,10 @@ from typing import Optional, Union import torch -import triton -import triton.language as tl + +from vllm.triton_utils import maybe_import_triton + +triton, tl = maybe_import_triton() def seeded_uniform( @@ -93,16 +95,16 @@ def seeded_uniform( @triton.jit def _seeded_uniform_triton( - out_ptr: torch.Tensor, - seed_ptr: torch.Tensor, - out_row_stride: int, - out_3d_stride: int, - seed_row_stride: int, - n_rows: int, - n_3d: int, - n_cols: int, - n_slices: tl.constexpr, - block_size: tl.constexpr, + out_ptr: torch.Tensor, + seed_ptr: torch.Tensor, + out_row_stride: int, + out_3d_stride: int, + seed_row_stride: int, + n_rows: int, + n_3d: int, + n_cols: int, + n_slices: tl.constexpr, # type: ignore + block_size: tl.constexpr, # type: ignore ): """ Generate a random float32 number in [0, 1) for each element in the output diff --git a/vllm/model_executor/layers/ops/sample.py 
b/vllm/model_executor/layers/ops/sample.py index d08ae6064aa2..779017e5659f 100644 --- a/vllm/model_executor/layers/ops/sample.py +++ b/vllm/model_executor/layers/ops/sample.py @@ -2,10 +2,11 @@ from typing import Optional, Tuple import torch -import triton -import triton.language as tl from vllm.model_executor.layers.ops.rand import seeded_uniform +from vllm.triton_utils import maybe_import_triton + +triton, tl = maybe_import_triton() _EPS = 1e-6 @@ -326,16 +327,25 @@ def _uniform_to_exponential(uniform_noise): @triton.jit def _sample_triton( - sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, + sample_indices_ptr: torch.Tensor, + output_ptr: torch.Tensor, output_logprobs_ptr: torch.Tensor, - output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, - logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, - uniform_noise_ptr: torch.Tensor, output_row_stride: int, - probs_row_stride: int, uniform_noise_row_stride: int, - uniform_noise_best_stride: int, n_samples: int, n_cols: int, - n_best: int, block_size: tl.constexpr, - modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, - save_modified_probs: tl.constexpr): + output_modified_probs_ptr: torch.Tensor, + probs_ptr: torch.Tensor, + logprobs_ptr: torch.Tensor, + seeds_ptr: torch.Tensor, + uniform_noise_ptr: torch.Tensor, + output_row_stride: int, + probs_row_stride: int, + uniform_noise_row_stride: int, + uniform_noise_best_stride: int, + n_samples: int, + n_cols: int, + n_best: int, + block_size: tl.constexpr, # type: ignore + modify_greedy_probs: tl.constexpr, # type: ignore + save_logprobs: tl.constexpr, # type: ignore + save_modified_probs: tl.constexpr): # type: ignore # The rows are independent, so we parallelize across those sample_idx = tl.program_id(0) best_idx = tl.program_id(1) diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 09843e5d1f30..509f9a394453 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -1,6 +1,5 @@ from vllm.triton_utils.custom_cache_manager import ( maybe_set_triton_cache_manager) +from vllm.triton_utils.importing import maybe_import_triton -__all__ = [ - "maybe_set_triton_cache_manager", -] +__all__ = ["maybe_import_triton", "maybe_set_triton_cache_manager"] diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py index 17039d7ba24c..8fb3c0411bd8 100644 --- a/vllm/triton_utils/custom_cache_manager.py +++ b/vllm/triton_utils/custom_cache_manager.py @@ -1,12 +1,12 @@ import os -from triton.runtime.cache import (FileCacheManager, default_cache_dir, - default_dump_dir, default_override_dir) - from vllm.logger import init_logger +from vllm.triton_utils import maybe_import_triton logger = init_logger(__name__) +triton, _ = maybe_import_triton() + def maybe_set_triton_cache_manager() -> None: """Set environment variable to tell Triton to use a @@ -18,36 +18,39 @@ def maybe_set_triton_cache_manager() -> None: os.environ["TRITON_CACHE_MANAGER"] = manager -class CustomCacheManager(FileCacheManager): - """Re-implements Triton's cache manager, ensuring that a - unique cache directory is created for each process. This is - needed to avoid collisions when running with tp>1 and - using multi-processing as the distributed backend. - - Note this issue was fixed by triton-lang/triton/pull/4295, - but the fix is not yet included in triton==v3.0.0. However, - it should be included in the subsequent version. 
- """ - - def __init__(self, key, override=False, dump=False): - self.key = key - self.lock_path = None - if dump: - self.cache_dir = default_dump_dir() - self.cache_dir = os.path.join(self.cache_dir, self.key) - self.lock_path = os.path.join(self.cache_dir, "lock") - os.makedirs(self.cache_dir, exist_ok=True) - elif override: - self.cache_dir = default_override_dir() - self.cache_dir = os.path.join(self.cache_dir, self.key) - else: - # create cache directory if it doesn't exist - self.cache_dir = os.getenv("TRITON_CACHE_DIR", - "").strip() or default_cache_dir() - if self.cache_dir: - self.cache_dir = f"{self.cache_dir}_{os.getpid()}" +if triton.__version__ != "0.0.0": + + class CustomCacheManager(triton.runtime.cache.FileCacheManager): + """Re-implements Triton's cache manager, ensuring that a + unique cache directory is created for each process. This is + needed to avoid collisions when running with tp>1 and + using multi-processing as the distributed backend. + + Note this issue was fixed by triton-lang/triton/pull/4295, + but the fix is not yet included in triton==v3.0.0. However, + it should be included in the subsequent version. + """ + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = triton.default_dump_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = triton.default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) else: - raise RuntimeError("Could not create or locate cache dir") + # create cache directory if it doesn't exist + self.cache_dir = os.getenv( + "TRITON_CACHE_DIR", + "").strip() or triton.default_cache_dir() + if self.cache_dir: + self.cache_dir = f"{self.cache_dir}_{os.getpid()}" + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py new file mode 100644 index 000000000000..78c2c2745e8b --- /dev/null +++ b/vllm/triton_utils/importing.py @@ -0,0 +1,26 @@ +from importlib import import_module + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +HAS_WARNED = False + + +def maybe_import_triton(): + + global HAS_WARNED + + try: + triton = import_module("triton") + tl = import_module("triton.language") + return triton, tl + except ImportError: + if not HAS_WARNED: + logger.info("Triton not installed; certain GPU-related functions" + " will be not be available.") + HAS_WARNED = True + + mock_triton = import_module("vllm.triton_utils.mock_triton") + mock_tl = import_module("vllm.triton_utils.mock_tl") + return mock_triton, mock_tl diff --git a/vllm/triton_utils/mock_tl.py b/vllm/triton_utils/mock_tl.py new file mode 100644 index 000000000000..f32f86001bf6 --- /dev/null +++ b/vllm/triton_utils/mock_tl.py @@ -0,0 +1,6 @@ +class constexpr: + pass + + +class dtype: + pass diff --git a/vllm/triton_utils/mock_triton.py b/vllm/triton_utils/mock_triton.py new file mode 100644 index 000000000000..d28971987b79 --- /dev/null +++ b/vllm/triton_utils/mock_triton.py @@ -0,0 +1,13 @@ +__version__ = "0.0.0" + + +def jit(cls): + + def disable_function(func): + + def disabled(*args, **kwargs): + return + + return disabled + + return disable_function From 
5e5f6ecf11659ab4990d6c3e23709ab9a9e8fe42 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 25 Jul 2024 13:10:12 +0000 Subject: [PATCH 02/13] Reduce diff Signed-off-by: Thomas Parnell --- vllm/attention/ops/prefix_prefill.py | 1479 +++++++++++++------------- 1 file changed, 749 insertions(+), 730 deletions(-) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 5885aeb0d214..edffe7880558 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -8,694 +8,762 @@ triton, tl = maybe_import_triton() - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len - - # start position inside of the query - # generally, N goes over kv, while M goes over query_len - block_start_loc = BLOCK_M * start_m - - # initialize offsets - # [N]; starts at 0 - offs_n = tl.arange(0, BLOCK_N) - # [D]; starts at 0 - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - # [M]; starts at current position in query - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # [M,D] - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, - 0).to(tl.int1) # [D] - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len), - other=0.0) # [M,D] - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] - - # compute query against context (no causal mask here) - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) # [N] - # [D,N] - off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - # [N,D] - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - 
((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - if SLIDING_WINDOW > 0: - # (cur_batch_ctx_len + offs_m[:, None]) are the positions of - # Q entries in sequence - # (start_n + offs_n[None, :]) are the positions of - # KV entries in sequence - # So the condition makes sure each entry in Q only attends - # to KV entries not more than SLIDING_WINDOW away. - # - # We can't use -inf here, because the - # sliding window may lead to the entire row being masked. - # This then makes m_ij contain -inf, which causes NaNs in - # exp(). - qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, - -10000) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) # [M] - p = tl.exp(qk - m_ij[:, None]) # [M,N] - l_ij = tl.sum(p, 1) # [M] - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) # [M] - alpha = tl.exp(m_i - m_i_new) # [M] - beta = tl.exp(m_ij - m_i_new) # [M] - l_i_new = alpha * l_i + beta * l_ij # [M] - - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - # block_mask is 0 when we're already past the current query length - block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) - - # compute query against itself (with causal mask) - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_query_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - # apply causal mask - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - if SLIDING_WINDOW > 0: - qk = tl.where( - offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, - qk, -10000) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, +if triton.__version__ >= "2.1.0" or triton.__version__ == "0.0.0": + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, 
+ stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_query_len), - other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) - return - - -@triton.jit -def _fwd_kernel_flash_attn_v2( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load(Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) - - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], 
dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], + dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) # [N] + # [D,N] + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + # [N,D] + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). 
+ qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len)) + return - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute 
qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + @triton.jit + def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- 
compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - # acc /= l_i[:, None] - # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) - return - - -@triton.jit -def _fwd_kernel_alibi( - Q, - K, - V, - K_cache, - V_cache, - B_Loc, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_Ctxlen, - Alibi_slopes, - block_size, - x, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: int, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, # head size - BLOCK_DMODEL_PADDED: tl.constexpr, # head size 
padded to a power of 2 - BLOCK_N: tl.constexpr, -): - # attn_bias[] - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // num_queries_per_kv - - # cur_batch_seq_len: the length of prompts - # cur_batch_ctx_len: the length of prefix - # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) - - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = 0 - for start_n in range(0, cur_batch_ctx_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) - off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) - qk *= sm_scale - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, - float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(V_cache + off_v, + @triton.jit + def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + 
stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + ): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), other=0.0) - p = p.to(v.dtype) - acc += tl.dot(p, v, allow_tf32=False) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - k_ptrs = K + off_k - v_ptrs = V + off_v - - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) - - # init alibi - alibi_slope = tl.load(Alibi_slopes + cur_head) - alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len - alibi_start_k = cur_batch_ctx_len - # # init debugger - # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc - # offset_db_k = tl.arange(0, BLOCK_N) - # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, allow_tf32=False) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) - - # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope - alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, - float("-inf")) - qk += alibi - alibi_start_k += BLOCK_N - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - m_i_new = tl.maximum(m_i, m_ij) - p = tl.math.exp(qk - m_i_new[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - - alpha = tl.math.exp(m_i - m_i_new) - l_i_new = alpha * l_i + l_ij - # -- update output accumulator -- - # scale p - # scale acc - acc_scale = alpha - # acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * 
stride_vbs, - mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, 
allow_tf32=False) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return - p = p.to(v.dtype) - acc += tl.dot(p, v, allow_tf32=False) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - acc = acc / l_i[:, None] - - # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) - return - - -@torch.inference_mode() -def context_attention_fwd(q, - k, - v, - o, - k_cache, - v_cache, - b_loc, - b_start_loc, - b_seq_len, - b_ctx_len, - max_input_len, - alibi_slopes=None, - sliding_window=None): - - cap = current_platform.get_device_capability() - BLOCK = 128 if cap[0] >= 8 else 64 - - # need to reduce num. blocks when using fp32 - # due to increased use of GPU shared memory - if q.dtype is torch.float32: - BLOCK = BLOCK // 2 - - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - # round up Lk to a power of 2 - this is required for Triton block size - Lk_padded = triton.next_power_of_2(Lk) - - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - num_queries_per_kv = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - # 0 means "disable" - if sliding_window is None or sliding_window <= 0: - sliding_window = 0 - - num_warps = 8 if Lk <= 64 else 8 - if alibi_slopes is not None: - _fwd_kernel_alibi[grid]( + @torch.inference_mode() + def context_attention_fwd(q, + k, + v, + o, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=None, + sliding_window=None): + + cap = current_platform.get_device_capability() + BLOCK = 128 if cap[0] >= 8 else 64 + + # need to reduce num. 
blocks when using fp32 + # due to increased use of GPU shared memory + if q.dtype is torch.float32: + BLOCK = BLOCK // 2 + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + num_warps = 8 if Lk <= 64 else 8 + if alibi_slopes is not None: + _fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + alibi_slopes, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4 + ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + _fwd_kernel[grid]( q, k, v, @@ -706,7 +774,6 @@ def context_attention_fwd(q, b_start_loc, b_seq_len, b_ctx_len, - alibi_slopes, v_cache.shape[3], k_cache.shape[4], o, @@ -740,56 +807,8 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, + SLIDING_WINDOW=sliding_window, num_warps=num_warps, num_stages=1, ) return - - _fwd_kernel[grid]( - q, - k, - v, - k_cache, - v_cache, - b_loc, - sm_scale, - b_start_loc, - b_seq_len, - b_ctx_len, - v_cache.shape[3], - k_cache.shape[4], - o, - b_loc.stride(0), - b_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - k_cache.stride( - 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] - v_cache.stride(0), - v_cache.stride(1), - v_cache.stride(2), - v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] - num_queries_per_kv=num_queries_per_kv, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_DMODEL_PADDED=Lk_padded, - BLOCK_N=BLOCK, - SLIDING_WINDOW=sliding_window, - num_warps=num_warps, - num_stages=1, - ) - return From e80fc3439248684d470e4ba8fbc601db4867a7f2 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 25 Jul 2024 09:18:36 -0400 Subject: [PATCH 03/13] Fix in custom_cache_manager. 
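Patch 01 looked the cache helpers up on the top-level triton module
(e.g. triton.default_dump_dir()), but they are defined in
triton.runtime.cache, so this patch imports them from there; it also
switches to the relative .importing path for maybe_import_triton so the
module no longer imports the vllm.triton_utils package from within one
of its own submodules. A minimal sketch of the resulting guarded
pattern, using only names visible in the diff below
(maybe_import_triton, the "0.0.0" mock version, FileCacheManager,
default_cache_dir); the dump/override branches, which use
default_dump_dir and default_override_dir, are omitted here for
brevity, so this is illustrative rather than the exact code:

    import os

    from .importing import maybe_import_triton

    triton, _ = maybe_import_triton()

    if triton.__version__ != "0.0.0":  # real Triton, not the mock

        # The helpers live in triton.runtime.cache, not on the
        # top-level module, so they must be imported from there.
        from triton.runtime.cache import FileCacheManager, default_cache_dir

        class CustomCacheManager(FileCacheManager):

            def __init__(self, key, override=False, dump=False):
                # Suffix the cache dir with the PID so each process gets
                # its own directory (avoids collisions with tp>1 when
                # multiprocessing is the distributed backend).
                cache_dir = os.getenv("TRITON_CACHE_DIR",
                                      "").strip() or default_cache_dir()
                self.cache_dir = os.path.join(f"{cache_dir}_{os.getpid()}",
                                              key)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)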
Signed-off-by: Thomas Parnell --- vllm/triton_utils/custom_cache_manager.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py index 8fb3c0411bd8..1ca2753a2da9 100644 --- a/vllm/triton_utils/custom_cache_manager.py +++ b/vllm/triton_utils/custom_cache_manager.py @@ -1,7 +1,8 @@ import os from vllm.logger import init_logger -from vllm.triton_utils import maybe_import_triton + +from .importing import maybe_import_triton logger = init_logger(__name__) @@ -20,7 +21,10 @@ def maybe_set_triton_cache_manager() -> None: if triton.__version__ != "0.0.0": - class CustomCacheManager(triton.runtime.cache.FileCacheManager): + from triton.runtime.cache import (FileCacheManager, default_cache_dir, + default_dump_dir, default_override_dir) + + class CustomCacheManager(FileCacheManager): """Re-implements Triton's cache manager, ensuring that a unique cache directory is created for each process. This is needed to avoid collisions when running with tp>1 and @@ -35,18 +39,17 @@ def __init__(self, key, override=False, dump=False): self.key = key self.lock_path = None if dump: - self.cache_dir = triton.default_dump_dir() + self.cache_dir = default_dump_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) elif override: - self.cache_dir = triton.default_override_dir() + self.cache_dir = default_override_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) else: # create cache directory if it doesn't exist - self.cache_dir = os.getenv( - "TRITON_CACHE_DIR", - "").strip() or triton.default_cache_dir() + self.cache_dir = os.getenv("TRITON_CACHE_DIR", + "").strip() or default_cache_dir() if self.cache_dir: self.cache_dir = f"{self.cache_dir}_{os.getpid()}" self.cache_dir = os.path.join(self.cache_dir, self.key) From e136740b3748f2d4f81f2171832f5e560d52f7aa Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 07:25:18 +0000 Subject: [PATCH 04/13] Rework without the mock Signed-off-by: Thomas Parnell --- vllm/attention/ops/paged_attn.py | 6 +- vllm/attention/ops/prefix_prefill.py | 7 +- .../layers/fused_moe/fused_moe.py | 101 +++--- vllm/model_executor/layers/ops/rand.py | 26 +- vllm/model_executor/layers/ops/sample.py | 32 +- .../model_executor/layers/quantization/fp8.py | 334 +++++++++--------- vllm/model_executor/layers/sampler.py | 6 +- vllm/model_executor/sampling_metadata.py | 6 +- vllm/triton_utils/__init__.py | 4 +- vllm/triton_utils/custom_cache_manager.py | 7 +- vllm/triton_utils/importing.py | 28 +- vllm/triton_utils/mock_tl.py | 6 - vllm/triton_utils/mock_triton.py | 13 - 13 files changed, 270 insertions(+), 306 deletions(-) delete mode 100644 vllm/triton_utils/mock_tl.py delete mode 100644 vllm/triton_utils/mock_triton.py diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index ce7b4d129779..c06af04d5276 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -4,7 +4,11 @@ import torch from vllm import _custom_ops as ops -from vllm.attention.ops.prefix_prefill import context_attention_fwd + +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.attention.ops.prefix_prefill import context_attention_fwd # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
_PARTITION_SIZE = 512 diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index edffe7880558..4577d84db18a 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -2,13 +2,12 @@ # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py import torch +import triton +import triton.language as tl from vllm.platforms import current_platform -from vllm.triton_utils import maybe_import_triton -triton, tl = maybe_import_triton() - -if triton.__version__ >= "2.1.0" or triton.__version__ == "0.0.0": +if triton.__version__ >= "2.1.0": @triton.jit def _fwd_kernel( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index b3c1917d7f2f..413c0b6d0924 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -5,54 +5,53 @@ from typing import Any, Dict, Optional, Tuple import torch +import triton +import triton.language as tl import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.triton_utils import maybe_import_triton - -triton, tl = maybe_import_triton() logger = init_logger(__name__) @triton.jit def fused_moe_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - a_scale_ptr, - b_scale_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N, - K, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, # type: ignore - BLOCK_SIZE_N: tl.constexpr, # type: ignore - BLOCK_SIZE_K: tl.constexpr, # type: ignore - GROUP_SIZE_M: tl.constexpr, # type: ignore - MUL_ROUTED_WEIGHT: tl.constexpr, # type: ignore - top_k: tl.constexpr, # type: ignore - compute_type: tl.constexpr, # type: ignore - use_fp8: tl.constexpr, # type: ignore + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). 
+ stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -221,22 +220,16 @@ def moe_align_block_size( return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel( - A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: Dict[str, Any], - compute_type: tl.dtype, # type: ignore - use_fp8: bool) -> None: +def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8: bool) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 diff --git a/vllm/model_executor/layers/ops/rand.py b/vllm/model_executor/layers/ops/rand.py index b24f73aedd37..4a429e329567 100644 --- a/vllm/model_executor/layers/ops/rand.py +++ b/vllm/model_executor/layers/ops/rand.py @@ -1,10 +1,8 @@ from typing import Optional, Union import torch - -from vllm.triton_utils import maybe_import_triton - -triton, tl = maybe_import_triton() +import triton +import triton.language as tl def seeded_uniform( @@ -95,16 +93,16 @@ def seeded_uniform( @triton.jit def _seeded_uniform_triton( - out_ptr: torch.Tensor, - seed_ptr: torch.Tensor, - out_row_stride: int, - out_3d_stride: int, - seed_row_stride: int, - n_rows: int, - n_3d: int, - n_cols: int, - n_slices: tl.constexpr, # type: ignore - block_size: tl.constexpr, # type: ignore + out_ptr: torch.Tensor, + seed_ptr: torch.Tensor, + out_row_stride: int, + out_3d_stride: int, + seed_row_stride: int, + n_rows: int, + n_3d: int, + n_cols: int, + n_slices: tl.constexpr, + block_size: tl.constexpr, ): """ Generate a random float32 number in [0, 1) for each element in the output diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py index 779017e5659f..d08ae6064aa2 100644 --- a/vllm/model_executor/layers/ops/sample.py +++ b/vllm/model_executor/layers/ops/sample.py @@ -2,11 +2,10 @@ from typing import Optional, Tuple import torch +import triton +import triton.language as tl from vllm.model_executor.layers.ops.rand import seeded_uniform -from vllm.triton_utils import maybe_import_triton - -triton, tl = maybe_import_triton() _EPS = 1e-6 @@ -327,25 +326,16 @@ def _uniform_to_exponential(uniform_noise): @triton.jit def _sample_triton( - sample_indices_ptr: torch.Tensor, - output_ptr: torch.Tensor, + sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, output_logprobs_ptr: torch.Tensor, - output_modified_probs_ptr: torch.Tensor, - probs_ptr: torch.Tensor, - logprobs_ptr: torch.Tensor, - seeds_ptr: torch.Tensor, - uniform_noise_ptr: torch.Tensor, - output_row_stride: int, - probs_row_stride: int, - uniform_noise_row_stride: int, - 
uniform_noise_best_stride: int, - n_samples: int, - n_cols: int, - n_best: int, - block_size: tl.constexpr, # type: ignore - modify_greedy_probs: tl.constexpr, # type: ignore - save_logprobs: tl.constexpr, # type: ignore - save_modified_probs: tl.constexpr): # type: ignore + output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, + logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, + uniform_noise_ptr: torch.Tensor, output_row_stride: int, + probs_row_stride: int, uniform_noise_row_stride: int, + uniform_noise_best_stride: int, n_samples: int, n_cols: int, + n_best: int, block_size: tl.constexpr, + modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, + save_modified_probs: tl.constexpr): # The rows are independent, so we parallelize across those sample_idx = tl.program_id(0) best_idx = tl.program_id(1) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 3a4f2a49a349..7c021fb30238 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -6,8 +6,13 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, - fused_moe) + +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + fused_moe) + from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( @@ -227,187 +232,188 @@ def apply(self, cutlass_fp8_supported=self.cutlass_fp8_supported, use_per_token_if_dynamic=False) +if HAS_TRITON: -class Fp8MoEMethod(FusedMoEMethodBase): - """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. + class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. - Also supports loading quantized FP16/BF16 model checkpoints with dynamic - activation scaling. The weight scaling factor will be initialized after - the model weights are loaded. + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. - Args: - quant_config: The quantization config. - """ + Args: + quant_config: The quantization config. 
+ """ - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, - intermediate_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, + intermediate_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty(num_experts, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size, - intermediate_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_scale", w13_scale) - - w2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_scale", w2_scale) - - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, extra_weight_attrs) + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn - # INPUT_SCALES - if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8.") + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) - a13_scale = torch.nn.Parameter(torch.ones(num_experts, + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, dtype=torch.float32), requires_grad=False) - layer.register_parameter("a13_scale", a13_scale) - set_weight_attrs(a13_scale, extra_weight_attrs) + layer.register_parameter("w13_scale", w13_scale) - a2_scale = torch.nn.Parameter(torch.ones(num_experts, + w2_scale = torch.nn.Parameter(torch.ones(num_experts, dtype=torch.float32), requires_grad=False) - layer.register_parameter("a2_scale", a2_scale) - set_weight_attrs(a2_scale, extra_weight_attrs) - else: - layer.a13_scale = None - layer.a2_scale = None + layer.register_parameter("w2_scale", w2_scale) - def process_weights_after_loading(self, layer: Module) -> None: + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, extra_weight_attrs) - # If checkpoint is fp16, quantize in place. - if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like(layer.w13_weight.data, - dtype=torch.float8_e4m3fn) - w2_weight = torch.empty_like(layer.w2_weight.data, - dtype=torch.float8_e4m3fn) - - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. - layer.w13_scale = torch.nn.Parameter(torch.ones( - layer.num_experts, - dtype=torch.float32, - device=w13_weight.device), - requires_grad=False) - for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_scale[ - expert] = ops.scaled_fp8_quant( - layer.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], layer.w2_scale[ - expert] = ops.scaled_fp8_quant( - layer.w2_weight.data[expert, :, :]) - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - return - - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. - else: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. + # INPUT_SCALES if self.quant_config.activation_scheme == "static": - if layer.a13_scale is None or layer.a2_scale is None: + if not self.quant_config.is_checkpoint_fp8_serialized: raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.a13_scale) - or not all_close_1d(layer.a2_scale)): - print_warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer. ") - layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + + a13_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("a13_scale", a13_scale) + set_weight_attrs(a13_scale, extra_weight_attrs) + + a2_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("a2_scale", a2_scale) + set_weight_attrs(a2_scale, extra_weight_attrs) + else: + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. 
+ if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like(layer.w13_weight.data, + dtype=torch.float8_e4m3fn) + w2_weight = torch.empty_like(layer.w2_weight.data, + dtype=torch.float8_e4m3fn) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_scale = torch.nn.Parameter(torch.ones( + layer.num_experts, + dtype=torch.float32, + device=w13_weight.device), requires_grad=False) - layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), - requires_grad=False) - - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. - assert layer.w13_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_scale.max(dim=1).values - for expert_id in range(layer.num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - shard_size, :], - layer.w13_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) - start += shard_size - - layer.w13_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) - return + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_scale[ + expert] = ops.scaled_fp8_quant( + layer.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], layer.w2_scale[ + expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :]) + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + return - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - - return fused_moe(x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_fp8=True, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.a13_scale is None or layer.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.a13_scale) + or not all_close_1d(layer.a2_scale)): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. ") + layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), + requires_grad=False) + layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. 
+ assert layer.w13_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + return + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: + + return fused_moe(x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 5c376797a054..e2a891d59c83 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -5,7 +5,11 @@ import torch import torch.nn as nn -from vllm.model_executor.layers.ops.sample import sample as sample_triton +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.model_executor.layers.ops.sample import sample as sample_triton + from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors, SequenceGroupToSample) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 390b5d173ebc..4968272d37e0 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -4,7 +4,11 @@ import torch -from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits + from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData, SequenceGroupMetadata from vllm.utils import (async_tensor_h2d, is_pin_memory_available, diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 509f9a394453..6a8c71878eaa 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -1,5 +1,5 @@ from vllm.triton_utils.custom_cache_manager import ( maybe_set_triton_cache_manager) -from vllm.triton_utils.importing import maybe_import_triton +from vllm.triton_utils.importing import HAS_TRITON -__all__ = ["maybe_import_triton", "maybe_set_triton_cache_manager"] +__all__ = ["HAS_TRITON", "maybe_set_triton_cache_manager"] diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py index 1ca2753a2da9..eb9a4a1d0815 100644 --- a/vllm/triton_utils/custom_cache_manager.py +++ b/vllm/triton_utils/custom_cache_manager.py @@ -2,11 +2,10 @@ from vllm.logger import init_logger -from .importing import maybe_import_triton -logger = init_logger(__name__) +from .importing import HAS_TRITON -triton, _ = maybe_import_triton() +logger = init_logger(__name__) def 
maybe_set_triton_cache_manager() -> None: @@ -19,7 +18,7 @@ def maybe_set_triton_cache_manager() -> None: os.environ["TRITON_CACHE_MANAGER"] = manager -if triton.__version__ != "0.0.0": +if HAS_TRITON: from triton.runtime.cache import (FileCacheManager, default_cache_dir, default_dump_dir, default_override_dir) diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index 78c2c2745e8b..d6aa0d21f098 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -1,26 +1,12 @@ -from importlib import import_module from vllm.logger import init_logger logger = init_logger(__name__) -HAS_WARNED = False - - -def maybe_import_triton(): - - global HAS_WARNED - - try: - triton = import_module("triton") - tl = import_module("triton.language") - return triton, tl - except ImportError: - if not HAS_WARNED: - logger.info("Triton not installed; certain GPU-related functions" - " will be not be available.") - HAS_WARNED = True - - mock_triton = import_module("vllm.triton_utils.mock_triton") - mock_tl = import_module("vllm.triton_utils.mock_tl") - return mock_triton, mock_tl +try: + import triton + HAS_TRITON = True +except ImportError: + logger.info("Triton not installed; certain GPU-related functions" + " will be not be available.") + HAS_TRITON = False diff --git a/vllm/triton_utils/mock_tl.py b/vllm/triton_utils/mock_tl.py deleted file mode 100644 index f32f86001bf6..000000000000 --- a/vllm/triton_utils/mock_tl.py +++ /dev/null @@ -1,6 +0,0 @@ -class constexpr: - pass - - -class dtype: - pass diff --git a/vllm/triton_utils/mock_triton.py b/vllm/triton_utils/mock_triton.py deleted file mode 100644 index d28971987b79..000000000000 --- a/vllm/triton_utils/mock_triton.py +++ /dev/null @@ -1,13 +0,0 @@ -__version__ = "0.0.0" - - -def jit(cls): - - def disable_function(func): - - def disabled(*args, **kwargs): - return - - return disabled - - return disable_function From f778898c7357f8317e72b0d9f45c046fd99bef86 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 08:14:25 +0000 Subject: [PATCH 05/13] Separate FP8 FusedMoE into separate module Signed-off-by: Thomas Parnell --- vllm/attention/ops/paged_attn.py | 1 - .../model_executor/layers/quantization/fp8.py | 201 +----------------- .../layers/quantization/fp8_fused_moe.py | 196 +++++++++++++++++ vllm/model_executor/sampling_metadata.py | 3 +- vllm/triton_utils/__init__.py | 10 +- vllm/triton_utils/custom_cache_manager.py | 69 +++--- vllm/triton_utils/importing.py | 9 +- 7 files changed, 250 insertions(+), 239 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/fp8_fused_moe.py diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index c06af04d5276..c3ff706c37e5 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -4,7 +4,6 @@ import torch from vllm import _custom_ops as ops - from vllm.triton_utils import HAS_TRITON if HAS_TRITON: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 7c021fb30238..a89612743088 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -6,13 +6,6 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger - -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, - fused_moe) - from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, 
UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( @@ -23,11 +16,16 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, apply_fp8_linear, create_per_tensor_scale_param, - cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale) + apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported, + requantize_with_max_scale) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import print_warning_once +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.quantization.fp8_fused_moe import ( + Fp8MoEMethod) ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -232,189 +230,6 @@ def apply(self, cutlass_fp8_supported=self.cutlass_fp8_supported, use_per_token_if_dynamic=False) -if HAS_TRITON: - - class Fp8MoEMethod(FusedMoEMethodBase): - """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. - - Also supports loading quantized FP16/BF16 model checkpoints with dynamic - activation scaling. The weight scaling factor will be initialized after - the model weights are loaded. - - Args: - quant_config: The quantization config. - """ - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, - intermediate_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty(num_experts, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size, - intermediate_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_scale", w13_scale) - - w2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_scale", w2_scale) - - # If loading fp8 checkpoint, pass the weight loaders. 
- # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, extra_weight_attrs) - - # INPUT_SCALES - if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8.") - - a13_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("a13_scale", a13_scale) - set_weight_attrs(a13_scale, extra_weight_attrs) - - a2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("a2_scale", a2_scale) - set_weight_attrs(a2_scale, extra_weight_attrs) - else: - layer.a13_scale = None - layer.a2_scale = None - - def process_weights_after_loading(self, layer: Module) -> None: - - # If checkpoint is fp16, quantize in place. - if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like(layer.w13_weight.data, - dtype=torch.float8_e4m3fn) - w2_weight = torch.empty_like(layer.w2_weight.data, - dtype=torch.float8_e4m3fn) - - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. - layer.w13_scale = torch.nn.Parameter(torch.ones( - layer.num_experts, - dtype=torch.float32, - device=w13_weight.device), - requires_grad=False) - for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_scale[ - expert] = ops.scaled_fp8_quant( - layer.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], layer.w2_scale[ - expert] = ops.scaled_fp8_quant( - layer.w2_weight.data[expert, :, :]) - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - return - - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. - else: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. - if self.quant_config.activation_scheme == "static": - if layer.a13_scale is None or layer.a2_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.a13_scale) - or not all_close_1d(layer.a2_scale)): - print_warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer. ") - layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), - requires_grad=False) - layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), - requires_grad=False) - - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. 
- assert layer.w13_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_scale.max(dim=1).values - for expert_id in range(layer.num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - shard_size, :], - layer.w13_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) - start += shard_size - - layer.w13_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) - return - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - - return fused_moe(x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_fp8=True, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) - class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/vllm/model_executor/layers/quantization/fp8_fused_moe.py b/vllm/model_executor/layers/quantization/fp8_fused_moe.py new file mode 100644 index 000000000000..fcc5d884d6fb --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp8_fused_moe.py @@ -0,0 +1,196 @@ +from typing import Optional + +import torch +from torch.nn import Module + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, + fused_moe) +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, per_tensor_dequantize) +from vllm.utils import print_warning_once + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, + quant_config: "Fp8Config" # type: ignore + ): + self.quant_config = quant_config + + def create_weights(self, layer: Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_scale", w2_scale) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + + a13_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("a13_scale", a13_scale) + set_weight_attrs(a13_scale, extra_weight_attrs) + + a2_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("a2_scale", a2_scale) + set_weight_attrs(a2_scale, extra_weight_attrs) + else: + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like(layer.w13_weight.data, + dtype=torch.float8_e4m3fn) + w2_weight = torch.empty_like(layer.w2_weight.data, + dtype=torch.float8_e4m3fn) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_scale = torch.nn.Parameter(torch.ones( + layer.num_experts, + dtype=torch.float32, + device=w13_weight.device), + requires_grad=False) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_scale[ + expert] = ops.scaled_fp8_quant( + layer.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], layer.w2_scale[ + expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :]) + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.a13_scale is None or layer.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.a13_scale) + or not all_close_1d(layer.a2_scale)): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. ") + layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), + requires_grad=False) + layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. 
+ assert layer.w13_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + return + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: + + return fused_moe(x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 4968272d37e0..0f193559e867 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -7,7 +7,8 @@ from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits + from vllm.model_executor.layers.ops.sample import ( + get_num_triton_sampler_splits) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData, SequenceGroupMetadata diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 6a8c71878eaa..4572b45ceaec 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -1,5 +1,11 @@ -from vllm.triton_utils.custom_cache_manager import ( - maybe_set_triton_cache_manager) from vllm.triton_utils.importing import HAS_TRITON +if HAS_TRITON: + from vllm.triton_utils.custom_cache_manager import ( + maybe_set_triton_cache_manager) + __all__ = ["HAS_TRITON", "maybe_set_triton_cache_manager"] + +if not HAS_TRITON: + # need to do this afterwards due to ruff complaining + __all__.pop() diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py index eb9a4a1d0815..17039d7ba24c 100644 --- a/vllm/triton_utils/custom_cache_manager.py +++ b/vllm/triton_utils/custom_cache_manager.py @@ -1,9 +1,9 @@ import os -from vllm.logger import init_logger - +from triton.runtime.cache import (FileCacheManager, default_cache_dir, + default_dump_dir, default_override_dir) -from .importing import HAS_TRITON +from vllm.logger import init_logger logger = init_logger(__name__) @@ -18,41 +18,36 @@ def maybe_set_triton_cache_manager() -> None: os.environ["TRITON_CACHE_MANAGER"] = manager -if HAS_TRITON: - - from triton.runtime.cache import (FileCacheManager, default_cache_dir, - default_dump_dir, default_override_dir) - - class CustomCacheManager(FileCacheManager): - """Re-implements Triton's cache manager, ensuring that a - unique cache directory is created for each process. This is - needed to avoid collisions when running with tp>1 and - using multi-processing as the distributed backend. 
- - Note this issue was fixed by triton-lang/triton/pull/4295, - but the fix is not yet included in triton==v3.0.0. However, - it should be included in the subsequent version. - """ - - def __init__(self, key, override=False, dump=False): - self.key = key - self.lock_path = None - if dump: - self.cache_dir = default_dump_dir() +class CustomCacheManager(FileCacheManager): + """Re-implements Triton's cache manager, ensuring that a + unique cache directory is created for each process. This is + needed to avoid collisions when running with tp>1 and + using multi-processing as the distributed backend. + + Note this issue was fixed by triton-lang/triton/pull/4295, + but the fix is not yet included in triton==v3.0.0. However, + it should be included in the subsequent version. + """ + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", + "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = f"{self.cache_dir}_{os.getpid()}" self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) - elif override: - self.cache_dir = default_override_dir() - self.cache_dir = os.path.join(self.cache_dir, self.key) else: - # create cache directory if it doesn't exist - self.cache_dir = os.getenv("TRITON_CACHE_DIR", - "").strip() or default_cache_dir() - if self.cache_dir: - self.cache_dir = f"{self.cache_dir}_{os.getpid()}" - self.cache_dir = os.path.join(self.cache_dir, self.key) - self.lock_path = os.path.join(self.cache_dir, "lock") - os.makedirs(self.cache_dir, exist_ok=True) - else: - raise RuntimeError("Could not create or locate cache dir") + raise RuntimeError("Could not create or locate cache dir") diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index d6aa0d21f098..3455036586a9 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -1,12 +1,11 @@ +from importlib.util import find_spec from vllm.logger import init_logger logger = init_logger(__name__) -try: - import triton - HAS_TRITON = True -except ImportError: +HAS_TRITON = find_spec("triton") is not None + +if not HAS_TRITON: logger.info("Triton not installed; certain GPU-related functions" " will be not be available.") - HAS_TRITON = False From 0039ba3568e3bfc0799cfba1b4668fbfc4161c87 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 04:31:27 -0400 Subject: [PATCH 06/13] Fix type checking Signed-off-by: Thomas Parnell --- .../layers/quantization/fp8_fused_moe.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8_fused_moe.py b/vllm/model_executor/layers/quantization/fp8_fused_moe.py index fcc5d884d6fb..2d31f04c702c 100644 --- a/vllm/model_executor/layers/quantization/fp8_fused_moe.py +++ b/vllm/model_executor/layers/quantization/fp8_fused_moe.py @@ -1,16 +1,19 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional import torch from torch.nn import Module from vllm import _custom_ops as ops -from 
vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, - fused_moe) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase, fused_moe from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, per_tensor_dequantize) +from vllm.model_executor.utils import set_weight_attrs from vllm.utils import print_warning_once +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + class Fp8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and @@ -24,14 +27,12 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_config: The quantization config. """ - def __init__(self, - quant_config: "Fp8Config" # type: ignore - ): + def __init__(self, quant_config: 'Fp8Config'): self.quant_config = quant_config - def create_weights(self, layer: Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, + intermediate_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn @@ -160,8 +161,7 @@ def process_weights_after_loading(self, layer: Module) -> None: shard_size, :], layer.w13_scale[expert_id][shard_id]) layer.w13_weight[expert_id][ - start:start + - shard_size, :], _ = ops.scaled_fp8_quant( + start:start + shard_size, :], _ = ops.scaled_fp8_quant( dq_weight, max_w13_scales[expert_id]) start += shard_size From b35adce959c2f09a97974667c902d28f047773df Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 04:54:29 -0400 Subject: [PATCH 07/13] Remove additional redundant Triton deps. Signed-off-by: Thomas Parnell --- requirements-openvino.txt | 2 -- requirements-tpu.txt | 1 - 2 files changed, 3 deletions(-) diff --git a/requirements-openvino.txt b/requirements-openvino.txt index e32c76fb0db2..fabac3c7bbaf 100644 --- a/requirements-openvino.txt +++ b/requirements-openvino.txt @@ -5,5 +5,3 @@ torch >= 2.1.2 openvino ~= 2024.3.0.dev optimum-intel[openvino] >= 1.18.1 - -triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 22487f5524dd..aef488853373 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -4,4 +4,3 @@ # Dependencies for TPU # Currently, the TPU backend uses a nightly version of PyTorch XLA. # You can install the dependencies in Dockerfile.tpu. 
-triton # To avoid import errors From 3f94f68bf76d76224dfff4e66fd7ddf6e208f334 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 06:33:22 -0400 Subject: [PATCH 08/13] Move get_num_triton_sampler_splits into triton_utils Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/ops/sample.py | 14 +------------- vllm/model_executor/sampling_metadata.py | 7 +------ vllm/triton_utils/sample.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 19 deletions(-) create mode 100644 vllm/triton_utils/sample.py diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py index d08ae6064aa2..bdb577da3172 100644 --- a/vllm/model_executor/layers/ops/sample.py +++ b/vllm/model_executor/layers/ops/sample.py @@ -1,4 +1,3 @@ -import math from typing import Optional, Tuple import torch @@ -6,21 +5,10 @@ import triton.language as tl from vllm.model_executor.layers.ops.rand import seeded_uniform +from vllm.triton_utils.sample import get_num_triton_sampler_splits _EPS = 1e-6 -# This is a hardcoded limit in Triton (max block size). -MAX_TRITON_N_COLS = 131072 - - -def get_num_triton_sampler_splits(n_cols: int) -> int: - """Get the number of splits to use for Triton sampling. - - Triton has a limit on the number of columns it can handle, so we need to - split the tensor and call the kernel multiple times if it's too large. - """ - return math.ceil(n_cols / MAX_TRITON_N_COLS) - def _multi_split_sample( probs: torch.Tensor, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 9a5f45902d2c..1caf9aa01d8c 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -5,14 +5,9 @@ import torch -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.model_executor.layers.ops.sample import ( - get_num_triton_sampler_splits) - from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.utils import (async_tensor_h2d, is_pin_memory_available, make_tensor_with_pad, maybe_expand_dim) diff --git a/vllm/triton_utils/sample.py b/vllm/triton_utils/sample.py new file mode 100644 index 000000000000..844b0003ce5a --- /dev/null +++ b/vllm/triton_utils/sample.py @@ -0,0 +1,12 @@ +import math + +# This is a hardcoded limit in Triton (max block size). +MAX_TRITON_N_COLS = 131072 + +def get_num_triton_sampler_splits(n_cols: int) -> int: + """Get the number of splits to use for Triton sampling. + + Triton has a limit on the number of columns it can handle, so we need to + split the tensor and call the kernel multiple times if it's too large. + """ + return math.ceil(n_cols / MAX_TRITON_N_COLS) From d1a4c517f25496277282c8e46125bc546a9f64f5 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 06:33:59 -0400 Subject: [PATCH 09/13] fmt Signed-off-by: Thomas Parnell --- vllm/triton_utils/sample.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/triton_utils/sample.py b/vllm/triton_utils/sample.py index 844b0003ce5a..401e4d28a3c9 100644 --- a/vllm/triton_utils/sample.py +++ b/vllm/triton_utils/sample.py @@ -3,6 +3,7 @@ # This is a hardcoded limit in Triton (max block size). MAX_TRITON_N_COLS = 131072 + def get_num_triton_sampler_splits(n_cols: int) -> int: """Get the number of splits to use for Triton sampling. 
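Taken together, patches 08-09 make get_num_triton_sampler_splits importable without Triton, since it only does integer math against Triton's block-size limit. As a quick sanity check of that helper (a standalone sketch; the vocab sizes mirror the SINGLE_SPLIT_VOCAB_SIZE / MULTI_SPLIT_VOCAB_SIZE constants used in the sampler test touched by the next patch):

import math

# Hardcoded Triton limit (max block size), as in vllm/triton_utils/sample.py.
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    return math.ceil(n_cols / MAX_TRITON_N_COLS)


# A 32k vocab (llama/mistral/mixtral) fits in a single split ...
assert get_num_triton_sampler_splits(32000) == 1
# ... while anything wider than one block needs an extra split.
assert get_num_triton_sampler_splits(MAX_TRITON_N_COLS + 100) == 2
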
From e1cff0a23fc36b106a4e37f1e0dddaab894f993d Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 07:08:49 -0400 Subject: [PATCH 10/13] Fix error in sampler test Signed-off-by: Thomas Parnell --- tests/kernels/test_sampler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index e28f809309ec..3c53f7decc6e 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -5,11 +5,12 @@ import triton import triton.language as tl -from vllm.model_executor.layers.ops.sample import ( - MAX_TRITON_N_COLS, _uniform_to_exponential, get_num_triton_sampler_splits, - sample) +from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential, + sample) from vllm.model_executor.sampling_metadata import SamplingTensors from vllm.model_executor.utils import set_random_seed +from vllm.triton_utils.sample import (MAX_TRITON_N_COLS, + get_num_triton_sampler_splits) SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 From 9816ac7b781582b304e1cfea488586a320eee72e Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 13:36:55 -0400 Subject: [PATCH 11/13] put fused_moe back in fp8.py Signed-off-by: Thomas Parnell --- .../layers/fused_moe/__init__.py | 16 +- .../model_executor/layers/quantization/fp8.py | 196 +++++++++++++++++- .../layers/quantization/fp8_fused_moe.py | 196 ------------------ 3 files changed, 200 insertions(+), 208 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/fp8_fused_moe.py diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index db837231c6ac..4efe3f28684e 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,14 +1,22 @@ -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, FusedMoEMethodBase) +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) __all__ = [ + "FusedMoE", + "FusedMoEMethodBase", "fused_moe", "fused_topk", "fused_experts", "get_config_file_name", "grouped_topk", - "FusedMoE", - "FusedMoEMethodBase", ] + +if not HAS_TRITON: + # need to do it like this other ruff complains + __all__ = __all__[:2] diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 542629badcbf..c829cb836ee4 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -6,6 +6,7 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( @@ -16,16 +17,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, convert_to_channelwise, create_per_tensor_scale_param, - cutlass_fp8_supported, requantize_with_max_scale) + all_close_1d, apply_fp8_linear, 
convert_to_channelwise, + create_per_tensor_scale_param, cutlass_fp8_supported, + per_tensor_dequantize, requantize_with_max_scale) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.model_executor.layers.fused_moe import FusedMoE - from vllm.model_executor.layers.quantization.fp8_fused_moe import ( - Fp8MoEMethod) +from vllm.utils import print_warning_once ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -241,6 +238,189 @@ def apply(self, use_per_token_if_dynamic=False) +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, + intermediate_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_scale", w2_scale) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + + a13_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("a13_scale", a13_scale) + set_weight_attrs(a13_scale, extra_weight_attrs) + + a2_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("a2_scale", a2_scale) + set_weight_attrs(a2_scale, extra_weight_attrs) + else: + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. 
+ if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like(layer.w13_weight.data, + dtype=torch.float8_e4m3fn) + w2_weight = torch.empty_like(layer.w2_weight.data, + dtype=torch.float8_e4m3fn) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_scale = torch.nn.Parameter(torch.ones( + layer.num_experts, + dtype=torch.float32, + device=w13_weight.device), + requires_grad=False) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_scale[ + expert] = ops.scaled_fp8_quant( + layer.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], layer.w2_scale[ + expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :]) + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.a13_scale is None or layer.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.a13_scale) + or not all_close_1d(layer.a2_scale)): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. ") + layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), + requires_grad=False) + layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + return + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe import fused_moe + return fused_moe(x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group) + + class Fp8KVCacheMethod(BaseKVCacheMethod): """ Supports loading kv-cache scaling factors from FP8 checkpoints. 
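The process_weights_after_loading path that this patch moves back into fp8.py has to reconcile two per-shard w13 scales with MoE kernels that expect a single scale per expert: it takes the per-expert maximum, dequantizes each shard with its own scale, and requantizes it against that maximum. A rough sketch of the arithmetic for one expert, using plain torch casts as stand-ins for vLLM's per_tensor_dequantize and ops.scaled_fp8_quant custom ops (illustrative only, not the actual kernels):

import torch


def requantize_w13_to_max_scale(w13_fp8: torch.Tensor,
                                w13_scale: torch.Tensor,
                                shard_size: int):
    """w13_fp8: [2 * shard_size, hidden] fp8 weight, w13_scale: [2] shard scales."""
    max_scale = w13_scale.max()
    out = torch.empty_like(w13_fp8)
    for shard_id in range(2):
        rows = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
        # Dequantize the shard with the scale it was stored with ...
        dq = w13_fp8[rows].to(torch.float32) * w13_scale[shard_id]
        # ... and requantize it against the shared (max) scale.
        out[rows] = (dq / max_scale).to(torch.float8_e4m3fn)
    return out, max_scale

In the real code the rounding and saturation are handled by ops.scaled_fp8_quant; the sketch only shows why a single max_w13_scales[expert_id] can replace the two per-shard entries without changing the dequantized values (up to fp8 precision).
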
diff --git a/vllm/model_executor/layers/quantization/fp8_fused_moe.py b/vllm/model_executor/layers/quantization/fp8_fused_moe.py deleted file mode 100644 index 2d31f04c702c..000000000000 --- a/vllm/model_executor/layers/quantization/fp8_fused_moe.py +++ /dev/null @@ -1,196 +0,0 @@ -from typing import TYPE_CHECKING, Optional - -import torch -from torch.nn import Module - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase, fused_moe -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, per_tensor_dequantize) -from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import print_warning_once - -if TYPE_CHECKING: - from vllm.model_executor.layers.quantization.fp8 import Fp8Config - - -class Fp8MoEMethod(FusedMoEMethodBase): - """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. - - Also supports loading quantized FP16/BF16 model checkpoints with dynamic - activation scaling. The weight scaling factor will be initialized after - the model weights are loaded. - - Args: - quant_config: The quantization config. - """ - - def __init__(self, quant_config: 'Fp8Config'): - self.quant_config = quant_config - - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, - intermediate_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty(num_experts, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size, - intermediate_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_scale", w13_scale) - - w2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_scale", w2_scale) - - # If loading fp8 checkpoint, pass the weight loaders. 
- # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, extra_weight_attrs) - - # INPUT_SCALES - if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8.") - - a13_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("a13_scale", a13_scale) - set_weight_attrs(a13_scale, extra_weight_attrs) - - a2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("a2_scale", a2_scale) - set_weight_attrs(a2_scale, extra_weight_attrs) - else: - layer.a13_scale = None - layer.a2_scale = None - - def process_weights_after_loading(self, layer: Module) -> None: - - # If checkpoint is fp16, quantize in place. - if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like(layer.w13_weight.data, - dtype=torch.float8_e4m3fn) - w2_weight = torch.empty_like(layer.w2_weight.data, - dtype=torch.float8_e4m3fn) - - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. - layer.w13_scale = torch.nn.Parameter(torch.ones( - layer.num_experts, - dtype=torch.float32, - device=w13_weight.device), - requires_grad=False) - for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_scale[ - expert] = ops.scaled_fp8_quant( - layer.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], layer.w2_scale[ - expert] = ops.scaled_fp8_quant( - layer.w2_weight.data[expert, :, :]) - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - return - - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. - else: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. - if self.quant_config.activation_scheme == "static": - if layer.a13_scale is None or layer.a2_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.a13_scale) - or not all_close_1d(layer.a2_scale)): - print_warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer. ") - layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), - requires_grad=False) - layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), - requires_grad=False) - - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. 
- assert layer.w13_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_scale.max(dim=1).values - for expert_id in range(layer.num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - shard_size, :], - layer.w13_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) - start += shard_size - - layer.w13_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) - return - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - - return fused_moe(x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_fp8=True, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) From b9bb0b4a37369ce316638ae2a7a02e4adf7ac721 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 13:59:58 -0400 Subject: [PATCH 12/13] Improved handling of ruff in __init__.py Signed-off-by: Thomas Parnell --- .../layers/fused_moe/__init__.py | 26 +++++++++---------- vllm/triton_utils/__init__.py | 11 ++++---- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 4efe3f28684e..3e0767c7d266 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -2,21 +2,21 @@ FusedMoEMethodBase) from vllm.triton_utils import HAS_TRITON -if HAS_TRITON: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) - __all__ = [ "FusedMoE", "FusedMoEMethodBase", - "fused_moe", - "fused_topk", - "fused_experts", - "get_config_file_name", - "grouped_topk", ] -if not HAS_TRITON: - # need to do it like this other ruff complains - __all__ = __all__[:2] +if HAS_TRITON: + + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) + + __all__ += [ + "fused_moe", + "fused_topk", + "fused_experts", + "get_config_file_name", + "grouped_topk", + ] diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 4572b45ceaec..4be743ada96e 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -1,11 +1,10 @@ from vllm.triton_utils.importing import HAS_TRITON -if HAS_TRITON: +__all__ = ["HAS_TRITON"] + +if not HAS_TRITON: + from vllm.triton_utils.custom_cache_manager import ( maybe_set_triton_cache_manager) -__all__ = ["HAS_TRITON", "maybe_set_triton_cache_manager"] - -if not HAS_TRITON: - # need to do this afterwards due to ruff complaining - __all__.pop() + __all__ += ["maybe_set_triton_cache_manager"] From 936290d0c79eb55c2599924d9e19c1492b376f1e Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 14:07:01 -0400 Subject: [PATCH 13/13] Fix small bug introduced. 
Signed-off-by: Thomas Parnell --- vllm/triton_utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 4be743ada96e..568185383aa5 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -2,7 +2,7 @@ __all__ = ["HAS_TRITON"] -if not HAS_TRITON: +if HAS_TRITON: from vllm.triton_utils.custom_cache_manager import ( maybe_set_triton_cache_manager)
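
By the end of the series the same guarded-import idiom is used everywhere Triton is optional: importing.py probes for the package with find_spec (so nothing Triton-related is imported at module load time), and each package __init__ starts from its unconditional exports and appends the Triton-only ones. A condensed sketch of that final shape (module paths as in the tree above; the sketch merges the importing.py check with the __init__.py export logic purely for illustration):

from importlib.util import find_spec

# importing.py: probe for triton without importing it, so a missing package
# can never raise at import time.
HAS_TRITON = find_spec("triton") is not None

# __init__.py: unconditional exports first, Triton-only exports appended,
# which keeps ruff happy without the earlier __all__.pop() workaround.
__all__ = ["HAS_TRITON"]

if HAS_TRITON:
    from vllm.triton_utils.custom_cache_manager import (
        maybe_set_triton_cache_manager)

    __all__ += ["maybe_set_triton_cache_manager"]

Callers such as vllm/model_executor/layers/fused_moe/__init__.py follow the same pattern for their Triton-backed symbols, which is what allows the earlier mock_triton/mock_tl shims to be deleted.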