[Config] Clean up SchedulerConfig initialization (#28665)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2025-11-14 22:41:02 +08:00; committed by GitHub
parent 96b23b8e3b
commit 511a6b611d
9 changed files with 182 additions and 163 deletions

View File

@@ -348,9 +348,14 @@ def test_fp32_cache_state(
# Helper functions for the APC tests
def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1):
def _get_vllm_runner_params(
model: str,
max_model_len: int,
tensor_parallel_size: int = 1,
):
return {
"model_name": model,
"enable_chunked_prefill": True,
"enable_prefix_caching": False,
"max_model_len": max_model_len,
"tensor_parallel_size": tensor_parallel_size,

View File

@@ -2256,6 +2256,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
scheduler_config = SchedulerConfig(
enable_chunked_prefill=enable_chunked_prefill,
is_encoder_decoder=is_encoder_decoder,
# Must be <= max_num_batched_tokens if chunked prefill is disabled
max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
# `is_encoder_decoder` should only be used during construction
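For context: with chunked prefill disabled, a whole prompt must fit into a single batch, so `max_model_len` may not exceed `max_num_batched_tokens`; the test therefore pins `max_model_len` to the new class-level default. A minimal sketch of that constraint, assuming the usual `from vllm.config import SchedulerConfig` import:
# --- sketch (not part of the commit) ---------------------------------------
from vllm.config import SchedulerConfig

config = SchedulerConfig(
    enable_chunked_prefill=False,
    is_encoder_decoder=True,
    # Equal to the implicit max_num_batched_tokens default (2048), so the
    # "max_model_len <= max_num_batched_tokens" check still passes.
    max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
assert config.chunked_prefill_enabled is False  # forced off for encoder-decoder
# --- end sketch -------------------------------------------------------------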

View File

@@ -47,6 +47,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
max_num_batched_tokens=16,
max_num_seqs=16,
max_model_len=128,
enable_chunked_prefill=True,
enforce_eager=True,
# TODO: enable this once we support it for
# prompt logprobs.

View File

@@ -4,7 +4,7 @@
import hashlib
from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
@@ -12,11 +12,6 @@ from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import (
DEFAULT_MAX_NUM_BATCHED_TOKENS,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
@@ -33,25 +28,32 @@ SchedulerPolicy = Literal["fcfs", "priority"]
class SchedulerConfig:
"""Scheduler configuration."""
DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128
runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""
max_num_batched_tokens: int = Field(default=None, ge=1)
max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
"""Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_num_seqs: int = Field(default=None, ge=1)
max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
"""Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_model_len: int = Field(default=None, ge=1)
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
max_model_len: int = Field(default=8192, ge=1)
"""Maximum length of a sequence (including prompt and generated text).
The default value here is mainly for convenience when testing.
In real usage, this should duplicate `ModelConfig.max_model_len` via
`EngineArgs`."""
max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be
@@ -76,9 +78,13 @@ class SchedulerConfig:
NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then."""
enable_chunked_prefill: bool = Field(default=None)
enable_chunked_prefill: bool = True
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
on the remaining `max_num_batched_tokens`.
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
is_multimodal_model: bool = False
"""True if the model is multimodal."""
@@ -111,9 +117,6 @@ class SchedulerConfig:
- "priority" means requests are handled based on given priority (lower
value means earlier handling) and time of arrival deciding any ties."""
chunked_prefill_enabled: bool = Field(init=False)
"""True if chunked prefill is enabled."""
disable_chunked_mm_input: bool = False
"""If set to true and chunked prefill is enabled, we do not want to
partially schedule a multimodal item. Only used in V1
@@ -188,15 +191,7 @@ class SchedulerConfig:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator(
"max_num_batched_tokens",
"max_num_seqs",
"max_model_len",
"enable_chunked_prefill",
"scheduler_cls",
"async_scheduling",
mode="wrap",
)
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
@@ -205,16 +200,9 @@ class SchedulerConfig:
return handler(value)
def __post_init__(self, is_encoder_decoder: bool) -> None:
if self.max_model_len is None:
self.max_model_len = 8192
if self.max_num_seqs is None:
self.max_num_seqs = 128
if is_encoder_decoder:
# Chunked prefill should be disabled for encoder-decoder models.
self.disable_chunked_mm_input = True
self.chunked_prefill_enabled = False
self.enable_chunked_prefill = False
self.long_prefill_token_threshold = 0
logger.info(
@@ -222,37 +210,6 @@ class SchedulerConfig:
" prefix caching; disabling both."
)
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
else:
# If max_model_len is too short, use
# DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
# for higher throughput.
self.max_num_batched_tokens = max(
self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
)
if self.runner_type == "pooling":
# Choose specific value for higher throughput
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if self.is_multimodal_model:
# The value needs to be at least the number of multimodal tokens
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)
# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
)
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens
@@ -262,7 +219,6 @@ class SchedulerConfig:
self.max_num_batched_tokens,
)
self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -276,6 +232,14 @@ class SchedulerConfig:
self.long_prefill_token_threshold,
)
@property
def chunked_prefill_enabled(self) -> bool:
return self.enable_chunked_prefill
@chunked_prefill_enabled.setter
def chunked_prefill_enabled(self, value: bool):
self.enable_chunked_prefill = value
@model_validator(mode="after")
def _verify_args(self) -> Self:
if (

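In short, the scheduler config now carries static, test-friendly defaults as `ClassVar` constants, and `chunked_prefill_enabled` is no longer a separately stored `Field(init=False)` but a property that aliases `enable_chunked_prefill`. The shape of that pattern as a standalone sketch (illustrative names, not the vLLM class itself):
# --- sketch (not part of the commit) ---------------------------------------
from typing import ClassVar

from pydantic import Field
from pydantic.dataclasses import dataclass


@dataclass
class TinySchedulerConfig:
    DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
    DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128

    max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
    max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
    enable_chunked_prefill: bool = True

    @property
    def chunked_prefill_enabled(self) -> bool:
        # Read-through alias kept for backwards compatibility.
        return self.enable_chunked_prefill

    @chunked_prefill_enabled.setter
    def chunked_prefill_enabled(self, value: bool) -> None:
        self.enable_chunked_prefill = value


config = TinySchedulerConfig()
assert config.max_num_batched_tokens == TinySchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS
config.chunked_prefill_enabled = False  # setter forwards to the real field
assert config.enable_chunked_prefill is False
# --- end sketch -------------------------------------------------------------
Because the alias is now derived rather than stored, call sites that used to set both flags (see the platform overrides further down) only need to touch `enable_chunked_prefill`.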
View File

@@ -428,11 +428,11 @@ class EngineArgs:
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
max_num_batched_tokens: int | None = SchedulerConfig.max_num_batched_tokens
max_num_batched_tokens: int | None = None
max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
max_num_seqs: int | None = SchedulerConfig.max_num_seqs
max_num_seqs: int | None = None
max_logprobs: int = ModelConfig.max_logprobs
logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
disable_log_stats: bool = False
@@ -485,7 +485,7 @@ class EngineArgs:
model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
enable_chunked_prefill: bool | None = SchedulerConfig.enable_chunked_prefill
enable_chunked_prefill: bool | None = None
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
disable_hybrid_kv_cache_manager: bool = (
@@ -1738,41 +1738,41 @@ class EngineArgs:
)
_raise_unsupported_error(feature_name=name)
def _set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
"""Set Default Arguments for V1 Engine."""
# V1 uses chunked prefills and prefix caching by default
# for non-pooling tasks.
# For pooling tasks the default is False
@classmethod
def get_chunked_prefill_prefix_caching_defaults(
cls,
model_config: ModelConfig,
) -> tuple[bool, bool]:
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True
default_chunked_prefill = True
if self.enable_prefix_caching is None:
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
if model_config.is_hybrid:
self.enable_prefix_caching = False
else:
self.enable_prefix_caching = True
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
default_prefix_caching = not model_config.is_hybrid
else:
assert model_config.pooler_config is not None
pooling_type = model_config.pooler_config.pooling_type
is_causal = getattr(model_config.hf_config, "is_causal", True)
incremental_prefill_supported = (
pooling_type is not None
and pooling_type.lower() == "last"
and bool(is_causal)
and getattr(model_config.hf_config, "is_causal", True)
)
action = "Enabling" if incremental_prefill_supported else "Disabling"
default_chunked_prefill = incremental_prefill_supported
default_prefix_caching = incremental_prefill_supported
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = incremental_prefill_supported
logger.info("(%s) chunked prefill by default", action)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = incremental_prefill_supported
logger.info("(%s) prefix caching by default", action)
return default_chunked_prefill, default_prefix_caching
@classmethod
def get_batch_defaults(
cls,
world_size: int,
) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
from vllm.usage.usage_lib import UsageContext
default_max_num_batched_tokens: dict[UsageContext | None, int]
default_max_num_seqs: dict[UsageContext | None, int]
# When no user override, set the default values based on the usage
# context.
@@ -1793,8 +1793,6 @@ class EngineArgs:
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details.
# So here we do an extra device name check to prevent such regression.
from vllm.usage.usage_lib import UsageContext
if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
# For GPUs like H100 and MI300x, use larger default values.
default_max_num_batched_tokens = {
@@ -1818,22 +1816,26 @@ class EngineArgs:
# tpu specific default values.
if current_platform.is_tpu():
default_max_num_batched_tokens_tpu = {
UsageContext.LLM_CLASS: {
"V6E": 2048,
"V5E": 1024,
"V5P": 512,
},
UsageContext.OPENAI_API_SERVER: {
"V6E": 1024,
"V5E": 512,
"V5P": 256,
},
}
chip_name = current_platform.get_device_name()
if chip_name == "V6E":
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 2048,
UsageContext.OPENAI_API_SERVER: 1024,
}
elif chip_name == "V5E":
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 1024,
UsageContext.OPENAI_API_SERVER: 512,
}
elif chip_name == "V5P":
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 512,
UsageContext.OPENAI_API_SERVER: 256,
}
# cpu specific default values.
if current_platform.is_cpu():
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 4096 * world_size,
UsageContext.OPENAI_API_SERVER: 2048 * world_size,
@@ -1843,44 +1845,104 @@ class EngineArgs:
UsageContext.OPENAI_API_SERVER: 128 * world_size,
}
use_context_value = usage_context.value if usage_context else None
if (
self.max_num_batched_tokens is None
and usage_context in default_max_num_batched_tokens
return default_max_num_batched_tokens, default_max_num_seqs
def _set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
"""Set Default Arguments for V1 Engine."""
(
default_chunked_prefill,
default_prefix_caching,
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = default_chunked_prefill
logger.debug(
"%s chunked prefill by default",
"Enabling" if default_chunked_prefill else "Disabling",
)
elif (
model_config.runner_type == "pooling"
and self.enable_chunked_prefill
and not default_chunked_prefill
):
if current_platform.is_tpu():
chip_name = current_platform.get_device_name()
if chip_name in default_max_num_batched_tokens_tpu[usage_context]:
self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[
usage_context
][chip_name]
else:
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context
]
else:
if not self.enable_chunked_prefill:
self.max_num_batched_tokens = model_config.max_model_len
else:
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context
]
logger.warning(
"This model does not officially support chunked prefill. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = default_prefix_caching
logger.debug(
"Setting max_num_batched_tokens to %d for %s usage context.",
"%s prefix caching by default",
"Enabling" if default_prefix_caching else "Disabling",
)
elif (
model_config.runner_type == "pooling"
and self.enable_prefix_caching
and not default_prefix_caching
):
logger.warning(
"This model does not officially support prefix caching. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
)
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
(
default_max_num_batched_tokens,
default_max_num_seqs,
) = self.get_batch_defaults(world_size)
orig_max_num_batched_tokens = self.max_num_batched_tokens
orig_max_num_seqs = self.max_num_seqs
if self.max_num_batched_tokens is None:
self.max_num_batched_tokens = default_max_num_batched_tokens.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
if self.max_num_seqs is None:
self.max_num_seqs = default_max_num_seqs.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
)
if orig_max_num_batched_tokens is None:
if not self.enable_chunked_prefill:
# If max_model_len is too short, use the default for higher throughput.
self.max_num_batched_tokens = max(
model_config.max_model_len,
self.max_num_batched_tokens,
)
# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * model_config.max_model_len,
self.max_num_batched_tokens,
use_context_value,
)
if self.max_num_seqs is None and usage_context in default_max_num_seqs:
self.max_num_seqs = min(
default_max_num_seqs[usage_context],
self.max_num_batched_tokens or sys.maxsize,
)
logger.debug(
"Setting max_num_seqs to %d for %s usage context.",
"Defaulting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens,
usage_context.value if usage_context else None,
)
if orig_max_num_seqs is None:
assert self.max_num_batched_tokens is not None # For type checking
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
logger.debug(
"Defaulting max_num_seqs to %d for %s usage context.",
self.max_num_seqs,
use_context_value,
usage_context.value if usage_context else None,
)
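Taken together, the refactor splits default selection into two side-effect-free classmethods (`get_chunked_prefill_prefix_caching_defaults` and `get_batch_defaults`) and has `_set_default_args` apply the results only to fields the user left as `None`, warning when a pooling model has chunked prefill or prefix caching force-enabled. A simplified sketch of the "compute defaults, apply only if unset" flow (placeholder numbers and names, not vLLM's real per-device tables):
# --- sketch (not part of the commit) ---------------------------------------
from dataclasses import dataclass
from enum import Enum


class UsageContext(Enum):
    LLM_CLASS = "LLM_CLASS"
    OPENAI_API_SERVER = "OPENAI_API_SERVER"


@dataclass
class TinyEngineArgs:
    max_num_batched_tokens: int | None = None
    max_num_seqs: int | None = None

    @classmethod
    def get_batch_defaults(
        cls, world_size: int
    ) -> tuple[dict[UsageContext, int], dict[UsageContext, int]]:
        # Side-effect-free: returns per-context suggestions without mutating args.
        # The numbers below are placeholders, not vLLM's real per-device defaults.
        default_tokens = {
            UsageContext.LLM_CLASS: 8192 * world_size,
            UsageContext.OPENAI_API_SERVER: 2048 * world_size,
        }
        default_seqs = {
            UsageContext.LLM_CLASS: 256,
            UsageContext.OPENAI_API_SERVER: 128,
        }
        return default_tokens, default_seqs

    def set_default_args(
        self, usage_context: UsageContext | None, world_size: int = 1
    ) -> None:
        tokens, seqs = self.get_batch_defaults(world_size)
        orig_max_num_seqs = self.max_num_seqs
        # Only fill in values the user left unset; 2048/128 mirror the
        # SchedulerConfig.DEFAULT_* fallbacks used for unknown contexts.
        if self.max_num_batched_tokens is None:
            self.max_num_batched_tokens = tokens.get(usage_context, 2048)
        if self.max_num_seqs is None:
            self.max_num_seqs = seqs.get(usage_context, 128)
        if orig_max_num_seqs is None:
            # A defaulted max_num_seqs must not exceed the token budget.
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)


args = TinyEngineArgs(max_num_seqs=8)  # user override survives
args.set_default_args(UsageContext.OPENAI_API_SERVER)
assert args.max_num_seqs == 8
assert args.max_num_batched_tokens == 2048
# --- end sketch -------------------------------------------------------------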

View File

@@ -15,7 +15,6 @@ import torch
from vllm import envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -339,10 +338,9 @@ class CpuPlatform(Platform):
"prefill and prefix caching to be disabled."
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
@classmethod

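This CPU override, like the TPU and XPU ones that follow, no longer imports `DEFAULT_MAX_NUM_BATCHED_TOKENS` from `vllm.utils` and drops the redundant `chunked_prefill_enabled = False` assignment, since that name is now a property tracking `enable_chunked_prefill`. Because the constant is a `ClassVar`, it can be read from the class or from any `SchedulerConfig` instance. A small sketch of the access pattern (assuming the usual `from vllm.config import SchedulerConfig` import; `vllm_config` is a hypothetical, already-built `VllmConfig`):
# --- sketch (not part of the commit) ---------------------------------------
from vllm.config import SchedulerConfig

# Class-level access:
assert SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS == 2048

# Instance-level access, as used in the platform overrides:
#   sched = vllm_config.scheduler_config
#   sched.enable_chunked_prefill = False  # chunked_prefill_enabled property follows
#   sched.max_num_batched_tokens = max(
#       sched.max_model_len, sched.DEFAULT_MAX_NUM_BATCHED_TOKENS
#   )
# --- end sketch -------------------------------------------------------------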
View File

@@ -10,7 +10,6 @@ from tpu_info import device
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum
@@ -186,10 +185,9 @@ class TpuPlatform(Platform):
"prefill and prefix caching to be disabled."
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
@classmethod

View File

@@ -9,7 +9,6 @@ import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum
@@ -185,10 +184,9 @@ class XPUPlatform(Platform):
"prefill and prefix caching to be disabled."
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
@classmethod

View File

@@ -3,7 +3,7 @@
import uuid
import warnings
from typing import Any, TypeVar
from typing import Any
import torch
@@ -39,12 +39,6 @@ def __dir__() -> list[str]:
logger = init_logger(__name__)
# This value is chosen to have a balance between ITL and TTFT. Note it is
# not optimized for throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
# Constants related to forcing the attention backend selection
# String name of register which may be set in order to
@@ -60,9 +54,6 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
T = TypeVar("T")
def random_uuid() -> str:
return str(uuid.uuid4().hex)