diff --git a/csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp b/csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp index a7753d2c0..6762b7cc8 100644 --- a/csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp +++ b/csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp @@ -798,7 +798,8 @@ scale_dot_product_for_indirect_access_kv_cache( } } } - flag_access[thread_id][bi][hi] = 1; + if (flag_access[thread_id][bi][hi] == 0) + flag_access[thread_id][bi][hi] = 1; } } } @@ -1102,7 +1103,8 @@ scale_dot_product_for_indirect_access_kv_cache_half( flag_access[thread_id][bi][hi]); } } - flag_access[thread_id][bi][hi] = 1; + if (flag_access[thread_id][bi][hi] == 0) + flag_access[thread_id][bi][hi] = 1; } } }