Remove convertor dependency on model & revision; only torch_dtype stays a classmethod

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
Xingyu Liu 2025-12-22 23:08:08 -08:00
parent 78d47494df
commit 5bde69c2b9
6 changed files with 44 additions and 37 deletions
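The net effect of the commit: ModelArchConfigConvertorBase.convert() now depends only on the two HF config objects handed to the constructor, while get_torch_dtype stays a classmethod because it is the one lookup that still needs model_id and revision. Below is a minimal, runnable sketch of the resulting call shapes; the _TinyConvertor class and the SimpleNamespace configs are illustrative stand-ins, not the vLLM implementation.

from types import SimpleNamespace


class _TinyConvertor:
    """Illustrative stand-in for ModelArchConfigConvertorBase (assumption)."""

    def __init__(self, hf_config, hf_text_config):
        self.hf_config = hf_config
        self.hf_text_config = hf_text_config

    def get_num_experts(self) -> int:
        # Instance method after the change: reads self.hf_text_config only.
        return getattr(self.hf_text_config, "num_experts", 0)

    def convert(self) -> dict:
        # No model_id/revision parameters, and no torch_dtype in the result.
        return {
            "model_type": self.hf_config.model_type,
            "num_experts": self.get_num_experts(),
        }

    @classmethod
    def get_torch_dtype(cls, hf_config, model_id, revision=None):
        # The one classmethod left; the real version may also consult the
        # checkpoint identified by model_id/revision (assumption).
        return getattr(hf_config, "torch_dtype", "float32")


cfg = SimpleNamespace(model_type="llama", torch_dtype="bfloat16", num_experts=0)
print(_TinyConvertor(cfg, cfg).convert())
print(_TinyConvertor.get_torch_dtype(cfg, "some/model-id", revision=None))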

[File 1/6, filename not shown: pytest groundtruth checks for model_arch_config]

@@ -8,6 +8,9 @@ from pathlib import Path
 import pytest
 from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig
+from vllm.transformers_utils.model_arch_config_convertor import (
+    ModelArchConfigConvertorBase,
+)
 BASE_TRUST_REMOTE_CODE_MODELS = {
     "nvidia/Llama-3_3-Nemotron-Super-49B-v1",
@@ -57,9 +60,10 @@ def _load_groundtruth(filename: str) -> dict:
 def _assert_model_arch_config(
-    model_arch_config, expected: dict, check_head_size: bool = True
+    model_config, expected: dict, check_head_size: bool = True
 ):
     """Assert model_arch_config matches expected values."""
+    model_arch_config = model_config.model_arch_config
     assert model_arch_config.architectures == expected["architectures"]
     assert model_arch_config.model_type == expected["model_type"]
     assert model_arch_config.text_model_type == expected["text_model_type"]
@@ -75,7 +79,11 @@ def _assert_model_arch_config(
     assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"]
     assert model_arch_config.num_experts == expected["num_experts"]
     assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"]
-    assert str(model_arch_config.torch_dtype) == expected["dtype"]
+    torch_dtype = ModelArchConfigConvertorBase.get_torch_dtype(
+        model_config.hf_config, model_config.model_id, revision=model_config.revision
+    )
+    assert str(torch_dtype) == expected["dtype"]
     if check_head_size:
         assert model_arch_config.head_size == expected["head_size"]
@@ -109,7 +117,7 @@ def test_base_model_arch_config(model: str):
         model, trust_remote_code=model in BASE_TRUST_REMOTE_CODE_MODELS
     )
-    _assert_model_arch_config(model_config.model_arch_config, expected)
+    _assert_model_arch_config(model_config, expected)
     _assert_model_config_methods(model_config, expected)

[File 2/6, filename not shown: model-test utilities defining dummy_hf_overrides]

@@ -14,9 +14,6 @@ from vllm.config.model import ModelConfig, ModelDType, RunnerOption
 from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
 from vllm.multimodal.processing import InputProcessingContext
 from vllm.tokenizers import cached_tokenizer_from_config
-from vllm.transformers_utils.model_arch_config_convertor import (
-    ModelArchConfigConvertorBase,
-)
 from .. import ci_envs
 from .registry import HF_EXAMPLE_MODELS
@@ -486,12 +483,10 @@ def dummy_hf_overrides(
         "num_kv_shared_layers": 1,
     }
-    class DummyConfig:
-        hf_text_config = text_config
+    model_arch_config = ModelConfig.get_model_arch_config(hf_config, text_config)
     # Only set MoE related config when the model has MoE layers.
     # Otherwise all models detected as MoE by _get_transformers_backend_cls.
-    if ModelArchConfigConvertorBase.get_num_experts(text_config) > 0:
+    if model_arch_config.num_experts > 0:
         update_dict.update(
             {
                 "num_experts": num_experts,

[File 3/6, filename not shown: v1 perf-metrics tests with MockModelConfig]

@@ -16,6 +16,10 @@ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
 from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
 from vllm.config.model import ModelConfig, get_hf_text_config
+from vllm.transformers_utils.model_arch_config_convertor import (
+    MODEL_ARCH_CONFIG_CONVERTORS,
+    ModelArchConfigConvertorBase,
+)
 from vllm.v1.metrics.perf import (
     AttentionMetrics,
     BaseConfigParser,
@@ -33,6 +37,12 @@ class MockModelConfig:
     def __init__(self, hf_config, dtype):
         self.hf_config = hf_config
         self.hf_text_config = get_hf_text_config(hf_config)
+        convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get(
+            self.hf_config.model_type, ModelArchConfigConvertorBase
+        )
+        self.model_arch_config = convertor_cls(
+            self.hf_config, self.hf_text_config
+        ).convert()
         self.dtype = dtype
         self.is_attention_free = False
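The mock reproduces the registry dispatch that ModelConfig.get_model_arch_config performs: pick a convertor class by model_type, fall back to the base, then call the now argument-free convert(). A hedged sketch of that pattern with a stand-in registry (the subclass and the dict contents here are illustrative, not vLLM's actual registry):

from types import SimpleNamespace


class BaseConvertor:
    # Stand-in for ModelArchConfigConvertorBase (assumption).
    def __init__(self, hf_config, hf_text_config):
        self.hf_config = hf_config
        self.hf_text_config = hf_text_config

    def convert(self) -> dict:
        return {"model_type": self.hf_config.model_type, "is_deepseek_mla": False}


class DeepseekConvertor(BaseConvertor):
    # Hypothetical specialization keyed by model_type.
    def convert(self) -> dict:
        out = super().convert()
        out["is_deepseek_mla"] = True
        return out


# Stand-in for MODEL_ARCH_CONFIG_CONVERTORS (contents invented).
CONVERTORS = {"deepseek_v3": DeepseekConvertor}


def build_arch_config(hf_config, hf_text_config) -> dict:
    convertor_cls = CONVERTORS.get(hf_config.model_type, BaseConvertor)
    return convertor_cls(hf_config, hf_text_config).convert()


qwen = SimpleNamespace(model_type="qwen3")
print(build_arch_config(qwen, qwen))  # falls back to BaseConvertor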

[File 4/6, filename not shown: ModelConfig in vllm.config.model]

@@ -484,7 +484,9 @@ class ModelConfig:
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, hf_token=self.hf_token, revision=self.revision
         )
-        self.model_arch_config = self.get_model_arch_config()
+        self.model_arch_config = self.get_model_arch_config(
+            self.hf_config, self.hf_text_config
+        )
         architectures = self.architectures
         registry = self.registry
@@ -602,12 +604,15 @@
         self._verify_cuda_graph()
         self._verify_bnb_config()
-    def get_model_arch_config(self) -> ModelArchitectureConfig:
+    @classmethod
+    def get_model_arch_config(
+        cls, hf_config, hf_text_config
+    ) -> ModelArchitectureConfig:
         convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get(
-            self.hf_config.model_type, ModelArchConfigConvertorBase
+            hf_config.model_type, ModelArchConfigConvertorBase
         )
-        convertor = convertor_cls(self.hf_config, self.hf_text_config)
-        return convertor.convert(self.model, self.revision)
+        convertor = convertor_cls(hf_config, hf_text_config)
+        return convertor.convert()
     @field_validator("tokenizer", "max_model_len", mode="wrap")
     @classmethod
@@ -850,7 +855,7 @@
         self.quantization = cast(me_quant.QuantizationMethods, self.quantization)
         # Parse quantization method from the HF model config, if available.
-        quant_cfg = ModelArchConfigConvertorBase.get_quantization_config(self.hf_config)
+        quant_cfg = self.model_arch_config.quantization_config
         if quant_cfg is not None:
             quant_method = quant_cfg["quant_method"]
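Since quantization_config is now captured once in model_arch_config during construction, the quantization check simply reads the cached field instead of re-walking the HF config. A small sketch of the consumer side; the dict shape follows the common HF "quant_method" convention, and the AWQ values here are invented:

from types import SimpleNamespace

# Illustrative model_arch_config carrying a cached quantization_config.
model_arch_config = SimpleNamespace(
    quantization_config={"quant_method": "awq", "bits": 4}
)

quant_cfg = model_arch_config.quantization_config
if quant_cfg is not None:
    quant_method = quant_cfg["quant_method"]
    print(f"detected quantization method: {quant_method}")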

[File 5/6, filename not shown: the ModelArchitectureConfig dataclass]

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Any
-import torch
 from pydantic import ConfigDict
 from pydantic.dataclasses import dataclass
@@ -51,9 +50,6 @@ class ModelArchitectureConfig:
     quantization_config: dict[str, Any] | None
     """Quantization configuration dictionary containing quantization parameters."""
-    torch_dtype: torch.dtype | str | None
-    """PyTorch data type for model weights (e.g., 'float16', 'bfloat16')."""
     is_deepseek_mla: bool
     """Whether the model is a DeepSeek MLA model."""

[File 6/6, filename not shown: ModelArchConfigConvertorBase in vllm.transformers_utils.model_arch_config_convertor]

@@ -81,9 +81,7 @@ class ModelArchConfigConvertorBase:
         return self.hf_text_config.num_attention_heads
     @final
-    @classmethod
-    def get_num_experts(cls, hf_text_config: PretrainedConfig) -> int:
+    def get_num_experts(self) -> int:
         """Returns the number of experts in the model."""
         num_expert_names = [
             "num_experts",  # Jamba
@@ -91,7 +89,7 @@
             "n_routed_experts",  # DeepSeek
             "num_local_experts",  # Mixtral
         ]
-        num_experts = getattr_iter(hf_text_config, num_expert_names, 0)
+        num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
         if isinstance(num_experts, list):
             # Ernie VL's remote code uses list[int]...
             # The values are always the same so we just take the first one.
@@ -137,9 +135,7 @@
         return config_dtype
     @final
-    @classmethod
-    def _normalize_quantization_config(cls, config: PretrainedConfig):
+    def _normalize_quantization_config(self, config: PretrainedConfig):
         quant_cfg = getattr(config, "quantization_config", None)
         if quant_cfg is None:
             # compressed-tensors uses a "compression_config" key
@@ -176,15 +172,13 @@
         return quant_cfg
     @final
-    @classmethod
-    def get_quantization_config(cls, hf_config: PretrainedConfig):
-        quant_cfg = cls._normalize_quantization_config(hf_config)
+    def get_quantization_config(self):
+        quant_cfg = self._normalize_quantization_config(self.hf_config)
         if quant_cfg is None and (
-            text_config := getattr(hf_config, "text_config", None)
+            text_config := getattr(self.hf_config, "text_config", None)
         ):
             # Check the text config as well for multi-modal models.
-            quant_cfg = cls._normalize_quantization_config(text_config)
+            quant_cfg = self._normalize_quantization_config(text_config)
         return quant_cfg
     def is_deepseek_mla(self) -> bool:
@@ -247,7 +241,7 @@
             derived_max_model_len = tmp_max_len
         return derived_max_model_len, max_len_key
-    def convert(self, model_id: str, revision: str | None) -> ModelArchitectureConfig:
+    def convert(self) -> ModelArchitectureConfig:
         model_arch_config = ModelArchitectureConfig(
             architectures=self.get_architectures(),
             model_type=self.hf_config.model_type,
@@ -258,9 +252,8 @@
             head_size=self.get_head_size(),
             vocab_size=self.get_vocab_size(),
             total_num_kv_heads=self.get_total_num_kv_heads(),
-            num_experts=self.get_num_experts(self.hf_text_config),
-            quantization_config=self.get_quantization_config(self.hf_config),
-            torch_dtype=self.get_torch_dtype(self.hf_config, model_id, revision),
+            num_experts=self.get_num_experts(),
+            quantization_config=self.get_quantization_config(),
             is_deepseek_mla=self.is_deepseek_mla(),
             derived_max_model_len_and_key=self.derive_max_model_len_and_key(),
        )
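After this commit, ModelArchitectureConfig carries only values derivable from the HF configs themselves; torch_dtype, the one field that needed model_id and revision, is resolved on demand via the classmethod instead. A condensed, runnable sketch of the dataclass restricted to the fields visible in this diff (the real class has more fields and uses pydantic dataclasses):

from dataclasses import dataclass
from typing import Any


@dataclass
class ArchConfigSketch:
    # Condensed stand-in for ModelArchitectureConfig; fields mirror those
    # visible in this diff, minus the removed torch_dtype.
    architectures: list[str]
    model_type: str
    head_size: int
    vocab_size: int
    total_num_kv_heads: int
    num_experts: int
    quantization_config: dict[str, Any] | None
    is_deepseek_mla: bool


# Example values are invented for illustration.
cfg = ArchConfigSketch(
    architectures=["LlamaForCausalLM"],
    model_type="llama",
    head_size=128,
    vocab_size=128256,
    total_num_kv_heads=8,
    num_experts=0,
    quantization_config=None,
    is_deepseek_mla=False,
)
print(cfg.model_type, cfg.num_experts)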