[Core] Enable prefix caching with block manager v2 enabled #4142

Merged · 7 commits · May 1, 2024
16 changes: 13 additions & 3 deletions benchmarks/benchmark_prefix_caching.py
@@ -16,20 +16,22 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):


def main(args):
llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
llm = LLM(model=args.model,
tokenizer_mode='auto',
trust_remote_code=True,
enforce_eager=True,
use_v2_block_manager=args.use_v2_block_manager,
tensor_parallel_size=args.tensor_parallel_size,
enable_prefix_caching=args.enable_prefix_caching)

num_prompts = 100
prompts = [PROMPT] * num_prompts
sampling_params = SamplingParams(temperature=0, max_tokens=100)
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

print("------warm up------")
test_prefix(
llm=llm,
prompts=prompts[:1],
prompts=prompts,
sampling_params=sampling_params,
)

@@ -45,8 +47,16 @@ def main(args):
parser = argparse.ArgumentParser(
description='Benchmark the performance with or without automatic '
'prefix caching.')
parser.add_argument('--model',
type=str,
default='baichuan-inc/Baichuan2-13B-Chat')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceManagerV2')
args = parser.parse_args()
main(args)
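For reference, here is a minimal sketch of the configuration the updated benchmark builds, using the same constructor arguments the diff adds. The model name, prompt, and token counts below are example values standing in for the new CLI flags, not part of the PR:

```python
from vllm import LLM, SamplingParams

# Example values only; the benchmark reads these from the argparse flags above.
llm = LLM(model="facebook/opt-125m",          # --model
          tokenizer_mode='auto',
          trust_remote_code=True,
          enforce_eager=True,
          use_v2_block_manager=True,          # --use-v2-block-manager
          tensor_parallel_size=1,             # --tensor-parallel-size
          enable_prefix_caching=True)         # --enable-prefix-caching

prompts = ["A long shared prefix followed by a question."] * 100
sampling_params = SamplingParams(temperature=0, max_tokens=10)  # --output-len
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
```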
146 changes: 146 additions & 0 deletions tests/core/block/e2e/test_correctness.py
@@ -300,6 +300,152 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",

# skip cuda graph creation for fast test.
"enforce_eager": True,

# Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16,
"num_gpu_blocks_override": 5 * (64 + 1),

# Enable prefix caching
"enable_prefix_caching": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
baseline_llm_generator, test_llm_generator, batch_size):
"""Verify block manager v2 produces same outputs as block manager v1, even
when there is preemption.

This constructs two LLM, each with limited number of GPU blocks. The limit
is decided such that as the sequences in the batch grow, sequences must be
preempted and removed from cache.

If the output token ids are equivalent, then we have confidence that the KV
cache is not corrupted in the v2 block manager.

NOTE: We want a significant number of generated tokens so that any incorrect
KV mapping has time to build up error.
"""
output_len = 1024
temperature = 0.0

# We want to ensure equality even with preemption.
# We force the total block size to be 1 + cdiv(output_len, block_size)
# so that only one sequence can fit at a time (once the sequences grow).

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

print('Getting token ids from block manager v1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)

print('Getting token ids from block manager v2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)

for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids

assert baseline_token_ids == test_token_ids
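A quick back-of-the-envelope check (my own arithmetic, not part of the test) of why `num_gpu_blocks_override = 5 * (64 + 1)` forces preemption with this configuration:

```python
# block_size and output_len as configured in the test above.
block_size = 16
output_len = 1024

decode_blocks = -(-output_len // block_size)    # ceil(1024 / 16) = 64 blocks of generated tokens
prompt_blocks = 1                               # the short prompts fit in a single block
blocks_per_seq = decode_blocks + prompt_blocks  # 65 blocks per fully grown sequence

num_gpu_blocks_override = 5 * blocks_per_seq    # 325: room for ~5 of the 10 sequences,
                                                # so the rest must be preempted.
```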


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",

# skip cuda graph creation for fast test.
"enforce_eager": True,

# Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16,
"num_gpu_blocks_override": 5 * (64 + 1),

# Test APC with the v2 block manager
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"enable_prefix_caching": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify block manager v2 with auto prefix caching enabled produces same
outputs as auto prefix caching disabled, even when there is preemption.

This constructs two LLM, each with limited number of GPU blocks. The limit
is decided such that as the sequences in the batch grow, sequences must be
preempted and removed from cache.

If the output token ids are equivalent, then we have confidence that auto
prefix caching itself at least don't cause result error.
"""
output_len = 1024
temperature = 0.0

# We want to ensure equality even with preemption.
# We force the total block size to be 1 + cdiv(output_len, block_size)
# so that only one sequence can fit at a time (once the sequences grow).
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)

print('Getting token ids with APC disabled')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)

print('Getting token ids with APC enabled')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)

for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids

assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
125 changes: 125 additions & 0 deletions tests/core/block/test_prefix_caching_block.py
@@ -358,6 +358,131 @@ def test_get_num_free_blocks_shared(num_blocks: int, block_size: int,
i)
allocator.free(block)

@staticmethod
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("seed", list(range(20)))
def test_get_common_computed_block_ids(num_blocks: int, block_size: int,
seed: int):
"""Verify get_common_computed_block_ids could get correct result
by create two immutable chain sharing prefix at specified pos,
and compare whether we also could get right result
from get_common_computed_block_ids.
"""
random.seed(seed)
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2,
block_size=block_size)
num_blocks_to_consume = random.randint(1, num_blocks - 1)

# Create token ids that will exhaust all blocks.
token_ids = list(range(num_blocks_to_consume * block_size))
blocks = list(range(num_blocks_to_consume))

first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)

# mark all blocks in first chain as computed
allocator.mark_blocks_as_computed(blocks)

# After zero_point, second_chain's token ids are set to -1, which makes
# them differ from first_chain's token ids.
zero_point = random.randint(1, len(token_ids) - 1)
zero_point_blocks = zero_point // block_size
token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point)

second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)

first_computed_ids = [
first_chain[i].block_id for i in range(num_blocks_to_consume)
]
second_computed_ids = [
second_chain[i].block_id for i in range(num_blocks_to_consume)
]
res = allocator.get_common_computed_block_ids(
[first_computed_ids, second_computed_ids])

assert (len(res) == zero_point_blocks)
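A concrete instance of the expected count, with numbers picked purely for illustration:

```python
# With block_size = 16 and the chains diverging at token 40, only the blocks
# that lie entirely before the divergence point are common to both chains.
block_size = 16
zero_point = 40                                # first token where the chains differ
zero_point_blocks = zero_point // block_size   # = 2 fully shared (computed) blocks
# get_common_computed_block_ids is then expected to return exactly 2 block ids.
```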

# Test case where two last accessed times are equal
@staticmethod
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("seed", list(range(20)))
def test_eviction_order(num_blocks: int, block_size: int, seed: int):
"""This test case simulate the two chain created and free in order,
and together they would exhaust the initial freed blocks.

So the next block created after those two chain shall use the block
from the first chain as that block has long access time.
While first chain has two blocks, it shall pick up the last one, as
it has larger token number.
"""

random.seed(seed)
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
num_blocks_to_consume = num_blocks + 1

token_ids = list(range(num_blocks_to_consume * block_size))

num_blocks_in_first_chain = 2
num_tokens_in_first_chain = block_size * num_blocks_in_first_chain
# The first chain takes the first two blocks
first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids[:num_tokens_in_first_chain],
allocator=allocator,
)
# Only the first chain's blocks should be allocated at this point
assert allocator.get_num_free_blocks() == (num_blocks -
num_blocks_in_first_chain)

# Set the last accessed time of the first chain's blocks to 1
blocks_ids = [block.block_id for block in first_chain]
allocator.mark_blocks_as_accessed(blocks_ids, 1)

# Second chain takes the rest of the blocks
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids[num_tokens_in_first_chain:-block_size],
allocator=allocator,
)

# There shouldn't be any blocks left at this point
assert allocator.get_num_free_blocks() == (0)

assert len(first_chain) == num_blocks_in_first_chain
last_block_id = first_chain[-1].block_id
# Free each block in the first chain.
for i, block in enumerate(first_chain):
allocator.free(block)

# Set the last accessed time on all of the blocks in the second chain
# to 2
blocks_ids = [block.block_id for block in second_chain]
allocator.mark_blocks_as_accessed(blocks_ids, 2)

# Free each block in the second chain.
for i, block in enumerate(second_chain):
allocator.free(block)

# Allocate a new block and check that it reuses the block evicted first:
# the last block of the first chain.
new_block = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids[-block_size:],
allocator=allocator,
)

assert new_block[0].block_id == last_block_id
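The ordering this assertion relies on can be pictured as a two-level sort. The snippet below is a toy illustration of that ordering under my reading of the docstring above, not the allocator's actual evictor:

```python
# Prefer the oldest last-accessed time; among equally old blocks, prefer the
# one covering more tokens (the deepest block of its chain).
candidates = [
    {"block_id": 0, "last_accessed": 1, "tokens_covered": 16},  # first chain, block 0
    {"block_id": 1, "last_accessed": 1, "tokens_covered": 32},  # first chain, block 1
]
evict_first = max(candidates, key=lambda b: (-b["last_accessed"], b["tokens_covered"]))
assert evict_first["block_id"] == 1  # the last block of the first chain is reused first
```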

@staticmethod
def create_immutable_chain(
block_size: int,
12 changes: 10 additions & 2 deletions vllm/core/block/cpu_gpu_block_allocator.py
@@ -190,10 +190,18 @@ def clear_copy_on_writes(self) -> Dict[int, List[int]]:
device = Device.GPU
return self._allocators[device].clear_copy_on_writes()

def mark_blocks_as_computed(self) -> None:
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].mark_blocks_as_computed()
return self._allocators[device].mark_blocks_as_accessed(block_ids, now)

def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as accessed, only use for prefix caching."""
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].mark_blocks_as_computed(block_ids)

def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
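A hedged sketch of how a caller might drive the split API this hunk introduces; `touch_blocks` is a hypothetical helper invented here for illustration, not something added by the PR:

```python
import time
from typing import List

def touch_blocks(allocator, block_ids: List[int]) -> None:
    # Record the access time so the prefix-caching evictor can apply its
    # least-recently-used policy (GPU allocator only, per the code above).
    allocator.mark_blocks_as_accessed(block_ids, time.time())
    # Flag the same blocks as having valid KV-cache contents available for reuse.
    allocator.mark_blocks_as_computed(block_ids)
```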
4 changes: 4 additions & 0 deletions vllm/core/block/interfaces.py
@@ -81,6 +81,10 @@ def all_block_ids(self) -> FrozenSet[int]:
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
pass

@abstractmethod
def mark_blocks_as_accessed(self) -> None:
pass

@abstractmethod
def mark_blocks_as_computed(self) -> None:
pass
11 changes: 10 additions & 1 deletion vllm/core/block/naive_block.py
@@ -174,7 +174,16 @@ def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
"""
return self._cow_tracker.clear_cows()

def mark_blocks_as_computed(self) -> None:
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
"""Mark blocks as accessed, used in prefix caching.

Since the naive allocator does not implement prefix caching, we do
nothing.
"""
pass

def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
"""Mark blocks as computed, used in prefix caching.

Since the naive allocator does not implement prefix caching, we do