[core][misc] improve free_finished_seq_groups #6865

Merged
2 commits merged into vllm-project:main on Jul 30, 2024

Conversation

youkaichao
Member

free_finished_seq_groups is in the critical path, because the engine calls it after every model step.

The main branch searches through all requests to find finished ones, while newly finished requests can only appear in the running queue.

This simple optimization turns out to be important when we process a large amount of offline batch inference, e.g. when users give 100k prompts to the LLM engine. In that case, the main branch iterates over all 100k requests at every step.
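
A minimal sketch of the idea (illustrative class and method names, not the actual vLLM scheduler; it assumes the scheduler keeps waiting/running/swapped deques and a _finished_requests_ids list, as discussed in this PR):

from collections import deque
from typing import Deque, List


class SchedulerSketch:
    """Toy model of the scheduler state relevant to this PR."""

    def __init__(self) -> None:
        self.waiting: Deque = deque()
        self.running: Deque = deque()
        self.swapped: Deque = deque()
        self._finished_requests_ids: List[str] = []

    def free_finished_seq_groups_before(self) -> None:
        # Before this PR: all requests are searched for finished ones,
        # so the per-step cost grows with the total number of requests
        # (e.g. 100k queued prompts).
        for queue in (self.waiting, self.running, self.swapped):
            for sg in queue:
                if sg.is_finished():
                    self._finished_requests_ids.append(sg.request_id)
        self.running = deque(
            sg for sg in self.running if not sg.is_finished())

    def free_finished_seq_groups_after(self) -> None:
        # After this PR: only the running queue can contain newly
        # finished requests, so the scan is bounded by the batch size.
        remaining: Deque = deque()
        for sg in self.running:
            if sg.is_finished():
                self._finished_requests_ids.append(sg.request_id)
            else:
                remaining.append(sg)
        self.running = remaining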

I have heard users complain that sending the whole dataset to the LLM is slow and that they have to chunk the data into small batches (25 requests each) to make it fast. This PR might fix that.

Some micro-benchmarks to demonstrate the benefit when the number of prompts is large and the model is small (so the scheduler overhead becomes significant):

Note: main branch is commit 925de97

$ python benchmarks/benchmark_throughput.py --output-len 16 --input 16 --model facebook/opt-125m --num-prompts 10000

# main branch:
Throughput: 577.63 requests/s, 18484.31 tokens/s

# this PR:
Throughput: 674.46 requests/s, 21582.66 tokens/s

$ python benchmarks/benchmark_throughput.py --output-len 16 --input 16 --model facebook/opt-125m --num-prompts 100000

# main branch:
Throughput: 211.18 requests/s, 6757.85 tokens/s

# this PR:
Throughput: 522.91 requests/s, 16733.13 tokens/s

Conclusion:

When I increase the number of prompts from 10k to 100k, the main branch slows down by almost 3x (578 → 211 requests/s). Improving free_finished_seq_groups helps a lot: with this PR the slowdown is only about 30% (674 → 523 requests/s). There are clearly other factors at play, but I think this was the largest one when processing large datasets.


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao
Member Author

Here is the flame chart of the main branch with 100k prompts. It makes clear that free_finished_seq_groups is the major source of overhead.

[flame chart: main branch, 100k prompts]

@youkaichao
Member Author

Here is the flame chart of this PR, with various numbers of prompts:

10k prompts:

[flame chart: this PR, 10k prompts]

100k prompts:

[flame chart: this PR, 100k prompts]

500k prompts:

[flame chart: this PR, 500k prompts]

Even after this PR, there is still a slowdown factor that is proportional to the number of waiting requests. The flame chart points to:

waiting_queue = deque([s for s in waiting_queue])

which is indeed very inefficient.
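
A toy timing (hypothetical, not a benchmark from the PR) showing why that line is expensive: rebuilding the deque once per scheduler step costs time proportional to the length of the waiting queue, i.e. to the total number of queued prompts.

import timeit
from collections import deque

for n in (10_000, 100_000, 500_000):
    waiting_queue = deque(range(n))  # stand-in for n waiting requests
    t = timeit.timeit(
        lambda: deque([s for s in waiting_queue]), number=10)
    print(f"{n:>7} waiting requests: {t / 10 * 1e3:.2f} ms per rebuild")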

youkaichao added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jul 27, 2024
@mgoin
Collaborator

This makes a lot of sense, matches behavior I've seen or had reported to me, and should not be controversial. LGTM pending review by @mzusman.

@mgoin
Collaborator

mgoin commented Jul 29, 2024

@youkaichao I think you must rebase to pass the failing Examples Test

[2024-07-29T19:01:27Z] python3: can't open file '/vllm-workspace/examples/llava_example.py': [Errno 2] No such file or directory

Comment on lines 1067 to 1073
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
    if seq_group.is_finished():
        self._finished_requests_ids.append(seq_group.request_id)
    else:
        remaining.append(seq_group)
self.running = remaining
Member

@youkaichao how about wrapping this in

 if any(seq_group.is_finished() for seq_group in self.running):

to avoid rebuilding the deque every time, since requests finish relatively rarely.
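
A sketch of the suggested guard (reusing the names from the diff hunk above; it would need the same deque/Deque/SequenceGroup imports as the surrounding file):

def free_finished_seq_groups(self) -> None:
    # Fast path: on most steps no request finishes, so skip the rebuild.
    if not any(seq_group.is_finished() for seq_group in self.running):
        return

    remaining: Deque[SequenceGroup] = deque()
    for seq_group in self.running:
        if seq_group.is_finished():
            self._finished_requests_ids.append(seq_group.request_id)
        else:
            remaining.append(seq_group)
    self.running = remaining

The trade-off is one extra pass over self.running on the (rare) steps where something did finish.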

Member Author

I plan to refactor this into a dictionary in the future, so that we can easily delete requests.
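
A rough sketch of what that dictionary-based refactor could look like (hypothetical, not part of this PR; Python dicts preserve insertion order, so iteration order stays stable):

from typing import Dict, List


class RunningDictSketch:
    """Toy sketch of tracking running requests in a dict keyed by id."""

    def __init__(self) -> None:
        self.running: Dict[str, "SequenceGroup"] = {}
        self._finished_requests_ids: List[str] = []

    def free_finished_seq_groups(self) -> None:
        finished_ids = [
            request_id for request_id, seq_group in self.running.items()
            if seq_group.is_finished()
        ]
        for request_id in finished_ids:
            self._finished_requests_ids.append(request_id)
            # O(1) deletion by request id; no deque rebuild required.
            del self.running[request_id]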

WoosukKwon self-assigned this on Jul 30, 2024
@@ -1058,13 +1061,16 @@ def free_seq(self, seq: Sequence) -> None:
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
@WoosukKwon
Collaborator
Jul 30, 2024

@mzusman I actually don't get the current (before this PR) logic here.

If I understand correctly, the order of execution is schedule -> get_and_reset_finished_requests_ids -> model execution -> free_finished_seq_groups. Hence, the requests freed from self.swapped or self.waiting were not actually added to finished_requests_ids before the model execution. This means the finished_requests_ids holds stale information lagged by 1 step. Is my understanding correct?
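
For reference, a simplified schematic of the step ordering being discussed (the scheduler method names come from this PR; the engine and model-runner call names are illustrative):

def engine_step(scheduler, model_runner):
    # 1. Decide what runs in this step.
    scheduler_outputs = scheduler.schedule()

    # 2. Collect the request ids recorded as finished so far. Anything
    #    that finished during the *previous* model execution was only
    #    recorded by free_finished_seq_groups() at the end of that step.
    finished_requests_ids = scheduler.get_and_reset_finished_requests_ids()

    # 3. Run the model with that one-step-lagged information.
    outputs = model_runner.execute(scheduler_outputs, finished_requests_ids)

    # 4. Only now are just-finished requests removed from the running
    #    queue and appended to _finished_requests_ids, so the model
    #    runner sees them at the next step.
    scheduler.free_finished_seq_groups()
    return outputs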

@WoosukKwon
Collaborator
Jul 30, 2024

OK, to my understanding:

  1. Actually finished_requests_ids doesn't have to include the requests rejected in self.waiting since they (ignored_seq_groups) are never passed into the model runner.
  2. finished_requests_ids must include the requests rejected in self.swapped in the current step, which means the current code (before this PR) has a bug. However, the bug has not been noticed because this case is very rare.

Contributor

Yeah, that's correct: finished_requests_ids holds information that is lagged by one step. Finished requests will always be released on the next step.

AFAIU, self.waiting also includes preempted requests that got rescheduled. Preempted requests were previously passed into the model runner and are already registered in the Mamba cache. If those requests get aborted, we do add them to finished_requests_ids and release them through here, though.
BTW, just to be sure: by requests rejected, do you mean aborted requests?

So actually, yeah, checking for finished requests in self.waiting and self.swapped in free_finished_seq_groups is not necessary.

Collaborator

@mzusman Thanks for the explanation! However, it's a bit unclear to me how it handles preemption, both recomputation and swapping. Could you elaborate on that?

@mzusman
Contributor
Jul 30, 2024

@WoosukKwon Sure! Upon recomputation we actually keep the Mamba cache as it is and do not evict it (since the Mamba cache is quite small). The request id persists during preemption, so we can still use its corresponding cache upon recomputation.
Re swapping: we do not handle swapping at the moment, as it occurs fairly rarely and it's quite complicated to implement for the Mamba cache.

Therefore there's no reason to run through self.swapped in search of finished requests.
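
A toy illustration of the design described above (hypothetical class and names, not the actual Mamba cache code): the cache is keyed by request id, so a preempted-and-recomputed request finds its state again, and entries are dropped only when the id shows up in finished_requests_ids.

from typing import Dict, List


class MambaCacheSketch:
    """Toy per-request cache that survives recomputation-style preemption."""

    def __init__(self) -> None:
        # request_id -> per-request state; small enough to keep across
        # preemption instead of being evicted.
        self._cache: Dict[str, dict] = {}

    def get_or_create(self, request_id: str) -> dict:
        # A preempted request that is rescheduled reuses its old entry,
        # because its request id does not change across preemption.
        return self._cache.setdefault(request_id, {})

    def release_finished(self, finished_requests_ids: List[str]) -> None:
        # Entries are freed only once the scheduler reports the request
        # as finished (or aborted).
        for request_id in finished_requests_ids:
            self._cache.pop(request_id, None)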

Collaborator

@mzusman Thanks for the detailed explanation! It is super helpful for understanding the implementation.

I think we should revisit this design decision in the near future. Currently, it's a bit confusing for code readers.

@youkaichao
Member Author

> @youkaichao I think you must rebase to pass the failing Examples Test
>
> [2024-07-29T19:01:27Z] python3: can't open file '/vllm-workspace/examples/llava_example.py': [Errno 2] No such file or directory

If there are no more comments, I will force-merge to avoid the cost of one additional CI run :)

@WoosukKwon
Collaborator

@youkaichao I think we need to understand the code before getting this PR merged.

@youkaichao
Member Author

> @youkaichao I think we need to understand the code before getting this PR merged.

LGTM. Please take your time. I think the scheduler logic is quite unclear, though; there might be lots of bugs lurking there.

@youkaichao
Member Author

@WoosukKwon please feel free to take over this PR if you want to make any modifications!

@WoosukKwon
Collaborator

@youkaichao I've updated the PR based on the discussion with @mzusman. I think the PR is good to go now.

WoosukKwon enabled auto-merge (squash) on July 30, 2024 at 20:53
youkaichao disabled auto-merge on July 30, 2024 at 21:32
youkaichao merged commit 6ca8031 into vllm-project:main on Jul 30, 2024
55 of 73 checks passed
youkaichao deleted the free_finished_seq_groups branch on July 30, 2024 at 21:32
tjohnson31415 added a commit to tjohnson31415/vllm that referenced this pull request Jul 30, 2024
* upstream/main:
  [Build] Temporarily Disable Kernels and LoRA tests (vllm-project#6961)
  [core][misc] improve free_finished_seq_groups (vllm-project#6865)
  [Kernel] Remove scaled_fp8_quant kernel padding footgun (vllm-project#6842)
  [Bugfix] Fix tensorizer memory profiling bug during testing (vllm-project#6881)
  [OpenVINO] Updated OpenVINO requirements and build docs (vllm-project#6948)
  [Kernel] Squash a few more warnings (vllm-project#6914)
  [BugFix] Fix use of per-request seed with pipeline parallel (vllm-project#6698)
  [Doc] Super tiny fix doc typo (vllm-project#6949)
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Labels
ready ONLY add when PR is ready to merge/full CI is needed
5 participants