diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py
index 62704bbcbbc7..2285709fa7d6 100644
--- a/tests/kernels/moe/test_batched_moe.py
+++ b/tests/kernels/moe/test_batched_moe.py
@@ -40,8 +40,6 @@ NUM_EXPERTS = [8, 64]
 TOP_KS = [1, 2, 6]
 
 vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 
 @dataclass
diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py
index cd34617ee0fc..88db4b3e537c 100644
--- a/tests/kernels/moe/test_block_fp8.py
+++ b/tests/kernels/moe/test_block_fp8.py
@@ -33,8 +33,6 @@ if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
 
 vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 # Test configurations
 DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py
index 3799e60f1294..e35ca4caa9db 100644
--- a/tests/kernels/moe/test_block_int8.py
+++ b/tests/kernels/moe/test_block_int8.py
@@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
     pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
 
 vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 
 DTYPES = [torch.bfloat16]
diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py
index 5512ccce47b0..c15837f14570 100644
--- a/tests/kernels/moe/test_cutlass_moe.py
+++ b/tests/kernels/moe/test_cutlass_moe.py
@@ -42,8 +42,6 @@ MNK_FACTORS = [
 ]
 
 vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 
 @dataclasses.dataclass
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index 707068b2bbdc..3a681d4603f8 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -45,8 +45,6 @@ MNK_FACTORS = [
 ]
 
 vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 
 def quant_fp8_per_tensor_batches(a):
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index c27cf2468ede..0550c2d9e212 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -81,8 +81,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [
 ]
 
 vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 
 def run_moe_test(
diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py
index a2de64974b35..dd4eb4da913b 100644
--- a/tests/kernels/moe/test_pplx_cutlass_moe.py
+++ b/tests/kernels/moe/test_pplx_cutlass_moe.py
@@ -192,8 +192,6 @@ def pplx_cutlass_moe(
 
 
 vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 
 def _pplx_moe(
diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py
index 0f0ed3326d15..f671b23d300c 100644
--- a/tests/kernels/moe/test_pplx_moe.py
+++ b/tests/kernels/moe/test_pplx_moe.py
@@ -81,8 +81,6 @@ TOP_KS = [1, 2, 6]
 DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
 
 vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 
 def torch_prepare(
diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
index 933cd9dbdeaa..7a467e160b78 100644
--- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
+++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py
@@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
 
 vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 
 def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py
index 55f092e7ea69..e9973c1fcc15 100644
--- a/tests/kernels/quantization/test_block_fp8.py
+++ b/tests/kernels/quantization/test_block_fp8.py
@@ -29,8 +29,6 @@ if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
 
 vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 # Test configurations
 DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py
index dabc10a122f7..310091b6a554 100644
--- a/tests/kernels/quantization/test_block_int8.py
+++ b/tests/kernels/quantization/test_block_int8.py
@@ -18,8 +18,6 @@ if current_platform.get_device_capability() < (7, 0):
    pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
 
 vllm_config = VllmConfig()
-vllm_config.scheduler_config.max_num_seqs = 128
-vllm_config.scheduler_config.max_model_len = 8192
 
 DTYPES = [torch.half, torch.bfloat16]
 M = [1, 33, 64, 222]
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index 444568994a95..8194295ffedb 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -6,7 +6,7 @@ from collections.abc import Callable
 from dataclasses import InitVar
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
 
-from pydantic import Field, field_validator, model_validator
+from pydantic import Field, field_validator
 from pydantic.dataclasses import dataclass
 from typing_extensions import Self, deprecated
 
@@ -48,13 +48,6 @@ class SchedulerConfig:
     In real usage, this should be set in `EngineArgs.create_engine_config`.
     """
 
-    max_model_len: int = Field(default=8192, ge=1)
-    """Maximum length of a sequence (including prompt and generated text).
-
-    The default value here is mainly for convenience when testing.
-    In real usage, this should duplicate `ModelConfig.max_model_len` via
-    `EngineArgs`."""
-
     max_num_partial_prefills: int = Field(default=1, ge=1)
     """For chunked prefill, the maximum number of sequences that can be
     partially prefilled concurrently."""
@@ -89,6 +82,12 @@ class SchedulerConfig:
     is_multimodal_model: bool = False
     """True if the model is multimodal."""
 
+    max_model_len: InitVar[int] = 8192
+    """Maximum length of a sequence (including prompt and generated text).
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    provide fallbacks and validate other attributes."""
+
     is_encoder_decoder: InitVar[bool] = False
     """True if the model is an encoder-decoder model.
@@ -199,7 +198,7 @@ class SchedulerConfig:
             return value
         return handler(value)
 
-    def __post_init__(self, is_encoder_decoder: bool) -> None:
+    def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
         if is_encoder_decoder:
             # Chunked prefill should be disabled for encoder-decoder models.
             self.disable_chunked_mm_input = True
@@ -221,7 +220,7 @@ class SchedulerConfig:
 
         if self.max_num_partial_prefills > 1:
             if self.long_prefill_token_threshold == 0:
-                self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
+                self.long_prefill_token_threshold = int(max_model_len * 0.04)
 
             logger.info(
                 "Concurrent partial prefills enabled with "
@@ -232,6 +231,8 @@ class SchedulerConfig:
                 self.long_prefill_token_threshold,
             )
 
+        self.verify_max_model_len(max_model_len)
+
     @property
     @deprecated(
         "`SchedulerConfig.chunked_prefill_enabled` has been renamed to "
@@ -245,15 +246,14 @@ class SchedulerConfig:
     def chunked_prefill_enabled(self, value: bool):
         self.enable_chunked_prefill = value
 
-    @model_validator(mode="after")
-    def _verify_args(self) -> Self:
+    def verify_max_model_len(self, max_model_len: int) -> Self:
         if (
-            self.max_num_batched_tokens < self.max_model_len
+            self.max_num_batched_tokens < max_model_len
            and not self.enable_chunked_prefill
        ):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
-                f"smaller than max_model_len ({self.max_model_len}). "
+                f"smaller than max_model_len ({max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
@@ -267,12 +267,12 @@ class SchedulerConfig:
                f"({self.max_num_seqs})."
            )
 
-        if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
+        if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
            logger.warning(
                "max_num_batched_tokens (%d) exceeds max_num_seqs "
                "* max_model_len (%d). This may lead to unexpected behavior.",
                self.max_num_batched_tokens,
-                self.max_num_seqs * self.max_model_len,
+                self.max_num_seqs * max_model_len,
            )
 
        if self.max_num_partial_prefills > 1:
@@ -282,11 +282,11 @@ class SchedulerConfig:
                "max_num_partial_prefills > 1."
            )
 
-        if self.long_prefill_token_threshold > self.max_model_len:
+        if self.long_prefill_token_threshold > max_model_len:
            raise ValueError(
                "long_prefill_token_threshold "
                f"({self.long_prefill_token_threshold}) cannot be greater "
-                f"than the max_model_len ({self.max_model_len})."
+                f"than the max_model_len ({max_model_len})."
            )
 
        if self.max_long_partial_prefills > self.max_num_partial_prefills:
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 1e6e455210c8..bf9bcd0e8a11 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -929,7 +929,6 @@ class VllmConfig:
        model_config = self.model_config
        max_model_len = model_config.get_and_verify_max_len(max_model_len)
        self.model_config.max_model_len = max_model_len
-        self.scheduler_config.max_model_len = max_model_len
 
    def try_verify_and_update_config(self):
        if self.model_config is None:
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 1da34629472c..ed655912d396 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -339,7 +339,7 @@ class CpuPlatform(Platform):
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
-                vllm_config.scheduler_config.max_model_len,
+                vllm_config.model_config.max_model_len,
                vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )
 
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index c1218801bc07..944344a22957 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -191,7 +191,7 @@ class TpuPlatform(Platform):
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
-                vllm_config.scheduler_config.max_model_len,
+                vllm_config.model_config.max_model_len,
                vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )
 
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index ad4beb28bdae..65516827a16d 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -185,7 +185,7 @@ class XPUPlatform(Platform):
            )
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
-                vllm_config.scheduler_config.max_model_len,
+                vllm_config.model_config.max_model_len,
                vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
            )
 
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index c640c40a455d..bc15979dea62 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -83,7 +83,7 @@ class Scheduler(SchedulerInterface):
        # Scheduling constraints.
        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
        self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens
-        self.max_model_len = self.scheduler_config.max_model_len
+        self.max_model_len = vllm_config.model_config.max_model_len
        self.enable_kv_cache_events = (
            self.kv_events_config is not None
            and self.kv_events_config.enable_kv_cache_events
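
Note (not part of the patch): after this change, `ModelConfig.max_model_len` is the single source of truth, and `SchedulerConfig` only receives the value as an `InitVar` so `__post_init__` can derive fallbacks and run `verify_max_model_len`. Below is a minimal sketch of the consumption pattern the platform and v1 scheduler hunks above converge on; `clamp_batched_tokens` is a hypothetical helper, and `vllm_config` is assumed to be an already-built `VllmConfig`.

```python
def clamp_batched_tokens(vllm_config) -> None:
    """Hypothetical helper mirroring the cpu/tpu/xpu platform hunks above."""
    scheduler_config = vllm_config.scheduler_config

    # Read max_model_len from its single source of truth, ModelConfig,
    # instead of the removed SchedulerConfig.max_model_len attribute.
    max_model_len = vllm_config.model_config.max_model_len

    # Ensure the token budget covers at least one full-length sequence,
    # as the platform overrides above do when disabling chunked prefill.
    scheduler_config.max_num_batched_tokens = max(
        max_model_len,
        scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
    )
```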