Merge 74b3a2014a3cd207d526917438ae28fe9bcdccfa into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
ゆり 2025-12-25 00:07:16 +00:00 committed by GitHub
commit d5e26513b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 13 deletions

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings import warnings
from collections.abc import Callable from collections.abc import Callable, Iterator
from dataclasses import InitVar, field from dataclasses import InitVar, field
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args from typing import TYPE_CHECKING, Any, Literal, cast, get_args
@ -1806,7 +1806,7 @@ class ModelConfig:
return getattr(self.hf_config, "quantization_config", None) is not None return getattr(self.hf_config, "quantization_config", None) is not None
def get_served_model_name(model: str, served_model_name: str | list[str] | None): def get_served_model_name(model: str, served_model_name: str | list[str] | None) -> str:
""" """
If the input is a non-empty list, the first model_name in If the input is a non-empty list, the first model_name in
`served_model_name` is taken. `served_model_name` is taken.
@ -1844,7 +1844,9 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
] ]
def iter_architecture_defaults(): def iter_architecture_defaults() -> Iterator[
tuple[str, tuple[RunnerType, ConvertType]]
]:
yield from _SUFFIX_TO_DEFAULTS yield from _SUFFIX_TO_DEFAULTS
@ -1877,7 +1879,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
} }
def str_dtype_to_torch_dtype(type: str): def str_dtype_to_torch_dtype(type: str) -> torch.dtype | None:
return _STR_DTYPE_TO_TORCH_DTYPE.get(type) return _STR_DTYPE_TO_TORCH_DTYPE.get(type)
@ -1891,14 +1893,14 @@ _FLOAT16_NOT_SUPPORTED_MODELS = {
} }
def _is_valid_dtype(model_type: str, dtype: torch.dtype): def _is_valid_dtype(model_type: str, dtype: torch.dtype) -> bool:
if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103 if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103
return False return False
return True return True
def _check_valid_dtype(model_type: str, dtype: torch.dtype): def _check_valid_dtype(model_type: str, dtype: torch.dtype) -> bool:
if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16:
reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type] reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type]
raise ValueError( raise ValueError(
@ -1913,7 +1915,7 @@ def _find_dtype(
config: PretrainedConfig, config: PretrainedConfig,
*, *,
revision: str | None, revision: str | None,
): ) -> torch.dtype:
# NOTE: getattr(config, "dtype", torch.float32) is not correct # NOTE: getattr(config, "dtype", torch.float32) is not correct
# because config.dtype can be None. # because config.dtype can be None.
config_dtype = getattr(config, "dtype", None) config_dtype = getattr(config, "dtype", None)
@ -1953,7 +1955,7 @@ def _resolve_auto_dtype(
config_dtype: torch.dtype, config_dtype: torch.dtype,
*, *,
is_pooling_model: bool, is_pooling_model: bool,
): ) -> torch.dtype:
from vllm.platforms import current_platform from vllm.platforms import current_platform
supported_dtypes = [ supported_dtypes = [

View File

@ -329,14 +329,23 @@ def initialize_ray_cluster(
available_gpus = cuda_device_count_stateless() available_gpus = cuda_device_count_stateless()
if parallel_config.world_size > available_gpus: if parallel_config.world_size > available_gpus:
logger.warning( logger.warning(
"Tensor parallel size (%d) exceeds available GPUs (%d). " "World size (%d) exceeds locally visible GPUs (%d). "
"This may result in Ray placement group allocation failures. " "For single-node deployments, this may result in Ray "
"Consider reducing tensor_parallel_size to %d or less, " "placement group allocation failures. For multi-node Ray "
"or ensure your Ray cluster has %d GPUs available.", "clusters, ensure your cluster has %d GPUs available across "
"all nodes. (world_size = tensor_parallel_size=%d × "
"pipeline_parallel_size=%d%s)",
parallel_config.world_size, parallel_config.world_size,
available_gpus, available_gpus,
available_gpus,
parallel_config.world_size, parallel_config.world_size,
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
(
f" × prefill_context_parallel_size="
f"{parallel_config.prefill_context_parallel_size}"
if parallel_config.prefill_context_parallel_size > 1
else ""
),
) )
if ray.is_initialized(): if ray.is_initialized():