diff --git a/vllm/config/model.py b/vllm/config/model.py
index a730aa8ad1b9c..8b26148ae36a0 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import warnings
-from collections.abc import Callable
+from collections.abc import Callable, Iterator
 from dataclasses import InitVar, field
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Literal, cast, get_args
@@ -1806,7 +1806,7 @@ class ModelConfig:
         return getattr(self.hf_config, "quantization_config", None) is not None
 
 
-def get_served_model_name(model: str, served_model_name: str | list[str] | None):
+def get_served_model_name(model: str, served_model_name: str | list[str] | None) -> str:
     """
     If the input is a non-empty list, the first model_name in
     `served_model_name` is taken.
@@ -1844,7 +1844,9 @@ _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
 ]
 
 
-def iter_architecture_defaults():
+def iter_architecture_defaults() -> Iterator[
+    tuple[str, tuple[RunnerType, ConvertType]]
+]:
     yield from _SUFFIX_TO_DEFAULTS
 
 
@@ -1877,7 +1879,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
 }
 
 
-def str_dtype_to_torch_dtype(type: str):
+def str_dtype_to_torch_dtype(type: str) -> torch.dtype | None:
     return _STR_DTYPE_TO_TORCH_DTYPE.get(type)
 
 
@@ -1891,14 +1893,14 @@ _FLOAT16_NOT_SUPPORTED_MODELS = {
 }
 
 
-def _is_valid_dtype(model_type: str, dtype: torch.dtype):
+def _is_valid_dtype(model_type: str, dtype: torch.dtype) -> bool:
     if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16:  # noqa: E501, SIM103
         return False
 
     return True
 
 
-def _check_valid_dtype(model_type: str, dtype: torch.dtype):
+def _check_valid_dtype(model_type: str, dtype: torch.dtype) -> bool:
     if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16:
         reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type]
         raise ValueError(
@@ -1913,7 +1915,7 @@ def _find_dtype(
     config: PretrainedConfig,
     *,
     revision: str | None,
-):
+) -> torch.dtype:
     # NOTE: getattr(config, "dtype", torch.float32) is not correct
     # because config.dtype can be None.
     config_dtype = getattr(config, "dtype", None)
@@ -1953,7 +1955,7 @@ def _resolve_auto_dtype(
     config_dtype: torch.dtype,
     *,
     is_pooling_model: bool,
-):
+) -> torch.dtype:
     from vllm.platforms import current_platform
 
     supported_dtypes = [