[WIP] [V1] TPU support #11936
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
This pull request has merge conflicts that must be resolved before it can be merged.
vllm/v1/worker/tpu_model_runner.py
Outdated
return PrefillInputData(
    request_ids=prefill_request_ids,
    prompt_lens=prefill_prompt_lens,
    token_ids=prefill_token_ids,
    position_ids=prefill_position_ids,
    attn_metadata=prefill_attn_metadata,
)
Remove the PrefillInputData data structure and make it consistent with gpu_model_runner?
This will be removed the moment Google provides the new attention kernel that supports chunked prefill.
How is the new attention kernel related to this PrefillInputData data structure?
Successfully ran an eval on GSM8k
Please revert these changes
vllm/v1/worker/tpu_model_runner.py
Outdated
# TODO: Remove prompt_len param here
prefill_attn_metadata.append(
    PallasMetadata(
        num_prefills=1,
        num_prefill_tokens=prompt_len,  # NOTE: This is not used.
        num_decode_tokens=0,
        slot_mapping=slot_mapping.to(self.device),
        multi_modal_placeholder_index_maps=None,
        block_tables=None,
        context_lens=None,
        effective_query_lens=None,
    ))
Can you address this TODO?
Done
vllm/v1/worker/tpu_model_runner.py
Outdated
assert req_id is not None
req_state = self.requests[req_id]

# TODO: ASSERT NO CHUNKED PREFILL.
Implement this TODO
It looks like the current assert combo is good enough
vllm/v1/worker/tpu_model_runner.py
Outdated
    scheduler_output.num_scheduled_tokens[req_id])
assert seq_len == req_state.num_tokens

# TODO: Verify if req_id_to_index mapping is needed here!
Ditto
removed, it is an old comment
vllm/v1/worker/tpu_model_runner.py
Outdated
# TODO: ASSERT NO PREFIX CACHING.
assert req_state.num_computed_tokens == 0
seq_len = (req_state.num_computed_tokens +
           scheduler_output.num_scheduled_tokens[req_id])

# TODO: ASSERT NO CHUNKED PREFILL.
Could you make these asserts at the initialization level? Why would you need to assert this for each request?
They are now inside the platform's tpu.py; the asserts here are just a safety net in case something changes in the code and breaks these assumptions. All of this will change the moment we have a chunked-prefill attention kernel.
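For illustration, a minimal sketch of what such a one-time, platform-level check could look like (the function name and the config attribute names below are assumptions for this sketch, not the actual code in tpu.py):

# Hypothetical sketch -- not the actual vllm/platforms/tpu.py code.
# The attribute names (chunked_prefill_enabled, enable_prefix_caching)
# are assumptions used only to illustrate a one-time validation.
def validate_tpu_v1_config(scheduler_config, cache_config) -> None:
    if getattr(scheduler_config, "chunked_prefill_enabled", False):
        raise NotImplementedError(
            "Chunked prefill is not supported on TPU with V1 yet.")
    if getattr(cache_config, "enable_prefix_caching", False):
        raise NotImplementedError(
            "Prefix caching is not supported on TPU with V1 yet.")

With a startup check along these lines, the per-request asserts in the model runner remain cheap sanity checks rather than the primary guard.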
vllm/v1/worker/tpu_model_runner.py
Outdated
token_ids = torch.zeros((batch_size, seq_len),
                        dtype=torch.int32,
                        device=self.device)
Why do you build these dummy tensors each time rather than allocating the max in the initializer and taking slices for each run like the gpu_model_runner?
taking slices will result in copies as well, no?
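For context, a minimal sketch of the pre-allocate-and-slice pattern being discussed (names and shapes are illustrative, not the actual gpu_model_runner code); whether the slice ends up being a view or a copy on TPU/XLA is exactly the open question in this thread:

import torch


class PreallocatedBuffers:
    """Illustrative only: allocate maximum-size buffers once, slice per step."""

    def __init__(self, max_batch_size: int, max_seq_len: int,
                 device: torch.device):
        # Allocated once at initialization at the maximum size.
        self.token_ids = torch.zeros((max_batch_size, max_seq_len),
                                     dtype=torch.int32,
                                     device=device)

    def get(self, batch_size: int, seq_len: int) -> torch.Tensor:
        # On CUDA this basic slice is a view (no copy); on TPU/XLA the
        # compiler may still materialize it or recompile for new shapes.
        return self.token_ids[:batch_size, :seq_len]


# Usage sketch:
bufs = PreallocatedBuffers(max_batch_size=8, max_seq_len=1024,
                           device=torch.device("cpu"))
step_tokens = bufs.get(batch_size=2, seq_len=128)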
@mgoin @vanbasten23 thanks for the review comments!
Force-pushed from dea6afd to c6f526c.
This pull request has merge conflicts that must be resolved before it can be merged.
@@ -89,4 +89,4 @@ repos:
        name: Suggestion
        entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
        language: system
-       verbose: true
+       verbose: true
nit
@@ -8,15 +8,15 @@
     "The future of AI is",
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
revert
@@ -34,4 +34,4 @@ run_mypy vllm/plugins
 run_mypy vllm/prompt_adapter
 run_mypy vllm/spec_decode
 run_mypy vllm/worker
-run_mypy vllm/v1
+run_mypy vllm/v1
nit
Force-pushed from 90ecdbd to eee6378.
Hi @alexm-redhat, thanks for adding vLLM V1 support for TPU! Could you help mark which changes are included in this PR and which are to be made in future PRs?
vllm/v1/core/scheduler.py
Outdated
@@ -212,6 +212,13 @@ def schedule(self) -> "SchedulerOutput":
    num_computed_tokens -= self.block_size
    num_new_tokens = self.block_size
    computed_blocks.pop()

    # If chunked prefill is not enabled, then breakout of the loop
This is a hack that hurts our development. We should find a way to not affect the scheduler.
@WoosukKwon I have addressed this by adding chunked-prompt support to TPU V1; the PR is updated. There are now no changes to the scheduler, so it is the same for both GPU and TPU. Thanks for pointing this out!
vllm/v1/worker/tpu_model_runner.py
Outdated
        self.model(token_ids, position_ids, None, kv_caches)

    def profile_run(self) -> None:
        raise NotImplementedError()
Move this to the base class?
Good catch
This pull request has merge conflicts that must be resolved before it can be merged.
@liangfu PrefillInputData stores a list of PallasMetadata, one per prompt.
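For readers following along, the fields visible in the diff earlier in this thread suggest PrefillInputData is roughly a per-prompt container like the following (a simplified sketch, not the exact definition in the PR; the zipped helper is illustrative):

from dataclasses import dataclass
from typing import Any, List

import torch


@dataclass
class PrefillInputData:
    # One entry per prefill request scheduled in this step.
    request_ids: List[str]
    prompt_lens: List[int]
    token_ids: List[torch.Tensor]
    position_ids: List[torch.Tensor]
    # One PallasMetadata per prompt (typed as Any to keep the sketch
    # self-contained).
    attn_metadata: List[Any]

    def zipped(self):
        # Iterate prompt by prompt, since the current TPU attention kernel
        # cannot batch prefills together with decodes.
        return zip(self.request_ids, self.prompt_lens, self.token_ids,
                   self.position_ids, self.attn_metadata)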
@vanbasten23 @miladm @bvrockwell Reply: simplified scheduler => no changes to the scheduler in this PR. Hope this helps!
# TODO: Remove
# def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
#     # Check if the GPU supports the dtype.
#     if torch_dtype == torch.bfloat16:  # noqa: SIM102
#         if not current_platform.has_device_capability(80):
#             capability = current_platform.get_device_capability()
#             gpu_name = current_platform.get_device_name()

#             if capability is None:
#                 compute_str = "does not have a compute capability"
#             else:
#                 version_str = capability.as_version_str()
#                 compute_str = f"has compute capability {version_str}"

#             raise ValueError(
#                 "Bfloat16 is only supported on GPUs with compute capability "
#                 f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
#                 "You can use float16 instead by explicitly setting the"
#                 "`dtype` flag in CLI, for example: --dtype=half.")
Remember to remove
-        else:
-            raise RuntimeError(
-                f"Not support device type: {self.device_config.device}")
+        assert self.device_config.device.type == "cuda"
Let's keep the error message:

- assert self.device_config.device.type == "cuda"
+ assert self.device_config.device.type == "cuda", \
+     f"Not supported device type: {self.device_config.device}"
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128
I think these can be removed as unused for now
nit: I think this should be called base_model_runner.py so that as we add more "base" files they are grouped together.
# input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
#     req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
# input_tokens[:, prompt_len:] = 0
Remove cruft
# TODO: Remove this
# if num_computed_tokens > 0:
#     print("-------------------")
#     print("input_tokens.shape = {}".format(input_tokens.shape))
#     print("input_positions.shape = {}".format(
#         input_positions.shape))
#     print("slot_mapping.shape = {}".format(slot_mapping.shape))
#     print("block_table.shape = {}".format(block_table.shape))
#     print("context_lens.shape = {} data = {}".format(
#         context_lens.shape, context_lens))
#     print("effective_query_lens.shape = {} data = {}".format(
#         effective_query_lens.shape, effective_query_lens))
Remove cruft or hide behind debug var
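If the shape dumps are still useful during bring-up, one way to hide them behind a debug variable is a small gated helper like the following (the environment-variable name is made up for illustration, not an existing vLLM flag):

import os

# Assumed, illustrative env-var name.
_TPU_RUNNER_DEBUG = os.environ.get("VLLM_TPU_RUNNER_DEBUG", "0") == "1"


def debug_dump_shapes(**tensors) -> None:
    """Print tensor shapes only when the debug flag is enabled."""
    if not _TPU_RUNNER_DEBUG:
        return
    for name, tensor in tensors.items():
        print(f"{name}.shape = {tuple(tensor.shape)}")

The commented-out prints above could then become a single call such as debug_dump_shapes(input_tokens=input_tokens, slot_mapping=slot_mapping).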
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
-# use an empty tensor instead of `None`` to force Dynamo to pass
-# it by reference, rather by specializing on the value ``None``.
+# use an empty tensor instead of `None` to force Dynamo to pass
+# it by reference, rather by specializing on the value `None`.
@@ -833,14 +629,15 @@ def load_model(self) -> None:
                         self.model_memory_usage / float(2**30))

     @torch.inference_mode()
-    def _dummy_run(
+    def dummy_run(
Keep the underscore?
def _prepare_prompt_inputs(
    self,
    scheduler_output: "SchedulerOutput",
) -> PromptInputData:
This is V0 style (_prepare_prompt_inputs + _prepare_decode_inputs). Can we reuse def _prepare_inputs(self, scheduler_output: "SchedulerOutput") from the V1 GPU model runner, as it prepares both prefill and decode inputs?
This PR is a rebase and modification of @robertgshaw2-redhat's original PR for TPU support in vLLM V1 from 1.5 months ago (#10241).
Currently, the TPU attention kernel has no support for mixing prefills and decodes in the same scheduler iteration. As a result, this PR separates the requests into (1) prefills and (2) decodes and executes each group separately. The Google team is working on a new TPU attention kernel that will allow mixing prefills and decodes; the moment it is ready, we will be able to remove the separation logic and unify the requests (which will also improve performance).
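In deliberately simplified form, the separation described above amounts to the control flow below (a self-contained toy with stubbed model calls; the real runner operates on tensors and PallasMetadata, not Python lists):

from typing import Dict, List, Tuple


def run_step(prefills: List[Tuple[str, List[int]]],
             decodes: List[Tuple[str, int]]) -> Dict[str, int]:
    """Toy sketch: run prefills one by one, then all decodes in one batch."""

    def run_prefill(prompt: List[int]) -> int:
        return len(prompt)  # stub standing in for a per-prompt model call

    def run_decode_batch(last_tokens: List[int]) -> List[int]:
        return [t + 1 for t in last_tokens]  # stub for one batched decode call

    sampled: Dict[str, int] = {}
    # (1) Prefills: the current TPU kernel cannot mix them with decodes,
    #     so each prompt is executed on its own.
    for req_id, prompt in prefills:
        sampled[req_id] = run_prefill(prompt)
    # (2) Decodes: all decode requests share a single batched call.
    outputs = run_decode_batch([tok for _, tok in decodes])
    for (req_id, _), out in zip(decodes, outputs):
        sampled[req_id] = out
    return sampled

Once the new kernel supports mixing prefills and decodes, steps (1) and (2) collapse into a single batched call, which is the unification mentioned above.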
Notes:
Follow-up tasks (maybe I missed something):