[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling #3103

Merged
merged 73 commits into from
Mar 9, 2024
d74ff5c
first test passes
cadedaniel Feb 24, 2024
74b5c48
test
cadedaniel Feb 24, 2024
f6a730b
test
cadedaniel Feb 24, 2024
aafebd0
test
cadedaniel Feb 29, 2024
415db01
test
cadedaniel Feb 29, 2024
76dfe1a
test
cadedaniel Feb 29, 2024
069b564
test
cadedaniel Feb 29, 2024
7f13ccd
test for metrics, separate out metrics functionality
cadedaniel Feb 29, 2024
c91a55b
metrics test
cadedaniel Feb 29, 2024
3a69d54
clean
cadedaniel Feb 29, 2024
73212b7
test
cadedaniel Feb 29, 2024
f796ab0
fixes
cadedaniel Feb 29, 2024
384dc9d
nvtx_range
cadedaniel Feb 29, 2024
2d1a192
profile and cache tests
cadedaniel Feb 29, 2024
bec5cba
test
cadedaniel Feb 29, 2024
273baea
Merge remote-tracking branch 'upstream/main' into draft-target-worker
cadedaniel Feb 29, 2024
7a18f37
lint
cadedaniel Feb 29, 2024
4eb8e04
attempt add tests to ci
cadedaniel Feb 29, 2024
e0ec4b4
refactor outline
cadedaniel Mar 4, 2024
b7e580b
wip
cadedaniel Mar 4, 2024
665ed8e
WIP
cadedaniel Mar 5, 2024
8fcb257
sampler mock raw tensors
cadedaniel Mar 5, 2024
79a1f6c
wip
cadedaniel Mar 5, 2024
7a42183
asd
cadedaniel Mar 5, 2024
c86b44e
asd
cadedaniel Mar 5, 2024
f5e5d76
wip
cadedaniel Mar 5, 2024
68284ed
bugfix
cadedaniel Mar 5, 2024
87cc31a
wip
cadedaniel Mar 5, 2024
c142006
wip
cadedaniel Mar 5, 2024
c026de9
wip
cadedaniel Mar 5, 2024
1486a84
wip
cadedaniel Mar 5, 2024
264b5cb
wip
cadedaniel Mar 5, 2024
8cc8caf
wip
cadedaniel Mar 5, 2024
ee1efff
wip
cadedaniel Mar 5, 2024
cadac54
wip
cadedaniel Mar 5, 2024
2c2dd86
wip
cadedaniel Mar 5, 2024
b112463
wip
cadedaniel Mar 5, 2024
e09c666
wip
cadedaniel Mar 5, 2024
807aa02
fix
cadedaniel Mar 5, 2024
5545e73
clean
cadedaniel Mar 5, 2024
49c9798
wip
cadedaniel Mar 6, 2024
978a711
remove
cadedaniel Mar 6, 2024
d548db4
clean
cadedaniel Mar 6, 2024
a277fe0
wip
cadedaniel Mar 6, 2024
c1357d6
wip
cadedaniel Mar 6, 2024
136a59b
wip
cadedaniel Mar 6, 2024
5657feb
wip
cadedaniel Mar 6, 2024
059beba
rename
cadedaniel Mar 6, 2024
4ce7119
wip
cadedaniel Mar 6, 2024
db52bee
clean
cadedaniel Mar 6, 2024
20297f2
clean
cadedaniel Mar 6, 2024
13aebbf
wip
cadedaniel Mar 6, 2024
524ada4
clean
cadedaniel Mar 6, 2024
e3f57b0
wip
cadedaniel Mar 6, 2024
aff7a34
clean
cadedaniel Mar 6, 2024
c2beb94
rename
cadedaniel Mar 6, 2024
2c835de
first autoformat
cadedaniel Mar 6, 2024
292c34f
wip
cadedaniel Mar 6, 2024
c387d56
wip
cadedaniel Mar 6, 2024
39382b7
move
cadedaniel Mar 6, 2024
f78325c
move
cadedaniel Mar 6, 2024
5cafc12
move
cadedaniel Mar 6, 2024
3a5dcfb
name
cadedaniel Mar 6, 2024
8e7ee97
sequence test and docs
cadedaniel Mar 6, 2024
e5e334f
docs
cadedaniel Mar 6, 2024
2d27c57
lint
cadedaniel Mar 6, 2024
3c61a52
typo
cadedaniel Mar 6, 2024
be66076
Merge remote-tracking branch 'upstream/main' into draft-target-worker
cadedaniel Mar 6, 2024
b165a73
lint
cadedaniel Mar 6, 2024
17725fb
fix test
cadedaniel Mar 6, 2024
364a415
Merge remote-tracking branch 'upstream/main' into draft-target-worker
cadedaniel Mar 6, 2024
11b6f39
pr feedback
cadedaniel Mar 9, 2024
ab00fcf
better comment
cadedaniel Mar 9, 2024
refactor outline
cadedaniel committed Mar 4, 2024
commit e0ec4b418460bb6d07d47b03d9ba62e282f366fe
95 changes: 95 additions & 0 deletions vllm/worker/spec_decode/draft_target_worker.py
@@ -3,6 +3,7 @@
from functools import cached_property
import logging
import time
from dataclasses import dataclass

#import msgspec
import torch
@@ -34,6 +35,100 @@

logger = logging.getLogger(__name__)

"""
Run spec step (ExecuteModelData -> AcceptedTokens)
    - get proposals (ExecuteModelData -> Top1Proposals)
    - score proposals ((ExecuteModelData, Top1Proposals) -> Top1Scores)
    - verify proposals ((ExecuteModelData, Top1Proposals, Top1Scores) -> AcceptedTokens)

Missing pieces:
    - get_proposals in DraftTargetWorker --> currently all mocked
    - Tests cover vertical functionality.

I can start by writing get_proposals without batch expansion, write tests.
Then abstract out score proposals batch expansion. Move tests over to batch expander.
Then abstract out Top1Scores and rejection sampler interface. Fix tests.
"""

@dataclass
class Top1Proposals:
    proposal_token_ids: torch.Tensor
    proposal_token_probs: torch.Tensor

    # not sure
    proposal_lens: torch.Tensor

@dataclass
class Top1Scores:
    # TODO: how to represent bonus token ?
    # current thinking ->
    # always have bonus token such that num_scored_tokens == k+1.
    token_probs: torch.Tensor

@dataclass
class AcceptedTokens:
    accepted_tokens: torch.Tensor
    logprobs: torch.Tensor

class Top1ProposerWorker:

    def init_model(self):
        pass
        # etc.

    def get_proposals(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        num_spec_tokens: int,
    ) -> Top1Proposals:
        """
        - determine k
        - drop k=0
        - fwd pass
        - create output
        """
        pass

class Top1ScorerWorker:
    def init_model(self):
        pass
        # etc.

    def score_proposals(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        num_spec_tokens: int,
        proposals: Top1Proposals,
    ) -> Top1Scores:
        """
        - extract k=0
        - batch expand k>0
        - combine
        - fwd pass
        - batch contract
        """
        pass

class Top1Verifier:
    def init_gpu_tensors(self):
        pass

    def verify_proposals(
        self,
        #seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        #blocks_to_swap_in: Optional[Dict[int, int]],
        #blocks_to_swap_out: Optional[Dict[int, int]],
        #blocks_to_copy: Optional[Dict[int, List[int]]],
        #num_spec_tokens: int,
        proposals: Top1Proposals,
        scores: Top1Scores,
    ) -> AcceptedTokens:
        pass

class DraftTargetWorker:
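
The outline above leaves `verify_proposals` as a stub. As a rough illustration of what the verify step could look like, here is a minimal sketch of standard speculative-decoding rejection sampling for a single sequence. It is not the code from this PR: `ToyProposals`, `ToyScores`, and `_sample` are hypothetical names, plain Python lists stand in for `torch.Tensor`, and there is no batching. It assumes, as the `Top1Scores` TODO suggests, that the scorer always returns k+1 distributions so the last one can supply the bonus token.

```python
import random
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ToyProposals:
    # Hypothetical stand-in for Top1Proposals, for one sequence only.
    token_ids: List[int]            # k proposed token ids
    token_probs: List[List[float]]  # k draft distributions over the vocab


@dataclass
class ToyScores:
    # Hypothetical stand-in for Top1Scores: k + 1 target distributions,
    # where the last one is used for the bonus token.
    token_probs: List[List[float]]


def _sample(weights: Sequence[float], rng: random.Random) -> int:
    """Draw one token id with probability proportional to `weights`."""
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


def verify_proposals(proposals: ToyProposals, scores: ToyScores,
                     rng: random.Random) -> List[int]:
    """Return the tokens emitted for this step (between 1 and k + 1)."""
    emitted: List[int] = []
    for i, token in enumerate(proposals.token_ids):
        q = proposals.token_probs[i]  # draft distribution at position i
        p = scores.token_probs[i]     # target distribution at position i
        # Accept the draft token with probability min(1, p(x) / q(x)).
        # Assumes q[token] > 0, which holds if the token was sampled from q.
        if rng.random() < min(1.0, p[token] / q[token]):
            emitted.append(token)
            continue
        # First rejection: resample from the residual max(0, p - q),
        # then discard all later proposals.
        residual = [max(0.0, pt - qt) for pt, qt in zip(p, q)]
        emitted.append(_sample(residual, rng))
        return emitted
    # Every proposal was accepted: sample the bonus token from the
    # (k + 1)-th target distribution.
    emitted.append(_sample(scores.token_probs[-1], rng))
    return emitted
```

When the draft and target distributions agree, every proposal is accepted and the step yields k+1 tokens. When they disagree completely at the first position, the step yields a single corrected token. A batched tensor version would need to handle per-sequence `proposal_lens`, which this sketch ignores.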