Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-18 11:04:33 +08:00)
Move ModelConfig from config/__init__.py to config/model.py (#25252)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

This commit is contained in: parent cf278ff3b2, commit aed16879a9
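
A quick sketch of what the move means for imports. The old and new paths are taken from the test-file hunks below; whether vllm.config continues to re-export ModelConfig for backwards compatibility cannot be confirmed from this page, since the config/__init__.py diff is suppressed.

# Before this commit:
# from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype

# After this commit the canonical location is the new module:
from vllm.config.model import (ConvertOption, ModelConfig, RunnerOption,
                               _get_and_verify_dtype)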
@@ -39,7 +39,8 @@ from vllm import LLM, SamplingParams
 from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
-from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype
+from vllm.config.model import (ConvertOption, RunnerOption,
+                               _get_and_verify_dtype)
 from vllm.connections import global_http_connection
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
@@ -14,7 +14,7 @@ from typing import Literal, NamedTuple, Optional

 import pytest

-from vllm.config import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption
+from vllm.config.model import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
@@ -7,7 +7,6 @@ from unittest.mock import patch
 import pytest

 from vllm import LLM
-from vllm.config import ModelImpl
 from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
 from vllm.utils import GiB_bytes
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
@@ -111,8 +110,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
         # these tests seem to produce leftover memory
         gpu_memory_utilization=0.80,
         load_format="dummy",
-        model_impl=ModelImpl.TRANSFORMERS
-        if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
+        model_impl="transformers"
+        if model_arch in _TRANSFORMERS_BACKEND_MODELS else "vllm",
         hf_overrides=hf_overrides_fn,
         max_num_seqs=model_info.max_num_seqs)
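
The hunk above drops the ModelImpl enum in favour of plain strings. A hedged usage sketch follows; the model name is a placeholder, and the accepted values ("auto", "vllm", "transformers", "terratorch") are the ones that appear elsewhere in this diff.

from vllm import LLM

# model_impl is now a plain string rather than ModelImpl.TRANSFORMERS etc.
llm = LLM(model="<your-model>", model_impl="transformers")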
@@ -3,6 +3,7 @@

 import itertools
 from collections.abc import Generator
+from typing import get_args

 import pytest
 import torch
@@ -464,7 +465,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
             assert len(prompt_logprob) == vocab_size


-@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
+@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
 def test_logprobs_mode(logprobs_mode: LogprobsMode,
                        monkeypatch: pytest.MonkeyPatch):
     """Test with LLM engine with different logprobs_mode.
@@ -493,14 +494,12 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
         for logprobs in output.logprobs:
             for token_id in logprobs:
                 logprob = logprobs[token_id]
-                if logprobs_mode in (LogprobsMode.RAW_LOGPROBS,
-                                     LogprobsMode.PROCESSED_LOGPROBS):
+                if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
                     assert logprob.logprob <= 0
                 if logprob.logprob > 0:
                     positive_values = positive_values + 1
             total_token_with_logprobs = total_token_with_logprobs + 1
     assert total_token_with_logprobs >= len(results[0].outputs)
-    if logprobs_mode in (LogprobsMode.RAW_LOGITS,
-                         LogprobsMode.PROCESSED_LOGITS):
+    if logprobs_mode in ("raw_logits", "processed_logits"):
         assert positive_values > 0
     del llm
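
The test now parameterises over get_args(LogprobsMode) and compares against strings, which implies LogprobsMode has become a typing.Literal alias rather than an enum. A self-contained sketch of that pattern; the alias definition itself lives in the suppressed vllm/config/model.py diff, so its exact shape here is an assumption, though all four values appear in this commit.

from typing import Literal, get_args

# Assumed shape of the new alias.
LogprobsMode = Literal["raw_logprobs", "raw_logits",
                       "processed_logprobs", "processed_logits"]

# get_args() recovers the allowed strings, which is what the updated test
# iterates over in place of list(LogprobsMode) on the old enum.
assert set(get_args(LogprobsMode)) == {
    "raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"
}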
File diff suppressed because it is too large

vllm/config/model.py (new file, 2006 lines)
File diff suppressed because it is too large
@@ -3,7 +3,7 @@

 import hashlib
 from dataclasses import field
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union

 from pydantic import SkipValidation, model_validator
 from pydantic.dataclasses import dataclass
@@ -15,13 +15,9 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                         MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                         POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)

-if TYPE_CHECKING:
-    from vllm.config import RunnerType
-else:
-    RunnerType = Any
-
 logger = init_logger(__name__)

+RunnerType = Literal["generate", "pooling", "draft"]
 PreemptionMode = Literal["swap", "recompute"]
 SchedulerPolicy = Literal["fcfs", "priority"]
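
Defining RunnerType as a runtime Literal removes the TYPE_CHECKING import dance the old code needed. A short sketch of why this works for both static typing and runtime validation; check_runner is a hypothetical helper, not part of this diff.

from typing import Literal, get_args

RunnerType = Literal["generate", "pooling", "draft"]

def check_runner(runner: RunnerType) -> None:
    # A Literal alias is an ordinary object at runtime, so the allowed
    # values can be validated without any TYPE_CHECKING-only import.
    if runner not in get_args(RunnerType):
        raise ValueError(f"unknown runner type: {runner!r}")

check_runner("pooling")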
@@ -1,8 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import ast
+import inspect
+import textwrap
 from dataclasses import MISSING, Field, field, fields, is_dataclass
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar

+import regex as re
+
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
@@ -45,3 +50,96 @@ def get_field(cls: ConfigType, name: str) -> Field:
         return field(default=default)
     raise ValueError(
         f"{cls.__name__}.{name} must have a default value or default factory.")
+
+
+def contains_object_print(text: str) -> bool:
+    """
+    Check if the text looks like a printed Python object, e.g.
+    contains any substring matching the pattern: "at 0xFFFFFFF>"
+    We match against 0x followed by 2-16 hex chars (there's
+    a max of 16 on a 64-bit system).
+
+    Args:
+        text (str): The text to check
+
+    Returns:
+        result (bool): `True` if a match is found, `False` otherwise.
+    """
+    pattern = r'at 0x[a-fA-F0-9]{2,16}>'
+    match = re.search(pattern, text)
+    return match is not None
+
+
+def assert_hashable(text: str) -> bool:
+    if not contains_object_print(text):
+        return True
+    raise AssertionError(
+        f"vLLM tried to hash some configs that may have Python objects ids "
+        f"in them. This is a bug, please file an issue. "
+        f"Text being hashed: {text}")
+
+
+def get_attr_docs(cls: type[Any]) -> dict[str, str]:
+    """
+    Get any docstrings placed after attribute assignments in a class body.
+
+    https://davidism.com/mit-license/
+    """
+
+    def pairwise(iterable):
+        """
+        Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
+
+        Can be removed when Python 3.9 support is dropped.
+        """
+        iterator = iter(iterable)
+        a = next(iterator, None)
+
+        for b in iterator:
+            yield a, b
+            a = b
+
+    try:
+        cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
+    except (OSError, KeyError, TypeError):
+        # HACK: Python 3.13+ workaround - set missing __firstlineno__
+        # Workaround can be removed after we upgrade to pydantic==2.12.0
+        with open(inspect.getfile(cls)) as f:
+            for i, line in enumerate(f):
+                if f"class {cls.__name__}" in line and ":" in line:
+                    cls.__firstlineno__ = i + 1
+                    break
+        cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
+
+    if not isinstance(cls_node, ast.ClassDef):
+        raise TypeError("Given object was not a class.")
+
+    out = {}
+
+    # Consider each pair of nodes.
+    for a, b in pairwise(cls_node.body):
+        # Must be an assignment then a constant string.
+        if (not isinstance(a, (ast.Assign, ast.AnnAssign))
+                or not isinstance(b, ast.Expr)
+                or not isinstance(b.value, ast.Constant)
+                or not isinstance(b.value.value, str)):
+            continue
+
+        doc = inspect.cleandoc(b.value.value)
+
+        # An assignment can have multiple targets (a = b = v), but an
+        # annotated assignment only has one target.
+        targets = a.targets if isinstance(a, ast.Assign) else [a.target]
+
+        for target in targets:
+            # Must be assigning to a plain name.
+            if not isinstance(target, ast.Name):
+                continue
+
+            out[target.id] = doc
+
+    return out
+
+
+def is_init_field(cls: ConfigType, name: str) -> bool:
+    return next(f for f in fields(cls) if f.name == name).init
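
The helpers above are moved into what the get_field/get_attr_docs context suggests is vllm/config/utils.py. A hedged usage sketch for get_attr_docs; ExampleConfig is hypothetical, and the class must live in a real source file because the helper parses inspect.getsource of the class.

from dataclasses import dataclass

from vllm.config.utils import get_attr_docs  # assumed import path after this move


@dataclass
class ExampleConfig:
    max_num_seqs: int = 256
    """Maximum number of sequences processed per batch."""

    enable_sleep_mode: bool = False
    """Whether the engine is allowed to sleep between requests."""


# Maps attribute names to the docstrings written immediately after them:
# {'max_num_seqs': 'Maximum number of sequences processed per batch.',
#  'enable_sleep_mode': 'Whether the engine is allowed to sleep between requests.'}
print(get_attr_docs(ExampleConfig))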
@@ -27,11 +27,11 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          EPLBConfig, HfOverrides, KVEventsConfig,
                          KVTransferConfig, LoadConfig, LogprobsMode,
                          LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
-                         ModelDType, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
-                         RunnerOption, SchedulerConfig, SchedulerPolicy,
-                         SpeculativeConfig, StructuredOutputsConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs)
+                         ModelDType, ObservabilityConfig, ParallelConfig,
+                         PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
+                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
+                         StructuredOutputsConfig, TaskOption, TokenizerMode,
+                         VllmConfig, get_attr_docs)
 from vllm.config.multimodal import MMCacheType, MultiModalConfig
 from vllm.config.parallel import ExpertPlacementStrategy
 from vllm.config.utils import get_field
@@ -548,7 +548,6 @@ class EngineArgs:
         model_group.add_argument("--max-logprobs",
                                  **model_kwargs["max_logprobs"])
         model_group.add_argument("--logprobs-mode",
-                                 choices=[f.value for f in LogprobsMode],
                                  **model_kwargs["logprobs_mode"])
         model_group.add_argument("--disable-sliding-window",
                                  **model_kwargs["disable_sliding_window"])
@@ -593,9 +592,7 @@ class EngineArgs:
                                  **model_kwargs["override_generation_config"])
         model_group.add_argument("--enable-sleep-mode",
                                  **model_kwargs["enable_sleep_mode"])
-        model_group.add_argument("--model-impl",
-                                 choices=[f.value for f in ModelImpl],
-                                 **model_kwargs["model_impl"])
+        model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
         model_group.add_argument("--override-attention-dtype",
                                  **model_kwargs["override_attention_dtype"])
         model_group.add_argument("--logits-processors",
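
The explicit choices=[f.value for f in ...] lists are dropped because the allowed values can now come from the Literal-typed config fields themselves. A generic, self-contained sketch of that pattern; ModelImplOption is an assumed alias name, and the real wiring in vLLM goes through its own EngineArgs kwargs machinery rather than this bare argparse call.

import argparse
from typing import Literal, get_args

# Assumed alias; the real definition lives in the suppressed vllm/config/model.py.
ModelImplOption = Literal["auto", "vllm", "transformers", "terratorch"]

parser = argparse.ArgumentParser()
# With a Literal-typed field the valid choices come from the type hint itself,
# so a hand-maintained enum-based choices list is no longer needed.
parser.add_argument("--model-impl",
                    choices=get_args(ModelImplOption),
                    default="auto")

args = parser.parse_args(["--model-impl", "transformers"])
assert args.model_impl == "transformers"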
@@ -13,8 +13,7 @@ from torch import nn
 from typing_extensions import assert_never

 from vllm.attention import Attention
-from vllm.config import (ModelConfig, ModelImpl, VllmConfig,
-                         set_current_vllm_config)
+from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import QKVCrossParallelLinear
 from vllm.model_executor.layers.quantization.base_config import (
@@ -176,8 +175,8 @@ def get_model_architecture(
         )

     if arch == model_config._get_transformers_backend_cls():
-        assert model_config.model_impl != ModelImpl.VLLM
-        if model_config.model_impl == ModelImpl.AUTO:
+        assert model_config.model_impl != "vllm"
+        if model_config.model_impl == "auto":
             logger.warning_once(
                 "%s has no vLLM implementation, falling back to Transformers "
                 "implementation. Some features may not be supported and "
@@ -19,7 +19,7 @@ from typing import Callable, Optional, TypeVar, Union
 import torch.nn as nn
 import transformers

-from vllm.config import (ModelConfig, ModelImpl, iter_architecture_defaults,
+from vllm.config import (ModelConfig, iter_architecture_defaults,
                          try_match_architecture_defaults)
 from vllm.logger import init_logger
 from vllm.transformers_utils.dynamic_module import (
@@ -587,7 +587,7 @@ class _ModelRegistry:
             if model_module is not None:
                 break
         else:
-            if model_config.model_impl != ModelImpl.TRANSFORMERS:
+            if model_config.model_impl != "transformers":
                 return None

             raise ValueError(
@@ -598,7 +598,7 @@ class _ModelRegistry:
                 "'auto_map' (relevant if the model is custom).")

         if not model_module.is_backend_compatible():
-            if model_config.model_impl != ModelImpl.TRANSFORMERS:
+            if model_config.model_impl != "transformers":
                 return None

             raise ValueError(
@@ -644,20 +644,20 @@ class _ModelRegistry:
             raise ValueError("No model architectures are specified")

         # Require transformers impl
-        if model_config.model_impl == ModelImpl.TRANSFORMERS:
+        if model_config.model_impl == "transformers":
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
             if arch is not None:
                 model_info = self._try_inspect_model_cls(arch)
                 if model_info is not None:
                     return (model_info, arch)
-        elif model_config.model_impl == ModelImpl.TERRATORCH:
+        elif model_config.model_impl == "terratorch":
             model_info = self._try_inspect_model_cls("Terratorch")
             return (model_info, "Terratorch")

         # Fallback to transformers impl (after resolving convert_type)
         if (all(arch not in self.models for arch in architectures)
-                and model_config.model_impl == ModelImpl.AUTO
+                and model_config.model_impl == "auto"
                 and getattr(model_config, "convert_type", "none") == "none"):
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
@@ -674,7 +674,7 @@ class _ModelRegistry:

         # Fallback to transformers impl (before resolving runner_type)
         if (all(arch not in self.models for arch in architectures)
-                and model_config.model_impl == ModelImpl.AUTO):
+                and model_config.model_impl == "auto"):
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
             if arch is not None:
@@ -695,14 +695,14 @@ class _ModelRegistry:
             raise ValueError("No model architectures are specified")

         # Require transformers impl
-        if model_config.model_impl == ModelImpl.TRANSFORMERS:
+        if model_config.model_impl == "transformers":
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
             if arch is not None:
                 model_cls = self._try_load_model_cls(arch)
                 if model_cls is not None:
                     return (model_cls, arch)
-        elif model_config.model_impl == ModelImpl.TERRATORCH:
+        elif model_config.model_impl == "terratorch":
             arch = "Terratorch"
             model_cls = self._try_load_model_cls(arch)
             if model_cls is not None:
@@ -710,7 +710,7 @@ class _ModelRegistry:

         # Fallback to transformers impl (after resolving convert_type)
         if (all(arch not in self.models for arch in architectures)
-                and model_config.model_impl == ModelImpl.AUTO
+                and model_config.model_impl == "auto"
                 and getattr(model_config, "convert_type", "none") == "none"):
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
@@ -727,7 +727,7 @@ class _ModelRegistry:

         # Fallback to transformers impl (before resolving runner_type)
         if (all(arch not in self.models for arch in architectures)
-                and model_config.model_impl == ModelImpl.AUTO):
+                and model_config.model_impl == "auto"):
             arch = self._try_resolve_transformers(architectures[0],
                                                   model_config)
             if arch is not None:
@@ -29,15 +29,12 @@ class TopKTopPSampler(nn.Module):
    Implementations may update the logits tensor in-place.
    """

-    def __init__(
-            self,
-            logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None:
+    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
         super().__init__()
         self.logprobs_mode = logprobs_mode
         # flashinfer optimization does not apply if intermediate
         # logprobs/logits after top_k/top_p need to be returned
-        if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS,
-                                 LogprobsMode.PROCESSED_LOGPROBS
+        if logprobs_mode not in ("processed_logits", "processed_logprobs"
                                   ) and current_platform.is_cuda():
             if is_flashinfer_available:
                 flashinfer_version = flashinfer.__version__
@@ -90,9 +87,9 @@ class TopKTopPSampler(nn.Module):
         """
         logits = self.apply_top_k_top_p(logits, k, p)
         logits_to_return = None
-        if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
+        if self.logprobs_mode == "processed_logits":
             logits_to_return = logits
-        elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
+        elif self.logprobs_mode == "processed_logprobs":
             logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators), logits_to_return
@@ -115,7 +112,7 @@ class TopKTopPSampler(nn.Module):
                 "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
         assert self.logprobs_mode not in (
-            LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS
+            "processed_logits", "processed_logprobs"
         ), "FlashInfer does not support returning logits/logprobs"
         # flashinfer sampling functions expect contiguous logits.
         # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
@@ -60,8 +60,7 @@ class Sampler(nn.Module):
    9. Return the final `SamplerOutput`.
    """

-    def __init__(self,
-                 logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS):
+    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
         self.pin_memory = is_pin_memory_available()
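
After this change the V1 sampler components take the logprobs mode as a plain string. A hedged construction sketch; the module path is assumed from the class names shown here, and "raw_logprobs" is the default visible in the new signatures.

from vllm.v1.sample.sampler import Sampler  # assumed module path

sampler = Sampler(logprobs_mode="processed_logprobs")
# The mode is forwarded to the nested TopKTopPSampler, as shown in __init__ above.
assert sampler.topk_topp_sampler.logprobs_mode == "processed_logprobs"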
@@ -78,9 +77,9 @@ class Sampler(nn.Module):
         # is used for sampling (after penalties and temperature scaling).
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS:
+            if self.logprobs_mode == "raw_logprobs":
                 raw_logprobs = self.compute_logprobs(logits)
-            elif self.logprobs_mode == LogprobsMode.RAW_LOGITS:
+            elif self.logprobs_mode == "raw_logits":
                 raw_logprobs = logits.clone()

         # Use float32 for the logits.
@@ -156,9 +155,9 @@ class Sampler(nn.Module):
         if sampling_metadata.all_greedy:
             processed_logprobs = None
             if sampling_metadata.max_num_logprobs is not None:
-                if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
+                if self.logprobs_mode == "processed_logits":
                     processed_logprobs = logits
-                elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
+                elif self.logprobs_mode == "processed_logprobs":
                     processed_logprobs = self.compute_logprobs(logits)
             return greedy_sampled, processed_logprobs