diff --git a/vllm/config.py b/vllm/config.py
index 6764694f80591..f118004b2f2f7 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -51,6 +51,9 @@ else:
 
 logger = init_logger(__name__)
 
+# This value is chosen to have a balance between ITL and TTFT. Note it is
+# not optimized for throughput.
+_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
 _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
 
@@ -1526,15 +1529,17 @@ class SchedulerConfig:
                     # for now. Have max_num_batched_tokens set to max_model_len
                     # so we don't reject sequences on account of a short
                     # max_num_batched_tokens.
-                    self.max_num_batched_tokens = max(self.max_model_len, 2048)
+                    self.max_num_batched_tokens = max(
+                        self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
                 else:
-                    # This value is chosen to have a balance between ITL
-                    # and TTFT. Note it is not optimized for throughput.
-                    self.max_num_batched_tokens = 2048
+                    self.max_num_batched_tokens = (
+                        _DEFAULT_MAX_NUM_BATCHED_TOKENS)
             else:
-                # If max_model_len is too short, use 2048 as the default value
+                # If max_model_len is too short, use
+                # _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                 # for higher throughput.
-                self.max_num_batched_tokens = max(self.max_model_len, 2048)
+                self.max_num_batched_tokens = max(
+                    self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
             if self.runner_type == "pooling":
                 # Choose specific value for higher throughput
@@ -3333,6 +3338,9 @@ class VllmConfig:
                     "caching to be disabled.")
                 self.scheduler_config.enable_chunked_prefill = False
                 self.scheduler_config.chunked_prefill_enabled = False
+                self.scheduler_config.max_num_batched_tokens = max(
+                    self.scheduler_config.max_model_len,
+                    _DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
             if self.cache_config is not None:
                 self.cache_config.enable_prefix_caching = False
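
For reference, a minimal standalone sketch (not part of the patch) of the default-selection rule that the new _DEFAULT_MAX_NUM_BATCHED_TOKENS constant consolidates. The helper name and signature below are illustrative only and do not exist in vLLM; only the branching mirrors the diff above.

# Illustrative sketch, assuming the constant keeps its value of 2048.
_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048


def default_max_num_batched_tokens(max_model_len: int,
                                    enable_chunked_prefill: bool,
                                    num_scheduler_steps: int = 1) -> int:
    # Hypothetical helper mirroring SchedulerConfig's defaulting logic.
    if enable_chunked_prefill:
        if num_scheduler_steps > 1:
            # Multi-step chunked prefill cannot chunk prompts, so the token
            # budget must cover the full model length.
            return max(max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
        # Single-step chunked prefill: fixed budget balancing ITL and TTFT.
        return _DEFAULT_MAX_NUM_BATCHED_TOKENS
    # Chunked prefill disabled: never drop below the default, even when
    # max_model_len is short.
    return max(max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)


assert default_max_num_batched_tokens(32768, True) == 2048
assert default_max_num_batched_tokens(1024, False) == 2048
assert default_max_num_batched_tokens(32768, False) == 32768

With chunked prefill enabled and single-step scheduling, the budget stays at the 2048 default regardless of max_model_len; every other branch, including the new fallback in VllmConfig when MLA forces chunked prefill off, applies the same max(max_model_len, default) rule.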