
[WIP] [V1] TPU support #11936

Open · wants to merge 7 commits into main from tpu_v1
Conversation

alexm-redhat (Collaborator) commented Jan 10, 2025

This PR is a rebase and modification of @robertgshaw2-redhat's original PR for TPU support in vLLM V1 from 1.5 months ago: #10241

Currently, the TPU attention kernel does not support mixing prefills and decodes in the same scheduler iteration. As a result, this PR separates the scheduled requests into (1) prefills and (2) decodes and executes each group separately (see the sketch below). Google engineers are working on a new TPU attention kernel that will allow mixing prefills and decodes; once it is ready, we will be able to remove the separation logic and unify the requests, which should also improve performance.
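A minimal sketch of the separation idea (illustrative only; the Request container and split_prefills_and_decodes helper below are hypothetical, not the PR's actual code):

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Request:
    req_id: str
    num_computed_tokens: int   # tokens already in the KV cache
    num_scheduled_tokens: int  # tokens scheduled for this iteration


def split_prefills_and_decodes(
        requests: List[Request]) -> Tuple[List[Request], List[Request]]:
    """Keep prefills and decodes in separate groups so they never share a batch."""
    prefills, decodes = [], []
    for req in requests:
        if req.num_computed_tokens == 0:
            # Nothing in the KV cache yet: the whole prompt must be prefilled.
            prefills.append(req)
        else:
            # Past the prompt: generate one token per step.
            decodes.append(req)
    return prefills, decodes

The model runner then executes the two groups in separate forward passes (one padded pass per prompt for the prefills, and a single batched pass for the decodes).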

Notes:

  1. @mgoin verified correctness with GSM8K on a TPU instance
  2. No TP > 1 support yet
  3. Only the greedy sampler for now
  4. The V1 code had no support for multiple architectures (this PR adds support for CUDA and TPU); the resulting code duplication is avoided as much as possible by introducing base classes for the worker and the model runner (a sketch of this layout follows this list)
  5. Not performance-optimized yet
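A rough sketch of the base-class layout mentioned in note 4, using method names that appear elsewhere in this review (load_model, dummy_run, profile_run, execute_model); the actual class structure in the PR may differ:

from abc import ABC, abstractmethod


class ModelRunnerBase(ABC):
    """Device-agnostic model-runner logic shared by the CUDA and TPU backends."""

    @abstractmethod
    def load_model(self) -> None:
        """Load (and, on TPU, compile) the model for the target device."""

    @abstractmethod
    def dummy_run(self, batch_size: int, seq_len: int) -> None:
        """Run dummy inputs of a fixed shape, e.g. for warm-up/precompilation."""

    @abstractmethod
    def profile_run(self) -> None:
        """Profile peak memory usage so the KV cache can be sized."""

    @abstractmethod
    def execute_model(self, scheduler_output) -> object:
        """Execute one scheduler iteration and return the sampled tokens."""


class GPUModelRunner(ModelRunnerBase):
    ...  # CUDA-specific input preparation and execution


class TPUModelRunner(ModelRunnerBase):
    ...  # TPU-specific paths (separate prefill/decode execution for now)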

Follow-up tasks (this list may be incomplete):

  1. Add all sampler options
  2. Add prefix caching (currently supported in V0 TPU)
  3. Add prefill chunking
  4. Integrate with Google's new attention kernel to support mixing prefills and decodes
  5. Optimize

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

mergify bot commented Jan 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexm-neuralmagic.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Comment on lines 382 to 223
return PrefillInputData(
request_ids=prefill_request_ids,
prompt_lens=prefill_prompt_lens,
token_ids=prefill_token_ids,
position_ids=prefill_position_ids,
attn_metadata=prefill_attn_metadata,
)
Contributor
Remove the PrefillInputData data structure and make it consistent with gpu_model_runner?

Collaborator Author

This will be removed the moment Google provides the new attention kernel that supports chunked prefill.

Contributor

How is the new attention kernel related to the PrefillInputData data structure?

(3 outdated/resolved review comments on vllm/v1/worker/tpu_model_runner.py)
mgoin (Member) commented Jan 13, 2025

Successfully ran an eval on GSM8k

VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=Qwen/Qwen2.5-1.5B-Instruct,max_model_len=2048,max_num_seqs=512 --tasks gsm8k --num_fewshot 5 --batch_size auto
...
vllm (pretrained=Qwen/Qwen2.5-1.5B-Instruct,max_model_len=2048,max_num_seqs=512), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5989|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.5428|±  |0.0137|

(Outdated/resolved review comments on vllm/platforms/tpu.py ×3 and vllm/v1/worker/tpu_worker.py ×1)
Member
Please revert these changes

(Further review comments on vllm/platforms/tpu.py ×4 and vllm/v1/worker/gpu_model_runner.py ×1)
Comment on lines 248 to 202
# TODO: Remove prompt_len param here
prefill_attn_metadata.append(
PallasMetadata(
num_prefills=1,
num_prefill_tokens=prompt_len, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping.to(self.device),
multi_modal_placeholder_index_maps=None,
block_tables=None,
context_lens=None,
effective_query_lens=None,
))
Member
Can you address this TODO?

Collaborator Author
Done

assert req_id is not None
req_state = self.requests[req_id]

# TODO: ASSERT NO CHUNKED PREFILL.
Member
Implement this TODO

Collaborator Author
It looks like the current assert combo is good enough

scheduler_output.num_scheduled_tokens[req_id])
assert seq_len == req_state.num_tokens

# TODO: Verify if req_id_to_index mapping is needed here!
Member
Ditto

Collaborator Author
Removed; it was an old comment.

Comment on lines 450 to 452
# TODO: ASSERT NO PREFIX CACHING.
assert req_state.num_computed_tokens == 0
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])

# TODO: ASSERT NO CHUNKED PREFILL.
Member
Could you make these asserts at the initialization level? Why would you need to assert this for each request?

Collaborator Author
They are now inside the platform's tpu.py; the asserts here are just in case something changes in the code and breaks an assumption. All of these will change the moment we have a chunked-prefill attention kernel.

mergify bot added the ci/build label (Jan 16, 2025)
Comment on lines 520 to 462
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
Member
Why do you build these dummy tensors each time rather than allocating the max in the initializer and taking slices for each run like the gpu_model_runner?

Collaborator Author
taking slices will result in copies as well, no?
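For reference, a minimal sketch of the two allocation strategies being discussed; shapes and attribute names here are assumptions, not the PR's code:

import torch


class PerRunAllocation:
    """Build the dummy tensor from scratch on every run (current approach)."""

    def get_token_ids(self, batch_size: int, seq_len: int,
                      device: torch.device) -> torch.Tensor:
        return torch.zeros((batch_size, seq_len),
                           dtype=torch.int32,
                           device=device)


class PreallocateAndSlice:
    """Allocate once at init and slice per run (gpu_model_runner style)."""

    def __init__(self, max_batch_size: int, max_num_tokens: int,
                 device: torch.device):
        self._token_ids = torch.zeros((max_batch_size, max_num_tokens),
                                      dtype=torch.int32,
                                      device=device)

    def get_token_ids(self, batch_size: int, seq_len: int) -> torch.Tensor:
        # Basic slicing returns a view of the preallocated buffer, so no new
        # device allocation occurs; filling the view with fresh contents is
        # still a copy, which is the trade-off being debated above.
        return self._token_ids[:batch_size, :seq_len]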

alexm-redhat (Collaborator Author) left a comment
@mgoin @vanbasten23 thanks for the review comments!

(Outdated review comments on vllm/platforms/tpu.py ×5)
(Review comments on vllm/v1/worker/tpu_worker.py ×2)
mergify bot removed the needs-rebase label (Jan 22, 2025)
alexm-redhat force-pushed the tpu_v1 branch 2 times, most recently from dea6afd to c6f526c (January 22, 2025, 22:38)
mergify bot commented Jan 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexm-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@@ -89,4 +89,4 @@ repos:
name: Suggestion
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
language: system
verbose: true
verbose: true
Collaborator
nit

@@ -8,15 +8,15 @@
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
Collaborator
revert

@@ -34,4 +34,4 @@ run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/worker
run_mypy vllm/v1
run_mypy vllm/v1
Collaborator
nit

alexm-redhat force-pushed the tpu_v1 branch 2 times, most recently from 90ecdbd to eee6378 (January 24, 2025, 19:44)
@vanbasten23

Hi @alexm-redhat, thanks for adding vLLM V1 support for TPU!
One quick question: these vLLM slides mention a few key changes in vLLM V1:

  • Simplified scheduler
  • Incremental input preparation
  • Piecewise CUDA graphs
  • Enhanced API server
  • More efficient Prefix caching
  • Fine-grained scheduling for VLMs

Could you help mark which changes are included in this PR and which will be made in future PRs?
cc @miladm

@@ -212,6 +212,13 @@ def schedule(self) -> "SchedulerOutput":
num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
computed_blocks.pop()

# If chunked prefill is not enabled, then break out of the loop
Collaborator
This is a hack that hurts our development. We should find a way to not affect the scheduler.

Collaborator Author

@WoosukKwon I have addressed this issue by adding chunked-prompt support to TPU V1, and the PR is updated. Now there are no changes to the scheduler, so it is the same for both GPU and TPU. Thanks for pointing this out!

self.model(token_ids, position_ids, None, kv_caches)

def profile_run(self) -> None:
raise NotImplementedError()
Contributor
move this to base class ?

Collaborator Author

Good catch

mergify bot commented Jan 28, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexm-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label (Jan 28, 2025)
alexm-redhat (Collaborator Author)

@liangfu PrefillInputData stores a list of PallasMetadata, one per prompt.
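For reference, a rough reconstruction of that structure from the fields visible in the diff above (the exact types in the PR may differ):

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class PrefillInputData:
    request_ids: List[str]
    prompt_lens: List[int]
    # One entry per prompt, since each prefill currently runs as its own
    # padded forward pass.
    token_ids: List[torch.Tensor]
    position_ids: List[torch.Tensor]
    attn_metadata: List["PallasMetadata"]  # TPU (Pallas) attention metadata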

alexm-redhat (Collaborator Author) commented Jan 28, 2025

@vanbasten23 @miladm @bvrockwell

Reply:

  • Simplified scheduler => No changes to the scheduler in this PR
  • Incremental input preparation => The input preparation is incremental here (same as for NVIDIA); however, it is not optimized yet (will work on it)
  • Piecewise CUDA graphs => Does TPU have support for this?
  • Enhanced API server => Same as NVIDIA; this PR does not touch the API server
  • More efficient prefix caching => Not enabled yet (will be next)
  • Fine-grained scheduling for VLMs => @mgoin's follow-up PR will have these optimizations. Need to land this PR so Michael can progress

Hope this helps!

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Comment on lines +222 to +240
# TODO: Remove
# def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# # Check if the GPU supports the dtype.
# if torch_dtype == torch.bfloat16: # noqa: SIM102
# if not current_platform.has_device_capability(80):
# capability = current_platform.get_device_capability()
# gpu_name = current_platform.get_device_name()

# if capability is None:
# compute_str = "does not have a compute capability"
# else:
# version_str = capability.as_version_str()
# compute_str = f"has compute capability {version_str}"

# raise ValueError(
# "Bfloat16 is only supported on GPUs with compute capability "
# f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
# "You can use float16 instead by explicitly setting the"
# "`dtype` flag in CLI, for example: --dtype=half.")
Member

Remember to remove

else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
assert self.device_config.device.type == "cuda"
Member

Let's keep the error message

Suggested change
assert self.device_config.device.type == "cuda"
assert self.device_config.device.type == "cuda", (
    f"Not supported device type: {self.device_config.device}")

Comment on lines +34 to +38
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128
Member

I think these can be removed as unused for now

Member

nit: I think this should be called base_model_runner.py so that, as we add more "base" files, they are grouped together

Comment on lines +142 to +144
# input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
# req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
# input_tokens[:, prompt_len:] = 0
Member

Remove cruft

Comment on lines +204 to +215
# TODO: Remove this
# if num_computed_tokens > 0:
# print("-------------------")
# print("input_tokens.shape = {}".format(input_tokens.shape))
# print("input_positions.shape = {}".format(
# input_positions.shape))
# print("slot_mapping.shape = {}".format(slot_mapping.shape))
# print("block_table.shape = {}".format(block_table.shape))
# print("context_lens.shape = {} data = {}".format(
# context_lens.shape, context_lens))
# print("effective_query_lens.shape = {} data = {}".format(
# effective_query_lens.shape, effective_query_lens))
Member

Remove cruft or hide behind debug var

Comment on lines +79 to +80
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
Member

Suggested change
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
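For readers unfamiliar with the Dynamo detail above, a minimal standalone illustration of the pattern (hypothetical function, not the PR's code):

import torch


def forward(x: torch.Tensor, kv_cache: torch.Tensor) -> torch.Tensor:
    # kv_cache is always a tensor here. Using an empty tensor for the
    # "no cache" case keeps it a traced tensor input; passing None instead
    # would make Dynamo specialize the compiled graph on the value None.
    if kv_cache.numel() == 0:
        return x  # warm-up / profiling path with no cache contents
    return x + kv_cache


compiled = torch.compile(forward)
x = torch.randn(4, 8)
print(compiled(x, torch.empty(0)).shape)   # empty placeholder instead of None
print(compiled(x, torch.randn(4, 8)).shape)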

@@ -833,14 +629,15 @@ def load_model(self) -> None:
self.model_memory_usage / float(2**30))

@torch.inference_mode()
def _dummy_run(
def dummy_run(
Contributor

keep the underscore ?

Comment on lines +94 to +97
def _prepare_prompt_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> PromptInputData:
Contributor

This is V0 style (_prepare_prompt_inputs + _prepare_decode_inputs). Can we reuse

def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):

from the V1 GPU model runner, as it prepares both prefill and decode inputs?
