[Config] Clean up SchedulerConfig initialization (#28665)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-11-14 22:41:02 +08:00 committed by GitHub
parent 96b23b8e3b
commit 511a6b611d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 182 additions and 163 deletions

View File

@ -348,9 +348,14 @@ def test_fp32_cache_state(
# Helper functions for the APC tests # Helper functions for the APC tests
def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1): def _get_vllm_runner_params(
model: str,
max_model_len: int,
tensor_parallel_size: int = 1,
):
return { return {
"model_name": model, "model_name": model,
"enable_chunked_prefill": True,
"enable_prefix_caching": False, "enable_prefix_caching": False,
"max_model_len": max_model_len, "max_model_len": max_model_len,
"tensor_parallel_size": tensor_parallel_size, "tensor_parallel_size": tensor_parallel_size,

View File

@ -2256,6 +2256,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
is_encoder_decoder=is_encoder_decoder, is_encoder_decoder=is_encoder_decoder,
# Must be <= max_num_batched_tokens if chunked prefill is disabled
max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
) )
# `is_encoder_decoder` should only be used during construction # `is_encoder_decoder` should only be used during construction

View File

@ -47,6 +47,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
max_num_batched_tokens=16, max_num_batched_tokens=16,
max_num_seqs=16, max_num_seqs=16,
max_model_len=128, max_model_len=128,
enable_chunked_prefill=True,
enforce_eager=True, enforce_eager=True,
# TODO: enable this once we support it for # TODO: enable this once we support it for
# prompt logprobs. # prompt logprobs.

View File

@ -4,7 +4,7 @@
import hashlib import hashlib
from collections.abc import Callable from collections.abc import Callable
from dataclasses import InitVar from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
from pydantic import Field, field_validator, model_validator from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
@ -12,11 +12,6 @@ from typing_extensions import Self
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (
DEFAULT_MAX_NUM_BATCHED_TOKENS,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING: if TYPE_CHECKING:
@ -33,25 +28,32 @@ SchedulerPolicy = Literal["fcfs", "priority"]
class SchedulerConfig: class SchedulerConfig:
"""Scheduler configuration.""" """Scheduler configuration."""
DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128
runner_type: RunnerType = "generate" runner_type: RunnerType = "generate"
"""The runner type to launch for the model.""" """The runner type to launch for the model."""
max_num_batched_tokens: int = Field(default=None, ge=1) max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
"""Maximum number of tokens to be processed in a single iteration. """Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will The default value here is mainly for convenience when testing.
be set in `EngineArgs.create_engine_config` based on the usage context.""" In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_num_seqs: int = Field(default=None, ge=1) max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
"""Maximum number of sequences to be processed in a single iteration. """Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will The default value here is mainly for convenience when testing.
be set in `EngineArgs.create_engine_config` based on the usage context.""" In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_model_len: int = Field(default=None, ge=1) max_model_len: int = Field(default=8192, ge=1)
"""Maximum length of a sequence (including prompt and generated text). This """Maximum length of a sequence (including prompt and generated text).
is primarily set in `ModelConfig` and that value should be manually
duplicated here.""" The default value here is mainly for convenience when testing.
In real usage, this should duplicate `ModelConfig.max_model_len` via
`EngineArgs`."""
max_num_partial_prefills: int = Field(default=1, ge=1) max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be """For chunked prefill, the maximum number of sequences that can be
@ -76,9 +78,13 @@ class SchedulerConfig:
NOTE: This will be replaced by speculative config in the future; it is NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then.""" present to enable correctness tests until then."""
enable_chunked_prefill: bool = Field(default=None) enable_chunked_prefill: bool = True
"""If True, prefill requests can be chunked based """If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.""" on the remaining `max_num_batched_tokens`.
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
is_multimodal_model: bool = False is_multimodal_model: bool = False
"""True if the model is multimodal.""" """True if the model is multimodal."""
@ -111,9 +117,6 @@ class SchedulerConfig:
- "priority" means requests are handled based on given priority (lower - "priority" means requests are handled based on given priority (lower
value means earlier handling) and time of arrival deciding any ties).""" value means earlier handling) and time of arrival deciding any ties)."""
chunked_prefill_enabled: bool = Field(init=False)
"""True if chunked prefill is enabled."""
disable_chunked_mm_input: bool = False disable_chunked_mm_input: bool = False
"""If set to true and chunked prefill is enabled, we do not want to """If set to true and chunked prefill is enabled, we do not want to
partially schedule a multimodal item. Only used in V1 partially schedule a multimodal item. Only used in V1
@ -188,15 +191,7 @@ class SchedulerConfig:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str return hash_str
@field_validator( @field_validator("scheduler_cls", "async_scheduling", mode="wrap")
"max_num_batched_tokens",
"max_num_seqs",
"max_model_len",
"enable_chunked_prefill",
"scheduler_cls",
"async_scheduling",
mode="wrap",
)
@classmethod @classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed.""" """Skip validation if the value is `None` when initialisation is delayed."""
@ -205,16 +200,9 @@ class SchedulerConfig:
return handler(value) return handler(value)
def __post_init__(self, is_encoder_decoder: bool) -> None: def __post_init__(self, is_encoder_decoder: bool) -> None:
if self.max_model_len is None:
self.max_model_len = 8192
if self.max_num_seqs is None:
self.max_num_seqs = 128
if is_encoder_decoder: if is_encoder_decoder:
# Chunked prefill should be disabled for encoder-decoder models. # Chunked prefill should be disabled for encoder-decoder models.
self.disable_chunked_mm_input = True self.disable_chunked_mm_input = True
self.chunked_prefill_enabled = False
self.enable_chunked_prefill = False self.enable_chunked_prefill = False
self.long_prefill_token_threshold = 0 self.long_prefill_token_threshold = 0
logger.info( logger.info(
@ -222,37 +210,6 @@ class SchedulerConfig:
" prefix caching; disabling both." " prefix caching; disabling both."
) )
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
else:
# If max_model_len is too short, use
# DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
# for higher throughput.
self.max_num_batched_tokens = max(
self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
)
if self.runner_type == "pooling":
# Choose specific value for higher throughput
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if self.is_multimodal_model:
# The value needs to be at least the number of multimodal tokens
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)
# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
)
self.max_num_encoder_input_tokens = self.max_num_batched_tokens self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens self.encoder_cache_size = self.max_num_batched_tokens
@ -262,7 +219,6 @@ class SchedulerConfig:
self.max_num_batched_tokens, self.max_num_batched_tokens,
) )
self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.max_num_partial_prefills > 1: if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0: if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04) self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@ -276,6 +232,14 @@ class SchedulerConfig:
self.long_prefill_token_threshold, self.long_prefill_token_threshold,
) )
@property
def chunked_prefill_enabled(self) -> bool:
    """Whether chunked prefill is enabled.

    Read-through alias: returns the current value of
    `enable_chunked_prefill`, so the two names can never disagree.
    """
    enabled = self.enable_chunked_prefill
    return enabled

@chunked_prefill_enabled.setter
def chunked_prefill_enabled(self, value: bool):
    """Write-through alias: stores `value` in `enable_chunked_prefill`."""
    self.enable_chunked_prefill = value
@model_validator(mode="after") @model_validator(mode="after")
def _verify_args(self) -> Self: def _verify_args(self) -> Self:
if ( if (

View File

@ -428,11 +428,11 @@ class EngineArgs:
cpu_offload_gb: float = CacheConfig.cpu_offload_gb cpu_offload_gb: float = CacheConfig.cpu_offload_gb
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
max_num_batched_tokens: int | None = SchedulerConfig.max_num_batched_tokens max_num_batched_tokens: int | None = None
max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
max_num_seqs: int | None = SchedulerConfig.max_num_seqs max_num_seqs: int | None = None
max_logprobs: int = ModelConfig.max_logprobs max_logprobs: int = ModelConfig.max_logprobs
logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
disable_log_stats: bool = False disable_log_stats: bool = False
@ -485,7 +485,7 @@ class EngineArgs:
model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config") model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns") ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
enable_chunked_prefill: bool | None = SchedulerConfig.enable_chunked_prefill enable_chunked_prefill: bool | None = None
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
disable_hybrid_kv_cache_manager: bool = ( disable_hybrid_kv_cache_manager: bool = (
@ -1738,41 +1738,41 @@ class EngineArgs:
) )
_raise_unsupported_error(feature_name=name) _raise_unsupported_error(feature_name=name)
def _set_default_args( @classmethod
self, usage_context: UsageContext, model_config: ModelConfig def get_chunked_prefill_prefix_caching_defaults(
) -> None: cls,
"""Set Default Arguments for V1 Engine.""" model_config: ModelConfig,
) -> tuple[bool, bool]:
# V1 uses chunked prefills and prefix caching by default
# for non-pooling tasks.
# For pooling tasks the default is False
if model_config.runner_type != "pooling": if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True default_chunked_prefill = True
if self.enable_prefix_caching is None: # Disable prefix caching default for hybrid models
# Disable prefix caching default for hybrid models # since the feature is still experimental.
# since the feature is still experimental. default_prefix_caching = not model_config.is_hybrid
if model_config.is_hybrid:
self.enable_prefix_caching = False
else:
self.enable_prefix_caching = True
else: else:
assert model_config.pooler_config is not None
pooling_type = model_config.pooler_config.pooling_type pooling_type = model_config.pooler_config.pooling_type
is_causal = getattr(model_config.hf_config, "is_causal", True)
incremental_prefill_supported = ( incremental_prefill_supported = (
pooling_type is not None pooling_type is not None
and pooling_type.lower() == "last" and pooling_type.lower() == "last"
and bool(is_causal) and getattr(model_config.hf_config, "is_causal", True)
) )
action = "Enabling" if incremental_prefill_supported else "Disabling" default_chunked_prefill = incremental_prefill_supported
default_prefix_caching = incremental_prefill_supported
if self.enable_chunked_prefill is None: return default_chunked_prefill, default_prefix_caching
self.enable_chunked_prefill = incremental_prefill_supported
logger.info("(%s) chunked prefill by default", action) @classmethod
if self.enable_prefix_caching is None: def get_batch_defaults(
self.enable_prefix_caching = incremental_prefill_supported cls,
logger.info("(%s) prefix caching by default", action) world_size: int,
) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
from vllm.usage.usage_lib import UsageContext
default_max_num_batched_tokens: dict[UsageContext | None, int]
default_max_num_seqs: dict[UsageContext | None, int]
# When no user override, set the default values based on the usage # When no user override, set the default values based on the usage
# context. # context.
@ -1793,8 +1793,6 @@ class EngineArgs:
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details. # throughput, see PR #17885 for more details.
# So here we do an extra device name check to prevent such regression. # So here we do an extra device name check to prevent such regression.
from vllm.usage.usage_lib import UsageContext
if device_memory >= 70 * GiB_bytes and "a100" not in device_name: if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
# For GPUs like H100 and MI300x, use larger default values. # For GPUs like H100 and MI300x, use larger default values.
default_max_num_batched_tokens = { default_max_num_batched_tokens = {
@ -1818,22 +1816,26 @@ class EngineArgs:
# tpu specific default values. # tpu specific default values.
if current_platform.is_tpu(): if current_platform.is_tpu():
default_max_num_batched_tokens_tpu = { chip_name = current_platform.get_device_name()
UsageContext.LLM_CLASS: {
"V6E": 2048, if chip_name == "V6E":
"V5E": 1024, default_max_num_batched_tokens = {
"V5P": 512, UsageContext.LLM_CLASS: 2048,
}, UsageContext.OPENAI_API_SERVER: 1024,
UsageContext.OPENAI_API_SERVER: { }
"V6E": 1024, elif chip_name == "V5E":
"V5E": 512, default_max_num_batched_tokens = {
"V5P": 256, UsageContext.LLM_CLASS: 1024,
}, UsageContext.OPENAI_API_SERVER: 512,
} }
elif chip_name == "V5P":
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 512,
UsageContext.OPENAI_API_SERVER: 256,
}
# cpu specific default values. # cpu specific default values.
if current_platform.is_cpu(): if current_platform.is_cpu():
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
default_max_num_batched_tokens = { default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 4096 * world_size, UsageContext.LLM_CLASS: 4096 * world_size,
UsageContext.OPENAI_API_SERVER: 2048 * world_size, UsageContext.OPENAI_API_SERVER: 2048 * world_size,
@ -1843,44 +1845,104 @@ class EngineArgs:
UsageContext.OPENAI_API_SERVER: 128 * world_size, UsageContext.OPENAI_API_SERVER: 128 * world_size,
} }
use_context_value = usage_context.value if usage_context else None return default_max_num_batched_tokens, default_max_num_seqs
if (
self.max_num_batched_tokens is None def _set_default_args(
and usage_context in default_max_num_batched_tokens self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
"""Set Default Arguments for V1 Engine."""
(
default_chunked_prefill,
default_prefix_caching,
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = default_chunked_prefill
logger.debug(
"%s chunked prefill by default",
"Enabling" if default_chunked_prefill else "Disabling",
)
elif (
model_config.runner_type == "pooling"
and self.enable_chunked_prefill
and not default_chunked_prefill
): ):
if current_platform.is_tpu(): logger.warning(
chip_name = current_platform.get_device_name() "This model does not officially support chunked prefill. "
if chip_name in default_max_num_batched_tokens_tpu[usage_context]: "Enabling this manually may cause the engine to crash "
self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[ "or produce incorrect outputs.",
usage_context )
][chip_name]
else: if self.enable_prefix_caching is None:
self.max_num_batched_tokens = default_max_num_batched_tokens[ self.enable_prefix_caching = default_prefix_caching
usage_context
]
else:
if not self.enable_chunked_prefill:
self.max_num_batched_tokens = model_config.max_model_len
else:
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context
]
logger.debug( logger.debug(
"Setting max_num_batched_tokens to %d for %s usage context.", "%s prefix caching by default",
"Enabling" if default_prefix_caching else "Disabling",
)
elif (
model_config.runner_type == "pooling"
and self.enable_prefix_caching
and not default_prefix_caching
):
logger.warning(
"This model does not officially support prefix caching. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
)
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
(
default_max_num_batched_tokens,
default_max_num_seqs,
) = self.get_batch_defaults(world_size)
orig_max_num_batched_tokens = self.max_num_batched_tokens
orig_max_num_seqs = self.max_num_seqs
if self.max_num_batched_tokens is None:
self.max_num_batched_tokens = default_max_num_batched_tokens.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
if self.max_num_seqs is None:
self.max_num_seqs = default_max_num_seqs.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
)
if orig_max_num_batched_tokens is None:
if not self.enable_chunked_prefill:
# If max_model_len is too short, use the default for higher throughput.
self.max_num_batched_tokens = max(
model_config.max_model_len,
self.max_num_batched_tokens,
)
# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * model_config.max_model_len,
self.max_num_batched_tokens, self.max_num_batched_tokens,
use_context_value,
)
if self.max_num_seqs is None and usage_context in default_max_num_seqs:
self.max_num_seqs = min(
default_max_num_seqs[usage_context],
self.max_num_batched_tokens or sys.maxsize,
) )
logger.debug( logger.debug(
"Setting max_num_seqs to %d for %s usage context.", "Defaulting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens,
usage_context.value if usage_context else None,
)
if orig_max_num_seqs is None:
assert self.max_num_batched_tokens is not None # For type checking
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
logger.debug(
"Defaulting max_num_seqs to %d for %s usage context.",
self.max_num_seqs, self.max_num_seqs,
use_context_value, usage_context.value if usage_context else None,
) )

View File

@ -15,7 +15,6 @@ import torch
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import CpuArchEnum, Platform, PlatformEnum from .interface import CpuArchEnum, Platform, PlatformEnum
@ -339,10 +338,9 @@ class CpuPlatform(Platform):
"prefill and prefix caching to be disabled." "prefill and prefix caching to be disabled."
) )
vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len, vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS, vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
) )
@classmethod @classmethod

View File

@ -10,7 +10,6 @@ from tpu_info import device
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum
@ -186,10 +185,9 @@ class TpuPlatform(Platform):
"prefill and prefix caching to be disabled." "prefill and prefix caching to be disabled."
) )
vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len, vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS, vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
) )
@classmethod @classmethod

View File

@ -9,7 +9,6 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
@ -185,10 +184,9 @@ class XPUPlatform(Platform):
"prefill and prefix caching to be disabled." "prefill and prefix caching to be disabled."
) )
vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.chunked_prefill_enabled = False
vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.scheduler_config.max_model_len, vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS, vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
) )
@classmethod @classmethod

View File

@ -3,7 +3,7 @@
import uuid import uuid
import warnings import warnings
from typing import Any, TypeVar from typing import Any
import torch import torch
@ -39,12 +39,6 @@ def __dir__() -> list[str]:
logger = init_logger(__name__) logger = init_logger(__name__)
# This value is chosen to have a balance between ITL and TTFT. Note it is
# not optimized for throughput.
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
# Constants related to forcing the attention backend selection # Constants related to forcing the attention backend selection
# String name of register which may be set in order to # String name of register which may be set in order to
@ -60,9 +54,6 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID" STR_INVALID_VAL: str = "INVALID"
T = TypeVar("T")
def random_uuid() -> str: def random_uuid() -> str:
return str(uuid.uuid4().hex) return str(uuid.uuid4().hex)