[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling #3103

Merged: 73 commits, merged Mar 9, 2024.

The diff shown below is from 1 commit ("pr feedback"). Full commit history (73 commits, all by cadedaniel):
d74ff5c  first test passes (Feb 24, 2024)
74b5c48  test (Feb 24, 2024)
f6a730b  test (Feb 24, 2024)
aafebd0  test (Feb 29, 2024)
415db01  test (Feb 29, 2024)
76dfe1a  test (Feb 29, 2024)
069b564  test (Feb 29, 2024)
7f13ccd  test for metrics, separate out metrics functionality (Feb 29, 2024)
c91a55b  metrics test (Feb 29, 2024)
3a69d54  clean (Feb 29, 2024)
73212b7  test (Feb 29, 2024)
f796ab0  fixes (Feb 29, 2024)
384dc9d  nvtx_range (Feb 29, 2024)
2d1a192  profile and cache tests (Feb 29, 2024)
bec5cba  test (Feb 29, 2024)
273baea  Merge remote-tracking branch 'upstream/main' into draft-target-worker (Feb 29, 2024)
7a18f37  lint (Feb 29, 2024)
4eb8e04  attempt add tests to ci (Feb 29, 2024)
e0ec4b4  refactor outline (Mar 4, 2024)
b7e580b  wip (Mar 4, 2024)
665ed8e  WIP (Mar 5, 2024)
8fcb257  sampler mock raw tensors (Mar 5, 2024)
79a1f6c  wip (Mar 5, 2024)
7a42183  asd (Mar 5, 2024)
c86b44e  asd (Mar 5, 2024)
f5e5d76  wip (Mar 5, 2024)
68284ed  bugfix (Mar 5, 2024)
87cc31a  wip (Mar 5, 2024)
c142006  wip (Mar 5, 2024)
c026de9  wip (Mar 5, 2024)
1486a84  wip (Mar 5, 2024)
264b5cb  wip (Mar 5, 2024)
8cc8caf  wip (Mar 5, 2024)
ee1efff  wip (Mar 5, 2024)
cadac54  wip (Mar 5, 2024)
2c2dd86  wip (Mar 5, 2024)
b112463  wip (Mar 5, 2024)
e09c666  wip (Mar 5, 2024)
807aa02  fix (Mar 5, 2024)
5545e73  clean (Mar 5, 2024)
49c9798  wip (Mar 6, 2024)
978a711  remove (Mar 6, 2024)
d548db4  clean (Mar 6, 2024)
a277fe0  wip (Mar 6, 2024)
c1357d6  wip (Mar 6, 2024)
136a59b  wip (Mar 6, 2024)
5657feb  wip (Mar 6, 2024)
059beba  rename (Mar 6, 2024)
4ce7119  wip (Mar 6, 2024)
db52bee  clean (Mar 6, 2024)
20297f2  clean (Mar 6, 2024)
13aebbf  wip (Mar 6, 2024)
524ada4  clean (Mar 6, 2024)
e3f57b0  wip (Mar 6, 2024)
aff7a34  clean (Mar 6, 2024)
c2beb94  rename (Mar 6, 2024)
2c835de  first autoformat (Mar 6, 2024)
292c34f  wip (Mar 6, 2024)
c387d56  wip (Mar 6, 2024)
39382b7  move (Mar 6, 2024)
f78325c  move (Mar 6, 2024)
5cafc12  move (Mar 6, 2024)
3a5dcfb  name (Mar 6, 2024)
8e7ee97  sequence test and docs (Mar 6, 2024)
e5e334f  docs (Mar 6, 2024)
2d27c57  lint (Mar 6, 2024)
3c61a52  typo (Mar 6, 2024)
be66076  Merge remote-tracking branch 'upstream/main' into draft-target-worker (Mar 6, 2024)
b165a73  lint (Mar 6, 2024)
17725fb  fix test (Mar 6, 2024)
364a415  Merge remote-tracking branch 'upstream/main' into draft-target-worker (Mar 6, 2024)
11b6f39  pr feedback (Mar 9, 2024)
ab00fcf  better comment (Mar 9, 2024)
pr feedback
cadedaniel committed Mar 9, 2024
commit 11b6f393601c177149c0fee42802eac133ce9af1
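
For orientation before the per-file diffs: this PR series adds a worker that speculates with a small draft model, scores the speculated tokens with the target model, and applies rejection sampling to decide which tokens to emit. The sketch below is a rough approximation only, not the vLLM implementation; the method names and the rejection-sampler call signature are illustrative stand-ins for the proposer/scorer/sampler interfaces that appear in the diffs.

# Rough sketch of one speculative decoding step (illustrative only).
def speculative_decode_step(proposer, scorer, rejection_sampler,
                            seq_group_metadata_list, k):
    # 1. Speculate: the draft model proposes up to k tokens per sequence.
    proposals = proposer.get_proposals(seq_group_metadata_list, k)

    # 2. Score: the target model scores the proposed tokens in one pass
    #    (done via batch expansion in batch_expansion.py below).
    scores = scorer.score_proposals(seq_group_metadata_list, proposals)

    # 3. Verify: rejection sampling accepts a prefix of each proposal so the
    #    output distribution matches the target model. Argument order here is
    #    illustrative, not the real API.
    accepted_token_ids = rejection_sampler(scores.probs,
                                           proposals.proposal_probs,
                                           proposals.proposal_token_ids)
    return accepted_token_ids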
83 changes: 83 additions & 0 deletions tests/spec_decode/test_utils.py
@@ -1,6 +1,8 @@
from vllm.spec_decode.util import get_all_seq_ids
from vllm.sequence import SequenceGroupMetadata
from vllm.spec_decode.util import split_batch_by_proposal_len

import pytest
from unittest.mock import MagicMock


@@ -26,3 +28,84 @@ def test_get_all_seq_ids():

    actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
    assert actual_seq_ids == expected_seq_ids


@pytest.fixture
def fake_sequence_group_metadata():
    seq_ids = list(range(3))
    return [
        SequenceGroupMetadata(
            request_id=str(i),
            is_prompt=True,
            seq_data={
                i: MagicMock(),
            },
            sampling_params=MagicMock(),
            block_tables={
                i: MagicMock(),
            },
            lora_request=None,
        ) for i in seq_ids
    ]


def test_filter_zero_length_proposals(fake_sequence_group_metadata):
    proposal_lens = [0, 1, 0]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=True)

    expected_groups = [
        fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
    ]
    expected_indices = [0, 2]

    assert filtered_groups == expected_groups
    assert indices == expected_indices


def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
    proposal_lens = [0, 1, 2]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=False)

    expected_groups = [
        fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
    ]
    expected_indices = [1, 2]

    assert filtered_groups == expected_groups
    assert indices == expected_indices


def test_empty_inputs():
    filtered_groups, indices = split_batch_by_proposal_len(
        [], [], select_proposal_len_zero=True)

    assert filtered_groups == []
    assert indices == []


def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
    proposal_lens = [0, 0, 0]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=False)

    assert filtered_groups == []
    assert indices == []


def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
    proposal_lens = [1, 1, 1]
    filtered_groups, indices = split_batch_by_proposal_len(
        fake_sequence_group_metadata,
        proposal_lens,
        select_proposal_len_zero=True)

    assert filtered_groups == []
    assert indices == []
48 changes: 23 additions & 25 deletions vllm/spec_decode/batch_expansion.py
@@ -5,7 +5,7 @@

from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData)
from vllm.worker.worker import Worker
from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids
from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids, split_batch_by_proposal_len
from vllm.spec_decode.interfaces import SpeculativeScorer, SpeculativeProposals, SpeculativeScores

SeqId = int
@@ -107,27 +107,18 @@ def _expand_batch(
        query token.
        """

        spec_seqs = [
            seq_group for seq_group, proposal_len in zip(
                seq_group_metadata_list, proposal_lens_list)
            if proposal_len != 0
        ]
        spec_indices = [
            i for i, (_, proposal_len) in enumerate(
                zip(seq_group_metadata_list, proposal_lens_list))
            if proposal_len != 0
        ]

        non_spec_seqs = [
            seq_group for seq_group, proposal_len in zip(
                seq_group_metadata_list, proposal_lens_list)
            if proposal_len == 0
        ]
        non_spec_indices = [
            i for i, (_, proposal_len) in enumerate(
                zip(seq_group_metadata_list, proposal_lens_list))
            if proposal_len == 0
        ]
        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        spec_seqs, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)

        target_seq_group_metadata_list = self._create_scoring_model_input(
            spec_seqs, proposal_token_ids_list)
@@ -159,8 +150,10 @@ def _contract_batch(self, original_bs: int,
        target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
                                                      self._vocab_size)

        all_tokens = torch.ones(
            original_bs, k + 1, device=self._device, dtype=torch.long) * -1
        all_tokens = torch.full(size=(original_bs, k + 1),
                                fill_value=-1,
                                device=self._device,
                                dtype=torch.long)
        all_probs = torch.zeros(original_bs,
                                k + 1,
                                self._vocab_size,
@@ -169,7 +162,7 @@ def _contract_batch(self, original_bs: int,

        if non_spec_indices:
            all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
            all_probs[non_spec_indices, 1:, :] = non_spec_target_probs
            all_probs[non_spec_indices, :1, :] = non_spec_target_probs

        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
@@ -288,6 +281,11 @@ def _split_scoring_output(
        output.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        #
        # First samples are from speculative scoring, latter samples are non-
        # speculative samples.
        split_sizes = [
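
To make the _contract_batch changes above easier to follow, here is a toy, self-contained illustration of the scatter-by-index pattern; the batch size, k, vocab size, and tensor values are made up and this is not the PR's code:

import torch

original_bs, k, vocab = 4, 2, 8          # toy sizes
spec_indices, non_spec_indices = [1, 3], [0, 2]

# torch.full replaces the previous torch.ones(...) * -1 pattern.
all_tokens = torch.full((original_bs, k + 1), fill_value=-1, dtype=torch.long)
all_probs = torch.zeros(original_bs, k + 1, vocab)

# Non-speculative sequences contribute one decoded token, placed in slot 0,
# and their probs go into the first slot (the :1 indexing fixed in this hunk).
non_spec_target_token_ids = torch.tensor([11, 22])
non_spec_target_probs = torch.rand(len(non_spec_indices), 1, vocab)
all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs

# Speculative sequences contribute k + 1 scored tokens per row.
target_token_ids = torch.randint(0, vocab, (len(spec_indices), k + 1))
all_tokens[spec_indices] = target_token_ids

print(all_tokens)  # rows 0 and 2 keep -1 beyond slot 0; rows 1 and 3 are filled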
69 changes: 42 additions & 27 deletions vllm/spec_decode/multi_step_worker.py
@@ -206,6 +206,18 @@ def _raise_if_unsupported(
class DraftModelTop1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
@@ -270,6 +282,32 @@ def get_proposals(

        return proposals

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        max_proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length.
        """

        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal len
            # are supported.
            if seq_len + max_proposal_len < self._max_model_len:
                proposal_lens.append(max_proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices

    def _merge_outputs(
        self,
        batch_size: int,
@@ -306,10 +344,11 @@ def _merge_outputs(
        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. the proposal can be empty, e.g. [-1, -1, -1]

        entire_proposal_tokens = torch.ones(batch_size,
                                            *proposal_tokens.shape[1:],
        entire_proposal_tokens = torch.full(size=(batch_size,
                                                  *proposal_tokens.shape[1:]),
                                            fill_value=-1,
                                            dtype=torch.long,
                                            device=self._device) * -1
                                            device=self._device)
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = torch.zeros(batch_size,
                                            *proposal_probs.shape[1:],
@@ -325,27 +364,3 @@ def _merge_outputs(
        proposal_lens[nonzero_proposal_len_indices] = max_proposal_len

        return proposal_tokens, proposal_probs, proposal_lens

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        max_proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length.
        """

        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            if seq_len + max_proposal_len < self._max_model_len:
                proposal_lens.append(max_proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices
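
A small, standalone illustration of the splitting rule implemented by _split_by_max_model_len above; the lengths and limits here are hypothetical, chosen only to exercise both branches:

# Hypothetical numbers: a draft model with a 2048-token limit and k = 5.
max_model_len = 2048
max_proposal_len = 5

seq_lens = [100, 2045]  # the second sequence cannot fit 5 more draft tokens
proposal_lens = []
nonzero_proposal_len_indices = []
for i, seq_len in enumerate(seq_lens):
    if seq_len + max_proposal_len < max_model_len:
        proposal_lens.append(max_proposal_len)
        nonzero_proposal_len_indices.append(i)
    else:
        # Too long for the draft model: proposal_len 0 means this sequence
        # skips speculation and is decoded normally by the target model.
        proposal_lens.append(0)

print(proposal_lens, nonzero_proposal_len_indices)  # [5, 0] [0]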
25 changes: 14 additions & 11 deletions vllm/spec_decode/spec_decode_worker.py
@@ -10,7 +10,7 @@
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.config import CacheConfig
from vllm.spec_decode.util import nvtx_range, get_all_seq_ids
from vllm.spec_decode.util import nvtx_range, get_all_seq_ids, split_batch_by_proposal_len
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeScorer
@@ -220,16 +220,19 @@ def _verify_tokens(
        probabilities of each token according to the proposer and scorer models.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()
        spec_indices = [
            i for i, (_, proposal_len) in enumerate(
                zip(seq_group_metadata_list, proposal_lens_list))
            if proposal_len != 0
        ]
        non_spec_indices = [
            i for i, (_, proposal_len) in enumerate(
                zip(seq_group_metadata_list, proposal_lens_list))
            if proposal_len == 0
        ]

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        _, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        _, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)
        original_indices = spec_indices + non_spec_indices

        proposal_probs = proposal_scores.probs[spec_indices, :-1]
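
One way a concatenated index list like original_indices = spec_indices + non_spec_indices can be used to map split-order results back to the original batch order; the values below are made up, and the surrounding vLLM code that consumes original_indices is not shown in this diff:

proposal_lens_list = [0, 3, 3, 0]   # hypothetical batch: sequences 1 and 2 are speculated on

spec_indices = [i for i, p in enumerate(proposal_lens_list) if p != 0]      # [1, 2]
non_spec_indices = [i for i, p in enumerate(proposal_lens_list) if p == 0]  # [0, 3]
original_indices = spec_indices + non_spec_indices                          # [1, 2, 0, 3]

# If results arrive in [spec..., non_spec...] order, scattering through
# original_indices restores the original batch order.
results_in_split_order = ["r_seq1", "r_seq2", "r_seq0", "r_seq3"]
restored = [None] * len(proposal_lens_list)
for pos, orig in enumerate(original_indices):
    restored[orig] = results_in_split_order[pos]
print(restored)  # ['r_seq0', 'r_seq1', 'r_seq2', 'r_seq3']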
27 changes: 27 additions & 0 deletions vllm/spec_decode/util.py
@@ -19,6 +19,33 @@ def get_all_seq_ids(
        ]))


def split_batch_by_proposal_len(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_lens: List[int], select_proposal_len_zero: bool
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
    """Utility function that splits a batch based on whether the proposal len is
    zero or not. We should remove this once vLLM supports per-sequence proposal
    lens in a batch.
    """

    if select_proposal_len_zero:
        predicate = lambda proposal_len: proposal_len == 0
    else:
        predicate = lambda proposal_len: proposal_len != 0

    indices = [
        i for i, (_, proposal_len
                  ) in enumerate(zip(seq_group_metadata_list, proposal_lens))
        if predicate(proposal_len)
    ]
    seq_groups = [
        seq_group for seq_group, proposal_len in zip(
            seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
    ]

    return seq_groups, indices


def sampler_output_to_torch(
    sampler_output_list: List[SamplerOutput],
) -> Tuple[torch.Tensor, torch.Tensor]:
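
Illustrative usage of the new split_batch_by_proposal_len helper, mirroring the tests added above; the string stand-ins replace real SequenceGroupMetadata objects:

groups = ["g0", "g1", "g2"]      # stand-ins for SequenceGroupMetadata objects
proposal_lens = [0, 5, 0]

spec, spec_idx = split_batch_by_proposal_len(
    groups, proposal_lens, select_proposal_len_zero=False)
non_spec, non_spec_idx = split_batch_by_proposal_len(
    groups, proposal_lens, select_proposal_len_zero=True)

assert (spec, spec_idx) == (["g1"], [1])
assert (non_spec, non_spec_idx) == (["g0", "g2"], [0, 2])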