mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:45:44 +08:00
[Platform] Move platform check to right place (#18470)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent
1f3a1200e4
commit
721fb9b181
@ -42,7 +42,10 @@ from vllm.transformers_utils.config import (
|
||||
try_get_generation_config, uses_mrope)
|
||||
from vllm.transformers_utils.s3_utils import S3Model
|
||||
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
|
||||
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
|
||||
from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
||||
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
|
||||
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
|
||||
LayerBlockType, cuda_device_count_stateless,
|
||||
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
|
||||
random_uuid, resolve_obj_by_qualname)
|
||||
|
||||
@ -64,12 +67,6 @@ logger = init_logger(__name__)
|
||||
|
||||
ConfigT = TypeVar("ConfigT", bound=ConfigType)
|
||||
|
||||
# This value is chosen to have a balance between ITL and TTFT. Note it is
|
||||
# not optimized for throughput.
|
||||
_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
|
||||
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
||||
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
|
||||
|
||||
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
||||
"score", "reward", "transcription"]
|
||||
|
||||
@ -2074,28 +2071,28 @@ class SchedulerConfig:
|
||||
# so we don't reject sequences on account of a short
|
||||
# max_num_batched_tokens.
|
||||
self.max_num_batched_tokens = max(
|
||||
self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
else:
|
||||
self.max_num_batched_tokens = (
|
||||
_DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
else:
|
||||
# If max_model_len is too short, use
|
||||
# _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
|
||||
# DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
|
||||
# for higher throughput.
|
||||
self.max_num_batched_tokens = max(
|
||||
self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
|
||||
if self.runner_type == "pooling":
|
||||
# Choose specific value for higher throughput
|
||||
self.max_num_batched_tokens = max(
|
||||
self.max_num_batched_tokens,
|
||||
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
|
||||
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
|
||||
)
|
||||
if self.is_multimodal_model:
|
||||
# The value needs to be at least the number of multimodal tokens
|
||||
self.max_num_batched_tokens = max(
|
||||
self.max_num_batched_tokens,
|
||||
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
|
||||
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
|
||||
)
|
||||
|
||||
# When using default settings,
|
||||
@ -4316,18 +4313,6 @@ class VllmConfig:
|
||||
"full_cuda_graph is not supported with "
|
||||
"cascade attention. Disabling cascade attention.")
|
||||
self.model_config.disable_cascade_attn = True
|
||||
|
||||
if self.model_config and self.model_config.use_mla and \
|
||||
not (current_platform.is_cuda() or current_platform.is_rocm()):
|
||||
logger.info(
|
||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
||||
"prefill and prefix caching to be disabled.")
|
||||
self.scheduler_config.enable_chunked_prefill = False
|
||||
self.scheduler_config.chunked_prefill_enabled = False
|
||||
self.scheduler_config.max_num_batched_tokens = max(
|
||||
self.scheduler_config.max_model_len,
|
||||
_DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
|
||||
if self.cache_config is not None:
|
||||
self.cache_config.enable_prefix_caching = False
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import psutil
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
|
||||
|
||||
@ -177,6 +178,16 @@ class CpuPlatform(Platform):
|
||||
" set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
|
||||
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
|
||||
|
||||
if vllm_config.model_config and vllm_config.model_config.use_mla:
|
||||
logger.info(
|
||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
||||
"prefill and prefix caching to be disabled.")
|
||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = max(
|
||||
vllm_config.scheduler_config.max_model_len,
|
||||
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls) -> bool:
|
||||
logger.warning("Pin memory is not supported on CPU.")
|
||||
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
|
||||
from .interface import Platform, PlatformEnum, _Backend
|
||||
|
||||
@ -80,6 +81,16 @@ class HpuPlatform(Platform):
|
||||
"VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
if vllm_config.model_config and vllm_config.model_config.use_mla:
|
||||
logger.info(
|
||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
||||
"prefill and prefix caching to be disabled.")
|
||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = max(
|
||||
vllm_config.scheduler_config.max_model_len,
|
||||
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
logger.warning("Pin memory is not supported on HPU.")
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
|
||||
@ -56,6 +57,16 @@ class NeuronPlatform(Platform):
|
||||
vllm_config.cache_config.block_size = \
|
||||
vllm_config.model_config.max_model_len # type: ignore
|
||||
|
||||
if vllm_config.model_config and vllm_config.model_config.use_mla:
|
||||
logger.info(
|
||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
||||
"prefill and prefix caching to be disabled.")
|
||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = max(
|
||||
vllm_config.scheduler_config.max_model_len,
|
||||
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls) -> bool:
|
||||
logger.warning("Pin memory is not supported on Neuron.")
|
||||
|
||||
@ -9,6 +9,7 @@ import vllm.envs as envs
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
|
||||
from .interface import Platform, PlatformEnum, _Backend
|
||||
|
||||
@ -161,6 +162,16 @@ class TpuPlatform(Platform):
|
||||
"Forcing --disable_chunked_mm_input.")
|
||||
scheduler_config.disable_chunked_mm_input = True
|
||||
|
||||
if vllm_config.model_config and vllm_config.model_config.use_mla:
|
||||
logger.info(
|
||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
||||
"prefill and prefix caching to be disabled.")
|
||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = max(
|
||||
vllm_config.scheduler_config.max_model_len,
|
||||
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
logger.warning("Pin memory is not supported on TPU.")
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
@ -113,6 +114,16 @@ class XPUPlatform(Platform):
|
||||
parallel_config.distributed_executor_backend)
|
||||
parallel_config.distributed_executor_backend = "ray"
|
||||
|
||||
if vllm_config.model_config and vllm_config.model_config.use_mla:
|
||||
logger.info(
|
||||
"MLA is enabled on a non-GPU platform; forcing chunked "
|
||||
"prefill and prefix caching to be disabled.")
|
||||
vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
vllm_config.scheduler_config.max_num_batched_tokens = max(
|
||||
vllm_config.scheduler_config.max_model_len,
|
||||
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
logger.warning("Pin memory is not supported on XPU.")
|
||||
|
||||
@ -77,6 +77,12 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# This value is chosen to have a balance between ITL and TTFT. Note it is
|
||||
# not optimized for throughput.
|
||||
DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
|
||||
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
||||
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
|
||||
|
||||
# Exception strings for non-implemented encoder/decoder scenarios
|
||||
|
||||
# Reminder: Please update docs/source/features/compatibility_matrix.md
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user