@@ -45,6 +45,7 @@ def _fwd_kernel(
45
45
stride_v_cache_h ,
46
46
stride_v_cache_d ,
47
47
stride_v_cache_bl ,
48
+ num_queries_per_kv : int ,
48
49
BLOCK_M : tl .constexpr ,
49
50
BLOCK_DMODEL : tl .constexpr ,
50
51
BLOCK_N : tl .constexpr ,
@@ -53,6 +54,8 @@ def _fwd_kernel(
53
54
cur_head = tl .program_id (1 )
54
55
start_m = tl .program_id (2 )
55
56
57
+ cur_kv_head = cur_head // num_queries_per_kv
58
+
56
59
cur_batch_ctx_len = tl .load (B_Ctxlen + cur_batch )
57
60
cur_batch_seq_len = tl .load (B_Seqlen + cur_batch )
58
61
cur_batch_in_all_start_index = tl .load (B_Start_Loc + cur_batch )
@@ -85,13 +88,14 @@ def _fwd_kernel(
85
88
mask = (start_n + offs_n ) < cur_batch_ctx_len ,
86
89
other = 0 )
87
90
off_k = (bn [None , :] * stride_k_cache_bs +
88
- cur_head * stride_k_cache_h +
91
+ cur_kv_head * stride_k_cache_h +
89
92
(offs_d [:, None ] // x ) * stride_k_cache_d +
90
93
((start_n + offs_n [None , :]) % block_size ) *
91
94
stride_k_cache_bl +
92
95
(offs_d [:, None ] % x ) * stride_k_cache_x )
93
96
off_v = (
94
- bn [:, None ] * stride_v_cache_bs + cur_head * stride_v_cache_h +
97
+ bn [:, None ] * stride_v_cache_bs +
98
+ cur_kv_head * stride_v_cache_h +
95
99
offs_d [None , :] * stride_v_cache_d +
96
100
(start_n + offs_n [:, None ]) % block_size * stride_v_cache_bl )
97
101
k = tl .load (K_cache + off_k ,
@@ -131,9 +135,9 @@ def _fwd_kernel(
131
135
l_i = l_i_new
132
136
m_i = m_i_new
133
137
134
- off_k = (offs_n [None , :] * stride_kbs + cur_head * stride_kh +
138
+ off_k = (offs_n [None , :] * stride_kbs + cur_kv_head * stride_kh +
135
139
offs_d [:, None ] * stride_kd )
136
- off_v = (offs_n [:, None ] * stride_vbs + cur_head * stride_vh +
140
+ off_v = (offs_n [:, None ] * stride_vbs + cur_kv_head * stride_vh +
137
141
offs_d [None , :] * stride_vd )
138
142
k_ptrs = K + off_k
139
143
v_ptrs = V + off_v
@@ -232,6 +236,7 @@ def _fwd_kernel_flash_attn_v2(
232
236
stride_v_cache_h ,
233
237
stride_v_cache_d ,
234
238
stride_v_cache_bl ,
239
+ num_queries_per_kv : int ,
235
240
BLOCK_M : tl .constexpr ,
236
241
BLOCK_DMODEL : tl .constexpr ,
237
242
BLOCK_N : tl .constexpr ,
@@ -240,6 +245,8 @@ def _fwd_kernel_flash_attn_v2(
240
245
cur_head = tl .program_id (1 )
241
246
start_m = tl .program_id (2 )
242
247
248
+ cur_kv_head = cur_head // num_queries_per_kv
249
+
243
250
cur_batch_ctx_len = tl .load (B_Ctxlen + cur_batch )
244
251
cur_batch_seq_len = tl .load (B_Seqlen + cur_batch )
245
252
cur_batch_in_all_start_index = tl .load (B_Start_Loc + cur_batch )
@@ -272,13 +279,14 @@ def _fwd_kernel_flash_attn_v2(
272
279
mask = (start_n + offs_n ) < cur_batch_ctx_len ,
273
280
other = 0 )
274
281
off_k = (bn [None , :] * stride_k_cache_bs +
275
- cur_head * stride_k_cache_h +
282
+ cur_kv_head * stride_k_cache_h +
276
283
(offs_d [:, None ] // x ) * stride_k_cache_d +
277
284
((start_n + offs_n [None , :]) % block_size ) *
278
285
stride_k_cache_bl +
279
286
(offs_d [:, None ] % x ) * stride_k_cache_x )
280
287
off_v = (
281
- bn [:, None ] * stride_v_cache_bs + cur_head * stride_v_cache_h +
288
+ bn [:, None ] * stride_v_cache_bs +
289
+ cur_kv_head * stride_v_cache_h +
282
290
offs_d [None , :] * stride_v_cache_d +
283
291
(start_n + offs_n [:, None ]) % block_size * stride_v_cache_bl )
284
292
k = tl .load (K_cache + off_k ,
@@ -317,9 +325,9 @@ def _fwd_kernel_flash_attn_v2(
317
325
l_i = l_i_new
318
326
m_i = m_i_new
319
327
320
- off_k = (offs_n [None , :] * stride_kbs + cur_head * stride_kh +
328
+ off_k = (offs_n [None , :] * stride_kbs + cur_kv_head * stride_kh +
321
329
offs_d [:, None ] * stride_kd )
322
- off_v = (offs_n [:, None ] * stride_vbs + cur_head * stride_vh +
330
+ off_v = (offs_n [:, None ] * stride_vbs + cur_kv_head * stride_vh +
323
331
offs_d [None , :] * stride_vd )
324
332
k_ptrs = K + off_k
325
333
v_ptrs = V + off_v
@@ -420,6 +428,7 @@ def _fwd_kernel_alibi(
420
428
stride_v_cache_h ,
421
429
stride_v_cache_d ,
422
430
stride_v_cache_bl ,
431
+ num_queries_per_kv : int ,
423
432
BLOCK_M : tl .constexpr ,
424
433
BLOCK_DMODEL : tl .constexpr ,
425
434
BLOCK_N : tl .constexpr ,
@@ -429,6 +438,8 @@ def _fwd_kernel_alibi(
429
438
cur_head = tl .program_id (1 )
430
439
start_m = tl .program_id (2 )
431
440
441
+ cur_kv_head = cur_head // num_queries_per_kv
442
+
432
443
# cur_batch_seq_len: the length of prompts
433
444
# cur_batch_ctx_len: the length of prefix
434
445
# cur_batch_in_all_start_index: the start id of the dim=0
@@ -468,13 +479,14 @@ def _fwd_kernel_alibi(
468
479
mask = (start_n + offs_n ) < cur_batch_ctx_len ,
469
480
other = 0 )
470
481
off_k = (bn [None , :] * stride_k_cache_bs +
471
- cur_head * stride_k_cache_h +
482
+ cur_kv_head * stride_k_cache_h +
472
483
(offs_d [:, None ] // x ) * stride_k_cache_d +
473
484
((start_n + offs_n [None , :]) % block_size ) *
474
485
stride_k_cache_bl +
475
486
(offs_d [:, None ] % x ) * stride_k_cache_x )
476
487
off_v = (
477
- bn [:, None ] * stride_v_cache_bs + cur_head * stride_v_cache_h +
488
+ bn [:, None ] * stride_v_cache_bs +
489
+ cur_kv_head * stride_v_cache_h +
478
490
offs_d [None , :] * stride_v_cache_d +
479
491
(start_n + offs_n [:, None ]) % block_size * stride_v_cache_bl )
480
492
k = tl .load (K_cache + off_k ,
@@ -522,9 +534,9 @@ def _fwd_kernel_alibi(
522
534
l_i = l_i_new
523
535
m_i = m_i_new
524
536
525
- off_k = (offs_n [None , :] * stride_kbs + cur_head * stride_kh +
537
+ off_k = (offs_n [None , :] * stride_kbs + cur_kv_head * stride_kh +
526
538
offs_d [:, None ] * stride_kd )
527
- off_v = (offs_n [:, None ] * stride_vbs + cur_head * stride_vh +
539
+ off_v = (offs_n [:, None ] * stride_vbs + cur_kv_head * stride_vh +
528
540
offs_d [None , :] * stride_vd )
529
541
k_ptrs = K + off_k
530
542
v_ptrs = V + off_v
@@ -628,6 +640,7 @@ def context_attention_fwd(q,
628
640
629
641
sm_scale = 1.0 / (Lq ** 0.5 )
630
642
batch , head = b_seq_len .shape [0 ], q .shape [1 ]
643
+ num_queries_per_kv = q .shape [1 ] // k .shape [1 ]
631
644
632
645
grid = (batch , head , triton .cdiv (max_input_len , BLOCK )) # batch, head,
633
646
@@ -674,6 +687,7 @@ def context_attention_fwd(q,
674
687
v_cache .stride (2 ),
675
688
v_cache .stride (
676
689
3 ), #[num_blocks, num_kv_heads, head_size, block_size]
690
+ num_queries_per_kv = num_queries_per_kv ,
677
691
BLOCK_M = BLOCK ,
678
692
BLOCK_DMODEL = Lk ,
679
693
BLOCK_N = BLOCK ,
@@ -721,6 +735,7 @@ def context_attention_fwd(q,
721
735
v_cache .stride (2 ),
722
736
v_cache .stride (
723
737
3 ), #[num_blocks, num_kv_heads, head_size, block_size]
738
+ num_queries_per_kv = num_queries_per_kv ,
724
739
BLOCK_M = BLOCK ,
725
740
BLOCK_DMODEL = Lk ,
726
741
BLOCK_N = BLOCK ,
0 commit comments