
[PREFIX CACHING FOLLOW UP] A bunch of fixes to block allocator performance when automatic prefix caching is disabled #3357

Merged: 20 commits into vllm-project:main on Mar 20, 2024

Conversation

@ElizaWszola (Contributor) commented Mar 12, 2024

The performance of the block allocator went down after automatic prefix caching was implemented, even when running with prefix caching disabled. This PR brings back parts of the old code and regains some of the lost performance in the scenario where prefix caching is disabled.
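To illustrate the direction of the fix: the idea is to keep a plain free-list fast path when caching is off and pay the hashing/eviction bookkeeping only on the cached path. A minimal sketch under that assumption (illustrative class names, not necessarily the code in this PR):

```python
from collections import deque
from typing import Dict, Optional

class UncachedBlockAllocator:
    """Plain free list: O(1) allocate/free, no hashing or eviction bookkeeping."""

    def __init__(self, num_blocks: int) -> None:
        self.free_blocks = deque(range(num_blocks))

    def allocate(self) -> int:
        if not self.free_blocks:
            raise RuntimeError("out of free blocks")
        return self.free_blocks.popleft()

    def free(self, block_id: int) -> None:
        self.free_blocks.append(block_id)

class CachedBlockAllocator(UncachedBlockAllocator):
    """Adds a hash -> block table so identical prefixes can reuse blocks."""

    def __init__(self, num_blocks: int) -> None:
        super().__init__(num_blocks)
        self.hash_to_block: Dict[int, int] = {}

    def allocate(self, block_hash: Optional[int] = None) -> int:
        if block_hash is not None and block_hash in self.hash_to_block:
            return self.hash_to_block[block_hash]  # prefix cache hit: reuse block
        block_id = super().allocate()
        if block_hash is not None:
            self.hash_to_block[block_hash] = block_id
        return block_id
```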

Benchmarked with:

python benchmark_throughput_cache.py --backend vllm --model huggyllama/llama-7b --dataset ../data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000
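(All runs below keep prefix caching disabled, which is the default. For reference, a minimal sketch of toggling it via the Python API; `enable_prefix_caching` is assumed to be the engine argument in the vLLM version used here, and the custom benchmark script above may expose it differently:)

```python
from vllm import LLM, SamplingParams

# Prefix caching off (the default): the scenario benchmarked in this PR.
llm = LLM(model="huggyllama/llama-7b", enable_prefix_caching=False)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))

# Re-running with enable_prefix_caching=True would exercise the cached
# allocator path instead of the uncached fast path this PR restores.
```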

Performance before introducing automatic prefix caching (commit baee28c):
Throughput: 10.37 requests/s, 5062.42 tokens/s
Throughput: 10.46 requests/s, 5102.27 tokens/s
Throughput: 10.47 requests/s, 5107.30 tokens/s
Throughput: 10.48 requests/s, 5113.97 tokens/s
Throughput: 10.53 requests/s, 5137.21 tokens/s
Throughput: 10.54 requests/s, 5145.38 tokens/s
Throughput: 10.56 requests/s, 5153.24 tokens/s
Throughput: 10.57 requests/s, 5157.54 tokens/s
Throughput: 10.63 requests/s, 5187.32 tokens/s
Throughput: 10.65 requests/s, 5198.19 tokens/s

Performance after applying this PR's changes on top of commit ce4f5a2:
Throughput: 10.40 requests/s, 5076.05 tokens/s
Throughput: 10.53 requests/s, 5137.97 tokens/s
Throughput: 10.57 requests/s, 5156.04 tokens/s
Throughput: 10.60 requests/s, 5173.07 tokens/s
Throughput: 10.61 requests/s, 5177.02 tokens/s
Throughput: 10.62 requests/s, 5179.91 tokens/s
Throughput: 10.63 requests/s, 5186.06 tokens/s
Throughput: 10.63 requests/s, 5186.63 tokens/s
Throughput: 10.64 requests/s, 5193.72 tokens/s
Throughput: 10.67 requests/s, 5207.76 tokens/s


(OLD: earlier benchmark results, superseded by the numbers above)

Benchmark results (10 runs each):

Performance before introducing automatic prefix caching (commit baee28c):
Throughput: 10.15 requests/s, 4909.50 tokens/s
Throughput: 10.17 requests/s, 4918.22 tokens/s
Throughput: 10.20 requests/s, 4936.93 tokens/s
Throughput: 10.23 requests/s, 4949.76 tokens/s
Throughput: 10.22 requests/s, 4945.64 tokens/s
Throughput: 10.27 requests/s, 4967.08 tokens/s
Throughput: 10.28 requests/s, 4971.52 tokens/s
Throughput: 10.29 requests/s, 4980.92 tokens/s
Throughput: 10.29 requests/s, 4976.94 tokens/s
Throughput: 10.30 requests/s, 4982.69 tokens/s

Performance after introducing automatic prefix caching (commit ce4f5a2):
Throughput: 9.91 requests/s, 4795.14 tokens/s
Throughput: 9.98 requests/s, 4830.01 tokens/s
Throughput: 9.99 requests/s, 4832.00 tokens/s
Throughput: 10.00 requests/s, 4839.62 tokens/s
Throughput: 10.03 requests/s, 4851.13 tokens/s
Throughput: 10.06 requests/s, 4868.87 tokens/s
Throughput: 10.07 requests/s, 4873.87 tokens/s
Throughput: 10.07 requests/s, 4872.51 tokens/s
Throughput: 10.08 requests/s, 4876.18 tokens/s
Throughput: 10.08 requests/s, 4877.26 tokens/s

Performance after applying this PR's changes on top of commit ce4f5a2:
Throughput: 10.07 requests/s, 4873.42 tokens/s
Throughput: 10.17 requests/s, 4919.84 tokens/s
Throughput: 10.18 requests/s, 4923.71 tokens/s
Throughput: 10.18 requests/s, 4925.56 tokens/s
Throughput: 10.19 requests/s, 4928.09 tokens/s
Throughput: 10.20 requests/s, 4937.20 tokens/s
Throughput: 10.21 requests/s, 4942.21 tokens/s
Throughput: 10.21 requests/s, 4938.38 tokens/s
Throughput: 10.21 requests/s, 4940.22 tokens/s
Throughput: 10.22 requests/s, 4946.95 tokens/s

@zhuohan123 (Member) commented:

cc @cadedaniel

@cadedaniel (Collaborator) left a comment:

Thanks for the PR!

  • I am concerned that our test coverage of the block manager is not sufficient to allow refactors like this without first adding good tests. There are a few branches in this PR that exist only for prefix caching, which adds a lot of complexity.
  • Could you comment on what causes the performance degradation / improvement?

@cadedaniel self-assigned this Mar 13, 2024
@zhuohan123 self-assigned this Mar 14, 2024
@zhuohan123 (Member) left a comment:

Some random small comments. Will review in more detail!

    def free(self, block: PhysicalTokenBlock) -> None:
        pass

    @abstractproperty
A Member commented:

Should be `@abstractmethod`
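For clarity, the suggested decorator usage (shown on the `free` method from the quoted snippet for illustration; the member actually flagged may be a different one, and the base-class name is hypothetical):

```python
from abc import ABC, abstractmethod

class BlockAllocatorBase(ABC):  # illustrative base-class name
    """Sketch of the abstract allocator interface under discussion."""

    @abstractmethod
    def free(self, block: "PhysicalTokenBlock") -> None:
        """Return a physical block to the allocator's free pool."""
        ...
```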

vllm/core/block_manager.py: 3 review comments (outdated, resolved)
ElizaWszola and others added 3 commits March 14, 2024 14:23
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
@ElizaWszola (Contributor, Author) commented:

@cadedaniel I can think up some tests to add. Is there anything specific you would like to have tested?
As for the performance gap that still exists, I'm not sure what causes it, since the non-cached code path is now very similar to what was there before the original auto prefix caching commit. I'm still poking around.

@ElizaWszola (Contributor, Author) commented:

Good news: I've found a small bug and redid some of the benchmarks. The performance now looks similar to the old numbers, but I'd be happy if more people could verify.

@ElizaWszola changed the title from "A bunch of fixes to block allocator performance when automatic prefix caching is disabled" to "[PREFIX CACHING FOLLOW UP] A bunch of fixes to block allocator performance when automatic prefix caching is disabled" on Mar 15, 2024
if block.num_hashed_tokens == highest_num_hashed_tokens:
if (block.last_accessed < evicted_block.last_accessed
        or block.last_accessed == evicted_block.last_accessed and
        block.num_hashed_tokens > evicted_block.num_hashed_tokens):
A Contributor commented:

I have also optimized the LRU evictor, but after learning more about evictors, I feel that LRU is unnecessary, as it is not as efficient as the random policy.
So, in my opinion, the LRU policy should be removed.
cc @cadedaniel

@ElizaWszola (Contributor, Author) replied:

The changes in this PR improve LRU evictor efficiency only marginally. I'm OK with removing them from this PR, especially since a better way to improve LRU evictor efficiency (bringing it roughly on par with the random evictor for the tested cases) is implemented in #3431.
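For readers following the thread: the efficiency gap comes from how a victim block is chosen. A rough sketch of the two policies being compared (hypothetical evictor shapes, not vLLM's actual classes):

```python
from dataclasses import dataclass
from typing import Dict

@dataclass
class FreeBlock:
    block_id: int
    last_accessed: float
    num_hashed_tokens: int

class RandomEvictor:
    """Any free block is an acceptable victim: no scan needed."""

    def __init__(self) -> None:
        self.free_table: Dict[int, FreeBlock] = {}

    def evict(self) -> FreeBlock:
        if not self.free_table:
            raise ValueError("no free blocks to evict")
        block_id = next(iter(self.free_table))  # arbitrary pick, O(1)
        return self.free_table.pop(block_id)

class LRUEvictor:
    """Scan all free blocks: oldest access wins, ties go to more hashed tokens."""

    def __init__(self) -> None:
        self.free_table: Dict[int, FreeBlock] = {}

    def evict(self) -> FreeBlock:
        if not self.free_table:
            raise ValueError("no free blocks to evict")
        victim = next(iter(self.free_table.values()))
        for block in self.free_table.values():
            if (block.last_accessed < victim.last_accessed
                    or (block.last_accessed == victim.last_accessed
                        and block.num_hashed_tokens > victim.num_hashed_tokens)):
                victim = block
        return self.free_table.pop(victim.block_id)
```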

@zhuohan123 (Member) left a comment:

LGTM! Thanks for the fix; I left some small comments. Regarding @cadedaniel's comment on tests, let's discuss more offline and figure out what tests we need to write.

vllm/core/block_manager.py: 2 review comments (outdated, resolved)
else:
    # Set the reference counts of the token blocks.
    block.ref_count = seq_group.num_seqs()
elif self.enable_caching:
A Member commented:

Does prefix caching work with sliding window now? Should we explicitly check somewhere that if caching is enabled, sliding window is not enabled?

@ElizaWszola (Contributor, Author) replied:

The prefix caching functionality is simply not used when sliding window is enabled; we have specific checks for that in different places in the code. Putting the check in a more central place sounds like a better and less confusing idea, though.
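A central guard of the kind suggested could be a single validation at engine/config construction time; a minimal sketch (hypothetical function name, not the actual vLLM check):

```python
from typing import Optional

def verify_cache_config(enable_prefix_caching: bool,
                        sliding_window: Optional[int]) -> None:
    """Fail fast on the unsupported combination in one central place."""
    if enable_prefix_caching and sliding_window is not None:
        raise ValueError(
            "Automatic prefix caching is not supported together with "
            "sliding-window attention; disable one of the two.")
```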

vllm/core/block_manager.py: 2 review comments (outdated, resolved)
ElizaWszola and others added 4 commits March 19, 2024 13:06
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
@zhuohan123 (Member) commented:

@ElizaWszola Please let me know when this PR is ready to be merged!

@zhuohan123 (Member) left a comment:

LGTM! Thanks for the fix!

@zhuohan123 enabled auto-merge (squash) March 20, 2024 07:11
@zhuohan123 merged commit 9474e89 into vllm-project:main Mar 20, 2024
28 of 30 checks passed
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
…mance when automatic prefix caching is disabled (vllm-project#3357)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>