[Bug Fix] Fix bugs in llama (#4601)
* fix tests

* clean

* clean

* fix bugs

* add

* fix llama non-vllm kernels bug

* modify

* clean codes

* bug fix: remove rotary_positions_ids

---------

Co-authored-by: cuiqing.li <lixx3527@gmail.com>
isky-cd and tiandiao123 authored Sep 4, 2023
1 parent 91338e3 commit c66849e
Showing 1 changed file with 1 addition and 6 deletions.
colossalai/inference/tensor_parallel/modeling/llama.py (1 addition, 6 deletions)
@@ -309,13 +309,8 @@ def llama_flash_attn_kvcache_forward(
     value_states_transposed = value_states.transpose(1, 2)
     cos, sin = self.rotary_emb(value_states_transposed,
                                seq_len=infer_state.cache_manager.past_key_values_length)

-    rotary_positions_ids = position_ids
-    idx = position_ids.shape[0] - 1
-    if idx >= 1:
-        rotary_positions_ids = [[idx]]
-
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, rotary_positions_ids)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids)

     query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
     key_states = key_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
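For reference: the removed branch replaced the position_ids tensor with a hand-built nested list [[idx]] derived from position_ids.shape[0], so apply_rotary_pos_emb no longer received the per-token positions it expects; passing position_ids through unchanged restores the correct per-token rotation (the commit message ties the bug to the non-vllm kernel path). The sketch below is a minimal, Hugging Face-style apply_rotary_pos_emb, an illustrative reimplementation rather than the exact ColossalAI kernel, showing why the position index argument matters: cos and sin are gathered by position_ids, so a single stale index would apply one rotation angle to every token.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # Assumed shapes: q/k [batch, heads, seq_len, head_dim],
    # cos/sin [max_seq_len, head_dim], position_ids [batch, seq_len].
    # Gathering cos/sin by position_ids is what makes the rotation per-token;
    # a hard-coded index such as [[idx]] broadcasts one angle to every token.
    cos = cos[position_ids].unsqueeze(1)  # [batch, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

In typical decode-step usage, position_ids holds each new token's absolute position in the sequence, so the gathered cos/sin rows match the token's place in the KV cache rather than a value derived from the batch dimension.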
