[Misc] Make SchedulerConfig.max_model_len init-only (#28733)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-11-15 17:59:31 +08:00 committed by GitHub
parent 1ec978c209
commit 638e4196d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 22 additions and 45 deletions

View File

@ -40,8 +40,6 @@ NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6] TOP_KS = [1, 2, 6]
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclass @dataclass

View File

@ -33,8 +33,6 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations # Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]

View File

@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.bfloat16] DTYPES = [torch.bfloat16]

View File

@ -42,8 +42,6 @@ MNK_FACTORS = [
] ]
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclasses.dataclass @dataclasses.dataclass

View File

@ -45,8 +45,6 @@ MNK_FACTORS = [
] ]
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def quant_fp8_per_tensor_batches(a): def quant_fp8_per_tensor_batches(a):

View File

@ -81,8 +81,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [
] ]
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def run_moe_test( def run_moe_test(

View File

@ -192,8 +192,6 @@ def pplx_cutlass_moe(
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def _pplx_moe( def _pplx_moe(

View File

@ -81,8 +81,6 @@ TOP_KS = [1, 2, 6]
DTYPES = [torch.float8_e4m3fn, torch.bfloat16] DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def torch_prepare( def torch_prepare(

View File

@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):

View File

@ -29,8 +29,6 @@ if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations # Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]

View File

@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.half, torch.bfloat16] DTYPES = [torch.half, torch.bfloat16]
M = [1, 33, 64, 222] M = [1, 33, 64, 222]

View File

@ -6,7 +6,7 @@ from collections.abc import Callable
from dataclasses import InitVar from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
from pydantic import Field, field_validator, model_validator from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from typing_extensions import Self, deprecated from typing_extensions import Self, deprecated
@ -48,13 +48,6 @@ class SchedulerConfig:
In real usage, this should be set in `EngineArgs.create_engine_config`. In real usage, this should be set in `EngineArgs.create_engine_config`.
""" """
max_model_len: int = Field(default=8192, ge=1)
"""Maximum length of a sequence (including prompt and generated text).
The default value here is mainly for convenience when testing.
In real usage, this should duplicate `ModelConfig.max_model_len` via
`EngineArgs`."""
max_num_partial_prefills: int = Field(default=1, ge=1) max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be """For chunked prefill, the maximum number of sequences that can be
partially prefilled concurrently.""" partially prefilled concurrently."""
@ -89,6 +82,12 @@ class SchedulerConfig:
is_multimodal_model: bool = False is_multimodal_model: bool = False
"""True if the model is multimodal.""" """True if the model is multimodal."""
max_model_len: InitVar[int] = 8192
"""Maximum length of a sequence (including prompt and generated text).
Note: This is stored in the ModelConfig, and is used only here to
provide fallbacks and validate other attributes."""
is_encoder_decoder: InitVar[bool] = False is_encoder_decoder: InitVar[bool] = False
"""True if the model is an encoder-decoder model. """True if the model is an encoder-decoder model.
@ -199,7 +198,7 @@ class SchedulerConfig:
return value return value
return handler(value) return handler(value)
def __post_init__(self, is_encoder_decoder: bool) -> None: def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
if is_encoder_decoder: if is_encoder_decoder:
# Chunked prefill should be disabled for encoder-decoder models. # Chunked prefill should be disabled for encoder-decoder models.
self.disable_chunked_mm_input = True self.disable_chunked_mm_input = True
@ -221,7 +220,7 @@ class SchedulerConfig:
if self.max_num_partial_prefills > 1: if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0: if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04) self.long_prefill_token_threshold = int(max_model_len * 0.04)
logger.info( logger.info(
"Concurrent partial prefills enabled with " "Concurrent partial prefills enabled with "
@ -232,6 +231,8 @@ class SchedulerConfig:
self.long_prefill_token_threshold, self.long_prefill_token_threshold,
) )
self.verify_max_model_len(max_model_len)
@property @property
@deprecated( @deprecated(
"`SchedulerConfig.chunked_prefill_enabled` has been renamed to " "`SchedulerConfig.chunked_prefill_enabled` has been renamed to "
@ -245,15 +246,14 @@ class SchedulerConfig:
def chunked_prefill_enabled(self, value: bool): def chunked_prefill_enabled(self, value: bool):
self.enable_chunked_prefill = value self.enable_chunked_prefill = value
@model_validator(mode="after") def verify_max_model_len(self, max_model_len: int) -> Self:
def _verify_args(self) -> Self:
if ( if (
self.max_num_batched_tokens < self.max_model_len self.max_num_batched_tokens < max_model_len
and not self.enable_chunked_prefill and not self.enable_chunked_prefill
): ):
raise ValueError( raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({self.max_model_len}). " f"smaller than max_model_len ({max_model_len}). "
"This effectively limits the maximum sequence length to " "This effectively limits the maximum sequence length to "
"max_num_batched_tokens and makes vLLM reject longer " "max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or " "sequences. Please increase max_num_batched_tokens or "
@ -267,12 +267,12 @@ class SchedulerConfig:
f"({self.max_num_seqs})." f"({self.max_num_seqs})."
) )
if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
logger.warning( logger.warning(
"max_num_batched_tokens (%d) exceeds max_num_seqs " "max_num_batched_tokens (%d) exceeds max_num_seqs "
"* max_model_len (%d). This may lead to unexpected behavior.", "* max_model_len (%d). This may lead to unexpected behavior.",
self.max_num_batched_tokens, self.max_num_batched_tokens,
self.max_num_seqs * self.max_model_len, self.max_num_seqs * max_model_len,
) )
if self.max_num_partial_prefills > 1: if self.max_num_partial_prefills > 1:
@ -282,11 +282,11 @@ class SchedulerConfig:
"max_num_partial_prefills > 1." "max_num_partial_prefills > 1."
) )
if self.long_prefill_token_threshold > self.max_model_len: if self.long_prefill_token_threshold > max_model_len:
raise ValueError( raise ValueError(
"long_prefill_token_threshold " "long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater " f"({self.long_prefill_token_threshold}) cannot be greater "
f"than the max_model_len ({self.max_model_len})." f"than the max_model_len ({max_model_len})."
) )
if self.max_long_partial_prefills > self.max_num_partial_prefills: if self.max_long_partial_prefills > self.max_num_partial_prefills:

View File

@ -929,7 +929,6 @@ class VllmConfig:
model_config = self.model_config model_config = self.model_config
max_model_len = model_config.get_and_verify_max_len(max_model_len) max_model_len = model_config.get_and_verify_max_len(max_model_len)
self.model_config.max_model_len = max_model_len self.model_config.max_model_len = max_model_len
self.scheduler_config.max_model_len = max_model_len
def try_verify_and_update_config(self): def try_verify_and_update_config(self):
if self.model_config is None: if self.model_config is None:

View File

@ -339,7 +339,7 @@ class CpuPlatform(Platform):
) )
vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len, vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
) )

View File

@ -191,7 +191,7 @@ class TpuPlatform(Platform):
) )
vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len, vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
) )

View File

@ -185,7 +185,7 @@ class XPUPlatform(Platform):
) )
vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len, vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
) )

View File

@ -83,7 +83,7 @@ class Scheduler(SchedulerInterface):
# Scheduling constraints. # Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.enable_kv_cache_events = ( self.enable_kv_cache_events = (
self.kv_events_config is not None self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events and self.kv_events_config.enable_kv_cache_events