Commit c73cb52
[Core] Cross-attention KV caching and memory-management (towards eventual encoder/decoder model support) (vllm-project#4837)
afeldman-nm authored and blinkbear committed Jun 6, 2024
1 parent baa43bd commit c73cb52
Showing 1 changed file with 0 additions and 51 deletions.

tests/core/utils.py

@@ -179,56 +179,5 @@ def create_seq_group_encoder_decoder(
                          encoder_seq=encoder_seq)
 
 
-def create_seq_group_encoder_decoder(
-        seq_prompt_len: int = 1024,
-        seq_output_lens: Iterable[int] = (128, ),
-        request_id: str = '0',
-        seq_id_start: int = 0,
-        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
-
-    assert len(seq_output_lens) > 0
-
-    if sampling_params is None:
-        sampling_params = SamplingParams()
-
-    prompt_token_ids = [0] * seq_prompt_len
-
-    seqs = []
-    for seq_id_offset, output_len in enumerate(seq_output_lens):
-        seq = Sequence(
-            seq_id=seq_id_start + seq_id_offset,
-            inputs={
-                "prompt": "",
-                "prompt_token_ids": prompt_token_ids,
-                "multi_modal_data": None,
-            },
-            block_size=16,
-        )
-
-        for i in range(output_len):
-            seq.append_token_id(
-                token_id=i,
-                logprobs={i: Logprob(0.0)},
-            )
-        seqs.append(seq)
-
-    # Encoder sequence
-    encoder_seq = Sequence(
-        seq_id=seq_id_start + len(seq_output_lens),
-        inputs={
-            "prompt": "",
-            "prompt_token_ids": prompt_token_ids,
-            "multi_modal_data": None,
-        },
-        block_size=16,
-    )
-
-    return SequenceGroup(request_id=request_id,
-                         seqs=seqs,
-                         sampling_params=sampling_params,
-                         arrival_time=time.time(),
-                         encoder_seq=encoder_seq)
-
-
 def round_up_to_next_block(seq_len: int, block_size: int) -> int:
     return (seq_len + block_size - 1) // block_size

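For context, here is a minimal sketch (not part of this commit) of how the surviving copy of the helper might be exercised in a scheduler test. The import path, test function name, and exact assertions are illustrative assumptions; create_seq_group_encoder_decoder and round_up_to_next_block are the utilities shown in the diff above, and the hunk header indicates a retained definition remains earlier in tests/core/utils.py.

# Hypothetical usage sketch (not from this commit). Assumes the retained
# create_seq_group_encoder_decoder and round_up_to_next_block are
# importable from tests/core/utils.py.
from tests.core.utils import (create_seq_group_encoder_decoder,
                              round_up_to_next_block)


def test_encoder_decoder_seq_group_sketch():
    block_size = 16
    prompt_len = 1024

    # Build a SequenceGroup with two decoder sequences plus one encoder
    # sequence holding the (dummy) prompt tokens.
    seq_group = create_seq_group_encoder_decoder(
        seq_prompt_len=prompt_len,
        seq_output_lens=(128, 256),
        request_id="0",
    )

    # One decoder Sequence per requested output length.
    assert len(seq_group.get_seqs()) == 2

    # Number of KV-cache blocks needed for the encoder prompt:
    # ceiling division of the sequence length by the block size.
    assert round_up_to_next_block(prompt_len, block_size) == 64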