[Config] Clean up SchedulerConfig initialization (#28665)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit 511a6b611d (parent 96b23b8e3b)
@@ -348,9 +348,14 @@ def test_fp32_cache_state(
 # Helper functions for the APC tests
 
-def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1):
+def _get_vllm_runner_params(
+    model: str,
+    max_model_len: int,
+    tensor_parallel_size: int = 1,
+):
     return {
         "model_name": model,
         "enable_chunked_prefill": True,
         "enable_prefix_caching": False,
         "max_model_len": max_model_len,
         "tensor_parallel_size": tensor_parallel_size,
@@ -2256,6 +2256,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
     scheduler_config = SchedulerConfig(
         enable_chunked_prefill=enable_chunked_prefill,
         is_encoder_decoder=is_encoder_decoder,
+        # Must <= max_num_batched_tokens if chunked prefill is disabled
+        max_model_len=SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
     )
 
     # `is_encoder_decoder` should only be used during construction
@@ -47,6 +47,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
         max_num_batched_tokens=16,
         max_num_seqs=16,
         max_model_len=128,
+        enable_chunked_prefill=True,
         enforce_eager=True,
         # TODO: enable this once we support it for
         # prompt logprobs.
@@ -4,7 +4,7 @@
 import hashlib
 from collections.abc import Callable
 from dataclasses import InitVar
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
 
 from pydantic import Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass
@@ -12,11 +12,6 @@ from typing_extensions import Self
 
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils import (
-    DEFAULT_MAX_NUM_BATCHED_TOKENS,
-    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
-)
 from vllm.utils.import_utils import resolve_obj_by_qualname
 
 if TYPE_CHECKING:
@@ -33,25 +28,32 @@ SchedulerPolicy = Literal["fcfs", "priority"]
 class SchedulerConfig:
     """Scheduler configuration."""
 
+    DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
+    DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128
+
     runner_type: RunnerType = "generate"
     """The runner type to launch for the model."""
 
-    max_num_batched_tokens: int = Field(default=None, ge=1)
+    max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
     """Maximum number of tokens to be processed in a single iteration.
 
-    This config has no static default. If left unspecified by the user, it will
-    be set in `EngineArgs.create_engine_config` based on the usage context."""
+    The default value here is mainly for convenience when testing.
+    In real usage, this should be set in `EngineArgs.create_engine_config`.
+    """
 
-    max_num_seqs: int = Field(default=None, ge=1)
+    max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
     """Maximum number of sequences to be processed in a single iteration.
 
-    This config has no static default. If left unspecified by the user, it will
-    be set in `EngineArgs.create_engine_config` based on the usage context."""
+    The default value here is mainly for convenience when testing.
+    In real usage, this should be set in `EngineArgs.create_engine_config`.
+    """
 
-    max_model_len: int = Field(default=None, ge=1)
-    """Maximum length of a sequence (including prompt and generated text). This
-    is primarily set in `ModelConfig` and that value should be manually
-    duplicated here."""
+    max_model_len: int = Field(default=8192, ge=1)
+    """Maximum length of a sequence (including prompt and generated text).
+
+    The default value here is mainly for convenience when testing.
+    In real usage, this should duplicate `ModelConfig.max_model_len` via
+    `EngineArgs`."""
 
     max_num_partial_prefills: int = Field(default=1, ge=1)
     """For chunked prefill, the maximum number of sequences that can be
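A minimal standalone sketch of the pattern adopted above (a pydantic dataclass with hypothetical names, not vLLM's full `SchedulerConfig`): the former "no static default" fields now carry real defaults backed by `ClassVar` constants, so the config can be built with no arguments in tests while `EngineArgs` still overrides the values in real usage.

```python
from typing import ClassVar

from pydantic import Field
from pydantic.dataclasses import dataclass


@dataclass
class SchedulerConfigSketch:
    # Class-level constants stay accessible as attributes of the class itself.
    DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
    DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128

    # Field defaults reference the constants, so bare construction works in tests.
    max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
    max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
    max_model_len: int = Field(default=8192, ge=1)
    enable_chunked_prefill: bool = True


config = SchedulerConfigSketch()  # no arguments needed for a test fixture
assert config.max_num_batched_tokens == SchedulerConfigSketch.DEFAULT_MAX_NUM_BATCHED_TOKENS
assert config.max_num_seqs == SchedulerConfigSketch.DEFAULT_MAX_NUM_SEQS
```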
@@ -76,9 +78,13 @@ class SchedulerConfig:
     NOTE: This will be replaced by speculative config in the future; it is
     present to enable correctness tests until then."""
 
-    enable_chunked_prefill: bool = Field(default=None)
+    enable_chunked_prefill: bool = True
     """If True, prefill requests can be chunked based
-    on the remaining max_num_batched_tokens."""
+    on the remaining `max_num_batched_tokens`.
+
+    The default value here is mainly for convenience when testing.
+    In real usage, this should be set in `EngineArgs.create_engine_config`.
+    """
 
     is_multimodal_model: bool = False
     """True if the model is multimodal."""
@@ -111,9 +117,6 @@ class SchedulerConfig:
     - "priority" means requests are handled based on given priority (lower
     value means earlier handling) and time of arrival deciding any ties)."""
 
-    chunked_prefill_enabled: bool = Field(init=False)
-    """True if chunked prefill is enabled."""
-
     disable_chunked_mm_input: bool = False
     """If set to true and chunked prefill is enabled, we do not want to
     partially schedule a multimodal item. Only used in V1
@@ -188,15 +191,7 @@ class SchedulerConfig:
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
 
-    @field_validator(
-        "max_num_batched_tokens",
-        "max_num_seqs",
-        "max_model_len",
-        "enable_chunked_prefill",
-        "scheduler_cls",
-        "async_scheduling",
-        mode="wrap",
-    )
+    @field_validator("scheduler_cls", "async_scheduling", mode="wrap")
     @classmethod
     def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
         """Skip validation if the value is `None` when initialisation is delayed."""
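For context, a small self-contained sketch of the `mode="wrap"` trick that the narrowed decorator above keeps for the two remaining delayed fields (the field name here is a placeholder, not vLLM's full set): validation is bypassed only while the value is still `None`.

```python
from collections.abc import Callable
from typing import Any

from pydantic import field_validator
from pydantic.dataclasses import dataclass


@dataclass
class DelayedInitSketch:
    # Placeholder for a field that may legitimately stay None until the
    # engine fills it in later (like `scheduler_cls` / `async_scheduling`).
    scheduler_cls: str | None = None

    @field_validator("scheduler_cls", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        if value is None:
            return value  # delayed initialisation: accept None as-is
        return handler(value)  # otherwise run the normal validation


assert DelayedInitSketch().scheduler_cls is None
assert DelayedInitSketch(scheduler_cls="my.module.Scheduler").scheduler_cls == "my.module.Scheduler"
```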
@@ -205,16 +200,9 @@ class SchedulerConfig:
         return handler(value)
 
     def __post_init__(self, is_encoder_decoder: bool) -> None:
-        if self.max_model_len is None:
-            self.max_model_len = 8192
-
-        if self.max_num_seqs is None:
-            self.max_num_seqs = 128
-
         if is_encoder_decoder:
             # Chunked prefill should be disabled for encoder-decoder models.
             self.disable_chunked_mm_input = True
-            self.chunked_prefill_enabled = False
             self.enable_chunked_prefill = False
             self.long_prefill_token_threshold = 0
             logger.info(
@@ -222,37 +210,6 @@ class SchedulerConfig:
                 " prefix caching; disabling both."
             )
 
-        if self.max_num_batched_tokens is None:
-            if self.enable_chunked_prefill:
-                self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
-            else:
-                # If max_model_len is too short, use
-                # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
-                # for higher throughput.
-                self.max_num_batched_tokens = max(
-                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
-                )
-
-            if self.runner_type == "pooling":
-                # Choose specific value for higher throughput
-                self.max_num_batched_tokens = max(
-                    self.max_num_batched_tokens,
-                    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
-                )
-            if self.is_multimodal_model:
-                # The value needs to be at least the number of multimodal tokens
-                self.max_num_batched_tokens = max(
-                    self.max_num_batched_tokens,
-                    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-                )
-
-            # When using default settings,
-            # Ensure max_num_batched_tokens does not exceed model limit.
-            # Some models (e.g., Whisper) have embeddings tied to max length.
-            self.max_num_batched_tokens = min(
-                self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
-            )
-
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
@@ -262,7 +219,6 @@ class SchedulerConfig:
                 self.max_num_batched_tokens,
             )
 
-        self.chunked_prefill_enabled = self.enable_chunked_prefill
         if self.max_num_partial_prefills > 1:
             if self.long_prefill_token_threshold == 0:
                 self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -276,6 +232,14 @@ class SchedulerConfig:
                 self.long_prefill_token_threshold,
             )
 
+    @property
+    def chunked_prefill_enabled(self) -> bool:
+        return self.enable_chunked_prefill
+
+    @chunked_prefill_enabled.setter
+    def chunked_prefill_enabled(self, value: bool):
+        self.enable_chunked_prefill = value
+
     @model_validator(mode="after")
     def _verify_args(self) -> Self:
         if (
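The stored `chunked_prefill_enabled` field removed earlier is replaced by the property pair above. A minimal sketch of the same alias pattern outside pydantic, to show the behaviour callers can rely on:

```python
class ConfigSketch:
    """Toy stand-in for SchedulerConfig, showing only the alias behaviour."""

    def __init__(self, enable_chunked_prefill: bool = True) -> None:
        self.enable_chunked_prefill = enable_chunked_prefill

    @property
    def chunked_prefill_enabled(self) -> bool:
        # Reads always reflect the single source of truth.
        return self.enable_chunked_prefill

    @chunked_prefill_enabled.setter
    def chunked_prefill_enabled(self, value: bool) -> None:
        # Writes pass through, so legacy callers keep working.
        self.enable_chunked_prefill = value


cfg = ConfigSketch()
cfg.chunked_prefill_enabled = False
assert cfg.enable_chunked_prefill is False
assert cfg.chunked_prefill_enabled is False
```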
@@ -428,11 +428,11 @@ class EngineArgs:
     cpu_offload_gb: float = CacheConfig.cpu_offload_gb
     gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
     kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
-    max_num_batched_tokens: int | None = SchedulerConfig.max_num_batched_tokens
+    max_num_batched_tokens: int | None = None
     max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
     max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
     long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
-    max_num_seqs: int | None = SchedulerConfig.max_num_seqs
+    max_num_seqs: int | None = None
     max_logprobs: int = ModelConfig.max_logprobs
     logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
     disable_log_stats: bool = False
@@ -485,7 +485,7 @@ class EngineArgs:
     model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
     ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")
 
-    enable_chunked_prefill: bool | None = SchedulerConfig.enable_chunked_prefill
+    enable_chunked_prefill: bool | None = None
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
 
     disable_hybrid_kv_cache_manager: bool = (
@@ -1738,41 +1738,41 @@ class EngineArgs:
             )
             _raise_unsupported_error(feature_name=name)
 
-    def _set_default_args(
-        self, usage_context: UsageContext, model_config: ModelConfig
-    ) -> None:
-        """Set Default Arguments for V1 Engine."""
-
-        # V1 uses chunked prefills and prefix caching by default
-        # for non-pooling tasks.
-        # For pooling tasks the default is False
+    @classmethod
+    def get_chunked_prefill_prefix_caching_defaults(
+        cls,
+        model_config: ModelConfig,
+    ) -> tuple[bool, bool]:
         if model_config.runner_type != "pooling":
-            self.enable_chunked_prefill = True
+            default_chunked_prefill = True
 
-            if self.enable_prefix_caching is None:
-                # Disable prefix caching default for hybrid models
-                # since the feature is still experimental.
-                if model_config.is_hybrid:
-                    self.enable_prefix_caching = False
-                else:
-                    self.enable_prefix_caching = True
+            # Disable prefix caching default for hybrid models
+            # since the feature is still experimental.
+            default_prefix_caching = not model_config.is_hybrid
         else:
             assert model_config.pooler_config is not None
 
             pooling_type = model_config.pooler_config.pooling_type
-            is_causal = getattr(model_config.hf_config, "is_causal", True)
             incremental_prefill_supported = (
                 pooling_type is not None
                 and pooling_type.lower() == "last"
-                and bool(is_causal)
+                and getattr(model_config.hf_config, "is_causal", True)
             )
 
-            action = "Enabling" if incremental_prefill_supported else "Disabling"
+            default_chunked_prefill = incremental_prefill_supported
+            default_prefix_caching = incremental_prefill_supported
 
-            if self.enable_chunked_prefill is None:
-                self.enable_chunked_prefill = incremental_prefill_supported
-                logger.info("(%s) chunked prefill by default", action)
-            if self.enable_prefix_caching is None:
-                self.enable_prefix_caching = incremental_prefill_supported
-                logger.info("(%s) prefix caching by default", action)
+        return default_chunked_prefill, default_prefix_caching
+
+    @classmethod
+    def get_batch_defaults(
+        cls,
+        world_size: int,
+    ) -> tuple[dict[UsageContext | None, int], dict[UsageContext | None, int]]:
+        from vllm.usage.usage_lib import UsageContext
+
+        default_max_num_batched_tokens: dict[UsageContext | None, int]
+        default_max_num_seqs: dict[UsageContext | None, int]
 
         # When no user override, set the default values based on the usage
         # context.
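To make the intent of the two new classmethods concrete, here is a hypothetical standalone version of the same decision table, with plain arguments instead of a `ModelConfig` so it runs without vLLM: non-pooling runners default to chunked prefill, and to prefix caching unless the model is hybrid; pooling runners only get either when the pooler does last-token pooling on a causal model.

```python
def chunked_prefill_prefix_caching_defaults(
    runner_type: str,
    is_hybrid: bool,
    pooling_type: str | None,
    is_causal: bool,
) -> tuple[bool, bool]:
    """Return (default_chunked_prefill, default_prefix_caching)."""
    if runner_type != "pooling":
        # Prefix caching stays off by default for hybrid models (still experimental).
        return True, not is_hybrid

    incremental_prefill_supported = (
        pooling_type is not None and pooling_type.lower() == "last" and is_causal
    )
    return incremental_prefill_supported, incremental_prefill_supported


assert chunked_prefill_prefix_caching_defaults("generate", False, None, True) == (True, True)
assert chunked_prefill_prefix_caching_defaults("generate", True, None, True) == (True, False)
assert chunked_prefill_prefix_caching_defaults("pooling", False, "last", True) == (True, True)
assert chunked_prefill_prefix_caching_defaults("pooling", False, "mean", True) == (False, False)
```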
@@ -1793,8 +1793,6 @@ class EngineArgs:
         # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
         # throughput, see PR #17885 for more details.
         # So here we do an extra device name check to prevent such regression.
-        from vllm.usage.usage_lib import UsageContext
-
         if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
             # For GPUs like H100 and MI300x, use larger default values.
             default_max_num_batched_tokens = {
@@ -1818,22 +1816,26 @@ class EngineArgs:
 
         # tpu specific default values.
         if current_platform.is_tpu():
-            default_max_num_batched_tokens_tpu = {
-                UsageContext.LLM_CLASS: {
-                    "V6E": 2048,
-                    "V5E": 1024,
-                    "V5P": 512,
-                },
-                UsageContext.OPENAI_API_SERVER: {
-                    "V6E": 1024,
-                    "V5E": 512,
-                    "V5P": 256,
-                },
-            }
+            chip_name = current_platform.get_device_name()
+
+            if chip_name == "V6E":
+                default_max_num_batched_tokens = {
+                    UsageContext.LLM_CLASS: 2048,
+                    UsageContext.OPENAI_API_SERVER: 1024,
+                }
+            elif chip_name == "V5E":
+                default_max_num_batched_tokens = {
+                    UsageContext.LLM_CLASS: 1024,
+                    UsageContext.OPENAI_API_SERVER: 512,
+                }
+            elif chip_name == "V5P":
+                default_max_num_batched_tokens = {
+                    UsageContext.LLM_CLASS: 512,
+                    UsageContext.OPENAI_API_SERVER: 256,
+                }
 
         # cpu specific default values.
         if current_platform.is_cpu():
-            world_size = self.pipeline_parallel_size * self.tensor_parallel_size
             default_max_num_batched_tokens = {
                 UsageContext.LLM_CLASS: 4096 * world_size,
                 UsageContext.OPENAI_API_SERVER: 2048 * world_size,
@@ -1843,44 +1845,104 @@ class EngineArgs:
                 UsageContext.OPENAI_API_SERVER: 128 * world_size,
             }
 
-        use_context_value = usage_context.value if usage_context else None
-        if (
-            self.max_num_batched_tokens is None
-            and usage_context in default_max_num_batched_tokens
+        return default_max_num_batched_tokens, default_max_num_seqs
+
+    def _set_default_args(
+        self, usage_context: UsageContext, model_config: ModelConfig
+    ) -> None:
+        """Set Default Arguments for V1 Engine."""
+        (
+            default_chunked_prefill,
+            default_prefix_caching,
+        ) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
+
+        if self.enable_chunked_prefill is None:
+            self.enable_chunked_prefill = default_chunked_prefill
+
+            logger.debug(
+                "%s chunked prefill by default",
+                "Enabling" if default_chunked_prefill else "Disabling",
+            )
+        elif (
+            model_config.runner_type == "pooling"
+            and self.enable_chunked_prefill
+            and not default_chunked_prefill
         ):
-            if current_platform.is_tpu():
-                chip_name = current_platform.get_device_name()
-                if chip_name in default_max_num_batched_tokens_tpu[usage_context]:
-                    self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[
-                        usage_context
-                    ][chip_name]
-                else:
-                    self.max_num_batched_tokens = default_max_num_batched_tokens[
-                        usage_context
-                    ]
-            else:
-                if not self.enable_chunked_prefill:
-                    self.max_num_batched_tokens = model_config.max_model_len
-                else:
-                    self.max_num_batched_tokens = default_max_num_batched_tokens[
-                        usage_context
-                    ]
+            logger.warning(
+                "This model does not officially support chunked prefill. "
+                "Enabling this manually may cause the engine to crash "
+                "or produce incorrect outputs.",
+            )
+
+        if self.enable_prefix_caching is None:
+            self.enable_prefix_caching = default_prefix_caching
 
             logger.debug(
-                "Setting max_num_batched_tokens to %d for %s usage context.",
+                "%s prefix caching by default",
+                "Enabling" if default_prefix_caching else "Disabling",
            )
+        elif (
+            model_config.runner_type == "pooling"
+            and self.enable_prefix_caching
+            and not default_prefix_caching
+        ):
+            logger.warning(
+                "This model does not officially support prefix caching. "
+                "Enabling this manually may cause the engine to crash "
+                "or produce incorrect outputs.",
+            )
+
+        world_size = self.pipeline_parallel_size * self.tensor_parallel_size
+        (
+            default_max_num_batched_tokens,
+            default_max_num_seqs,
+        ) = self.get_batch_defaults(world_size)
+
+        orig_max_num_batched_tokens = self.max_num_batched_tokens
+        orig_max_num_seqs = self.max_num_seqs
+
+        if self.max_num_batched_tokens is None:
+            self.max_num_batched_tokens = default_max_num_batched_tokens.get(
+                usage_context,
+                SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
+            )
+
+        if self.max_num_seqs is None:
+            self.max_num_seqs = default_max_num_seqs.get(
+                usage_context,
+                SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
+            )
+
+        if orig_max_num_batched_tokens is None:
+            if not self.enable_chunked_prefill:
+                # If max_model_len is too short, use the default for higher throughput.
+                self.max_num_batched_tokens = max(
+                    model_config.max_model_len,
+                    self.max_num_batched_tokens,
+                )
 
+            # When using default settings,
+            # Ensure max_num_batched_tokens does not exceed model limit.
+            # Some models (e.g., Whisper) have embeddings tied to max length.
+            self.max_num_batched_tokens = min(
+                self.max_num_seqs * model_config.max_model_len,
                 self.max_num_batched_tokens,
-                use_context_value,
             )
 
-        if self.max_num_seqs is None and usage_context in default_max_num_seqs:
-            self.max_num_seqs = min(
-                default_max_num_seqs[usage_context],
-                self.max_num_batched_tokens or sys.maxsize,
-            )
-
             logger.debug(
-                "Setting max_num_seqs to %d for %s usage context.",
+                "Defaulting max_num_batched_tokens to %d for %s usage context.",
                 self.max_num_batched_tokens,
+                usage_context.value if usage_context else None,
             )
 
+        if orig_max_num_seqs is None:
+            assert self.max_num_batched_tokens is not None  # For type checking
+            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
+
+            logger.debug(
+                "Defaulting max_num_seqs to %d for %s usage context.",
                 self.max_num_seqs,
-                use_context_value,
+                usage_context.value if usage_context else None,
             )
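A simplified sketch of the new defaulting flow in `_set_default_args` (the enum and the per-context numbers below are illustrative stand-ins, not vLLM's actual tables): an explicit user value always wins, otherwise the usage context is looked up and the `SchedulerConfig` class constant is the final fallback.

```python
from enum import Enum

DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048  # stands in for SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS


class UsageContextSketch(Enum):
    LLM_CLASS = "LLM_CLASS"
    OPENAI_API_SERVER = "OPENAI_API_SERVER"


# Illustrative per-context defaults; the real values depend on platform and GPU.
default_max_num_batched_tokens = {
    UsageContextSketch.LLM_CLASS: 16384,
    UsageContextSketch.OPENAI_API_SERVER: 8192,
}


def resolve_max_num_batched_tokens(
    user_value: int | None,
    usage_context: UsageContextSketch | None,
) -> int:
    if user_value is not None:
        return user_value  # an explicit user setting is never touched
    return default_max_num_batched_tokens.get(usage_context, DEFAULT_MAX_NUM_BATCHED_TOKENS)


assert resolve_max_num_batched_tokens(4096, UsageContextSketch.LLM_CLASS) == 4096
assert resolve_max_num_batched_tokens(None, UsageContextSketch.LLM_CLASS) == 16384
assert resolve_max_num_batched_tokens(None, None) == DEFAULT_MAX_NUM_BATCHED_TOKENS
```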
@@ -15,7 +15,6 @@ import torch
 
 from vllm import envs
 from vllm.logger import init_logger
-from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
 
 from .interface import CpuArchEnum, Platform, PlatformEnum
 
@@ -339,10 +338,9 @@ class CpuPlatform(Platform):
                     "prefill and prefix caching to be disabled."
                 )
                 vllm_config.scheduler_config.enable_chunked_prefill = False
-                vllm_config.scheduler_config.chunked_prefill_enabled = False
                 vllm_config.scheduler_config.max_num_batched_tokens = max(
                     vllm_config.scheduler_config.max_model_len,
-                    DEFAULT_MAX_NUM_BATCHED_TOKENS,
+                    vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
                 )
 
     @classmethod
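The arithmetic behind the platform override above, as a tiny sketch (assuming the class constant is 2048): with chunked prefill forced off, a whole prompt must fit into one batch, so the budget is raised to at least the model length while short-context models keep the default for throughput.

```python
def batched_tokens_without_chunked_prefill(max_model_len: int, default_tokens: int = 2048) -> int:
    # Mirrors max(scheduler_config.max_model_len,
    #             scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS)
    return max(max_model_len, default_tokens)


assert batched_tokens_without_chunked_prefill(1024) == 2048    # short model: keep the default
assert batched_tokens_without_chunked_prefill(32768) == 32768  # long model: fit the full prompt
```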
@@ -10,7 +10,6 @@ from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
 
 from .interface import Platform, PlatformEnum
 
@@ -186,10 +185,9 @@ class TpuPlatform(Platform):
                     "prefill and prefix caching to be disabled."
                 )
                 vllm_config.scheduler_config.enable_chunked_prefill = False
-                vllm_config.scheduler_config.chunked_prefill_enabled = False
                 vllm_config.scheduler_config.max_num_batched_tokens = max(
                     vllm_config.scheduler_config.max_model_len,
-                    DEFAULT_MAX_NUM_BATCHED_TOKENS,
+                    vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
                 )
 
     @classmethod
@@ -9,7 +9,6 @@ import torch
 
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
 
 from .interface import DeviceCapability, Platform, PlatformEnum
 
@@ -185,10 +184,9 @@ class XPUPlatform(Platform):
                     "prefill and prefix caching to be disabled."
                 )
                 vllm_config.scheduler_config.enable_chunked_prefill = False
-                vllm_config.scheduler_config.chunked_prefill_enabled = False
                 vllm_config.scheduler_config.max_num_batched_tokens = max(
                     vllm_config.scheduler_config.max_model_len,
-                    DEFAULT_MAX_NUM_BATCHED_TOKENS,
+                    vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
                 )
 
     @classmethod
@@ -3,7 +3,7 @@
 
 import uuid
 import warnings
-from typing import Any, TypeVar
+from typing import Any
 
 import torch
 
@@ -39,12 +39,6 @@ def __dir__() -> list[str]:
 
 logger = init_logger(__name__)
 
-# This value is chosen to have a balance between ITL and TTFT. Note it is
-# not optimized for throughput.
-DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
-POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
-MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
-
 # Constants related to forcing the attention backend selection
 
 # String name of register which may be set in order to
@@ -60,9 +54,6 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
 STR_INVALID_VAL: str = "INVALID"
 
 
-T = TypeVar("T")
-
-
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)