Move ModelConfig from config/__init__.py to config/model.py (#25252)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-09-19 17:22:33 +01:00, committed by GitHub
parent cf278ff3b2
commit aed16879a9
13 changed files with 2160 additions and 2149 deletions


@@ -39,7 +39,8 @@ from vllm import LLM, SamplingParams
 from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
-from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype
+from vllm.config.model import (ConvertOption, RunnerOption,
+                               _get_and_verify_dtype)
 from vllm.connections import global_http_connection
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,


@@ -14,7 +14,7 @@ from typing import Literal, NamedTuple, Optional
 import pytest
 
-from vllm.config import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption
+from vllm.config.model import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config


@@ -7,7 +7,6 @@ from unittest.mock import patch
 import pytest
 
 from vllm import LLM
-from vllm.config import ModelImpl
 from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
 from vllm.utils import GiB_bytes
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
@@ -111,8 +110,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
         # these tests seem to produce leftover memory
         gpu_memory_utilization=0.80,
         load_format="dummy",
-        model_impl=ModelImpl.TRANSFORMERS
-        if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
+        model_impl="transformers"
+        if model_arch in _TRANSFORMERS_BACKEND_MODELS else "vllm",
         hf_overrides=hf_overrides_fn,
         max_num_seqs=model_info.max_num_seqs)


@@ -3,6 +3,7 @@
 import itertools
 from collections.abc import Generator
+from typing import get_args
 
 import pytest
 import torch
@@ -464,7 +465,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
             assert len(prompt_logprob) == vocab_size
 
-@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
+@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
 def test_logprobs_mode(logprobs_mode: LogprobsMode,
                        monkeypatch: pytest.MonkeyPatch):
     """Test with LLM engine with different logprobs_mode.
@@ -493,14 +494,12 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
             for logprobs in output.logprobs:
                 for token_id in logprobs:
                     logprob = logprobs[token_id]
-                    if logprobs_mode in (LogprobsMode.RAW_LOGPROBS,
-                                         LogprobsMode.PROCESSED_LOGPROBS):
+                    if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
                         assert logprob.logprob <= 0
                     if logprob.logprob > 0:
                         positive_values = positive_values + 1
                     total_token_with_logprobs = total_token_with_logprobs + 1
         assert total_token_with_logprobs >= len(results[0].outputs)
-        if logprobs_mode in (LogprobsMode.RAW_LOGITS,
-                             LogprobsMode.PROCESSED_LOGITS):
+        if logprobs_mode in ("raw_logits", "processed_logits"):
             assert positive_values > 0
     del llm
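
Because LogprobsMode is now a Literal alias rather than an enum, the parametrization above switches from list(LogprobsMode) to typing.get_args(LogprobsMode), and the comparisons use plain strings. A minimal sketch of the pattern, assuming the alias carries the four mode names seen throughout this diff:

```python
# Minimal sketch, not vLLM source: assumes LogprobsMode is a Literal alias
# with the four values referenced in this diff.
from typing import Literal, get_args

LogprobsMode = Literal["raw_logprobs", "raw_logits",
                       "processed_logprobs", "processed_logits"]

# list(LogprobsMode) worked when it was an Enum but fails for a Literal;
# get_args() returns the allowed values as a tuple of plain strings.
assert get_args(LogprobsMode) == ("raw_logprobs", "raw_logits",
                                  "processed_logprobs", "processed_logits")
assert "raw_logprobs" in get_args(LogprobsMode)
```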

(File diff suppressed because it is too large.)

vllm/config/model.py (new file, 2006 lines)
(File diff suppressed because it is too large.)


@@ -3,7 +3,7 @@
 import hashlib
 from dataclasses import field
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from pydantic import SkipValidation, model_validator
 from pydantic.dataclasses import dataclass
@@ -15,13 +15,9 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                         MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                         POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
 
-if TYPE_CHECKING:
-    from vllm.config import RunnerType
-else:
-    RunnerType = Any
-
 logger = init_logger(__name__)
 
+RunnerType = Literal["generate", "pooling", "draft"]
 PreemptionMode = Literal["swap", "recompute"]
 SchedulerPolicy = Literal["fcfs", "priority"]
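
Defining RunnerType as a Literal directly in this module removes the TYPE_CHECKING import from vllm.config and its RunnerType = Any runtime fallback, which presumably existed to avoid a circular import. A Literal alias is plain data, so no import of the model config module is needed; a minimal sketch, with the dispatch function purely illustrative:

```python
# Illustrative sketch: RunnerType as defined in this diff; the function below
# is an example of dispatching on the literal values, not vLLM code.
from typing import Literal

RunnerType = Literal["generate", "pooling", "draft"]


def runner_label(runner: RunnerType) -> str:
    # Compare against the literal strings directly; no enum import required.
    return {"generate": "text generation",
            "pooling": "embedding/pooling",
            "draft": "speculative draft"}[runner]
```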


@@ -1,8 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import ast
+import inspect
+import textwrap
 from dataclasses import MISSING, Field, field, fields, is_dataclass
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar
+
+import regex as re
 
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
@@ -45,3 +50,96 @@ def get_field(cls: ConfigType, name: str) -> Field:
             return field(default=default)
     raise ValueError(
         f"{cls.__name__}.{name} must have a default value or default factory.")
+
+
+def contains_object_print(text: str) -> bool:
+    """
+    Check if the text looks like a printed Python object, e.g.
+    contains any substring matching the pattern: "at 0xFFFFFFF>"
+    We match against 0x followed by 2-16 hex chars (there's
+    a max of 16 on a 64-bit system).
+
+    Args:
+        text (str): The text to check
+
+    Returns:
+        result (bool): `True` if a match is found, `False` otherwise.
+    """
+    pattern = r'at 0x[a-fA-F0-9]{2,16}>'
+    match = re.search(pattern, text)
+    return match is not None
+
+
+def assert_hashable(text: str) -> bool:
+    if not contains_object_print(text):
+        return True
+    raise AssertionError(
+        f"vLLM tried to hash some configs that may have Python objects ids "
+        f"in them. This is a bug, please file an issue. "
+        f"Text being hashed: {text}")
+
+
+def get_attr_docs(cls: type[Any]) -> dict[str, str]:
+    """
+    Get any docstrings placed after attribute assignments in a class body.
+
+    https://davidism.com/mit-license/
+    """
+
+    def pairwise(iterable):
+        """
+        Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
+
+        Can be removed when Python 3.9 support is dropped.
+        """
+        iterator = iter(iterable)
+        a = next(iterator, None)
+
+        for b in iterator:
+            yield a, b
+            a = b
+
+    try:
+        cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
+    except (OSError, KeyError, TypeError):
+        # HACK: Python 3.13+ workaround - set missing __firstlineno__
+        # Workaround can be removed after we upgrade to pydantic==2.12.0
+        with open(inspect.getfile(cls)) as f:
+            for i, line in enumerate(f):
+                if f"class {cls.__name__}" in line and ":" in line:
+                    cls.__firstlineno__ = i + 1
+                    break
+        cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
+
+    if not isinstance(cls_node, ast.ClassDef):
+        raise TypeError("Given object was not a class.")
+
+    out = {}
+
+    # Consider each pair of nodes.
+    for a, b in pairwise(cls_node.body):
+        # Must be an assignment then a constant string.
+        if (not isinstance(a, (ast.Assign, ast.AnnAssign))
+                or not isinstance(b, ast.Expr)
+                or not isinstance(b.value, ast.Constant)
+                or not isinstance(b.value.value, str)):
+            continue
+
+        doc = inspect.cleandoc(b.value.value)
+
+        # An assignment can have multiple targets (a = b = v), but an
+        # annotated assignment only has one target.
+        targets = a.targets if isinstance(a, ast.Assign) else [a.target]
+        for target in targets:
+            # Must be assigning to a plain name.
+            if not isinstance(target, ast.Name):
+                continue
+
+            out[target.id] = doc
+
+    return out
+
+
+def is_init_field(cls: ConfigType, name: str) -> bool:
+    return next(f for f in fields(cls) if f.name == name).init
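
For context on the helpers added above: get_attr_docs parses the class source with inspect and ast and pairs each attribute assignment with the string literal that immediately follows it, which is how vLLM turns inline attribute docstrings into help text. A small usage sketch, assuming the function is imported from vllm.config.utils as added here:

```python
# Usage sketch; assumes get_attr_docs is importable from vllm.config.utils
# as added in this diff. The example class is illustrative only.
from dataclasses import dataclass

from vllm.config.utils import get_attr_docs


@dataclass
class ExampleConfig:
    max_len: int = 2048
    """Maximum sequence length accepted by the model."""

    dtype: str = "auto"
    """Data type for weights; "auto" picks a sensible default."""


docs = get_attr_docs(ExampleConfig)
assert docs["max_len"].startswith("Maximum sequence length")
assert "Data type" in docs["dtype"]
```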


@@ -27,11 +27,11 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          EPLBConfig, HfOverrides, KVEventsConfig,
                          KVTransferConfig, LoadConfig, LogprobsMode,
                          LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
-                         ModelDType, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
-                         RunnerOption, SchedulerConfig, SchedulerPolicy,
-                         SpeculativeConfig, StructuredOutputsConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs)
+                         ModelDType, ObservabilityConfig, ParallelConfig,
+                         PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
+                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
+                         StructuredOutputsConfig, TaskOption, TokenizerMode,
+                         VllmConfig, get_attr_docs)
 from vllm.config.multimodal import MMCacheType, MultiModalConfig
 from vllm.config.parallel import ExpertPlacementStrategy
 from vllm.config.utils import get_field
@@ -548,7 +548,6 @@ class EngineArgs:
         model_group.add_argument("--max-logprobs",
                                  **model_kwargs["max_logprobs"])
         model_group.add_argument("--logprobs-mode",
-                                 choices=[f.value for f in LogprobsMode],
                                  **model_kwargs["logprobs_mode"])
         model_group.add_argument("--disable-sliding-window",
                                  **model_kwargs["disable_sliding_window"])
@@ -593,9 +592,7 @@ class EngineArgs:
                                  **model_kwargs["override_generation_config"])
         model_group.add_argument("--enable-sleep-mode",
                                  **model_kwargs["enable_sleep_mode"])
-        model_group.add_argument("--model-impl",
-                                 choices=[f.value for f in ModelImpl],
-                                 **model_kwargs["model_impl"])
+        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
         model_group.add_argument("--override-attention-dtype",
                                  **model_kwargs["override_attention_dtype"])
         model_group.add_argument("--logits-processors",


@@ -13,8 +13,7 @@ from torch import nn
 from typing_extensions import assert_never
 
 from vllm.attention import Attention
-from vllm.config import (ModelConfig, ModelImpl, VllmConfig,
-                         set_current_vllm_config)
+from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import QKVCrossParallelLinear
 from vllm.model_executor.layers.quantization.base_config import (
@@ -176,8 +175,8 @@ def get_model_architecture(
             )
 
     if arch == model_config._get_transformers_backend_cls():
-        assert model_config.model_impl != ModelImpl.VLLM
-        if model_config.model_impl == ModelImpl.AUTO:
+        assert model_config.model_impl != "vllm"
+        if model_config.model_impl == "auto":
             logger.warning_once(
                 "%s has no vLLM implementation, falling back to Transformers "
                 "implementation. Some features may not be supported and "


@@ -19,7 +19,7 @@ from typing import Callable, Optional, TypeVar, Union
 import torch.nn as nn
 import transformers
 
-from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
+from vllm.config import (ModelConfig, iter_architecture_defaults,
                          try_match_architecture_defaults)
 from vllm.logger import init_logger
 from vllm.transformers_utils.dynamic_module import (
@@ -587,7 +587,7 @@ class _ModelRegistry:
                 if model_module is not None:
                     break
             else:
-                if model_config.model_impl != ModelImpl.TRANSFORMERS:
+                if model_config.model_impl != "transformers":
                     return None
 
                 raise ValueError(
@@ -598,7 +598,7 @@ class _ModelRegistry:
                     "'auto_map' (relevant if the model is custom).")
 
         if not model_module.is_backend_compatible():
-            if model_config.model_impl != ModelImpl.TRANSFORMERS:
+            if model_config.model_impl != "transformers":
                 return None
 
             raise ValueError(
@@ -644,20 +644,20 @@ class _ModelRegistry:
             raise ValueError("No model architectures are specified")
 
         # Require transformers impl
-        if model_config.model_impl == ModelImpl.TRANSFORMERS:
+        if model_config.model_impl == "transformers":
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
             if arch is not None:
                 model_info = self._try_inspect_model_cls(arch)
                 if model_info is not None:
                     return (model_info, arch)
-        elif model_config.model_impl == ModelImpl.TERRATORCH:
+        elif model_config.model_impl == "terratorch":
             model_info = self._try_inspect_model_cls("Terratorch")
             return (model_info, "Terratorch")
 
         # Fallback to transformers impl (after resolving convert_type)
         if (all(arch not in self.models for arch in architectures)
-                and model_config.model_impl == ModelImpl.AUTO
+                and model_config.model_impl == "auto"
                 and getattr(model_config, "convert_type", "none") == "none"):
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
@@ -674,7 +674,7 @@ class _ModelRegistry:
 
         # Fallback to transformers impl (before resolving runner_type)
         if (all(arch not in self.models for arch in architectures)
-                and model_config.model_impl == ModelImpl.AUTO):
+                and model_config.model_impl == "auto"):
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
             if arch is not None:
@@ -695,14 +695,14 @@ class _ModelRegistry:
             raise ValueError("No model architectures are specified")
 
         # Require transformers impl
-        if model_config.model_impl == ModelImpl.TRANSFORMERS:
+        if model_config.model_impl == "transformers":
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
             if arch is not None:
                 model_cls = self._try_load_model_cls(arch)
                 if model_cls is not None:
                     return (model_cls, arch)
-        elif model_config.model_impl == ModelImpl.TERRATORCH:
+        elif model_config.model_impl == "terratorch":
             arch = "Terratorch"
             model_cls = self._try_load_model_cls(arch)
             if model_cls is not None:
@@ -710,7 +710,7 @@ class _ModelRegistry:
 
         # Fallback to transformers impl (after resolving convert_type)
         if (all(arch not in self.models for arch in architectures)
-                and model_config.model_impl == ModelImpl.AUTO
+                and model_config.model_impl == "auto"
                 and getattr(model_config, "convert_type", "none") == "none"):
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
@@ -727,7 +727,7 @@ class _ModelRegistry:
 
         # Fallback to transformers impl (before resolving runner_type)
        if (all(arch not in self.models for arch in architectures)
-                and model_config.model_impl == ModelImpl.AUTO):
+                and model_config.model_impl == "auto"):
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
             if arch is not None:
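
With ModelImpl now a Literal, the registry compares model_config.model_impl against plain strings, but the resolution order is unchanged: an explicit "transformers" or "terratorch" request is honored first, and "auto" only falls back to the Transformers backend when no native vLLM class resolves. A condensed sketch of that ordering, with illustrative names rather than the registry's real helpers:

```python
# Condensed, illustrative sketch of the fallback order; pick_backend and
# has_native_impl are assumptions for illustration, not vLLM functions.
from typing import Literal

ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]


def pick_backend(model_impl: ModelImpl, has_native_impl: bool) -> str:
    if model_impl == "transformers":  # explicitly requested Transformers backend
        return "transformers"
    if model_impl == "terratorch":    # explicitly requested Terratorch backend
        return "terratorch"
    if model_impl == "auto" and not has_native_impl:
        return "transformers"         # fallback when no native class exists
    return "vllm"                     # native implementation


assert pick_backend("auto", has_native_impl=False) == "transformers"
assert pick_backend("auto", has_native_impl=True) == "vllm"
```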


@@ -29,15 +29,12 @@ class TopKTopPSampler(nn.Module):
     Implementations may update the logits tensor in-place.
     """
 
-    def __init__(
-            self,
-            logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None:
+    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
         super().__init__()
         self.logprobs_mode = logprobs_mode
         # flashinfer optimization does not apply if intermediate
         # logprobs/logits after top_k/top_p need to be returned
-        if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS,
-                                 LogprobsMode.PROCESSED_LOGPROBS
+        if logprobs_mode not in ("processed_logits", "processed_logprobs"
                                  ) and current_platform.is_cuda():
             if is_flashinfer_available:
                 flashinfer_version = flashinfer.__version__
@@ -90,9 +87,9 @@ class TopKTopPSampler(nn.Module):
         """
         logits = self.apply_top_k_top_p(logits, k, p)
         logits_to_return = None
-        if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
+        if self.logprobs_mode == "processed_logits":
             logits_to_return = logits
-        elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
+        elif self.logprobs_mode == "processed_logprobs":
             logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
 
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators), logits_to_return
@@ -115,7 +112,7 @@ class TopKTopPSampler(nn.Module):
                     "PyTorch-native implementation.")
                 return self.forward_native(logits, generators, k, p)
         assert self.logprobs_mode not in (
-            LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS
+            "processed_logits", "processed_logprobs"
         ), "FlashInfer does not support returning logits/logprobs"
         # flashinfer sampling functions expect contiguous logits.
         # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous


@@ -60,8 +60,7 @@ class Sampler(nn.Module):
     9. Return the final `SamplerOutput`.
     """
 
-    def __init__(self,
-                 logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS):
+    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
         self.pin_memory = is_pin_memory_available()
@@ -78,9 +77,9 @@ class Sampler(nn.Module):
         # is used for sampling (after penalties and temperature scaling).
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS:
+            if self.logprobs_mode == "raw_logprobs":
                 raw_logprobs = self.compute_logprobs(logits)
-            elif self.logprobs_mode == LogprobsMode.RAW_LOGITS:
+            elif self.logprobs_mode == "raw_logits":
                 raw_logprobs = logits.clone()
 
         # Use float32 for the logits.
@@ -156,9 +155,9 @@ class Sampler(nn.Module):
         if sampling_metadata.all_greedy:
             processed_logprobs = None
             if sampling_metadata.max_num_logprobs is not None:
-                if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
+                if self.logprobs_mode == "processed_logits":
                     processed_logprobs = logits
-                elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
+                elif self.logprobs_mode == "processed_logprobs":
                     processed_logprobs = self.compute_logprobs(logits)
             return greedy_sampled, processed_logprobs