
Commit 71face8

Authored by mgoin
[Bugfix] Fix max_num_batched_tokens for MLA (vllm-project#13620)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent bfbc0b3 commit 71face8

File tree

1 file changed: +14 -6 lines changed

vllm/config.py

+14 -6
@@ -51,6 +51,9 @@
 
 logger = init_logger(__name__)
 
+# This value is chosen to have a balance between ITL and TTFT. Note it is
+# not optimized for throughput.
+_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
 _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
 
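
Hunk 1 introduces a module-level constant for the previously hard-coded 2048 default. A minimal, standalone sketch of the fallback rule it feeds into (plain Python; default_batched_tokens and the example context lengths are illustrative, not part of vLLM):

    _DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048

    def default_batched_tokens(max_model_len: int) -> int:
        # Never drop the token budget below the full model context, and never
        # below the 2048-token floor chosen for the ITL/TTFT balance.
        return max(max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)

    assert default_batched_tokens(1024) == 2048        # short-context model
    assert default_batched_tokens(131072) == 131072    # long-context model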

@@ -1526,15 +1529,17 @@ def __post_init__(self) -> None:
                     # for now. Have max_num_batched_tokens set to max_model_len
                     # so we don't reject sequences on account of a short
                     # max_num_batched_tokens.
-                    self.max_num_batched_tokens = max(self.max_model_len, 2048)
+                    self.max_num_batched_tokens = max(
+                        self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
                 else:
-                    # This value is chosen to have a balance between ITL
-                    # and TTFT. Note it is not optimized for throughput.
-                    self.max_num_batched_tokens = 2048
+                    self.max_num_batched_tokens = (
+                        _DEFAULT_MAX_NUM_BATCHED_TOKENS)
             else:
-                # If max_model_len is too short, use 2048 as the default value
+                # If max_model_len is too short, use
+                # _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                 # for higher throughput.
-                self.max_num_batched_tokens = max(self.max_model_len, 2048)
+                self.max_num_batched_tokens = max(
+                    self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
             if self.runner_type == "pooling":
                 # Choose specific value for higher throughput
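
Hunk 2 swaps the scattered 2048 literals in SchedulerConfig.__post_init__ for the new constant without changing behavior. A simplified, standalone reproduction of the branch structure (not vLLM's actual SchedulerConfig; the dataclass and field names below are trimmed to the handful used here):

    from dataclasses import dataclass

    _DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048

    @dataclass
    class SchedulerDefaults:
        max_model_len: int
        enable_chunked_prefill: bool
        num_scheduler_steps: int = 1

        def resolve_max_num_batched_tokens(self) -> int:
            if self.enable_chunked_prefill:
                if self.num_scheduler_steps > 1:
                    # Multi-step chunked prefill cannot chunk prompts, so the
                    # budget must cover a full-length prompt.
                    return max(self.max_model_len,
                               _DEFAULT_MAX_NUM_BATCHED_TOKENS)
                # Plain chunked prefill: use the ITL/TTFT-balanced default.
                return _DEFAULT_MAX_NUM_BATCHED_TOKENS
            # No chunked prefill: the budget must admit a full-length prompt,
            # padded up to the default for short models for throughput.
            return max(self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
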
@@ -3333,6 +3338,9 @@ def __post_init__(self):
                 "caching to be disabled.")
             self.scheduler_config.enable_chunked_prefill = False
             self.scheduler_config.chunked_prefill_enabled = False
+            self.scheduler_config.max_num_batched_tokens = max(
+                self.scheduler_config.max_model_len,
+                _DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
             if self.cache_config is not None:
                 self.cache_config.enable_prefix_caching = False
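
Hunk 3 is the actual fix: when MLA forces chunked prefill off, a previously chosen 2048-token budget would reject any prompt longer than that, so the budget is raised back to at least max_model_len. A minimal sketch of the resulting invariant (standalone Python; SimpleNamespace stands in for the real scheduler config and 163840 is just an illustrative long context length):

    from types import SimpleNamespace

    _DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048

    # Stand-in for the scheduler config of an MLA model with a long context;
    # 2048 is what the pre-fix default could have left in place.
    scheduler_config = SimpleNamespace(max_model_len=163840,
                                       max_num_batched_tokens=2048)

    # The override added by this commit: with chunked prefill disabled, a
    # whole prompt must fit in one batch, so raise the budget accordingly.
    scheduler_config.max_num_batched_tokens = max(
        scheduler_config.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)

    assert (scheduler_config.max_num_batched_tokens
            >= scheduler_config.max_model_len)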

0 commit comments
