[Core] [Bugfix] Refactor block manager subsystem for better testability #3492

Merged: youkaichao merged 96 commits from cadedaniel:block-manager-tests into vllm-project:main on Mar 28, 2024.
Commits (96, all by cadedaniel):
85fb179  logical block test
0f19984  sequence
0306a8c  notes
de14e54  wip
7d66c4a  prefix caching bug when prompt len < block size
e03e057  wip
c162283  refcount
99a5b59  wip
1fe4cbb  wip
5e70924  wip
ea94ecc  wip
376cdb6  wip
d7e122e  wip
2b821dc  wip
0a6fbd2  wip
658b4c5  wip
e976541  wip
085f419  content hash
6fc22ef  wip
cbea543  unused cached blocks
029d39a  wip
d2ca90b  wip
1eee08c  break files
ebe6ccf  wip
9dfc821  wip
619fb0d  wip
ea49f23  device aware
1252223  wip
c1e1b2f  wip
d0b4f20  wip0
a3cffb9  wip
cd75992  wip
1d25cf2  wip
335a218  wip
960da58  wip
63f5dd5  fork
d5ebfd2  fork
c127343  wip
a20051a  remove
02e4154  simple generation works
6f88528  interfaces
70c3fff  wip
5867272  wip
65cfac8  wip
7d059a6  lint
c286632  lint2
46bbd14  lint3
2e794de  lint4
2416c22  lint5
558ad36  v2 config
3fa5b2b  lint
de2a5c9  Merge remote-tracking branch 'upstream/main' into block-manager-tests
0464d48  clean
6ac0318  wip
9fb053c  wip
7f33d2f  wip
9455a46  cow in naive
2f9ebac  wip
26b6ce7  fix cow bug
548aec8  cow test
f0025ab  wip
3be4040  wip
62a616b  wip
9fd6c08  wip prefix cow
6ded181  wip
95b65f1  wip
b03693c  wip
d582cb6  wip
70b1f60  lint
3ce9347  lint2
ed6c2e6  Merge remote-tracking branch 'upstream/main' into block-manager-tests
640d7e5  isort
0f0daf8  fix
ba8acbd  wip
b51287c  adding to entrypoint tests
1f3483f  try
4ebc0c0  docstrings!
1f09fd0  wip
80cdc3c  more docstring / format
36bd93f  entrypoints
79dac79  model correctness test
b392a5d  remove
8d42bd7  lint
1b3fe9f  note
9680dc8  remove
dd4bcee  clean
9000b41  name
a2897b0  rename
bead69a  wip
132e7a3  clean
0d75e12  comment
70d1812  lint
321dc16  comment
887496b  fix test
f0b1bf1  empty
5b86297  pr feedback
(One of the new files in this diff is empty.)
@@ -0,0 +1,59 @@
import contextlib
import gc

import pytest
import ray
import torch

from vllm import LLM
from vllm.model_executor.parallel_utils.parallel_state import (
    destroy_model_parallel)
from vllm.model_executor.utils import set_random_seed


def cleanup():
    destroy_model_parallel()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()
    ray.shutdown()


@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                           baseline_llm_kwargs, seed):
    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                                baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                                test_llm_kwargs, seed)


def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }

    def generator_inner():
        llm = LLM(**kwargs)

        set_random_seed(seed)

        yield llm
        del llm
        cleanup()

    def generator_outer():
        for llm in generator_inner():
            yield llm
            del llm

    return generator_outer()
@@ -0,0 +1,86 @@
from itertools import cycle

import pytest

from vllm import SamplingParams


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # Skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Allow only 5 sequences of ~1024 tokens in worst case.
        "block_size": 16,
        "forced_num_gpu_blocks": 5 * (64 + 1),
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
                                               test_llm_generator, batch_size):
    """Verify block manager v2 produces the same outputs as block manager v1,
    even when there is preemption.

    This constructs two LLMs, each with a limited number of GPU blocks. The
    limit is chosen such that as the sequences in the batch grow, sequences
    must be preempted and removed from the cache.

    If the output token ids are equivalent, then we have confidence that the
    KV cache is not corrupted in the v2 block manager.

    NOTE: We want a significant number of generated tokens so that any
    incorrect KV mapping has time to build up error.
    """
    output_len = 1024
    temperature = 0.0

    # We want to ensure equality even with preemption.
    # We force the total block size to be 1 + cdiv(output_len, block_size)
    # so that only one sequence can fit at a time (once the sequences grow).

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids from block manager v1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids from block manager v2')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        del llm

    return token_ids
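As a sanity check on the block budget used above (my arithmetic, not part of the diff): with block_size=16 and output_len=1024, each sequence needs cdiv(1024, 16) = 64 output blocks plus 1 prompt block, so forced_num_gpu_blocks = 5 * (64 + 1) = 325 leaves room for roughly five full-length sequences; with batch_size=10, sequences must be preempted as they grow. A minimal sketch (the cdiv helper is mine, mirroring the comment in the test):

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as referenced in the test's comment."""
    return -(-a // b)

block_size = 16
output_len = 1024

# 1 prompt block + 64 output blocks per sequence.
blocks_per_seq = 1 + cdiv(output_len, block_size)

forced_num_gpu_blocks = 5 * (64 + 1)  # 325, as in the parametrization
max_full_seqs = forced_num_gpu_blocks // blocks_per_seq

print(blocks_per_seq)  # 65
print(max_full_seqs)   # 5 -> a batch of 10 cannot all fit; preemption occurs
```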
@@ -0,0 +1,50 @@
import pytest

from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.core.interfaces import AllocStatus

from ..utils import create_seq_group


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
                                num_gpu_blocks: int, watermark: float):
    block_manager = BlockSpaceManagerV2(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        watermark=watermark,
    )
    num_watermark_blocks = int(watermark * num_gpu_blocks)

    num_output_blocks_per_seq = 1

    # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
    # the current implementation assumes all seqs are new prompts / don't have
    # different output lens.
    num_output_blocks = num_output_blocks_per_seq

    for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
        seq_group = create_seq_group(
            seq_prompt_lens=block_size * num_prompt_blocks,
            seq_output_lens=[
                block_size * num_output_blocks_per_seq
                for _ in range(num_seqs_per_group)
            ],
        )

        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks

        can_allocate_result = block_manager.can_allocate(seq_group)

        num_required_blocks = num_prompt_blocks + num_output_blocks

        if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
            assert can_allocate_result == AllocStatus.NEVER
        elif num_gpu_blocks >= num_required_blocks:
            assert can_allocate_result == AllocStatus.OK
        else:
            assert can_allocate_result == AllocStatus.LATER
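The assertions at the end of this test encode the allocation policy the test expects from `can_allocate`. A minimal restatement of that rule (my sketch of the test's assertions, with a hypothetical helper name; not the actual BlockSpaceManagerV2 implementation):

```python
from vllm.core.interfaces import AllocStatus


def expected_alloc_status(num_gpu_blocks: int, num_required_blocks: int,
                          num_watermark_blocks: int) -> AllocStatus:
    # NEVER: even with every GPU block free, allocating this group would
    # leave fewer free blocks than the watermark reserves.
    if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
        return AllocStatus.NEVER
    # OK: the device has enough total blocks to hold the group now.
    if num_gpu_blocks >= num_required_blocks:
        return AllocStatus.OK
    # LATER: the group could fit once other sequences free their blocks.
    return AllocStatus.LATER
```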
Review comments on the create_llm_generator wrapper:

- "why do we need another level of wrapper? would it work without it? if not please comment why"
- "I see the usage of the llm generator below but still confused since we are only yielding one llm instance"
- "oh good catch, not necessary"
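Following up on that exchange, a minimal sketch of the simplification the comments point at (my illustration, not a commit in this PR): `generator_outer` only re-yields the single LLM instance from `generator_inner`, so the fixture helper can return the inner generator directly.

```python
from vllm import LLM
from vllm.model_executor.utils import set_random_seed

# cleanup() as defined in the conftest file above.


def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }

    def generator_inner():
        llm = LLM(**kwargs)
        set_random_seed(seed)
        yield llm
        # Runs when the consuming for-loop resumes the generator after the
        # yield, exactly as in the two-level version.
        del llm
        cleanup()

    return generator_inner()
```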