[Bugfix] Fix max_num_batched_tokens for MLA (#13620)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin, 2025-02-20 20:45:20 -05:00 (committed by GitHub)
parent bfbc0b32c6
commit 71face8540


@@ -51,6 +51,9 @@ else:
 logger = init_logger(__name__)
 
+# This value is chosen to have a balance between ITL and TTFT. Note it is
+# not optimized for throughput.
+_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
 _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
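
For context (not part of this diff), a minimal sketch of how these three module-level defaults are intended to combine; the helper name pick_default_max_num_batched_tokens is hypothetical and does not exist in the file, but the branches follow the SchedulerConfig logic shown in the next hunk.

# Sketch only; pick_default_max_num_batched_tokens is a hypothetical helper.
# It uses the three constants defined in the hunk above.
def pick_default_max_num_batched_tokens(runner_type: str,
                                        is_multimodal_model: bool) -> int:
    # Start from the shared baseline introduced by this commit.
    value = _DEFAULT_MAX_NUM_BATCHED_TOKENS
    if runner_type == "pooling":
        # Pooling runners trade ITL/TTFT balance for throughput.
        value = max(value, _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
    if is_multimodal_model:
        # Multimodal models need a larger budget for their extra tokens.
        value = max(value, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS)
    return value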
@@ -1526,15 +1529,17 @@ class SchedulerConfig:
                     # for now. Have max_num_batched_tokens set to max_model_len
                     # so we don't reject sequences on account of a short
                     # max_num_batched_tokens.
-                    self.max_num_batched_tokens = max(self.max_model_len, 2048)
+                    self.max_num_batched_tokens = max(
+                        self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
                 else:
                     # This value is chosen to have a balance between ITL
                     # and TTFT. Note it is not optimized for throughput.
-                    self.max_num_batched_tokens = 2048
+                    self.max_num_batched_tokens = (
+                        _DEFAULT_MAX_NUM_BATCHED_TOKENS)
             else:
-                # If max_model_len is too short, use 2048 as the default value
+                # If max_model_len is too short, use
+                # _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                 # for higher throughput.
-                self.max_num_batched_tokens = max(self.max_model_len, 2048)
+                self.max_num_batched_tokens = max(
+                    self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
             if self.runner_type == "pooling":
                 # Choose specific value for higher throughput
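
For readers skimming the hunk above, the same default-selection logic as a standalone sketch; resolve_max_num_batched_tokens is a hypothetical name, but its three branches mirror the diff.

_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048  # same constant as in the first hunk


def resolve_max_num_batched_tokens(max_model_len: int,
                                   enable_chunked_prefill: bool,
                                   num_scheduler_steps: int = 1) -> int:
    # Hypothetical standalone rendering of the SchedulerConfig branches above.
    if enable_chunked_prefill:
        if num_scheduler_steps > 1:
            # Multi-step chunked prefill cannot chunk prompts, so the budget
            # must be large enough to admit a full-length prompt.
            return max(max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
        # Plain chunked prefill: the 2048 default balances ITL and TTFT.
        return _DEFAULT_MAX_NUM_BATCHED_TOKENS
    # Without chunked prefill a prompt is scheduled in one shot, so the budget
    # must cover max_model_len (or fall back to the default for short models).
    return max(max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)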
@@ -3333,6 +3338,9 @@ class VllmConfig:
                 "caching to be disabled.")
             self.scheduler_config.enable_chunked_prefill = False
             self.scheduler_config.chunked_prefill_enabled = False
+            self.scheduler_config.max_num_batched_tokens = max(
+                self.scheduler_config.max_model_len,
+                _DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
             if self.cache_config is not None:
                 self.cache_config.enable_prefix_caching = False
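
A small, illustrative check of why the added lines matter (the numbers are made up): when MLA forces chunked prefill off after SchedulerConfig has already settled on the 2048 default, the token budget has to be widened back to max_model_len, otherwise long prompts would be rejected.

# Illustrative values only.
_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
max_model_len = 32768
max_num_batched_tokens = 2048  # default picked while chunked prefill was on

# MLA path: chunked prefill is disabled, so a prompt must fit in one batch.
# Re-derive the budget exactly as the added lines do.
max_num_batched_tokens = max(max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)

assert max_num_batched_tokens == 32768  # full-length prompts fit again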