[Misc] Clean up duplicated hf overrides (#22311)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-08-06 15:50:25 +08:00 committed by GitHub
parent 134a8ee8fd
commit fa00c5d75b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 71 additions and 103 deletions

View File

@ -1,11 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import partial
from typing import Any
from unittest.mock import patch
import pytest
from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
@ -19,6 +17,7 @@ from vllm.v1.engine.core import EngineCore as V1EngineCore
from ...conftest import VllmRunner
from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
from ..utils import dummy_hf_overrides
ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements",
@ -51,51 +50,6 @@ def create_batched_mm_kwargs(
return mm_kwargs
# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig,
exist_overrides: dict[str, Any]) -> PretrainedConfig:
hf_config.update(exist_overrides)
text_config = hf_config.get_text_config()
# Ensure at least 2 expert per group
# Since `grouped_topk` assumes top-2
n_group = getattr(text_config, 'n_group', None)
num_experts = n_group * 2 if n_group is not None else 2
# we use three layers for Gemma-3n to check
# both normal layer and kv_shared_layer
text_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
# Otherwise there will not be any expert layers
"first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts,
# For Gemma-3n
"num_kv_shared_layers": 1,
})
if hasattr(hf_config, "vision_config"):
hf_config.vision_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})
# e.g.: ibm-granite/granite-speech-3.3-2b
if hasattr(hf_config, "encoder_config"):
hf_config.encoder_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})
# e.g.: Qwen/Qwen2-Audio-7B-Instruct
if hasattr(hf_config, "audio_config"):
hf_config.audio_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
"encoder_layers": 1,
})
return hf_config
@pytest.mark.core_model
@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
@ -110,7 +64,8 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
model_id = model_info.default
hf_overrides_fn = partial(hf_overrides,
hf_overrides_fn = partial(dummy_hf_overrides,
model_arch=model_arch,
exist_overrides=model_info.hf_overrides)
model_config = ModelConfig(

View File

@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import partial
from unittest.mock import patch
import pytest
from transformers import PretrainedConfig
from vllm import LLM
from vllm.config import ModelImpl
@ -16,6 +16,7 @@ from vllm.v1.engine.core import EngineCore as V1EngineCore
from ..utils import create_new_process_for_each_test
from .registry import (_TRANSFORMERS_BACKEND_MODELS, AUTO_EXAMPLE_MODELS,
HF_EXAMPLE_MODELS, HfExampleModels)
from .utils import dummy_hf_overrides
@create_new_process_for_each_test()
@ -33,64 +34,15 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
hf_overrides_fn = partial(dummy_hf_overrides,
model_arch=model_arch,
exist_overrides=model_info.hf_overrides)
if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
from vllm.model_executor.models.llama4 import Llama4ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)
# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides)
text_config = hf_config.get_text_config()
# Ensure at least 2 expert per group
# Since `grouped_topk` assumes top-2
n_group = getattr(text_config, 'n_group', None)
num_experts = n_group * 2 if n_group is not None else 2
# we use three layers for Gemma-3n to check
# both normal layer and kv_shared_layer
num_hidden_layers = (3 if model_arch
== "Gemma3nForConditionalGeneration" else 1)
text_config.update({
"num_layers": 1,
"num_hidden_layers": num_hidden_layers,
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
# Otherwise there will not be any expert layers
"first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts,
# For Gemma-3n
"num_kv_shared_layers": 1,
})
if hasattr(hf_config, "vision_config"):
hf_config.vision_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})
# e.g.: ibm-granite/granite-speech-3.3-2b
if hasattr(hf_config, "encoder_config"):
hf_config.encoder_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})
# e.g.: Qwen/Qwen2-Audio-7B-Instruct
if hasattr(hf_config, "audio_config"):
hf_config.audio_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
"encoder_layers": 1,
})
return hf_config
# Avoid calling model.forward()
def _initialize_kv_caches_v0(self) -> None:
self.cache_config.num_gpu_blocks = 0
@ -132,7 +84,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
load_format="dummy",
model_impl=ModelImpl.TRANSFORMERS
if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
hf_overrides=hf_overrides,
hf_overrides=hf_overrides_fn,
)

View File

@ -7,6 +7,7 @@ from typing import Any, NamedTuple, Optional, Union
import torch
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import ModelConfig, RunnerOption
from vllm.inputs import InputContext
@ -351,3 +352,63 @@ class RerankModelInfo(NamedTuple):
architecture: str = ""
dtype: str = "auto"
enable_test: bool = True
def dummy_hf_overrides(
hf_config: PretrainedConfig,
model_arch: str,
exist_overrides: Optional[dict[str, Any]] = None,
) -> PretrainedConfig:
"""
Dummy HF overrides function used to create dummy model
with only minimum nums of layer.
"""
hf_config.update(exist_overrides or {})
text_config = hf_config.get_text_config()
# Ensure at least 2 expert per group
# Since `grouped_topk` assumes top-2
n_group = getattr(text_config, 'n_group', None)
num_experts = n_group * 2 if n_group is not None else 2
# we use three layers for Gemma-3n to check
# both normal layer and kv_shared_layer
num_hidden_layers = (3 if model_arch == "Gemma3nForConditionalGeneration"
else 1)
text_config.update({
"num_layers": 1,
"num_hidden_layers": num_hidden_layers,
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
# Otherwise there will not be any expert layers
"first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts,
# For Gemma-3n
"num_kv_shared_layers": 1,
})
if hasattr(hf_config, "vision_config"):
hf_config.vision_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})
# e.g.: ibm-granite/granite-speech-3.3-2b
if hasattr(hf_config, "encoder_config"):
hf_config.encoder_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})
# e.g.: Qwen/Qwen2-Audio-7B-Instruct
if hasattr(hf_config, "audio_config"):
hf_config.audio_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
"encoder_layers": 1,
})
return hf_config