[Misc] Make SchedulerConfig.max_model_len init-only (#28733)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-11-15 17:59:31 +08:00 committed by GitHub
parent 1ec978c209
commit 638e4196d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 22 additions and 45 deletions

View File

@ -40,8 +40,6 @@ NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclass

View File

@ -33,8 +33,6 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]

View File

@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.bfloat16]

View File

@ -42,8 +42,6 @@ MNK_FACTORS = [
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclasses.dataclass

View File

@ -45,8 +45,6 @@ MNK_FACTORS = [
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def quant_fp8_per_tensor_batches(a):

View File

@ -81,8 +81,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [
]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def run_moe_test(

View File

@ -192,8 +192,6 @@ def pplx_cutlass_moe(
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def _pplx_moe(

View File

@ -81,8 +81,6 @@ TOP_KS = [1, 2, 6]
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def torch_prepare(

View File

@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):

View File

@ -29,8 +29,6 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]

View File

@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.half, torch.bfloat16]
M = [1, 33, 64, 222]

View File

@ -6,7 +6,7 @@ from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
from pydantic import Field, field_validator, model_validator
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self, deprecated
@ -48,13 +48,6 @@ class SchedulerConfig:
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_model_len: int = Field(default=8192, ge=1)
"""Maximum length of a sequence (including prompt and generated text).
The default value here is mainly for convenience when testing.
In real usage, this should duplicate `ModelConfig.max_model_len` via
`EngineArgs`."""
max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be
partially prefilled concurrently."""
@ -89,6 +82,12 @@ class SchedulerConfig:
is_multimodal_model: bool = False
"""True if the model is multimodal."""
max_model_len: InitVar[int] = 8192
"""Maximum length of a sequence (including prompt and generated text).
Note: This is stored in the ModelConfig, and is used only here to
provide fallbacks and validate other attributes."""
is_encoder_decoder: InitVar[bool] = False
"""True if the model is an encoder-decoder model.
@ -199,7 +198,7 @@ class SchedulerConfig:
return value
return handler(value)
def __post_init__(self, is_encoder_decoder: bool) -> None:
def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
if is_encoder_decoder:
# Chunked prefill should be disabled for encoder-decoder models.
self.disable_chunked_mm_input = True
@ -221,7 +220,7 @@ class SchedulerConfig:
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
self.long_prefill_token_threshold = int(max_model_len * 0.04)
logger.info(
"Concurrent partial prefills enabled with "
@ -232,6 +231,8 @@ class SchedulerConfig:
self.long_prefill_token_threshold,
)
self.verify_max_model_len(max_model_len)
@property
@deprecated(
"`SchedulerConfig.chunked_prefill_enabled` has been renamed to "
@ -245,15 +246,14 @@ class SchedulerConfig:
def chunked_prefill_enabled(self, value: bool):
self.enable_chunked_prefill = value
@model_validator(mode="after")
def _verify_args(self) -> Self:
def verify_max_model_len(self, max_model_len: int) -> Self:
if (
self.max_num_batched_tokens < self.max_model_len
self.max_num_batched_tokens < max_model_len
and not self.enable_chunked_prefill
):
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({self.max_model_len}). "
f"smaller than max_model_len ({max_model_len}). "
"This effectively limits the maximum sequence length to "
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
@ -267,12 +267,12 @@ class SchedulerConfig:
f"({self.max_num_seqs})."
)
if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
logger.warning(
"max_num_batched_tokens (%d) exceeds max_num_seqs "
"* max_model_len (%d). This may lead to unexpected behavior.",
self.max_num_batched_tokens,
self.max_num_seqs * self.max_model_len,
self.max_num_seqs * max_model_len,
)
if self.max_num_partial_prefills > 1:
@ -282,11 +282,11 @@ class SchedulerConfig:
"max_num_partial_prefills > 1."
)
if self.long_prefill_token_threshold > self.max_model_len:
if self.long_prefill_token_threshold > max_model_len:
raise ValueError(
"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
f"than the max_model_len ({self.max_model_len})."
f"than the max_model_len ({max_model_len})."
)
if self.max_long_partial_prefills > self.max_num_partial_prefills:

View File

@ -929,7 +929,6 @@ class VllmConfig:
model_config = self.model_config
max_model_len = model_config.get_and_verify_max_len(max_model_len)
self.model_config.max_model_len = max_model_len
self.scheduler_config.max_model_len = max_model_len
def try_verify_and_update_config(self):
if self.model_config is None:

View File

@ -339,7 +339,7 @@ class CpuPlatform(Platform):
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)

View File

@ -191,7 +191,7 @@ class TpuPlatform(Platform):
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)

View File

@ -185,7 +185,7 @@ class XPUPlatform(Platform):
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)

View File

@ -83,7 +83,7 @@ class Scheduler(SchedulerInterface):
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len
self.max_model_len = vllm_config.model_config.max_model_len
self.enable_kv_cache_events = (
self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events