mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 01:34:40 +08:00
[Misc] Clean up duplicated hf overrides (#22311)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
134a8ee8fd
commit
fa00c5d75b
@ -1,11 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
|
||||
@ -19,6 +17,7 @@ from vllm.v1.engine.core import EngineCore as V1EngineCore
|
||||
|
||||
from ...conftest import VllmRunner
|
||||
from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
||||
from ..utils import dummy_hf_overrides
|
||||
|
||||
ARCH_TO_SKIP = {
|
||||
"MolmoForCausalLM": "incompatible requirements",
|
||||
@ -51,51 +50,6 @@ def create_batched_mm_kwargs(
|
||||
return mm_kwargs
|
||||
|
||||
|
||||
# Avoid OOM and reduce initialization time by only using 1 layer
|
||||
def hf_overrides(hf_config: PretrainedConfig,
|
||||
exist_overrides: dict[str, Any]) -> PretrainedConfig:
|
||||
hf_config.update(exist_overrides)
|
||||
text_config = hf_config.get_text_config()
|
||||
# Ensure at least 2 expert per group
|
||||
# Since `grouped_topk` assumes top-2
|
||||
n_group = getattr(text_config, 'n_group', None)
|
||||
num_experts = n_group * 2 if n_group is not None else 2
|
||||
# we use three layers for Gemma-3n to check
|
||||
# both normal layer and kv_shared_layer
|
||||
text_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
"num_experts": num_experts,
|
||||
"num_experts_per_tok": 2,
|
||||
"num_local_experts": num_experts,
|
||||
# Otherwise there will not be any expert layers
|
||||
"first_k_dense_replace": 0,
|
||||
# To avoid OOM on DeepSeek-V3
|
||||
"n_routed_experts": num_experts,
|
||||
# For Gemma-3n
|
||||
"num_kv_shared_layers": 1,
|
||||
})
|
||||
if hasattr(hf_config, "vision_config"):
|
||||
hf_config.vision_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
})
|
||||
# e.g.: ibm-granite/granite-speech-3.3-2b
|
||||
if hasattr(hf_config, "encoder_config"):
|
||||
hf_config.encoder_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
})
|
||||
# e.g.: Qwen/Qwen2-Audio-7B-Instruct
|
||||
if hasattr(hf_config, "audio_config"):
|
||||
hf_config.audio_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
"encoder_layers": 1,
|
||||
})
|
||||
return hf_config
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys()))
|
||||
def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
|
||||
@ -110,7 +64,8 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
|
||||
|
||||
model_id = model_info.default
|
||||
|
||||
hf_overrides_fn = partial(hf_overrides,
|
||||
hf_overrides_fn = partial(dummy_hf_overrides,
|
||||
model_arch=model_arch,
|
||||
exist_overrides=model_info.hf_overrides)
|
||||
|
||||
model_config = ModelConfig(
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import partial
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.config import ModelImpl
|
||||
@ -16,6 +16,7 @@ from vllm.v1.engine.core import EngineCore as V1EngineCore
|
||||
from ..utils import create_new_process_for_each_test
|
||||
from .registry import (_TRANSFORMERS_BACKEND_MODELS, AUTO_EXAMPLE_MODELS,
|
||||
HF_EXAMPLE_MODELS, HfExampleModels)
|
||||
from .utils import dummy_hf_overrides
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@ -33,64 +34,15 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
hf_overrides_fn = partial(dummy_hf_overrides,
|
||||
model_arch=model_arch,
|
||||
exist_overrides=model_info.hf_overrides)
|
||||
|
||||
if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
|
||||
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
|
||||
from vllm.model_executor.models.registry import ModelRegistry
|
||||
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)
|
||||
|
||||
# Avoid OOM and reduce initialization time by only using 1 layer
|
||||
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
hf_config.update(model_info.hf_overrides)
|
||||
|
||||
text_config = hf_config.get_text_config()
|
||||
|
||||
# Ensure at least 2 expert per group
|
||||
# Since `grouped_topk` assumes top-2
|
||||
n_group = getattr(text_config, 'n_group', None)
|
||||
num_experts = n_group * 2 if n_group is not None else 2
|
||||
|
||||
# we use three layers for Gemma-3n to check
|
||||
# both normal layer and kv_shared_layer
|
||||
num_hidden_layers = (3 if model_arch
|
||||
== "Gemma3nForConditionalGeneration" else 1)
|
||||
|
||||
text_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": num_hidden_layers,
|
||||
"num_experts": num_experts,
|
||||
"num_experts_per_tok": 2,
|
||||
"num_local_experts": num_experts,
|
||||
# Otherwise there will not be any expert layers
|
||||
"first_k_dense_replace": 0,
|
||||
# To avoid OOM on DeepSeek-V3
|
||||
"n_routed_experts": num_experts,
|
||||
# For Gemma-3n
|
||||
"num_kv_shared_layers": 1,
|
||||
})
|
||||
|
||||
if hasattr(hf_config, "vision_config"):
|
||||
hf_config.vision_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
})
|
||||
|
||||
# e.g.: ibm-granite/granite-speech-3.3-2b
|
||||
if hasattr(hf_config, "encoder_config"):
|
||||
hf_config.encoder_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
})
|
||||
|
||||
# e.g.: Qwen/Qwen2-Audio-7B-Instruct
|
||||
if hasattr(hf_config, "audio_config"):
|
||||
hf_config.audio_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
"encoder_layers": 1,
|
||||
})
|
||||
|
||||
return hf_config
|
||||
|
||||
# Avoid calling model.forward()
|
||||
def _initialize_kv_caches_v0(self) -> None:
|
||||
self.cache_config.num_gpu_blocks = 0
|
||||
@ -132,7 +84,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
|
||||
load_format="dummy",
|
||||
model_impl=ModelImpl.TRANSFORMERS
|
||||
if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
|
||||
hf_overrides=hf_overrides,
|
||||
hf_overrides=hf_overrides_fn,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import Any, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, RunnerOption
|
||||
from vllm.inputs import InputContext
|
||||
@ -351,3 +352,63 @@ class RerankModelInfo(NamedTuple):
|
||||
architecture: str = ""
|
||||
dtype: str = "auto"
|
||||
enable_test: bool = True
|
||||
|
||||
|
||||
def dummy_hf_overrides(
|
||||
hf_config: PretrainedConfig,
|
||||
model_arch: str,
|
||||
exist_overrides: Optional[dict[str, Any]] = None,
|
||||
) -> PretrainedConfig:
|
||||
"""
|
||||
Dummy HF overrides function used to create dummy model
|
||||
with only minimum nums of layer.
|
||||
"""
|
||||
hf_config.update(exist_overrides or {})
|
||||
|
||||
text_config = hf_config.get_text_config()
|
||||
|
||||
# Ensure at least 2 expert per group
|
||||
# Since `grouped_topk` assumes top-2
|
||||
n_group = getattr(text_config, 'n_group', None)
|
||||
num_experts = n_group * 2 if n_group is not None else 2
|
||||
|
||||
# we use three layers for Gemma-3n to check
|
||||
# both normal layer and kv_shared_layer
|
||||
num_hidden_layers = (3 if model_arch == "Gemma3nForConditionalGeneration"
|
||||
else 1)
|
||||
text_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": num_hidden_layers,
|
||||
"num_experts": num_experts,
|
||||
"num_experts_per_tok": 2,
|
||||
"num_local_experts": num_experts,
|
||||
# Otherwise there will not be any expert layers
|
||||
"first_k_dense_replace": 0,
|
||||
# To avoid OOM on DeepSeek-V3
|
||||
"n_routed_experts": num_experts,
|
||||
# For Gemma-3n
|
||||
"num_kv_shared_layers": 1,
|
||||
})
|
||||
|
||||
if hasattr(hf_config, "vision_config"):
|
||||
hf_config.vision_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
})
|
||||
|
||||
# e.g.: ibm-granite/granite-speech-3.3-2b
|
||||
if hasattr(hf_config, "encoder_config"):
|
||||
hf_config.encoder_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
})
|
||||
|
||||
# e.g.: Qwen/Qwen2-Audio-7B-Instruct
|
||||
if hasattr(hf_config, "audio_config"):
|
||||
hf_config.audio_config.update({
|
||||
"num_layers": 1,
|
||||
"num_hidden_layers": 1,
|
||||
"encoder_layers": 1,
|
||||
})
|
||||
|
||||
return hf_config
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user