[Platform] Move platform check to right place (#18470)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Authored by wangxiyuan on 2025-05-23 03:11:28 +08:00; committed by GitHub
parent 1f3a1200e4
commit 721fb9b181
7 changed files with 71 additions and 25 deletions
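In short, this change moves the MLA-on-non-GPU fixup out of the shared validation in vllm/config.py and into each non-GPU platform's own config hook, and it promotes the batched-token constants into vllm.utils so both sides can import them. Below is a minimal sketch of the per-platform pattern; the DummyPlatform name and the check_and_update_config hook name are assumptions for illustration (the hook itself is not visible in the hunks), while the body of the if mirrors the added lines.

from typing import TYPE_CHECKING

from vllm.logger import init_logger
from vllm.platforms.interface import Platform
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


class DummyPlatform(Platform):
    """Illustrative stand-in for CpuPlatform, HpuPlatform, NeuronPlatform, etc."""

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        # Mirrors the added hunks: when MLA is requested on this platform,
        # chunked prefill is forced off and the token budget falls back to
        # the shared default from vllm.utils.
        if vllm_config.model_config and vllm_config.model_config.use_mla:
            logger.info(
                "MLA is enabled on a non-GPU platform; forcing chunked "
                "prefill and prefix caching to be disabled.")
            vllm_config.scheduler_config.enable_chunked_prefill = False
            vllm_config.scheduler_config.chunked_prefill_enabled = False
            vllm_config.scheduler_config.max_num_batched_tokens = max(
                vllm_config.scheduler_config.max_model_len,
                DEFAULT_MAX_NUM_BATCHED_TOKENS)

Keeping the check inside the platform hook lets each backend decide this for itself, instead of routing every device-specific exception through the generic VllmConfig path.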

vllm/config.py

@@ -42,7 +42,10 @@ from vllm.transformers_utils.config import (
try_get_generation_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
- from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
+ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
+ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+ POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
+ LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
random_uuid, resolve_obj_by_qualname)
@@ -64,12 +67,6 @@ logger = init_logger(__name__)
ConfigT = TypeVar("ConfigT", bound=ConfigType)
- # This value is chosen to have a balance between ITL and TTFT. Note it is
- # not optimized for throughput.
- _DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
- _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
- _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription"]
@@ -2074,28 +2071,28 @@ class SchedulerConfig:
# so we don't reject sequences on account of a short
# max_num_batched_tokens.
self.max_num_batched_tokens = max(
- self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+ self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
else:
self.max_num_batched_tokens = (
- _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+ DEFAULT_MAX_NUM_BATCHED_TOKENS)
else:
# If max_model_len is too short, use
- # _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
+ # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
# for higher throughput.
self.max_num_batched_tokens = max(
- self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+ self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
if self.runner_type == "pooling":
# Choose specific value for higher throughput
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
- _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
+ POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if self.is_multimodal_model:
# The value needs to be at least the number of multimodal tokens
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
- _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)
# When using default settings,
@@ -4316,18 +4313,6 @@ class VllmConfig:
"full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
- if self.model_config and self.model_config.use_mla and \
- not (current_platform.is_cuda() or current_platform.is_rocm()):
- logger.info(
- "MLA is enabled on a non-GPU platform; forcing chunked "
- "prefill and prefix caching to be disabled.")
- self.scheduler_config.enable_chunked_prefill = False
- self.scheduler_config.chunked_prefill_enabled = False
- self.scheduler_config.max_num_batched_tokens = max(
- self.scheduler_config.max_model_len,
- _DEFAULT_MAX_NUM_BATCHED_TOKENS)
- if self.cache_config is not None:
- self.cache_config.enable_prefix_caching = False

vllm/platforms/cpu.py

@@ -9,6 +9,7 @@ import psutil
import torch
from vllm.logger import init_logger
+ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
@@ -177,6 +178,16 @@ class CpuPlatform(Platform):
" set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+ if vllm_config.model_config and vllm_config.model_config.use_mla:
+ logger.info(
+ "MLA is enabled on a non-GPU platform; forcing chunked "
+ "prefill and prefix caching to be disabled.")
+ vllm_config.scheduler_config.enable_chunked_prefill = False
+ vllm_config.scheduler_config.chunked_prefill_enabled = False
+ vllm_config.scheduler_config.max_num_batched_tokens = max(
+ vllm_config.scheduler_config.max_model_len,
+ DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on CPU.")

vllm/platforms/hpu.py

@@ -7,6 +7,7 @@ import torch
from vllm import envs
from vllm.logger import init_logger
+ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum, _Backend
@@ -80,6 +81,16 @@ class HpuPlatform(Platform):
"VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+ if vllm_config.model_config and vllm_config.model_config.use_mla:
+ logger.info(
+ "MLA is enabled on a non-GPU platform; forcing chunked "
+ "prefill and prefix caching to be disabled.")
+ vllm_config.scheduler_config.enable_chunked_prefill = False
+ vllm_config.scheduler_config.chunked_prefill_enabled = False
+ vllm_config.scheduler_config.max_num_batched_tokens = max(
+ vllm_config.scheduler_config.max_model_len,
+ DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on HPU.")

vllm/platforms/neuron.py

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.logger import init_logger
+ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum
@@ -56,6 +57,16 @@ class NeuronPlatform(Platform):
vllm_config.cache_config.block_size = \
vllm_config.model_config.max_model_len # type: ignore
+ if vllm_config.model_config and vllm_config.model_config.use_mla:
+ logger.info(
+ "MLA is enabled on a non-GPU platform; forcing chunked "
+ "prefill and prefix caching to be disabled.")
+ vllm_config.scheduler_config.enable_chunked_prefill = False
+ vllm_config.scheduler_config.chunked_prefill_enabled = False
+ vllm_config.scheduler_config.max_num_batched_tokens = max(
+ vllm_config.scheduler_config.max_model_len,
+ DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on Neuron.")

vllm/platforms/tpu.py

@@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType
+ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum, _Backend
@@ -161,6 +162,16 @@ class TpuPlatform(Platform):
"Forcing --disable_chunked_mm_input.")
scheduler_config.disable_chunked_mm_input = True
+ if vllm_config.model_config and vllm_config.model_config.use_mla:
+ logger.info(
+ "MLA is enabled on a non-GPU platform; forcing chunked "
+ "prefill and prefix caching to be disabled.")
+ vllm_config.scheduler_config.enable_chunked_prefill = False
+ vllm_config.scheduler_config.chunked_prefill_enabled = False
+ vllm_config.scheduler_config.max_num_batched_tokens = max(
+ vllm_config.scheduler_config.max_model_len,
+ DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on TPU.")

vllm/platforms/xpu.py

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
import torch
from vllm.logger import init_logger
+ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
@@ -113,6 +114,16 @@ class XPUPlatform(Platform):
parallel_config.distributed_executor_backend)
parallel_config.distributed_executor_backend = "ray"
+ if vllm_config.model_config and vllm_config.model_config.use_mla:
+ logger.info(
+ "MLA is enabled on a non-GPU platform; forcing chunked "
+ "prefill and prefix caching to be disabled.")
+ vllm_config.scheduler_config.enable_chunked_prefill = False
+ vllm_config.scheduler_config.chunked_prefill_enabled = False
+ vllm_config.scheduler_config.max_num_batched_tokens = max(
+ vllm_config.scheduler_config.max_model_len,
+ DEFAULT_MAX_NUM_BATCHED_TOKENS)
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on XPU.")

vllm/utils.py

@@ -77,6 +77,12 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
+ # This value is chosen to have a balance between ITL and TTFT. Note it is
+ # not optimized for throughput.
+ DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
+ POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
+ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
# Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/features/compatibility_matrix.md